cochi1706's picture
Update label definitions in app.py to include label_id mapping for improved clarity and organization.
9a2ab2b
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Định nghĩa các nhãn theo label_id (0-9)
# Mapping: label_id -> label
LABELS = [
'Chính trị Xã hội', # label_id 0
'Khoa học', # label_id 1
'Kinh doanh', # label_id 2
'Pháp luật', # label_id 3
'Sức khỏe', # label_id 4
'Thế giới', # label_id 5
'Thể thao', # label_id 6
'Vi tính', # label_id 7
'Văn hóa', # label_id 8
'Đời sống', # label_id 9
]
# Khởi tạo device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load model và tokenizer
print("Đang tải model...")
model_name = "cochi1706/phobert-vntc-chunk1"
# Sử dụng tokenizer từ model PhoBERT gốc vì model fine-tuned có thể không có tokenizer config đầy đủ
tokenizer_name = "vinai/phobert-base" # Hoặc "vinai/phobert-large" nếu model dùng large
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
print(f"Đã tải tokenizer từ {tokenizer_name}")
except Exception as e:
print(f"Không thể tải tokenizer từ {tokenizer_name}, thử từ model fine-tuned...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device)
model.eval()
# Lấy max_length từ model config (nếu có) hoặc dùng giá trị mặc định
# Dựa trên lỗi, model có vẻ được train với max_length=258
try:
if hasattr(model.config, 'max_position_embeddings'):
max_length = min(model.config.max_position_embeddings, 258)
else:
max_length = 258 # Giá trị dựa trên lỗi
except:
max_length = 258 # Giá trị mặc định dựa trên lỗi
print(f"Model đã được tải thành công! Max length: {max_length}")
def classify_text(text):
"""
Phân loại văn bản tiếng Việt
"""
if not text or text.strip() == "":
return "Vui lòng nhập văn bản cần phân loại!"
try:
# Tokenize văn bản
# Model có vẻ được train với max_length=258, nên cần pad đến đúng độ dài này
encoding = tokenizer(
text,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt'
)
# Chuyển sang device
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# Dự đoán
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
pred_label_id = torch.argmax(outputs.logits, dim=1).item()
# Lấy xác suất cho tất cả các lớp
probabilities = torch.softmax(outputs.logits, dim=1)[0]
# Tạo kết quả
predicted_label = LABELS[pred_label_id]
confidence = probabilities[pred_label_id].item() * 100
# Tạo danh sách xác suất cho tất cả các nhãn
results = []
for i, label in enumerate(LABELS):
prob = probabilities[i].item() * 100
results.append(f"{label}: {prob:.2f}%")
result_text = f"**Nhãn dự đoán: {predicted_label}**\n"
result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n"
result_text += "**Xác suất cho tất cả các nhãn:**\n"
result_text += "\n".join(results)
return result_text
except Exception as e:
import traceback
return f"Lỗi khi phân loại: {str(e)}\n\nTraceback: {traceback.format_exc()}"
# Tạo giao diện Gradio
with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 📰 Ứng dụng Phân loại Văn bản Tiếng Việt
Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục:
- Thế giới
- Văn hóa
- Chính trị Xã hội
- Vi tính
- Đời sống
- Thể thao
- Sức khỏe
- Kinh doanh
- Pháp luật
- Khoa học
Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả!
"""
)
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Nhập văn bản cần phân loại",
placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...",
lines=5,
max_lines=10
)
classify_btn = gr.Button("Phân loại", variant="primary", size="lg")
with gr.Column():
output = gr.Markdown(label="Kết quả phân loại")
# Ví dụ
gr.Markdown("### 📝 Ví dụ:")
examples = gr.Examples(
examples=[
["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."],
["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."],
["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."],
["Chính phủ ban hành luật mới về bảo vệ môi trường."],
["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."]
],
inputs=text_input
)
# Xử lý sự kiện
classify_btn.click(
fn=classify_text,
inputs=text_input,
outputs=output
)
text_input.submit(
fn=classify_text,
inputs=text_input,
outputs=output
)
if __name__ == "__main__":
demo.launch(share=False)