Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| # Định nghĩa các nhãn theo label_id (0-9) | |
| # Mapping: label_id -> label | |
| LABELS = [ | |
| 'Chính trị Xã hội', # label_id 0 | |
| 'Khoa học', # label_id 1 | |
| 'Kinh doanh', # label_id 2 | |
| 'Pháp luật', # label_id 3 | |
| 'Sức khỏe', # label_id 4 | |
| 'Thế giới', # label_id 5 | |
| 'Thể thao', # label_id 6 | |
| 'Vi tính', # label_id 7 | |
| 'Văn hóa', # label_id 8 | |
| 'Đời sống', # label_id 9 | |
| ] | |
| # Khởi tạo device | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # Load model và tokenizer | |
| print("Đang tải model...") | |
| model_name = "cochi1706/phobert-vntc-chunk1" | |
| # Sử dụng tokenizer từ model PhoBERT gốc vì model fine-tuned có thể không có tokenizer config đầy đủ | |
| tokenizer_name = "vinai/phobert-base" # Hoặc "vinai/phobert-large" nếu model dùng large | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | |
| print(f"Đã tải tokenizer từ {tokenizer_name}") | |
| except Exception as e: | |
| print(f"Không thể tải tokenizer từ {tokenizer_name}, thử từ model fine-tuned...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
| model.to(device) | |
| model.eval() | |
| # Lấy max_length từ model config (nếu có) hoặc dùng giá trị mặc định | |
| # Dựa trên lỗi, model có vẻ được train với max_length=258 | |
| try: | |
| if hasattr(model.config, 'max_position_embeddings'): | |
| max_length = min(model.config.max_position_embeddings, 258) | |
| else: | |
| max_length = 258 # Giá trị dựa trên lỗi | |
| except: | |
| max_length = 258 # Giá trị mặc định dựa trên lỗi | |
| print(f"Model đã được tải thành công! Max length: {max_length}") | |
| def classify_text(text): | |
| """ | |
| Phân loại văn bản tiếng Việt | |
| """ | |
| if not text or text.strip() == "": | |
| return "Vui lòng nhập văn bản cần phân loại!" | |
| try: | |
| # Tokenize văn bản | |
| # Model có vẻ được train với max_length=258, nên cần pad đến đúng độ dài này | |
| encoding = tokenizer( | |
| text, | |
| truncation=True, | |
| padding='max_length', | |
| max_length=max_length, | |
| return_tensors='pt' | |
| ) | |
| # Chuyển sang device | |
| input_ids = encoding['input_ids'].to(device) | |
| attention_mask = encoding['attention_mask'].to(device) | |
| # Dự đoán | |
| with torch.no_grad(): | |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask) | |
| pred_label_id = torch.argmax(outputs.logits, dim=1).item() | |
| # Lấy xác suất cho tất cả các lớp | |
| probabilities = torch.softmax(outputs.logits, dim=1)[0] | |
| # Tạo kết quả | |
| predicted_label = LABELS[pred_label_id] | |
| confidence = probabilities[pred_label_id].item() * 100 | |
| # Tạo danh sách xác suất cho tất cả các nhãn | |
| results = [] | |
| for i, label in enumerate(LABELS): | |
| prob = probabilities[i].item() * 100 | |
| results.append(f"{label}: {prob:.2f}%") | |
| result_text = f"**Nhãn dự đoán: {predicted_label}**\n" | |
| result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n" | |
| result_text += "**Xác suất cho tất cả các nhãn:**\n" | |
| result_text += "\n".join(results) | |
| return result_text | |
| except Exception as e: | |
| import traceback | |
| return f"Lỗi khi phân loại: {str(e)}\n\nTraceback: {traceback.format_exc()}" | |
| # Tạo giao diện Gradio | |
| with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| """ | |
| # 📰 Ứng dụng Phân loại Văn bản Tiếng Việt | |
| Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục: | |
| - Thế giới | |
| - Văn hóa | |
| - Chính trị Xã hội | |
| - Vi tính | |
| - Đời sống | |
| - Thể thao | |
| - Sức khỏe | |
| - Kinh doanh | |
| - Pháp luật | |
| - Khoa học | |
| Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả! | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_input = gr.Textbox( | |
| label="Nhập văn bản cần phân loại", | |
| placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...", | |
| lines=5, | |
| max_lines=10 | |
| ) | |
| classify_btn = gr.Button("Phân loại", variant="primary", size="lg") | |
| with gr.Column(): | |
| output = gr.Markdown(label="Kết quả phân loại") | |
| # Ví dụ | |
| gr.Markdown("### 📝 Ví dụ:") | |
| examples = gr.Examples( | |
| examples=[ | |
| ["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."], | |
| ["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."], | |
| ["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."], | |
| ["Chính phủ ban hành luật mới về bảo vệ môi trường."], | |
| ["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."] | |
| ], | |
| inputs=text_input | |
| ) | |
| # Xử lý sự kiện | |
| classify_btn.click( | |
| fn=classify_text, | |
| inputs=text_input, | |
| outputs=output | |
| ) | |
| text_input.submit( | |
| fn=classify_text, | |
| inputs=text_input, | |
| outputs=output | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=False) | |