Minte committed on
Commit
3c52cba
·
1 Parent(s): c3e5dc7
Files changed (1) hide show
  1. app.py +36 -33
app.py CHANGED
@@ -2,9 +2,12 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from fastapi import FastAPI, Request
 
 
 
5
 
6
  # -------------------------------------------------
7
- # 1. Load model (same as your old code)
8
  # -------------------------------------------------
9
  print("Initializing DialoGPT-medium model...")
10
  model_name = "microsoft/DialoGPT-medium"
@@ -14,25 +17,23 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
14
  if tokenizer.pad_token is None:
15
  tokenizer.pad_token = tokenizer.eos_token
16
 
17
- print("DialoGPT-medium loaded!")
18
 
19
  # -------------------------------------------------
20
- # 2. Generation helper (your old logic, cleaned up)
21
  # -------------------------------------------------
22
  def generate_response(message: str, chat_history: list):
23
  if not message.strip():
24
  return "Please enter a message."
25
 
26
- # Build conversation string
27
- conv = ""
28
  for user, bot in chat_history:
29
- conv += f"User: {user}\nBot: {bot}\n"
30
- conv += f"User: {message}\nBot:"
31
 
32
- # Encode
33
- inputs = tokenizer.encode(conv, return_tensors="pt", max_length=1024, truncation=True)
34
 
35
- # Generate
36
  with torch.no_grad():
37
  outputs = model.generate(
38
  inputs,
@@ -53,20 +54,20 @@ def generate_response(message: str, chat_history: list):
53
  return response
54
 
55
  # -------------------------------------------------
56
- # 3. Gradio chat function (used by /run/predict)
57
  # -------------------------------------------------
58
  def chat_fn(message: str, history: list):
59
  response = generate_response(message, history or [])
60
  history.append((message, response))
61
- return "", history # clear textbox, update chat
62
 
63
  # -------------------------------------------------
64
- # 4. Build the UI (your Blocks layout)
65
  # -------------------------------------------------
66
  example_questions = [
67
  "Hello! How are you today?",
68
  "What can you help me with?",
69
- "Tell me about artificial intelligence",
70
  "What's your favorite programming language?",
71
  "Can you explain machine learning?",
72
  "How does a neural network work?"
@@ -77,8 +78,8 @@ with gr.Blocks(
77
  title="GihonTech - AI Conversation Assistant"
78
  ) as demo:
79
 
80
- gr.Markdown("# GihonTech AI Conversation Assistant")
81
- gr.Markdown("Chat with an AI powered by **DialoGPT-medium**")
82
 
83
  with gr.Row():
84
  with gr.Column(scale=3):
@@ -104,7 +105,7 @@ with gr.Blocks(
104
  gr.Markdown("---")
105
  gr.Markdown("### Model Info")
106
  gr.Textbox(
107
- value="DialoGPT-medium: Loaded",
108
  label="Model Status",
109
  interactive=False,
110
  )
@@ -115,39 +116,41 @@ with gr.Blocks(
115
  - Conversation memory
116
 
117
  **Tips**
118
- - Ask clear questions
119
- - Use *Clear Chat* to start over
120
  """
121
  )
122
 
123
- # Event wiring
124
  send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
125
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
126
  clear.click(lambda: ([], ""), outputs=[chatbot, msg])
127
 
128
  # -------------------------------------------------
129
- # 5. OPTIONAL: expose /lambda (same JSON format)
130
  # -------------------------------------------------
131
  fastapi_app = FastAPI()
132
 
 
 
 
 
 
 
 
 
133
  @fastapi_app.post("/lambda")
134
  async def lambda_endpoint(req: Request):
135
  payload = await req.json()
136
- # Gradio sends {"data": [...]} ; we accept anything
137
  user_msg = payload.get("data", [""])[0]
138
- # Use the same generation logic (no history for this endpoint)
139
- resp = generate_response(user_msg, [])
140
- return {"data": [resp]}
141
 
142
- demo.mount_app(fastapi_app) # makes /lambda reachable
 
143
 
144
  # -------------------------------------------------
145
- # 6. Launch with queue (critical for API!)
146
  # -------------------------------------------------
147
  if __name__ == "__main__":
148
- demo.queue().launch(
149
- server_name="0.0.0.0",
150
- server_port=7860,
151
- share=False,
152
- show_error=True,
153
- )
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from fastapi import FastAPI, Request
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from gradio.routes import mount_gradio_app
7
+ import uvicorn
8
 
9
  # -------------------------------------------------
10
+ # 1. Load model
11
  # -------------------------------------------------
12
  print("Initializing DialoGPT-medium model...")
13
  model_name = "microsoft/DialoGPT-medium"
 
17
  if tokenizer.pad_token is None:
18
  tokenizer.pad_token = tokenizer.eos_token
19
 
20
+ print("DialoGPT-medium loaded successfully!")
21
 
22
  # -------------------------------------------------
23
+ # 2. Helper: Generate a response
24
  # -------------------------------------------------
25
  def generate_response(message: str, chat_history: list):
26
  if not message.strip():
27
  return "Please enter a message."
28
 
29
+ # Build the conversation context
30
+ conversation = ""
31
  for user, bot in chat_history:
32
+ conversation += f"User: {user}\nBot: {bot}\n"
33
+ conversation += f"User: {message}\nBot:"
34
 
35
+ inputs = tokenizer.encode(conversation, return_tensors="pt", max_length=1024, truncation=True)
 
36
 
 
37
  with torch.no_grad():
38
  outputs = model.generate(
39
  inputs,
 
54
  return response
55
 
56
  # -------------------------------------------------
57
+ # 3. Gradio chat handler
58
  # -------------------------------------------------
59
  def chat_fn(message: str, history: list):
60
  response = generate_response(message, history or [])
61
  history.append((message, response))
62
+ return "", history # clear textbox, update chat
63
 
64
  # -------------------------------------------------
65
+ # 4. Build the Gradio UI
66
  # -------------------------------------------------
67
  example_questions = [
68
  "Hello! How are you today?",
69
  "What can you help me with?",
70
+ "Tell me about artificial intelligence.",
71
  "What's your favorite programming language?",
72
  "Can you explain machine learning?",
73
  "How does a neural network work?"
 
78
  title="GihonTech - AI Conversation Assistant"
79
  ) as demo:
80
 
81
+ gr.Markdown("# 🤖 GihonTech AI Conversation Assistant")
82
+ gr.Markdown("Chat with an AI powered by **DialoGPT-medium**.")
83
 
84
  with gr.Row():
85
  with gr.Column(scale=3):
 
105
  gr.Markdown("---")
106
  gr.Markdown("### Model Info")
107
  gr.Textbox(
108
+ value="DialoGPT-medium: Loaded",
109
  label="Model Status",
110
  interactive=False,
111
  )
 
116
  - Conversation memory
117
 
118
  **Tips**
119
+ - Ask clear, simple questions
120
+ - Use *Clear Chat* to start over
121
  """
122
  )
123
 
124
+ # Wire up events
125
  send.click(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
126
  msg.submit(chat_fn, inputs=[msg, chatbot], outputs=[msg, chatbot])
127
  clear.click(lambda: ([], ""), outputs=[chatbot, msg])
128
 
129
  # -------------------------------------------------
130
+ # 5. FastAPI app + Lambda route
131
  # -------------------------------------------------
132
  fastapi_app = FastAPI()
133
 
134
+ # Allow AnythingLLM / frontend CORS access
135
+ fastapi_app.add_middleware(
136
+ CORSMiddleware,
137
+ allow_origins=["*"],
138
+ allow_methods=["*"],
139
+ allow_headers=["*"],
140
+ )
141
+
142
  @fastapi_app.post("/lambda")
143
  async def lambda_endpoint(req: Request):
144
  payload = await req.json()
 
145
  user_msg = payload.get("data", [""])[0]
146
+ response = generate_response(user_msg, [])
147
+ return {"data": [response]}
 
148
 
149
+ # Mount Gradio app inside FastAPI
150
+ mount_gradio_app(fastapi_app, demo, path="/")
151
 
152
  # -------------------------------------------------
153
+ # 6. Run the combined FastAPI + Gradio app
154
  # -------------------------------------------------
155
  if __name__ == "__main__":
156
+ uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)