Trigger82 committed on
Commit
a7c32b2
Β·
verified Β·
1 Parent(s): b82cab2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -1,40 +1,36 @@
1
- import torch
2
  from fastapi import FastAPI, Request
3
  from fastapi.responses import JSONResponse
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
-
6
# Load model — microsoft/phi-2, a base (non-chat) causal LM.
# NOTE(review): no device placement or dtype is specified, so this loads on CPU
# in full precision — confirm that is intended for the deployment target.
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Memory dict: maps user_id -> list of (user_message, bot_reply) tuples.
# Process-local and unbounded across users; lost on restart.
chat_history = {}
13
 
14
- # History formatter
15
def format_context(history):
    """Render up to the last three (user, bot) exchanges as prompt text.

    Each exchange becomes one "You: ...\n𝕴 𝖆𝖒 π–π–Žπ–’: ...\n" segment;
    an empty history yields an empty string.
    """
    rendered = []
    for user_msg, bot_msg in history[-3:]:
        rendered.append(f"You: {user_msg}\n𝕴 𝖆𝖒 π–π–Žπ–’: {bot_msg}\n")
    return "".join(rendered)
17
 
18
# Create FastAPI app — routes are registered on this instance via decorators below.
app = FastAPI()
20
 
21
@app.get("/ai")
async def ai_chat(request: Request):
    """Chat endpoint: build a prompt from recent history, generate, reply as JSON.

    Query params: ``query`` (user message, default "") and ``user_id``
    (conversation key, default "default"). Returns ``{"reply": <text>}``.
    """
    params = dict(request.query_params)
    text = params.get("query", "")
    uid = params.get("user_id", "default")

    # Fetch this user's prior turns and render the full prompt.
    turns = chat_history.get(uid, [])
    prompt = format_context(turns) + f"You: {text}\n𝕴 𝖆𝖒 π–π–Žπ–’:"

    # Tokenize and generate, then keep only the text after the last persona marker.
    encoded = tokenizer(prompt, return_tensors="pt", return_attention_mask=True)
    generated = model.generate(**encoded, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    reply = decoded.split("𝕴 𝖆𝖒 π–π–Žπ–’:")[-1].strip()

    # Remember this exchange, capped at the ten most recent turns.
    turns.append((text, reply))
    chat_history[uid] = turns[-10:]

    return JSONResponse({"reply": reply})
 
 
1
  from fastapi import FastAPI, Request
2
  from fastapi.responses import JSONResponse
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
 
 
 
 
5
 
6
# FastAPI application instance; the /ai route below is registered on it.
app = FastAPI()

# DialoGPT-medium: a conversational causal LM whose multi-turn protocol is
# "concatenate all turns, each terminated by EOS".
# NOTE(review): loads on CPU in full precision (no device/dtype specified) — confirm intended.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Memory store per user: user_id -> list of token-id tensors holding the
# conversation so far. Process-local and lost on restart.
chat_history = {}
13
 
14
@app.get("/ai")
async def chat(request: Request):
    """Chat endpoint: generate a DialoGPT reply, keeping per-user context.

    Query params:
        query:   the user's message (defaults to "").
        user_id: conversation key selecting the stored history (defaults to "default").

    Returns:
        JSONResponse with body ``{"reply": <generated text>}``.
    """
    query_params = dict(request.query_params)
    user_input = query_params.get("query", "")
    user_id = query_params.get("user_id", "default")

    # Pull user history — a list holding at most the full-conversation tensor.
    user_history = chat_history.get(user_id, [])

    # Tokenize the new message, terminated by EOS as DialoGPT's turn protocol expects.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Concatenate history if available.
    bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids

    # Generate, then decode only the tokens produced beyond the prompt.
    output_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # Save history. BUG FIX: store only output_ids — generate() returns the
    # prompt plus the reply, so output_ids already contains bot_input_ids as a
    # prefix. The previous [bot_input_ids, output_ids] made the next turn's
    # torch.cat feed the model the entire conversation twice.
    chat_history[user_id] = [output_ids]

    return JSONResponse({"reply": response})