Trigger82 committed on
Commit
034af7a
·
verified ·
1 Parent(s): 0d3e157

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -50
app.py CHANGED
@@ -1,68 +1,113 @@
1
- from fastapi import FastAPI
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- from tinydb import TinyDB, Query
 
5
  import uvicorn
6
 
7
- app = FastAPI()
 
 
8
 
9
- # Load model
10
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
11
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
12
- model.eval()
13
 
14
- # Memory setup
15
- db = TinyDB("memory.json")
16
- User = Query()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @app.get("/")
19
- async def chat(query: str = "", number: str = "", prompt: str = ""):
20
- if not query or not number:
21
- return {"response": "Missing query or number."}
22
-
23
- record = db.get(User.number == number)
24
- if not record:
25
- record = {
26
- "number": number,
27
- "history": [],
28
- "prompt": "You are Trigger, a clever, emotional AI boy created by 'I am him'. Speak like a witty, chill human friend."
29
- }
30
- db.insert(record)
31
-
32
- if prompt:
33
- record["prompt"] = prompt
34
- db.update(record, User.number == number)
35
-
36
- history = record["history"]
37
- system_prompt = record.get("prompt", "")
38
-
39
- # Append new input
40
- history.append(f"User: {query}")
41
- trimmed = history[-6:]
42
-
43
- # Construct final prompt
44
- full_prompt = system_prompt + "\n" + "\n".join(trimmed) + "\nTrigger:"
45
- inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True)
46
-
47
- with torch.no_grad():
48
  output = model.generate(
49
- **inputs,
50
  max_new_tokens=100,
51
- temperature=0.9,
52
- top_p=0.9,
53
  do_sample=True,
 
 
 
54
  pad_token_id=tokenizer.eos_token_id
55
  )
56
 
57
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
58
- reply = decoded.split("Trigger:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Save memory
61
- history.append(f"Trigger: {reply}")
62
- db.update({"number": number, "history": history, "prompt": system_prompt}, User.number == number)
 
 
 
 
 
63
 
64
- return {"response": reply}
 
 
 
 
 
 
 
 
 
65
 
66
- # Only runs locally, HF handles it differently
67
  if __name__ == "__main__":
68
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import logging
import uvicorn

# Setup logging (module-level logger, lazy %-style args at call sites)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base path from Hugging Face Spaces; empty string when the env var is absent.
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info("Using base path: '%s'", BASE_PATH)

# Initialize app
app = FastAPI(title="Trigger AI", description="Lightning fast chatbot", version="1.0")

# Load lightweight fast model (phi-1.5)
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
    model.eval()  # inference only: disable dropout / training-mode layers
    logger.info("Model loaded.")
except Exception as e:
    logger.error(f"Model load error: {e}")
    # Chain the original exception so the real cause survives in the traceback.
    raise RuntimeError("Model failed to load") from e

# In-memory chat memory per user_id: maps user_id -> list of token-id tensors.
# NOTE(review): process-local and unbounded; lost on restart — acceptable for a demo Space.
chat_memory = {}
33
@app.middleware("http")
async def strip_base_path(request: Request, call_next):
    """Strip the Spaces base-path prefix so routes match at their bare paths.

    Hugging Face Spaces may mount the app under BASE_PATH; rewriting
    ``request.scope["path"]`` lets the routes below stay prefix-free.
    """
    path = request.scope["path"]
    if BASE_PATH and path.startswith(BASE_PATH):
        # Fall back to "/" when the prefix is the entire path — an empty
        # path is not routable and would 404 the root route.
        request.scope["path"] = path[len(BASE_PATH):] or "/"
    return await call_next(request)
 
40
@app.get("/")
async def root():
    """Landing route: confirm the service is up and show an example call."""
    example_call = f"{BASE_PATH}/ai?query=Hello&user_id=233XXXXXXXXX"
    return {"message": " Trigger AI is active", "try": example_call}
47
@app.get("/ai")
async def ai(request: Request):
    """Chat endpoint: generate a reply to ``query`` with per-user context.

    Query params:
        query:   the user's message (required).
        user_id: opaque key selecting this user's conversation memory (required).

    Returns:
        ``{"reply": <generated text>}``.

    Raises:
        HTTPException 400 when a required param is missing; 500 on any
        tokenization/generation failure.
    """
    query = request.query_params.get("query", "").strip()
    user_id = request.query_params.get("user_id", "").strip()

    if not query or not user_id:
        raise HTTPException(status_code=400, detail="Missing 'query' or 'user_id'")

    try:
        # Tokenize input
        input_ids = tokenizer.encode(query, return_tensors="pt")

        # Load history. The stored tensor already holds the full conversation,
        # so only the new input needs to be appended.
        history = chat_memory.get(user_id, [])
        full_input = torch.cat(history + [input_ids], dim=-1) if history else input_ids

        # Cap the context window (phi-1.5 supports 2048 tokens); drop the
        # oldest tokens first so generation never overflows the model.
        max_context = 1024
        if full_input.shape[-1] > max_context:
            full_input = full_input[:, -max_context:]

        # Generate response; no_grad skips autograd bookkeeping at inference.
        with torch.no_grad():
            output = model.generate(
                full_input,
                max_new_tokens=100,
                do_sample=True,
                top_k=40,
                top_p=0.9,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (generate() echoes the prompt).
        response = tokenizer.decode(output[:, full_input.shape[-1]:][0], skip_special_tokens=True)

        # Save memory. ``output`` already contains ``full_input`` as its prefix,
        # so storing it alone is the whole conversation. (The previous code
        # stored [full_input, output], duplicating the entire context on every
        # turn and growing memory quadratically.)
        chat_memory[user_id] = [output]

        return {"reply": response}

    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
86
@app.get("/reset")
async def reset(user_id: str = "default"):
    """Drop the stored conversation for ``user_id``, if any exists."""
    # pop() with a default is a no-op when the user has no history.
    chat_memory.pop(user_id, None)
    return {"status": "cleared", "user_id": user_id}
92
@app.get("/health")
async def health():
    """Liveness probe returning a small status snapshot."""
    snapshot = {
        "status": "🟢 online",
        "users": len(chat_memory),
        "model": "phi-1.5",
        "base_path": BASE_PATH,
    }
    return snapshot
101
@app.get("/test", response_class=HTMLResponse)
async def test():
    """Minimal HTML page with a link that exercises the /ai endpoint."""
    return f"""
    <html>
    <body>
    <h2>Trigger AI Test</h2>
    <a href="{BASE_PATH}/ai?query=Hello&user_id=tester">Talk to Trigger</a>
    </body>
    </html>
    """
112
# Local entry point; on Hugging Face Spaces the server is launched externally.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)