Trigger82 committed on
Commit
81cbe70
·
verified ·
1 Parent(s): c7ceace

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -58
app.py CHANGED
@@ -1,37 +1,43 @@
1
  from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.responses import JSONResponse, HTMLResponse
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
  import os
6
  import logging
7
  import uvicorn
8
 
9
- # Setup logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Base path from Hugging Face Spaces
 
 
 
 
 
 
 
14
  BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
15
  logger.info(f"Using base path: '{BASE_PATH}'")
16
 
17
- # Initialize app
18
- app = FastAPI(title="Trigger AI", description="Lightning fast chatbot", version="1.0")
19
-
20
- # Load lightweight fast model (phi-1.5)
21
  try:
22
  logger.info("Loading tokenizer and model...")
23
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
24
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
25
- logger.info("Model loaded.")
 
26
  except Exception as e:
27
- logger.error(f"Model load error: {e}")
28
- raise RuntimeError("Model failed to load")
29
 
30
- # In-memory chat memory per user_id
31
- chat_memory = {}
32
 
 
33
  @app.middleware("http")
34
- async def strip_base_path(request: Request, call_next):
35
  path = request.scope["path"]
36
  if BASE_PATH and path.startswith(BASE_PATH):
37
  request.scope["path"] = path[len(BASE_PATH):]
@@ -40,74 +46,93 @@ async def strip_base_path(request: Request, call_next):
40
  @app.get("/")
41
  async def root():
42
  return {
43
- "message": " Trigger AI is active",
44
- "try": f"{BASE_PATH}/ai?query=Hello&user_id=233XXXXXXXXX"
 
 
 
 
 
 
45
  }
46
 
47
  @app.get("/ai")
48
- async def ai(request: Request):
49
- query = request.query_params.get("query", "").strip()
50
- user_id = request.query_params.get("user_id", "").strip()
51
-
52
- if not query or not user_id:
53
- raise HTTPException(status_code=400, detail="Missing 'query' or 'user_id'")
54
-
55
  try:
56
- # Tokenize input
57
- input_ids = tokenizer.encode(query, return_tensors="pt")
58
-
59
- # Load history
60
- history = chat_memory.get(user_id, [])
61
- full_input = torch.cat(history + [input_ids], dim=-1) if history else input_ids
62
-
63
- # Generate response
 
 
 
 
 
 
 
 
64
  output = model.generate(
65
- full_input,
66
- max_new_tokens=100,
67
- do_sample=True,
68
- top_k=40,
69
- top_p=0.9,
70
  temperature=0.8,
71
- pad_token_id=tokenizer.eos_token_id
 
72
  )
 
 
73
 
74
- # Decode result
75
- response = tokenizer.decode(output[:, full_input.shape[-1]:][0], skip_special_tokens=True)
76
-
77
- # Save memory
78
- chat_memory[user_id] = [full_input, output]
79
 
80
  return {"reply": response}
81
 
 
 
 
 
 
82
  except Exception as e:
83
- logger.error(f"Error: {e}")
84
- raise HTTPException(status_code=500, detail=str(e))
85
-
86
- @app.get("/reset")
87
- async def reset(user_id: str = "default"):
88
- if user_id in chat_memory:
89
- del chat_memory[user_id]
90
- return {"status": "cleared", "user_id": user_id}
91
 
92
  @app.get("/health")
93
  async def health():
94
  return {
95
- "status": "🟢 online",
96
- "users": len(chat_memory),
97
- "model": "phi-1.5",
98
  "base_path": BASE_PATH
99
  }
100
 
 
 
 
 
 
 
101
  @app.get("/test", response_class=HTMLResponse)
102
- async def test():
103
  return f"""
104
  <html>
105
  <body>
106
- <h2>Trigger AI Test</h2>
107
- <a href="{BASE_PATH}/ai?query=Hello&user_id=tester">Talk to Trigger</a>
 
 
 
 
 
 
108
  </body>
109
  </html>
110
  """
111
 
 
112
  if __name__ == "__main__":
113
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, Request, HTTPException
2
  from fastapi.responses import JSONResponse, HTMLResponse
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  import os
6
  import logging
7
  import uvicorn
8
 
9
# Module-wide logging: INFO level, named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
12
 
13
# URL prefix injected by Hugging Face Spaces; empty when running elsewhere.
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info(f"Using base path: '{BASE_PATH}'")

# FastAPI application object; interactive docs are served under /docs.
app = FastAPI(
    title="PHI Chatbot API",
    description="Chatbot API using Microsoft's Phi-2 model",
    version="1.0",
)
23
 
24
# Load the Phi-2 tokenizer and weights once at import time; fail fast on error
# so the service never starts without a working model.
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    model.eval()  # inference mode: disables dropout and similar train-time layers
    logger.info("Model loaded successfully!")
except Exception as exc:
    logger.error(f"Model loading failed: {str(exc)}")
    raise RuntimeError("Model initialization failed") from exc
34
 
35
# Per-user conversation memory: user_id -> list of (query, reply) pairs.
# Process-local only; lost on restart. The /ai handler bounds each list.
chat_history: dict = {}
37
 
38
+ # Middleware for base path
39
  @app.middleware("http")
40
+ async def add_base_path(request: Request, call_next):
41
  path = request.scope["path"]
42
  if BASE_PATH and path.startswith(BASE_PATH):
43
  request.scope["path"] = path[len(BASE_PATH):]
 
46
@app.get("/")
async def root():
    """Landing endpoint: report liveness and list every route with its prefix."""
    links = {
        "chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
        "health": f"{BASE_PATH}/health",
        "reset": f"{BASE_PATH}/reset?user_id=yourname",
        "test": f"{BASE_PATH}/test",
        "docs": f"{BASE_PATH}/docs",
    }
    return {"message": "🟢 PHI API is running", "endpoints": links}
58
 
59
@app.get("/ai")
async def chat(request: Request):
    """Generate a chat reply for ``query`` with short per-user history.

    Query params:
        query:   user message (required, max 200 characters)
        user_id: conversation key (defaults to "default")

    Returns:
        {"reply": <str>}

    Raises:
        HTTPException 400 on missing/oversized query, 500 on model errors.
    """
    try:
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip()

        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query'")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

        # Prompt style: phi models work best with natural instructions,
        # so replay the stored transcript and append the new turn.
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # no_grad: inference only — skip building an autograd graph.
        # do_sample=True is required for temperature/top_k/top_p to have any
        # effect; without it generate() silently falls back to greedy decoding.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the first line of the continuation so the model's
        # hallucinated follow-up "User:" turns are discarded.
        response = generated_text[len(prompt):].strip().split("\n")[0]

        # Save history (bounded to the last 5 exchanges per user)
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]

        return {"reply": response}

    except HTTPException:
        # Propagate deliberate 4xx errors unchanged; the generic handler
        # below would otherwise mask them as 500s.
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
 
 
 
 
 
 
103
 
104
@app.get("/health")
async def health():
    """Liveness probe: model id, tracked-user count, and configured prefix."""
    payload = {"status": "healthy", "model": "microsoft/phi-2"}
    payload["users"] = len(chat_history)
    payload["base_path"] = BASE_PATH
    return payload
112
 
113
@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Drop the stored conversation for ``user_id`` (no-op when absent)."""
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}
118
+
119
@app.get("/test", response_class=HTMLResponse)
async def test_page():
    """Minimal HTML page with clickable links to every endpoint."""
    return f"""
    <html>
        <body>
            <h1>PHI Chatbot Test</h1>
            <p>Base path: {BASE_PATH}</p>
            <ul>
                <li><a href="{BASE_PATH}/">Root endpoint</a></li>
                <li><a href="{BASE_PATH}/ai?query=Hello&user_id=test">Chat endpoint</a></li>
                <li><a href="{BASE_PATH}/health">Health check</a></li>
                <li><a href="{BASE_PATH}/docs">API Docs</a></li>
            </ul>
        </body>
    </html>
    """
135
 
136
# Run locally / in the Space container.
# NOTE(review): reload=True is a development convenience (file-watcher
# restarts); presumably intentional for Spaces iteration — confirm for prod.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)