"""FastAPI + Gradio chat service wrapping microsoft/DialoGPT-medium.

Exposes GET /ai, which keeps per-user conversation context in memory,
plus a Gradio UI (port 7861, launched on a background thread) that calls
the hosted endpoint over HTTP.
"""

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import requests
import threading

app = FastAPI()

# Load model and tokenizer once, at module import time.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# In-memory chat history keyed by user_id. Each value is a one-element list
# holding the token ids of the full conversation so far (inputs + replies).
# NOTE(review): unbounded — entries are never evicted; acceptable for a demo.
chat_history = {}


@app.get("/")
async def root():
    """Health check / usage hint."""
    return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}


@app.get("/ai")
async def chat(request: Request):
    """Generate a DialoGPT reply to `query`, continuing `user_id`'s conversation.

    Query parameters:
        query:   the user's message (required; 400 if missing/empty).
        user_id: conversation key (defaults to "default").

    Returns JSON {"reply": <generated text>}.
    """
    query_params = dict(request.query_params)
    user_input = query_params.get("query", "")
    user_id = query_params.get("user_id", "default")
    if not user_input:
        return JSONResponse({"error": "Missing 'query' parameter"}, status_code=400)

    new_input_ids = tokenizer.encode(
        user_input + tokenizer.eos_token, return_tensors="pt"
    )
    user_history = chat_history.get(user_id, [])
    bot_input_ids = (
        torch.cat(user_history + [new_input_ids], dim=-1)
        if user_history
        else new_input_ids
    )

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            bot_input_ids,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The reply is everything generated after the prompt tokens.
    response = tokenizer.decode(
        output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    # BUG FIX: store only `output_ids` — it already contains the entire
    # conversation (bot_input_ids is its prefix). The original stored
    # [bot_input_ids, output_ids], which re-concatenated the prior context
    # twice on every turn, growing the prompt roughly exponentially.
    chat_history[user_id] = [output_ids]
    return JSONResponse({"reply": response})


# Gradio UI to call the /ai endpoint easily via browser.
def gradio_chat(user_input, user_id="default"):
    """Forward a message to the hosted /ai endpoint and return its reply."""
    if not user_input:
        return "Please enter some text."
    url = "https://Trigger82--API.hf.space/ai"
    try:
        # Let requests URL-encode the parameters (the original f-string left
        # spaces and '&'/'?' unescaped) and time out instead of hanging.
        res = requests.get(
            url,
            params={"query": user_input, "user_id": user_id},
            timeout=30,
        )
        if res.status_code == 200:
            return res.json().get("reply", "No reply")
        return f"Error: {res.status_code}"
    except Exception as e:  # surface any network/JSON failure in the UI
        return f"Exception: {e}"


iface = gr.Interface(
    fn=gradio_chat,
    inputs=[
        gr.Textbox(label="Your Message"),
        gr.Textbox(label="User ID", value="default"),
    ],
    outputs="text",
    title="Chat with DialoGPT API",
    description="Type your message and user id to chat with the model.",
)


# Launch Gradio app in a thread alongside FastAPI.
def run_gradio():
    """Launch the Gradio UI (blocking call; run from a background thread)."""
    iface.launch(server_name="0.0.0.0", server_port=7861, share=False)


# daemon=True so this helper thread cannot keep the process alive on shutdown.
threading.Thread(target=run_gradio, daemon=True).start()
# No need for uvicorn.run here on Spaces; it manages startup automatically.