Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import gradio as gr | |
| # Kiểm tra thiết bị (GPU nếu có) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Using device: {device}") # Thêm log để biết thiết bị đang sử dụng | |
| # --- Tải mô hình và tokenizer --- | |
| # Khởi tạo biến model và tokenizer là None | |
| model_1, tokenizer_1 = None, None | |
| model_2, tokenizer_2 = None, None | |
| # model_3, tokenizer_3 = None, None # Không cần tải model_3 nữa | |
| model_4, tokenizer_4 = None, None | |
| # Sử dụng try-except để xử lý lỗi nếu không tải được mô hình | |
| try: | |
| model_name_1 = "sunbv56/ViLawT5_QAChatBot" | |
| print(f"Loading model: {model_name_1}...") | |
| model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device) | |
| tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1) | |
| print(f"Model {model_name_1} loaded successfully.") | |
| except Exception as e: | |
| print(f"Error loading model {model_name_1}: {e}") | |
| try: | |
| model_name_2 = "sunbv56/ViT5_QAChatBot" | |
| print(f"Loading model: {model_name_2}...") | |
| model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device) | |
| tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2) | |
| print(f"Model {model_name_2} loaded successfully.") | |
| except Exception as e: | |
| print(f"Error loading model {model_name_2}: {e}") | |
| # Bỏ qua việc tải model_3 (ViLawT5_RL) | |
| # ... (phần code tải model_3 bị comment như cũ) ... | |
| try: | |
| model_name_4 = "sunbv56/V-LegalQA" | |
| print(f"Loading model: {model_name_4}...") | |
| model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device) | |
| tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4) | |
| print(f"Model {model_name_4} loaded successfully.") | |
| except Exception as e: | |
| print(f"Error loading model {model_name_4}: {e}") | |
| # --- Hàm sinh phản hồi --- | |
| def chatbot_response(question, model_choice, max_new_tokens, temperature, top_k, top_p, repetition_penalty, use_early_stopping, use_do_sample): | |
| model = None | |
| tokenizer = None | |
| # Chọn model dựa trên lựa chọn của người dùng (đã bỏ ViLawT5_RL) | |
| if model_choice == "ViLawT5" and model_1 and tokenizer_1: | |
| model = model_1 | |
| tokenizer = tokenizer_1 | |
| elif model_choice == "ViT5" and model_2 and tokenizer_2: | |
| model = model_2 | |
| tokenizer = tokenizer_2 | |
| # Bỏ điều kiện kiểm tra ViLawT5_RL | |
| # elif model_choice == "ViLawT5_RL" and model_3 and tokenizer_3: | |
| # model = model_3 | |
| # tokenizer = tokenizer_3 | |
| elif model_choice == "V-LegalQA" and model_4 and tokenizer_4: | |
| model = model_4 | |
| tokenizer = tokenizer_4 | |
| else: | |
| # Kiểm tra xem model có được tải không | |
| available_models = [] | |
| if model_1: available_models.append("ViLawT5") | |
| if model_2: available_models.append("ViT5") | |
| # Không thêm ViLawT5_RL vào danh sách kiểm tra | |
| if model_4: available_models.append("V-LegalQA") | |
| if not available_models: | |
| return "Error: No models were loaded successfully. Please check the logs." | |
| if model_choice not in available_models: | |
| return f"Error: Model '{model_choice}' was not loaded successfully or is invalid. Available models: {', '.join(available_models)}" | |
| else: # Trường hợp model_choice hợp lệ nhưng model/tokenizer là None (lỗi không mong muốn) | |
| return f"Error: An unexpected issue occurred with model '{model_choice}'. Please check the logs." | |
| print(f"Generating response using {model_choice} with params: max_new_tokens={max_new_tokens}, temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty}, early_stop={use_early_stopping}, do_sample={use_do_sample}") | |
| input_text = f"câu_hỏi: {question}" | |
| try: | |
| data = tokenizer( | |
| input_text, | |
| return_tensors="pt", | |
| truncation=True, | |
| return_attention_mask=True, | |
| add_special_tokens=True, | |
| padding="max_length", | |
| max_length=256 # Cân nhắc tăng max_length nếu câu hỏi/context dài | |
| ) | |
| input_ids = data.input_ids.to(device) | |
| attention_mask = data.attention_mask.to(device) | |
| # Suy luận với mô hình | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| max_new_tokens=int(max_new_tokens), | |
| early_stopping=use_early_stopping, | |
| do_sample=use_do_sample, | |
| temperature=float(temperature), | |
| top_k=int(top_k), | |
| top_p=float(top_p), | |
| repetition_penalty=float(repetition_penalty), | |
| # Thêm pad_token_id nếu cần (thường không cần cho T5) | |
| # pad_token_id=tokenizer.pad_token_id | |
| ) | |
| response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| print(f"Raw output shape: {outputs[0].shape}") # Log thêm shape | |
| print(f"Decoded response: {response}") | |
| return response | |
| except Exception as e: | |
| print(f"Error during generation: {e}") | |
| # In thêm traceback để debug | |
| import traceback | |
| traceback.print_exc() | |
| return f"An error occurred during response generation: {e}" | |
| # --- Tạo danh sách các model đã tải thành công (bỏ ViLawT5_RL) --- | |
| loaded_models = [] | |
| if model_1 and tokenizer_1: loaded_models.append("ViLawT5") | |
| if model_2 and tokenizer_2: loaded_models.append("ViT5") | |
| if model_4 and tokenizer_4: loaded_models.append("V-LegalQA") | |
| # Chọn model mặc định | |
| default_model = "V-LegalQA" if "V-LegalQA" in loaded_models else (loaded_models[0] if loaded_models else "No models available") | |
| # --- Tạo giao diện với Gradio --- | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| """ | |
| # 🤖 AI Chatbot Pháp luật Việt Nam (Demo) | |
| Chọn mô hình và đặt câu hỏi liên quan đến pháp luật. | |
| Nhấn **Shift + Enter** để gửi câu hỏi, **Enter** để xuống dòng. | |
| """ | |
| ) | |
| with gr.Row(): | |
| model_choice = gr.Dropdown( | |
| choices=loaded_models, | |
| label="Chọn Mô hình AI", | |
| value=default_model, | |
| interactive=bool(loaded_models) # Chỉ cho phép tương tác nếu có model | |
| ) | |
| # Đảm bảo 'lines' >= 2 để Shift+Enter có tác dụng rõ ràng | |
| question_input = gr.Textbox( | |
| label="Nhập câu hỏi của bạn (Shift+Enter để gửi)", | |
| placeholder="Ví dụ: Thế nào là tội cố ý gây thương tích?", | |
| lines=3, # Giữ nguyên hoặc tăng nếu muốn ô nhập cao hơn | |
| # scale=7 # Ví dụ: làm cho ô nhập rộng hơn nếu cần | |
| ) | |
| # --- Cập nhật giá trị mặc định trong Accordion --- | |
| with gr.Accordion("Tùy chọn Nâng cao (Generation Parameters)", open=False): | |
| with gr.Row(): | |
| early_stopping_checkbox = gr.Checkbox(label="Enable Early Stopping", value=False, info="Dừng sớm khi gặp token EOS.") | |
| do_sample_checkbox = gr.Checkbox(label="Enable Sampling (do_sample)", value=False, info="Sử dụng sampling (cần thiết cho temperature, top_k, top_p). Tắt nếu muốn greedy search.") | |
| with gr.Row(): | |
| max_new_tokens_slider = gr.Slider(minimum=10, maximum=1024, value=512, step=10, label="Max New Tokens", info="Số lượng token tối đa được sinh ra.") | |
| temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature", info="Độ 'sáng tạo' của câu trả lời (thấp hơn = bảo thủ hơn). Cần bật do_sample.") | |
| with gr.Row(): | |
| top_k_slider = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-K", info="Chỉ xem xét K token có xác suất cao nhất. Cần bật do_sample.") | |
| top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-P (Nucleus Sampling)", info="Chỉ xem xét các token có tổng xác suất >= P. Cần bật do_sample.") | |
| repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Repetition Penalty", info="Phạt các token đã xuất hiện (cao hơn = ít lặp lại hơn).") | |
| response_output = gr.Textbox(label="Phản hồi của Chatbot", lines=5, interactive=False) | |
| # Nút gửi vẫn giữ lại phòng trường hợp người dùng thích click hơn | |
| submit_btn = gr.Button("Gửi câu hỏi", variant="primary") | |
| # --- THAY ĐỔI QUAN TRỌNG --- | |
| # Tạo một list các inputs để dùng chung cho cả nút bấm và nhấn Enter | |
| chatbot_inputs = [ | |
| question_input, | |
| model_choice, | |
| max_new_tokens_slider, | |
| temperature_slider, | |
| top_k_slider, | |
| top_p_slider, | |
| repetition_penalty_slider, | |
| early_stopping_checkbox, | |
| do_sample_checkbox | |
| ] | |
| # 1. Gửi khi nhấn nút | |
| submit_btn.click( | |
| fn=chatbot_response, | |
| inputs=chatbot_inputs, | |
| outputs=response_output | |
| ) | |
| # 2. Gửi khi nhấn Enter trong Textbox question_input | |
| # Shift+Enter sẽ tự động xuống dòng (hành vi mặc định khi lines > 1) | |
| question_input.submit( | |
| fn=chatbot_response, | |
| inputs=chatbot_inputs, | |
| outputs=response_output | |
| ) | |
| # ----------------------------- | |
| gr.Examples( | |
| examples=[ | |
| ["Hợp đồng vô hiệu khi nào?", "V-LegalQA"], | |
| ["Quyền và nghĩa vụ của người lao động là gì?", "ViT5"], | |
| ["Người dưới 18 tuổi có được ký hợp đồng lao động không?\nThời gian làm việc tối đa là bao lâu?", "V-LegalQA"] # Ví dụ multi-line | |
| ], | |
| inputs=[question_input, model_choice] | |
| ) | |
| # --- Chạy Gradio --- | |
| if __name__ == "__main__": | |
| if not loaded_models: | |
| print("WARNING: No models were loaded successfully. The application might not function correctly.") | |
| # Cân nhắc thêm: gr.Info("Không có mô hình nào được tải thành công!") trong Blocks | |
| # Bật share=True nếu muốn tạo link chia sẻ tạm thời | |
| demo.launch(debug=True, share=False) |