import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import urllib.parse

# Load model and tokenizer
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# In-memory chat history shared across users, keyed by user_id
chat_history = {}

# Format the most recent exchanges into a dialogue context
def format_context(history):
    context = ""
    for user, bot in history[-3:]:  # last 3 exchanges only
        context += f"You: {user}\n𝕴 𝖆𝖒 𝖍𝖎𝖒: {bot}\n"
    return context

# Main chat function with per-user memory
def chat_with_memory(query_string):
    # The single text input is a URL-style query string, e.g. "query=hello&user_id=alice"
    parsed = urllib.parse.parse_qs(query_string)
    user_input = parsed.get("query", [""])[0]
    user_id = parsed.get("user_id", ["default"])[0]

    # Get or initialize this user's history
    history = chat_history.get(user_id, [])

    # Build the prompt: recent context plus the new message
    context = format_context(history) + f"You: {user_input}\n𝕴 𝖆𝖒 𝖍𝖎𝖒:"

    # Tokenize & generate
    inputs = tokenizer(context, return_tensors="pt", return_attention_mask=True)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)

    # Keep only the model's turn: strip the prompt and cut off any "You:" continuation it invents
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    reply = reply.split("𝕴 𝖆𝖒 𝖍𝖎𝖒:")[-1].split("You:")[0].strip()

    # Save memory, keeping only the last 10 exchanges per user
    history.append((user_input, reply))
    chat_history[user_id] = history[-10:]

    return {"reply": reply}

# Expose the chat function publicly; callers pass query and user_id in a single
# URL-style query string (e.g. "query=hello&user_id=alice") as the text input
iface = gr.Interface(
    fn=chat_with_memory,
    inputs="text",   # raw query string
    outputs="json",
)

iface.launch()
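
# Example client call, as a minimal sketch: "user/space-name" below is a placeholder
# for wherever this app is actually hosted, and api_name="/predict" is Gradio's
# default endpoint name for a single Interface.
#
#   from gradio_client import Client
#
#   client = Client("user/space-name")
#   result = client.predict("query=hello&user_id=alice", api_name="/predict")
#   print(result)  # e.g. {"reply": "..."}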