cochi1706's picture
Initial implementation of the project structure and core functionality.
45c9bd2
raw
history blame
5.39 kB
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
# Định nghĩa các nhãn
LABELS = ['Thế giới', 'Văn hóa', 'Chính trị Xã hội', 'Vi tính', 'Đời sống',
'Thể thao', 'Sức khỏe', 'Kinh doanh', 'Pháp luật', 'Khoa học']
# Khởi tạo device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load model và tokenizer
print("Đang tải model...")
model_name = "cochi1706/phobert-vntc-chunk1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device)
model.eval()
print("Model đã được tải thành công!")
# Dataset class cho inference
class TextDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=512):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = str(self.texts[idx])
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten()
}
def classify_text(text):
"""
Phân loại văn bản tiếng Việt
"""
if not text or text.strip() == "":
return "Vui lòng nhập văn bản cần phân loại!"
try:
# Tạo dataset và dataloader
dataset = TextDataset([text], tokenizer)
dataloader = DataLoader(dataset, batch_size=1)
# Dự đoán
with torch.no_grad():
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
pred_label_id = torch.argmax(outputs.logits, dim=1).item()
# Lấy xác suất cho tất cả các lớp
probabilities = torch.softmax(outputs.logits, dim=1)[0]
# Tạo kết quả
predicted_label = LABELS[pred_label_id]
confidence = probabilities[pred_label_id].item() * 100
# Tạo danh sách xác suất cho tất cả các nhãn
results = []
for i, label in enumerate(LABELS):
prob = probabilities[i].item() * 100
results.append(f"{label}: {prob:.2f}%")
result_text = f"**Nhãn dự đoán: {predicted_label}**\n"
result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n"
result_text += "**Xác suất cho tất cả các nhãn:**\n"
result_text += "\n".join(results)
return result_text
except Exception as e:
return f"Lỗi khi phân loại: {str(e)}"
# Tạo giao diện Gradio
with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 📰 Ứng dụng Phân loại Văn bản Tiếng Việt
Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục:
- Thế giới
- Văn hóa
- Chính trị Xã hội
- Vi tính
- Đời sống
- Thể thao
- Sức khỏe
- Kinh doanh
- Pháp luật
- Khoa học
Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả!
"""
)
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Nhập văn bản cần phân loại",
placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...",
lines=5,
max_lines=10
)
classify_btn = gr.Button("Phân loại", variant="primary", size="lg")
with gr.Column():
output = gr.Markdown(label="Kết quả phân loại")
# Ví dụ
gr.Markdown("### 📝 Ví dụ:")
examples = gr.Examples(
examples=[
["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."],
["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."],
["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."],
["Chính phủ ban hành luật mới về bảo vệ môi trường."],
["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."]
],
inputs=text_input
)
# Xử lý sự kiện
classify_btn.click(
fn=classify_text,
inputs=text_input,
outputs=output
)
text_input.submit(
fn=classify_text,
inputs=text_input,
outputs=output
)
if __name__ == "__main__":
demo.launch(share=False)