Trigger82 committed on
Commit
2fd3a49
·
verified ·
1 Parent(s): bf28dd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -6
app.py CHANGED
@@ -2,14 +2,17 @@ from fastapi import FastAPI, Request
2
  from fastapi.responses import JSONResponse
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
 
 
5
 
6
  app = FastAPI()
7
 
8
- # Load model and tokenizer
9
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
10
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
11
 
12
- # In-memory history per user
13
  chat_history = {}
14
 
15
  @app.get("/")
@@ -35,7 +38,31 @@ async def chat(request: Request):
35
  chat_history[user_id] = [bot_input_ids, output_ids]
36
  return JSONResponse({"reply": response})
37
 
38
- # Only needed if running locally, not in Hugging Face Space
39
- if __name__ == "__main__":
40
- import uvicorn
41
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from fastapi.responses import JSONResponse
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ import gradio as gr
6
+ import requests
7
+ import threading
8
 
9
app = FastAPI()

# Load tokenizer and model once at startup (downloaded on first run);
# reused by every request instead of reloading per call.
MODEL_NAME = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Per-user conversation state, keyed by user id.
# NOTE(review): in-memory only — lost on restart; presumably fine for a demo Space.
chat_history = {}
17
 
18
  @app.get("/")
 
38
  chat_history[user_id] = [bot_input_ids, output_ids]
39
  return JSONResponse({"reply": response})
40
 
41
# Gradio UI helper that calls the deployed /ai endpoint via HTTP.
def gradio_chat(user_input, user_id="default"):
    """Send *user_input* to the /ai endpoint and return the model's reply.

    Parameters:
        user_input: Message text; empty/falsy input short-circuits with a prompt.
        user_id: Key the server uses to keep per-user chat history.

    Returns:
        The reply string on success, or a human-readable error description.
    """
    if not user_input:
        return "Please enter some text."
    url = "https://Trigger82--API.hf.space/ai"
    try:
        # params= lets requests percent-encode the query string, so messages
        # containing spaces, '&', '?' or unicode survive intact (the previous
        # raw f-string interpolation corrupted such inputs).
        # timeout= keeps the UI from hanging forever if the API stalls.
        res = requests.get(
            url,
            params={"query": user_input, "user_id": user_id},
            timeout=30,
        )
        if res.status_code == 200:
            return res.json().get("reply", "No reply")
        return f"Error: {res.status_code}"
    except Exception as e:
        # Best-effort UI: surface the failure as text rather than crashing.
        return f"Exception: {e}"
53
+
54
# Browser-facing form: one box for the message, one for the user id;
# replies come back from gradio_chat as plain text.
_message_box = gr.Textbox(label="Your Message")
_user_id_box = gr.Textbox(label="User ID", value="default")

iface = gr.Interface(
    fn=gradio_chat,
    inputs=[_message_box, _user_id_box],
    outputs="text",
    title="Chat with DialoGPT API",
    description="Type your message and user id to chat with the model.",
)
61
+
62
# Launch the Gradio app in a background thread alongside FastAPI.
def run_gradio():
    """Serve the Gradio UI on port 7861 (FastAPI owns 7860 on Spaces)."""
    iface.launch(server_name="0.0.0.0", server_port=7861, share=False)

# daemon=True so this thread never blocks process shutdown/reload; the
# previous non-daemon thread could keep a dying worker alive indefinitely.
threading.Thread(target=run_gradio, daemon=True).start()

# No need for uvicorn.run here on Spaces; it manages startup automatically