Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ app = FastAPI()
|
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
| 9 |
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
| 10 |
|
| 11 |
-
# Memory store per user
|
| 12 |
chat_history = {}
|
| 13 |
|
| 14 |
@app.get("/ai")
|
|
@@ -17,20 +16,17 @@ async def chat(request: Request):
|
|
| 17 |
user_input = query_params.get("query", "")
|
| 18 |
user_id = query_params.get("user_id", "default")
|
| 19 |
|
| 20 |
-
# Pull user history
|
| 21 |
-
user_history = chat_history.get(user_id, [])
|
| 22 |
-
|
| 23 |
-
# Tokenize with context
|
| 24 |
new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
|
| 25 |
-
|
| 26 |
-
# Concatenate history if available
|
| 27 |
bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids
|
| 28 |
|
| 29 |
-
# Generate response
|
| 30 |
output_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
|
| 31 |
response = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
| 32 |
|
| 33 |
-
# Save history
|
| 34 |
chat_history[user_id] = [bot_input_ids, output_ids]
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
# Load the DialoGPT model and tokenizer once, at module import time, so every
# request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Per-user conversation memory: user_id -> [token-id tensor of the dialogue so far].
# NOTE(review): this grows without bound per user and is process-local; fine for a
# single-worker Spaces demo, but confirm before scaling out.
chat_history = {}


@app.get("/ai")
async def chat(request: Request):
    """Generate a DialoGPT reply for ?query=...&user_id=..., keeping per-user context.

    Returns a JSON body of the form {"reply": "<generated text>"}.
    """
    # presumably reconstructed from an unseen line: Starlette's query-param mapping.
    query_params = request.query_params
    user_input = query_params.get("query", "")
    user_id = query_params.get("user_id", "default")

    # Encode the new user turn, terminated by EOS as DialoGPT expects.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Prepend any stored context for this user before generating.
    user_history = chat_history.get(user_id, [])
    bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids

    output_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens (everything past the prompt length).
    response = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # BUG FIX: model.generate returns the prompt ids WITH the new tokens appended,
    # so output_ids already contains bot_input_ids as a prefix. The old code stored
    # [bot_input_ids, output_ids], which duplicated the entire context on every
    # turn and quickly overflowed the model's context window. Keep only output_ids.
    chat_history[user_id] = [output_ids]

    return JSONResponse({"reply": response})


# ✅ Launch with uvicorn when run directly (Hugging Face Spaces expects port 7860).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)