MindVR committed on
Commit
c67107b
·
verified ·
1 Parent(s): 5fb8b67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -81
app.py CHANGED
@@ -3,10 +3,7 @@ import torch
3
  from huggingface_hub import login
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import gradio as gr
6
- from fastapi import FastAPI, Request
7
- from pydantic import BaseModel
8
 
9
- # ---- Load Model ----
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
11
  if HF_TOKEN:
12
  login(token=HF_TOKEN)
@@ -22,7 +19,6 @@ model = AutoModelForCausalLM.from_pretrained(
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  model.to(device)
24
 
25
- # ---- Chat Function ----
26
  def build_prompt(history, new_message):
27
  prompt = ""
28
  if history:
@@ -30,14 +26,9 @@ def build_prompt(history, new_message):
30
  prompt += f"User: {new_message}\nAI:"
31
  return prompt
32
 
33
- def chat_gradio(message, history):
34
- history_text = []
35
- if history:
36
- # history là dạng list các cặp [msg, response]
37
- for user_msg, ai_msg in history:
38
- history_text.append(f"User: {user_msg}")
39
- history_text.append(f"AI: {ai_msg}")
40
- prompt = build_prompt(history_text, message)
41
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
42
  with torch.no_grad():
43
  output = model.generate(
@@ -49,79 +40,22 @@ def chat_gradio(message, history):
49
  pad_token_id=tokenizer.eos_token_id
50
  )
51
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
52
- # Lấy đoạn trả lời AI cuối cùng
53
  if "AI:" in output_text:
54
  response = output_text.split("AI:")[-1].strip()
55
  else:
56
  response = output_text.strip()
57
  return response
58
 
59
- # ---- Gradio Interface ----
60
- with gr.Blocks() as demo:
61
- gr.Markdown("# MindVR Therapy Chatbot")
62
- chatbot = gr.Chatbot()
63
- msg = gr.Textbox(label="Nhập câu hỏi")
64
- send = gr.Button("Gửi")
65
-
66
- def user_chat(message, history):
67
- response = chat_gradio(message, history)
68
- return response
69
-
70
- send.click(
71
- fn=user_chat,
72
- inputs=[msg, chatbot],
73
- outputs=chatbot,
74
- queue=False
75
- )
76
- msg.submit(
77
- fn=user_chat,
78
- inputs=[msg, chatbot],
79
- outputs=chatbot,
80
- queue=False
81
- )
82
-
83
- # ---- REST API Endpoint ----
84
- app = FastAPI()
85
-
86
- class ChatRequest(BaseModel):
87
- history: list
88
- new_message: str
89
-
90
- @app.post("/generate")
91
- async def generate(data: ChatRequest):
92
- # history dạng ["User: ...", "AI: ...", ...]
93
- prompt = build_prompt(data.history, data.new_message)
94
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
95
- with torch.no_grad():
96
- output = model.generate(
97
- input_ids,
98
- max_new_tokens=200,
99
- do_sample=True,
100
- top_p=0.95,
101
- temperature=0.7,
102
- pad_token_id=tokenizer.eos_token_id
103
- )
104
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
105
- if "AI:" in output_text:
106
- response = output_text.split("AI:")[-1].strip()
107
- else:
108
- response = output_text.strip()
109
- return {"response": response}
110
-
111
- # ---- Export both Gradio and API ----
112
- import uvicorn
113
-
114
- def main():
115
- import threading
116
- import time
117
-
118
- # Run FastAPI on background
119
- def run_api():
120
- uvicorn.run(app, host="0.0.0.0", port=7861)
121
- threading.Thread(target=run_api, daemon=True).start()
122
-
123
- # Run Gradio interface
124
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
125
 
126
- if __name__ == "__main__":
127
- main()
 
3
  from huggingface_hub import login
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import gradio as gr
 
 
6
 
 
7
  HF_TOKEN = os.environ.get("HF_TOKEN")
8
  if HF_TOKEN:
9
  login(token=HF_TOKEN)
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  model.to(device)
21
 
 
22
  def build_prompt(history, new_message):
23
  prompt = ""
24
  if history:
 
26
  prompt += f"User: {new_message}\nAI:"
27
  return prompt
28
 
29
+ def chat(history, new_message):
30
+ # history: list[str], new_message: str
31
+ prompt = build_prompt(history, new_message)
 
 
 
 
 
32
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
33
  with torch.no_grad():
34
  output = model.generate(
 
40
  pad_token_id=tokenizer.eos_token_id
41
  )
42
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
43
  if "AI:" in output_text:
44
  response = output_text.split("AI:")[-1].strip()
45
  else:
46
  response = output_text.strip()
47
  return response
48
 
49
# Gradio-only UI: the user supplies the prior turns and the new message,
# and `chat` returns the model's reply as plain text.
iface = gr.Interface(
    fn=chat,
    inputs=[
        # NOTE: `gr.inputs.*` was deprecated in Gradio 3.0 and removed in
        # later releases — use the top-level component classes instead.
        gr.Textbox(lines=8, label="History (JSON list, ví dụ: [\"User: Xin chào\"] )"),
        gr.Textbox(label="New message"),
    ],
    outputs="text",
    title="MindVR Therapy Chatbot",
    allow_flagging="never",  # demo app: don't collect flagged samples
)

iface.launch()