API / app.py
Trigger82's picture
Update app.py
2fd3a49 verified
raw
history blame
2.45 kB
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import requests
import threading

# FastAPI application served by the Space's default runner.
app = FastAPI()

# Load model and tokenizer once at import time so every request reuses them
# (first startup is slow while weights download; subsequent requests are fast).
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# In-memory chat history by user: user_id -> list of token-id tensors.
# NOTE(review): process-local and unbounded — history is lost on restart and
# grows without limit per user; fine for a demo, not for production.
chat_history: dict = {}
@app.get("/")
async def root():
    """Health-check endpoint: confirms the API is up and shows usage."""
    payload = {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}
    return payload
@app.get("/ai")
async def chat(request: Request):
    """Generate a DialoGPT reply for GET /ai?query=...&user_id=...

    Query params:
        query:   required user message; 400 if missing/empty.
        user_id: optional conversation key (defaults to "default").

    Returns JSON {"reply": <text>} and keeps per-user context in the
    module-level ``chat_history`` dict so follow-up turns see prior turns.
    """
    query_params = dict(request.query_params)
    user_input = query_params.get("query", "")
    user_id = query_params.get("user_id", "default")
    if not user_input:
        return JSONResponse({"error": "Missing 'query' parameter"}, status_code=400)

    # Encode the new user turn, terminated by EOS as DialoGPT expects.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
    user_history = chat_history.get(user_id, [])
    bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids
    output_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens: generate() returns prompt + reply.
    response = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # BUG FIX: the original stored [bot_input_ids, output_ids]. output_ids
    # already contains bot_input_ids as its prefix, so concatenating both on
    # the next turn duplicated the whole conversation (corrupting context and
    # doubling sequence length every turn). Keep only the generated sequence.
    chat_history[user_id] = [output_ids]
    return JSONResponse({"reply": response})
# Gradio UI to call your /ai endpoint easily via browser
def gradio_chat(user_input, user_id="default"):
    """Proxy a message to the Space's /ai endpoint and return the reply text.

    Args:
        user_input: message to send; empty/None short-circuits with a prompt.
        user_id: conversation key forwarded to the API.

    Returns:
        The model's reply, or a human-readable error string on failure.
    """
    if not user_input:
        return "Please enter some text."
    # BUG FIX: the original interpolated raw text into the query string,
    # which broke on spaces, '&', '#', '?' and non-ASCII input. Passing
    # params= lets requests URL-encode them correctly. A timeout is added
    # because requests waits forever by default.
    url = "https://Trigger82--API.hf.space/ai"
    try:
        res = requests.get(url, params={"query": user_input, "user_id": user_id}, timeout=30)
        if res.status_code == 200:
            return res.json().get("reply", "No reply")
        return f"Error: {res.status_code}"
    except Exception as e:
        # Best-effort UI helper: surface the failure as text, never crash.
        return f"Exception: {e}"
# Gradio interface wired to gradio_chat: two text inputs (message, user id),
# one text output that displays the model's reply.
iface = gr.Interface(
    fn=gradio_chat,
    inputs=[gr.Textbox(label="Your Message"), gr.Textbox(label="User ID", value="default")],
    outputs="text",
    title="Chat with DialoGPT API",
    description="Type your message and user id to chat with the model."
)
# Launch Gradio app in a thread alongside FastAPI
def run_gradio():
    """Serve the Gradio UI on port 7861 (FastAPI owns the Space's main port)."""
    iface.launch(server_name="0.0.0.0", server_port=7861, share=False)

# FIX: daemon=True so this helper thread cannot keep the process alive
# (or hang shutdown) after the FastAPI server exits.
threading.Thread(target=run_gradio, daemon=True).start()
# No need for uvicorn.run here on Spaces; it manages startup automatically