MindVR commited on
Commit
1624f80
·
verified ·
1 Parent(s): 5667853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from typing import List, Union
4
  from huggingface_hub import login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import gradio as gr
@@ -20,22 +20,20 @@ model = AutoModelForCausalLM.from_pretrained(
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model.to(device)
22
 
23
- def build_prompt(history, new_message):
24
- prompt = ""
25
- if history:
26
- prompt += "\n".join(history) + "\n"
27
- prompt += f"User: {new_message}\nAI:"
28
- return prompt
29
 
30
- def chat(history: Union[str, List[str]], new_message: str) -> str:
31
- if isinstance(history, str):
32
- import ast
33
- try:
34
- history = ast.literal_eval(history)
35
- except:
36
- history = [history]
37
- prompt = build_prompt(history, new_message)
38
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
39
  with torch.no_grad():
40
  output = model.generate(
41
  input_ids,
@@ -54,15 +52,20 @@ def chat(history: Union[str, List[str]], new_message: str) -> str:
54
 
55
  with gr.Blocks() as demo:
56
  gr.Markdown("# MindVR Therapy Chatbot\n\nDùng thử UI hoặc gọi API!")
57
- with gr.Row():
58
- history = gr.Textbox(lines=8, label="History (JSON list, ví dụ: [\"User: Xin chào\"] )")
59
- new_message = gr.Textbox(label="New message")
60
  output = gr.Textbox(label="AI Response")
61
- def _chat_ui(history, new_message):
62
- return chat(history, new_message)
 
 
 
 
63
  btn = gr.Button("Gửi")
64
- btn.click(_chat_ui, inputs=[history, new_message], outputs=output)
65
 
 
66
  gr.api(chat, api_name="chat_ai")
67
 
68
  demo.launch()
 
1
  import os
2
  import torch
3
+ from typing import List
4
  from huggingface_hub import login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import gradio as gr
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model.to(device)
22
 
23
+ def build_prompt(prompt: str, histories: List[str], new_message: str) -> str:
24
+ prompt_text = prompt.strip() + "\n" if prompt else ""
25
+ if histories:
26
+ prompt_text += "\n".join(histories) + "\n"
27
+ prompt_text += f"User: {new_message}\nAI:"
28
+ return prompt_text
29
 
30
+ def chat(
31
+ prompt: str,
32
+ histories: List[str],
33
+ new_message: str
34
+ ) -> str:
35
+ prompt_text = build_prompt(prompt, histories, new_message)
36
+ input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
 
 
37
  with torch.no_grad():
38
  output = model.generate(
39
  input_ids,
 
52
 
53
  with gr.Blocks() as demo:
54
  gr.Markdown("# MindVR Therapy Chatbot\n\nDùng thử UI hoặc gọi API!")
55
+ prompt_box = gr.Textbox(lines=2, label="Prompt (System Prompt, chỉ dẫn context cho AI, có thể bỏ trống)")
56
+ histories_box = gr.Textbox(lines=8, label="Histories (mỗi dòng là một message, ví dụ: User: Xin chào)")
57
+ new_message_box = gr.Textbox(label="New message")
58
  output = gr.Textbox(label="AI Response")
59
+
60
+ def _chat_ui(prompt, histories, new_message):
61
+ # histories nhập từ UI là multiline string -> chuyển thành list
62
+ histories_list = [line.strip() for line in histories.split('\n') if line.strip()]
63
+ return chat(prompt, histories_list, new_message)
64
+
65
  btn = gr.Button("Gửi")
66
+ btn.click(_chat_ui, inputs=[prompt_box, histories_box, new_message_box], outputs=output)
67
 
68
+ # API chuẩn RESTful với prompt, histories, new_message
69
  gr.api(chat, api_name="chat_ai")
70
 
71
  demo.launch()