File size: 3,758 Bytes
5c48b65
00d89b2
1624f80
5c48b65
62d55b4
05a9b7e
fad09d8
 
 
 
 
 
 
77464a5
 
 
 
 
fad09d8
 
 
 
 
 
 
 
 
 
 
 
00d89b2
 
 
 
a24a633
029c034
00d89b2
bb47474
 
77464a5
a24a633
 
00d89b2
bb47474
00d89b2
fad09d8
1624f80
 
fad09d8
 
 
 
 
 
1624f80
 
bb47474
1624f80
 
 
 
 
 
 
05a9b7e
5c48b65
 
77464a5
5c48b65
 
00d89b2
 
5c48b65
00d89b2
 
 
1e99ca6
 
00d89b2
 
7e877bc
e93544f
91e2402
 
1624f80
 
 
91e2402
1624f80
 
 
 
 
 
91e2402
1624f80
91e2402
1624f80
91e2402
e93544f
91e2402
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import torch
from typing import List
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from transformers import AutoTokenizer as SummarizerTokenizer, AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Summarization model, used by build_prompt to compress long chat histories.
# BART-large-CNN fits on a single device, so it is loaded without
# ``device_map`` and moved explicitly with ``.to(device)``: calling ``.to()``
# on a model already dispatched by accelerate (``device_map="auto"``) can warn
# or raise. Half precision is used only on GPU because many CPU kernels have
# no float16 implementation (e.g. "addmm_impl_cpu_ not implemented for 'Half'").
summarizer_model_id = "facebook/bart-large-cnn"
summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarizer_model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
summarizer_model.to(device)

def summarize_text(text: str, max_length=150) -> str:
    """Summarize *text* with the BART summarizer.

    Args:
        text: Text to summarize. It is truncated to the summarizer's first
            1024 tokens before generation.
        max_length: Maximum length, in tokens, of the generated summary.

    Returns:
        The decoded summary with special tokens removed, or ``""`` when
        *text* is empty or whitespace-only (nothing to summarize).
    """
    if not text or not text.strip():
        return ""
    # Send tensors to wherever the summarizer actually lives, so this works
    # whether the model was moved with .to() or placed via device_map.
    inputs = summarizer_tokenizer(
        [text], return_tensors="pt", max_length=1024, truncation=True
    ).to(summarizer_model.device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        summary_ids = summarizer_model.generate(
            **inputs,  # input_ids AND attention_mask
            num_beams=4,
            max_length=max_length,
            early_stopping=True,
        )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Optional Hugging Face access token taken from the environment. When set, it
# is used for a global hub login and also passed explicitly to from_pretrained.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Main chat model. device_map="auto" lets accelerate decide where the weights
# go (possibly sharded across several devices); low_cpu_mem_usage avoids
# materializing a full extra copy of the weights in RAM while loading.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# --- Single prompt builder; long histories are compressed via summarization ---
def build_prompt(
    prompt: str,
    histories: List[str],
    new_message: str,
    *,
    summarize_threshold: int = 3000,
    summary_max_length: int = 180,
) -> str:
    """Assemble the plain-text prompt sent to the causal LM.

    Layout (each part on its own line):
      1. the system prompt, stripped (omitted when blank),
      2. the history lines joined by newlines (omitted when empty; replaced
         by a summary when longer than ``summarize_threshold`` characters),
      3. ``User: <new_message>`` followed by ``AI:``.

    Args:
        prompt: Optional system prompt / context instructions. ``None`` or a
            whitespace-only string means "no system prompt".
        histories: Previous turns, one message per item
            (e.g. ``"User: Hello"``). ``None`` or empty means no history.
        new_message: The user's latest message.
        summarize_threshold: Character length above which the joined history
            is replaced by ``summarize_text(...)``.
        summary_max_length: ``max_length`` passed to ``summarize_text``.

    Returns:
        The prompt string, always ending with ``"AI:"`` so the model
        continues as the assistant.
    """
    # Strip first, then test: a whitespace-only prompt must not leave a
    # stray blank line at the top of the prompt.
    system_text = (prompt or "").strip()
    prompt_text = system_text + "\n" if system_text else ""

    histories_text = "\n".join(histories) if histories else ""
    # Summarize the history when it is too long (threshold is tunable). The
    # summary likely drops the "User:"/"AI:" turn labels of the original lines.
    if len(histories_text) > summarize_threshold:
        histories_text = summarize_text(histories_text, max_length=summary_max_length)
    if histories_text:
        prompt_text += histories_text + "\n"

    prompt_text += f"User: {new_message}\nAI:"
    return prompt_text

def chat(
    prompt: str,
    histories: List[str],
    new_message: str
) -> str:
    """Generate the assistant's reply to *new_message*.

    Builds the prompt with ``build_prompt``, samples up to 256 new tokens from
    the causal LM (top-p 0.95, temperature 0.7) and returns only the reply.

    Args:
        prompt: Optional system prompt / context for the assistant.
        histories: Previous conversation lines, one message per item.
        new_message: The user's latest message.

    Returns:
        The generated reply, cut before any ``"User:"`` turn the model may
        have continued with, stripped of surrounding whitespace.
    """
    prompt_text = build_prompt(prompt, histories, new_message)
    # Keep the attention mask, and send tensors to the device reported by the
    # model (with device_map="auto" this is typically where the first layers
    # live, which need not equal the global ``device``).
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens. Searching the full decoded text
    # for the last "AI:" marker would return a hallucinated later turn
    # whenever the model writes "User: ...\nAI: ..." after its own reply.
    reply = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # The model may keep going and invent the next user turn; drop it.
    return reply.split("User:", 1)[0].strip()

def _chat_ui(prompt, histories, new_message):
    """Gradio click handler: adapt the UI's multiline history box for ``chat``.

    The histories textbox holds one message per line. Each line is stripped,
    blank lines are dropped, and the rest are passed on as a list.
    """
    cleaned_lines = (line.strip() for line in histories.split("\n"))
    histories_list = [line for line in cleaned_lines if line]
    return chat(prompt, histories_list, new_message)


with gr.Blocks() as demo:
    gr.Markdown("# MindVR Therapy Chatbot\n\nDùng thử UI hoặc gọi API!")
    prompt_box = gr.Textbox(
        lines=2,
        label="Prompt (System Prompt, chỉ dẫn context cho AI, có thể bỏ trống)",
    )
    histories_box = gr.Textbox(
        lines=8,
        label="Histories (mỗi dòng là một message, ví dụ: User: Xin chào)",
    )
    new_message_box = gr.Textbox(label="New message")
    output = gr.Textbox(label="AI Response")

    btn = gr.Button("Gửi")
    btn.click(
        fn=_chat_ui,
        inputs=[prompt_box, histories_box, new_message_box],
        outputs=output,
    )

    # Expose ``chat`` itself as a named API endpoint ("chat_ai") that takes
    # prompt, histories (as a list) and new_message.
    gr.api(chat, api_name="chat_ai")

demo.launch()