# app.py — FastAPI + Gradio chat Space serving microsoft/phi-2
# (Hugging Face Space by Trigger82, commit 08aad81)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import urllib.parse
# Load the base model once at import time (downloads weights on first run;
# this blocks startup until microsoft/phi-2 is cached locally).
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# In-process per-user conversation memory: user_id -> list of
# (user_message, bot_reply) tuples. NOTE(review): not persisted and not
# shared across workers — resets on every restart.
chat_history = {}
# Format recent history into a prompt prefix.
def format_context(history):
    """Render the last three (user, bot) exchanges as a prompt prefix.

    Each turn becomes ``You: <user>\n𝕴 𝖆𝖒 π–π–Žπ–’: <bot>\n``; an empty
    history yields an empty string.
    """
    turns = []
    for user_msg, bot_msg in history[-3:]:
        turns.append(f"You: {user_msg}\n𝕴 𝖆𝖒 π–π–Žπ–’: {bot_msg}\n")
    return "".join(turns)
# FastAPI app exposing the chat model over a simple GET endpoint.
app = FastAPI()


@app.get("/ai")
async def ai_chat(request: Request):
    """Answer ``GET /ai?query=...&user_id=...`` with a model-generated reply.

    Reads the rolling per-user memory from the module-level ``chat_history``
    dict, prepends the last 3 turns as context, generates up to 100 new
    tokens, and stores the new turn (capped at 10 turns per user).

    Returns a JSON body ``{"reply": <str>}``.
    """
    query_params = dict(request.query_params)
    user_input = query_params.get("query", "")
    user_id = query_params.get("user_id", "default")

    # Build the prompt from this user's recent history.
    history = chat_history.get(user_id, [])
    prompt = format_context(history) + f"You: {user_input}\n𝕴 𝖆𝖒 π–π–Žπ–’:"

    # Tokenize & run the model. no_grad() avoids building autograd state
    # for inference (the original version leaked that memory per request).
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The decoded text echoes the whole prompt; keep only the text after the
    # last persona marker, then cut anything the model hallucinated as a
    # follow-up user turn ("You: ..."), which the original left in the reply.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    reply = decoded.split("𝕴 𝖆𝖒 π–π–Žπ–’:")[-1]
    reply = reply.split("You:")[0].strip()

    # Persist the turn, keeping at most the 10 most recent exchanges.
    history.append((user_input, reply))
    chat_history[user_id] = history[-10:]
    return JSONResponse({"reply": reply})
# Mount a minimal Gradio UI onto the FastAPI app so the Space serves both
# the /ai endpoint and a web interface. gr.mount_gradio_app takes an
# explicit mount path; the previous `gradio_app = gr.FastAPI(app)` line is
# removed because Gradio has no `FastAPI` attribute — it raised
# AttributeError at import time and its result was never used by Spaces.
demo = gr.Interface(lambda x: x, "textbox", "textbox")
app = gr.mount_gradio_app(app, demo, path="/")