import gradio as gr import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer # Định nghĩa các nhãn theo label_id (0-9) # Mapping: label_id -> label LABELS = [ 'Chính trị Xã hội', # label_id 0 'Khoa học', # label_id 1 'Kinh doanh', # label_id 2 'Pháp luật', # label_id 3 'Sức khỏe', # label_id 4 'Thế giới', # label_id 5 'Thể thao', # label_id 6 'Vi tính', # label_id 7 'Văn hóa', # label_id 8 'Đời sống', # label_id 9 ] # Khởi tạo device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load model và tokenizer print("Đang tải model...") model_name = "cochi1706/phobert-vntc-chunk1" # Sử dụng tokenizer từ model PhoBERT gốc vì model fine-tuned có thể không có tokenizer config đầy đủ tokenizer_name = "vinai/phobert-base" # Hoặc "vinai/phobert-large" nếu model dùng large try: tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) print(f"Đã tải tokenizer từ {tokenizer_name}") except Exception as e: print(f"Không thể tải tokenizer từ {tokenizer_name}, thử từ model fine-tuned...") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) model.to(device) model.eval() # Lấy max_length từ model config (nếu có) hoặc dùng giá trị mặc định # Dựa trên lỗi, model có vẻ được train với max_length=258 try: if hasattr(model.config, 'max_position_embeddings'): max_length = min(model.config.max_position_embeddings, 258) else: max_length = 258 # Giá trị dựa trên lỗi except: max_length = 258 # Giá trị mặc định dựa trên lỗi print(f"Model đã được tải thành công! Max length: {max_length}") def classify_text(text): """ Phân loại văn bản tiếng Việt """ if not text or text.strip() == "": return "Vui lòng nhập văn bản cần phân loại!" try: # Tokenize văn bản # Model có vẻ được train với max_length=258, nên cần pad đến đúng độ dài này encoding = tokenizer( text, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt' ) # Chuyển sang device input_ids = encoding['input_ids'].to(device) attention_mask = encoding['attention_mask'].to(device) # Dự đoán with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) pred_label_id = torch.argmax(outputs.logits, dim=1).item() # Lấy xác suất cho tất cả các lớp probabilities = torch.softmax(outputs.logits, dim=1)[0] # Tạo kết quả predicted_label = LABELS[pred_label_id] confidence = probabilities[pred_label_id].item() * 100 # Tạo danh sách xác suất cho tất cả các nhãn results = [] for i, label in enumerate(LABELS): prob = probabilities[i].item() * 100 results.append(f"{label}: {prob:.2f}%") result_text = f"**Nhãn dự đoán: {predicted_label}**\n" result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n" result_text += "**Xác suất cho tất cả các nhãn:**\n" result_text += "\n".join(results) return result_text except Exception as e: import traceback return f"Lỗi khi phân loại: {str(e)}\n\nTraceback: {traceback.format_exc()}" # Tạo giao diện Gradio with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 📰 Ứng dụng Phân loại Văn bản Tiếng Việt Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục: - Thế giới - Văn hóa - Chính trị Xã hội - Vi tính - Đời sống - Thể thao - Sức khỏe - Kinh doanh - Pháp luật - Khoa học Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả! """ ) with gr.Row(): with gr.Column(): text_input = gr.Textbox( label="Nhập văn bản cần phân loại", placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...", lines=5, max_lines=10 ) classify_btn = gr.Button("Phân loại", variant="primary", size="lg") with gr.Column(): output = gr.Markdown(label="Kết quả phân loại") # Ví dụ gr.Markdown("### 📝 Ví dụ:") examples = gr.Examples( examples=[ ["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."], ["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."], ["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."], ["Chính phủ ban hành luật mới về bảo vệ môi trường."], ["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."] ], inputs=text_input ) # Xử lý sự kiện classify_btn.click( fn=classify_text, inputs=text_input, outputs=output ) text_input.submit( fn=classify_text, inputs=text_input, outputs=output ) if __name__ == "__main__": demo.launch(share=False)