Spaces:
Sleeping
Sleeping
File size: 3,758 Bytes
5c48b65 00d89b2 1624f80 5c48b65 62d55b4 05a9b7e fad09d8 77464a5 fad09d8 00d89b2 a24a633 029c034 00d89b2 bb47474 77464a5 a24a633 00d89b2 bb47474 00d89b2 fad09d8 1624f80 fad09d8 1624f80 bb47474 1624f80 05a9b7e 5c48b65 77464a5 5c48b65 00d89b2 5c48b65 00d89b2 1e99ca6 00d89b2 7e877bc e93544f 91e2402 1624f80 91e2402 1624f80 91e2402 1624f80 91e2402 1624f80 91e2402 e93544f 91e2402 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import torch
from typing import List
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from transformers import AutoTokenizer as SummarizerTokenizer, AutoModelForSeq2SeqLM
# Pick the compute device once; reused by all tokenization/generation calls below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Summarization model (BART fine-tuned on CNN/DailyMail), used to condense
# long chat histories before they are fed to the main chat LLM.
summarizer_model_id = "facebook/bart-large-cnn"
summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
# FIX: the original passed device_map="auto" (accelerate-managed placement)
# AND then called .to(device) — the two conflict and can raise a RuntimeError
# with recent accelerate versions. Place the model explicitly instead.
# Half precision is only safe on GPU; fall back to float32 on CPU.
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarizer_model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
summarizer_model.to(device)
def summarize_text(text: str, max_length: int = 150) -> str:
    """Summarize *text* with the BART summarizer.

    Args:
        text: Raw text to condense; inputs longer than 1024 tokens are truncated.
        max_length: Maximum number of tokens in the generated summary.

    Returns:
        The decoded summary string (special tokens stripped).
    """
    inputs = summarizer_tokenizer(
        [text], return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    # FIX: run inference without autograd bookkeeping, and pass the attention
    # mask explicitly — omitting it triggers a transformers warning and can
    # mis-handle padded positions.
    with torch.no_grad():
        summary_ids = summarizer_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=4,
            max_length=max_length,
            early_stopping=True,
        )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Authenticate with the Hugging Face Hub when a token is supplied via the
# environment — required to download gated models such as the one below.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Main chat model used by chat() below.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
# device_map="auto" lets accelerate shard/place the (very large) weights
# across available devices; low_cpu_mem_usage streams weights to avoid a
# full fp32 copy in host RAM during loading.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=HF_TOKEN
)
# Single prompt builder (history summarization included).
def build_prompt(prompt: str, histories: List[str], new_message: str) -> str:
    """Assemble the full text prompt sent to the chat model.

    Combines the optional system prompt, the chat history (condensed via
    summarize_text once it grows past ~3000 characters) and the new user
    message into the "User: ...\\nAI:" completion format.
    """
    parts = []
    if prompt:
        parts.append(prompt.strip())
    joined_history = "\n".join(histories) if histories else ""
    # Shrink an overly long history so the prompt still fits the context
    # window (threshold is tunable).
    if len(joined_history) > 3000:
        joined_history = summarize_text(joined_history, max_length=180)
    if joined_history:
        parts.append(joined_history)
    parts.append(f"User: {new_message}\nAI:")
    return "\n".join(parts)
def chat(
    prompt: str,
    histories: List[str],
    new_message: str
) -> str:
    """Generate one assistant reply for the given conversation state.

    Builds the prompt via build_prompt, samples a continuation from the
    main model, and returns only the assistant's portion of the output.
    """
    full_prompt = build_prompt(prompt, histories, new_message)
    encoded_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        generated = model.generate(
            encoded_ids,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Without an "AI:" marker there is nothing to split on — return everything.
    if "AI:" not in decoded:
        return decoded.strip()
    # Keep only the text after the final "AI:" marker, and drop anything the
    # model hallucinated as a follow-up "User:" turn.
    reply = decoded.rpartition("AI:")[2].strip()
    if "User:" in reply:
        reply = reply.partition("User:")[0].strip()
    return reply
# Gradio UI + REST API wiring for the chatbot.
with gr.Blocks() as demo:
    gr.Markdown("# MindVR Therapy Chatbot\n\nDùng thử UI hoặc gọi API!")
    prompt_box = gr.Textbox(lines=2, label="Prompt (System Prompt, chỉ dẫn context cho AI, có thể bỏ trống)")
    histories_box = gr.Textbox(lines=8, label="Histories (mỗi dòng là một message, ví dụ: User: Xin chào)")
    new_message_box = gr.Textbox(label="New message")
    output = gr.Textbox(label="AI Response")
    def _chat_ui(prompt, histories, new_message):
        # The UI delivers histories as one multiline string -> convert it to
        # a list of messages, dropping blank lines, before calling chat().
        histories_list = [line.strip() for line in histories.split('\n') if line.strip()]
        return chat(prompt, histories_list, new_message)
    btn = gr.Button("Gửi")
    btn.click(_chat_ui, inputs=[prompt_box, histories_box, new_message_box], outputs=output)
    # Expose chat() as a RESTful API endpoint taking (prompt, histories, new_message).
    gr.api(chat, api_name="chat_ai")
demo.launch()
|