sunbv56's picture
Update app.py
12dfc47 verified
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
# Kiểm tra thiết bị (GPU nếu có)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") # Thêm log để biết thiết bị đang sử dụng
# --- Tải mô hình và tokenizer ---
# Khởi tạo biến model và tokenizer là None
model_1, tokenizer_1 = None, None
model_2, tokenizer_2 = None, None
# model_3, tokenizer_3 = None, None # Không cần tải model_3 nữa
model_4, tokenizer_4 = None, None
# Sử dụng try-except để xử lý lỗi nếu không tải được mô hình
try:
model_name_1 = "sunbv56/ViLawT5_QAChatBot"
print(f"Loading model: {model_name_1}...")
model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device)
tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1)
print(f"Model {model_name_1} loaded successfully.")
except Exception as e:
print(f"Error loading model {model_name_1}: {e}")
try:
model_name_2 = "sunbv56/ViT5_QAChatBot"
print(f"Loading model: {model_name_2}...")
model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device)
tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2)
print(f"Model {model_name_2} loaded successfully.")
except Exception as e:
print(f"Error loading model {model_name_2}: {e}")
# Bỏ qua việc tải model_3 (ViLawT5_RL)
# ... (phần code tải model_3 bị comment như cũ) ...
try:
model_name_4 = "sunbv56/V-LegalQA"
print(f"Loading model: {model_name_4}...")
model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device)
tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4)
print(f"Model {model_name_4} loaded successfully.")
except Exception as e:
print(f"Error loading model {model_name_4}: {e}")
# --- Hàm sinh phản hồi ---
def chatbot_response(question, model_choice, max_new_tokens, temperature, top_k, top_p, repetition_penalty, use_early_stopping, use_do_sample):
model = None
tokenizer = None
# Chọn model dựa trên lựa chọn của người dùng (đã bỏ ViLawT5_RL)
if model_choice == "ViLawT5" and model_1 and tokenizer_1:
model = model_1
tokenizer = tokenizer_1
elif model_choice == "ViT5" and model_2 and tokenizer_2:
model = model_2
tokenizer = tokenizer_2
# Bỏ điều kiện kiểm tra ViLawT5_RL
# elif model_choice == "ViLawT5_RL" and model_3 and tokenizer_3:
# model = model_3
# tokenizer = tokenizer_3
elif model_choice == "V-LegalQA" and model_4 and tokenizer_4:
model = model_4
tokenizer = tokenizer_4
else:
# Kiểm tra xem model có được tải không
available_models = []
if model_1: available_models.append("ViLawT5")
if model_2: available_models.append("ViT5")
# Không thêm ViLawT5_RL vào danh sách kiểm tra
if model_4: available_models.append("V-LegalQA")
if not available_models:
return "Error: No models were loaded successfully. Please check the logs."
if model_choice not in available_models:
return f"Error: Model '{model_choice}' was not loaded successfully or is invalid. Available models: {', '.join(available_models)}"
else: # Trường hợp model_choice hợp lệ nhưng model/tokenizer là None (lỗi không mong muốn)
return f"Error: An unexpected issue occurred with model '{model_choice}'. Please check the logs."
print(f"Generating response using {model_choice} with params: max_new_tokens={max_new_tokens}, temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty}, early_stop={use_early_stopping}, do_sample={use_do_sample}")
input_text = f"câu_hỏi: {question}"
try:
data = tokenizer(
input_text,
return_tensors="pt",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
padding="max_length",
max_length=256 # Cân nhắc tăng max_length nếu câu hỏi/context dài
)
input_ids = data.input_ids.to(device)
attention_mask = data.attention_mask.to(device)
# Suy luận với mô hình
with torch.no_grad():
outputs = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=int(max_new_tokens),
early_stopping=use_early_stopping,
do_sample=use_do_sample,
temperature=float(temperature),
top_k=int(top_k),
top_p=float(top_p),
repetition_penalty=float(repetition_penalty),
# Thêm pad_token_id nếu cần (thường không cần cho T5)
# pad_token_id=tokenizer.pad_token_id
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Raw output shape: {outputs[0].shape}") # Log thêm shape
print(f"Decoded response: {response}")
return response
except Exception as e:
print(f"Error during generation: {e}")
# In thêm traceback để debug
import traceback
traceback.print_exc()
return f"An error occurred during response generation: {e}"
# --- Tạo danh sách các model đã tải thành công (bỏ ViLawT5_RL) ---
loaded_models = []
if model_1 and tokenizer_1: loaded_models.append("ViLawT5")
if model_2 and tokenizer_2: loaded_models.append("ViT5")
if model_4 and tokenizer_4: loaded_models.append("V-LegalQA")
# Chọn model mặc định
default_model = "V-LegalQA" if "V-LegalQA" in loaded_models else (loaded_models[0] if loaded_models else "No models available")
# --- Tạo giao diện với Gradio ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 🤖 AI Chatbot Pháp luật Việt Nam (Demo)
Chọn mô hình và đặt câu hỏi liên quan đến pháp luật.
Nhấn **Shift + Enter** để gửi câu hỏi, **Enter** để xuống dòng.
"""
)
with gr.Row():
model_choice = gr.Dropdown(
choices=loaded_models,
label="Chọn Mô hình AI",
value=default_model,
interactive=bool(loaded_models) # Chỉ cho phép tương tác nếu có model
)
# Đảm bảo 'lines' >= 2 để Shift+Enter có tác dụng rõ ràng
question_input = gr.Textbox(
label="Nhập câu hỏi của bạn (Shift+Enter để gửi)",
placeholder="Ví dụ: Thế nào là tội cố ý gây thương tích?",
lines=3, # Giữ nguyên hoặc tăng nếu muốn ô nhập cao hơn
# scale=7 # Ví dụ: làm cho ô nhập rộng hơn nếu cần
)
# --- Cập nhật giá trị mặc định trong Accordion ---
with gr.Accordion("Tùy chọn Nâng cao (Generation Parameters)", open=False):
with gr.Row():
early_stopping_checkbox = gr.Checkbox(label="Enable Early Stopping", value=False, info="Dừng sớm khi gặp token EOS.")
do_sample_checkbox = gr.Checkbox(label="Enable Sampling (do_sample)", value=False, info="Sử dụng sampling (cần thiết cho temperature, top_k, top_p). Tắt nếu muốn greedy search.")
with gr.Row():
max_new_tokens_slider = gr.Slider(minimum=10, maximum=1024, value=512, step=10, label="Max New Tokens", info="Số lượng token tối đa được sinh ra.")
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature", info="Độ 'sáng tạo' của câu trả lời (thấp hơn = bảo thủ hơn). Cần bật do_sample.")
with gr.Row():
top_k_slider = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-K", info="Chỉ xem xét K token có xác suất cao nhất. Cần bật do_sample.")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-P (Nucleus Sampling)", info="Chỉ xem xét các token có tổng xác suất >= P. Cần bật do_sample.")
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Repetition Penalty", info="Phạt các token đã xuất hiện (cao hơn = ít lặp lại hơn).")
response_output = gr.Textbox(label="Phản hồi của Chatbot", lines=5, interactive=False)
# Nút gửi vẫn giữ lại phòng trường hợp người dùng thích click hơn
submit_btn = gr.Button("Gửi câu hỏi", variant="primary")
# --- THAY ĐỔI QUAN TRỌNG ---
# Tạo một list các inputs để dùng chung cho cả nút bấm và nhấn Enter
chatbot_inputs = [
question_input,
model_choice,
max_new_tokens_slider,
temperature_slider,
top_k_slider,
top_p_slider,
repetition_penalty_slider,
early_stopping_checkbox,
do_sample_checkbox
]
# 1. Gửi khi nhấn nút
submit_btn.click(
fn=chatbot_response,
inputs=chatbot_inputs,
outputs=response_output
)
# 2. Gửi khi nhấn Enter trong Textbox question_input
# Shift+Enter sẽ tự động xuống dòng (hành vi mặc định khi lines > 1)
question_input.submit(
fn=chatbot_response,
inputs=chatbot_inputs,
outputs=response_output
)
# -----------------------------
gr.Examples(
examples=[
["Hợp đồng vô hiệu khi nào?", "V-LegalQA"],
["Quyền và nghĩa vụ của người lao động là gì?", "ViT5"],
["Người dưới 18 tuổi có được ký hợp đồng lao động không?\nThời gian làm việc tối đa là bao lâu?", "V-LegalQA"] # Ví dụ multi-line
],
inputs=[question_input, model_choice]
)
# --- Chạy Gradio ---
if __name__ == "__main__":
if not loaded_models:
print("WARNING: No models were loaded successfully. The application might not function correctly.")
# Cân nhắc thêm: gr.Info("Không có mô hình nào được tải thành công!") trong Blocks
# Bật share=True nếu muốn tạo link chia sẻ tạm thời
demo.launch(debug=True, share=False)