MindVR committed on
Commit
00d89b2
·
verified ·
1 Parent(s): c85c1b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -22
app.py CHANGED
@@ -1,25 +1,43 @@
1
  import os
 
2
  from huggingface_hub import login
3
- login(token=os.environ["HF_TOKEN"]) # Dùng biến môi trường để lấy token
4
-
5
- import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import gradio as gr
8
-
9
- # Load model
10
- model_id = "MindVR/JohnTran_Fine-tune" # ⚠️ Đảm bảo đây là bản mới fine-tune không dùng 4bit
11
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ["HF_TOKEN"])
 
 
 
12
 
 
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_id,
15
  device_map="auto",
16
  low_cpu_mem_usage=True,
17
- token=os.environ["HF_TOKEN"]
18
  )
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Hàm xử lý yêu cầu
21
- def chat(prompt):
22
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
23
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
24
  with torch.no_grad():
25
  output = model.generate(
@@ -27,18 +45,83 @@ def chat(prompt):
27
  max_new_tokens=200,
28
  do_sample=True,
29
  top_p=0.95,
30
- temperature=0.7
 
31
  )
32
- response = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
33
  return response
34
 
35
- # Giao diện Gradio
36
- demo = gr.Interface(
37
- fn=chat,
38
- inputs=gr.Textbox(label="Nhập câu hỏi"),
39
- outputs=gr.Textbox(label="Phản hồi từ AI"),
40
- title="MindVR Therapy Chatbot",
41
- allow_flagging="never"
42
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- demo.launch()
 
 
1
  import os
2
+ import torch
3
  from huggingface_hub import login
 
 
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import gradio as gr
6
+ from fastapi import FastAPI, Request
7
+ from pydantic import BaseModel
8
+
9
+ # ---- Load Model ----
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+ if HF_TOKEN:
12
+ login(token=HF_TOKEN)
13
 
14
+ model_id = "MindVR/JohnTran_Fine-tune"
15
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_id,
18
  device_map="auto",
19
  low_cpu_mem_usage=True,
20
+ token=HF_TOKEN
21
  )
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model.to(device)
24
+
# ---- Chat Function ----
def build_prompt(history, new_message):
    """Assemble the plain-text chat prompt.

    `history` is a list of already-formatted lines ("User: ..." / "AI: ...").
    The prompt ends with a bare "AI:" so the model continues with the
    assistant's reply.
    """
    lines = list(history) if history else []
    lines.append(f"User: {new_message}\nAI:")
    return "\n".join(lines)
def chat_gradio(message, history):
    """Generate one assistant reply for the Gradio chat UI.

    `history` arrives as a list of [user_msg, ai_msg] pairs (gr.Chatbot
    format) and is flattened into "User:" / "AI:" lines before prompting.
    Returns only the newly generated assistant turn as a string.
    """
    turns = []
    for user_msg, ai_msg in history or []:
        turns.append(f"User: {user_msg}")
        turns.append(f"AI: {ai_msg}")
    prompt = build_prompt(turns, message)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text echoes the whole prompt; keep only the final AI turn.
    _, found, tail = output_text.rpartition("AI:")
    return tail.strip() if found else output_text.strip()
58
 
59
+ # ---- Gradio Interface ----
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("# MindVR Therapy Chatbot")
62
+ chatbot = gr.Chatbot()
63
+ msg = gr.Textbox(label="Nhập câu hỏi")
64
+ send = gr.Button("Gửi")
65
+
66
+ def user_chat(message, history):
67
+ response = chat_gradio(message, history)
68
+ return response
69
+
70
+ send.click(
71
+ fn=user_chat,
72
+ inputs=[msg, chatbot],
73
+ outputs=chatbot,
74
+ queue=False
75
+ )
76
+ msg.submit(
77
+ fn=user_chat,
78
+ inputs=[msg, chatbot],
79
+ outputs=chatbot,
80
+ queue=False
81
+ )
82
# ---- REST API Endpoint ----
app = FastAPI()

class ChatRequest(BaseModel):
    # history is a flat list of already-formatted lines:
    # ["User: ...", "AI: ...", ...]
    history: list
    new_message: str

@app.post("/generate")
async def generate(data: ChatRequest):
    """Generate a reply for an external (non-Gradio) client.

    Returns {"response": <assistant reply string>}.
    """
    prompt = build_prompt(data.history, data.new_message)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the whole prompt; keep only the final AI turn.
    _, found, reply = decoded.rpartition("AI:")
    response = reply.strip() if found else decoded.strip()
    return {"response": response}
# ---- Export both Gradio and API ----
import uvicorn

def main():
    """Serve both front ends from one process.

    The FastAPI app runs on port 7861 in a daemon background thread;
    the Gradio UI runs on port 7860 in the foreground and blocks until
    shutdown. (Removed an unused `import time`.)
    """
    import threading

    # daemon=True: the API thread dies with the main process, so the
    # container exits cleanly when the Gradio server stops.
    def run_api():
        uvicorn.run(app, host="0.0.0.0", port=7861)
    threading.Thread(target=run_api, daemon=True).start()

    # Run Gradio interface (blocking call).
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

if __name__ == "__main__":
    main()