Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from torch.utils.data import Dataset, DataLoader | |
| # Định nghĩa các nhãn | |
| LABELS = ['Thế giới', 'Văn hóa', 'Chính trị Xã hội', 'Vi tính', 'Đời sống', | |
| 'Thể thao', 'Sức khỏe', 'Kinh doanh', 'Pháp luật', 'Khoa học'] | |
| # Khởi tạo device | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # Load model và tokenizer | |
| print("Đang tải model...") | |
| model_name = "cochi1706/phobert-vntc-chunk1" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
| model.to(device) | |
| model.eval() | |
| print("Model đã được tải thành công!") | |
| # Dataset class cho inference | |
| class TextDataset(Dataset): | |
| def __init__(self, texts, tokenizer, max_length=512): | |
| self.texts = texts | |
| self.tokenizer = tokenizer | |
| self.max_length = max_length | |
| def __len__(self): | |
| return len(self.texts) | |
| def __getitem__(self, idx): | |
| text = str(self.texts[idx]) | |
| encoding = self.tokenizer( | |
| text, | |
| truncation=True, | |
| padding='max_length', | |
| max_length=self.max_length, | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': encoding['input_ids'].flatten(), | |
| 'attention_mask': encoding['attention_mask'].flatten() | |
| } | |
| def classify_text(text): | |
| """ | |
| Phân loại văn bản tiếng Việt | |
| """ | |
| if not text or text.strip() == "": | |
| return "Vui lòng nhập văn bản cần phân loại!" | |
| try: | |
| # Tạo dataset và dataloader | |
| dataset = TextDataset([text], tokenizer) | |
| dataloader = DataLoader(dataset, batch_size=1) | |
| # Dự đoán | |
| with torch.no_grad(): | |
| for batch in dataloader: | |
| batch = {k: v.to(device) for k, v in batch.items()} | |
| outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']) | |
| pred_label_id = torch.argmax(outputs.logits, dim=1).item() | |
| # Lấy xác suất cho tất cả các lớp | |
| probabilities = torch.softmax(outputs.logits, dim=1)[0] | |
| # Tạo kết quả | |
| predicted_label = LABELS[pred_label_id] | |
| confidence = probabilities[pred_label_id].item() * 100 | |
| # Tạo danh sách xác suất cho tất cả các nhãn | |
| results = [] | |
| for i, label in enumerate(LABELS): | |
| prob = probabilities[i].item() * 100 | |
| results.append(f"{label}: {prob:.2f}%") | |
| result_text = f"**Nhãn dự đoán: {predicted_label}**\n" | |
| result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n" | |
| result_text += "**Xác suất cho tất cả các nhãn:**\n" | |
| result_text += "\n".join(results) | |
| return result_text | |
| except Exception as e: | |
| return f"Lỗi khi phân loại: {str(e)}" | |
| # Tạo giao diện Gradio | |
| with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown( | |
| """ | |
| # 📰 Ứng dụng Phân loại Văn bản Tiếng Việt | |
| Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục: | |
| - Thế giới | |
| - Văn hóa | |
| - Chính trị Xã hội | |
| - Vi tính | |
| - Đời sống | |
| - Thể thao | |
| - Sức khỏe | |
| - Kinh doanh | |
| - Pháp luật | |
| - Khoa học | |
| Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả! | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| text_input = gr.Textbox( | |
| label="Nhập văn bản cần phân loại", | |
| placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...", | |
| lines=5, | |
| max_lines=10 | |
| ) | |
| classify_btn = gr.Button("Phân loại", variant="primary", size="lg") | |
| with gr.Column(): | |
| output = gr.Markdown(label="Kết quả phân loại") | |
| # Ví dụ | |
| gr.Markdown("### 📝 Ví dụ:") | |
| examples = gr.Examples( | |
| examples=[ | |
| ["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."], | |
| ["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."], | |
| ["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."], | |
| ["Chính phủ ban hành luật mới về bảo vệ môi trường."], | |
| ["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."] | |
| ], | |
| inputs=text_input | |
| ) | |
| # Xử lý sự kiện | |
| classify_btn.click( | |
| fn=classify_text, | |
| inputs=text_input, | |
| outputs=output | |
| ) | |
| text_input.submit( | |
| fn=classify_text, | |
| inputs=text_input, | |
| outputs=output | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=False) | |