cochi1706 commited on
Commit
45c9bd2
·
1 Parent(s): 240e159

Initial implementation of the project structure and core functionality.

Browse files
Files changed (2) hide show
  1. app.py +151 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ from torch.utils.data import Dataset, DataLoader
5
+
6
+ # Định nghĩa các nhãn
7
+ LABELS = ['Thế giới', 'Văn hóa', 'Chính trị Xã hội', 'Vi tính', 'Đời sống',
8
+ 'Thể thao', 'Sức khỏe', 'Kinh doanh', 'Pháp luật', 'Khoa học']
9
+
10
+ # Khởi tạo device
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # Load model và tokenizer
14
+ print("Đang tải model...")
15
+ model_name = "cochi1706/phobert-vntc-chunk1"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
18
+ model.to(device)
19
+ model.eval()
20
+ print("Model đã được tải thành công!")
21
+
22
+ # Dataset class cho inference
23
+ class TextDataset(Dataset):
24
+ def __init__(self, texts, tokenizer, max_length=512):
25
+ self.texts = texts
26
+ self.tokenizer = tokenizer
27
+ self.max_length = max_length
28
+
29
+ def __len__(self):
30
+ return len(self.texts)
31
+
32
+ def __getitem__(self, idx):
33
+ text = str(self.texts[idx])
34
+ encoding = self.tokenizer(
35
+ text,
36
+ truncation=True,
37
+ padding='max_length',
38
+ max_length=self.max_length,
39
+ return_tensors='pt'
40
+ )
41
+ return {
42
+ 'input_ids': encoding['input_ids'].flatten(),
43
+ 'attention_mask': encoding['attention_mask'].flatten()
44
+ }
45
+
46
+ def classify_text(text):
47
+ """
48
+ Phân loại văn bản tiếng Việt
49
+ """
50
+ if not text or text.strip() == "":
51
+ return "Vui lòng nhập văn bản cần phân loại!"
52
+
53
+ try:
54
+ # Tạo dataset và dataloader
55
+ dataset = TextDataset([text], tokenizer)
56
+ dataloader = DataLoader(dataset, batch_size=1)
57
+
58
+ # Dự đoán
59
+ with torch.no_grad():
60
+ for batch in dataloader:
61
+ batch = {k: v.to(device) for k, v in batch.items()}
62
+ outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
63
+ pred_label_id = torch.argmax(outputs.logits, dim=1).item()
64
+
65
+ # Lấy xác suất cho tất cả các lớp
66
+ probabilities = torch.softmax(outputs.logits, dim=1)[0]
67
+
68
+ # Tạo kết quả
69
+ predicted_label = LABELS[pred_label_id]
70
+ confidence = probabilities[pred_label_id].item() * 100
71
+
72
+ # Tạo danh sách xác suất cho tất cả các nhãn
73
+ results = []
74
+ for i, label in enumerate(LABELS):
75
+ prob = probabilities[i].item() * 100
76
+ results.append(f"{label}: {prob:.2f}%")
77
+
78
+ result_text = f"**Nhãn dự đoán: {predicted_label}**\n"
79
+ result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n"
80
+ result_text += "**Xác suất cho tất cả các nhãn:**\n"
81
+ result_text += "\n".join(results)
82
+
83
+ return result_text
84
+
85
+ except Exception as e:
86
+ return f"Lỗi khi phân loại: {str(e)}"
87
+
88
+ # Tạo giao diện Gradio
89
+ with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo:
90
+ gr.Markdown(
91
+ """
92
+ # 📰 Ứng dụng Phân loại Văn bản Tiếng Việt
93
+
94
+ Ứng dụng này sử dụng mô hình PhoBERT để phân loại văn bản tiếng Việt vào 10 danh mục:
95
+ - Thế giới
96
+ - Văn hóa
97
+ - Chính trị Xã hội
98
+ - Vi tính
99
+ - Đời sống
100
+ - Thể thao
101
+ - Sức khỏe
102
+ - Kinh doanh
103
+ - Pháp luật
104
+ - Khoa học
105
+
106
+ Nhập văn bản vào ô bên dưới và nhấn nút "Phân loại" để xem kết quả!
107
+ """
108
+ )
109
+
110
+ with gr.Row():
111
+ with gr.Column():
112
+ text_input = gr.Textbox(
113
+ label="Nhập văn bản cần phân loại",
114
+ placeholder="Ví dụ: Hôm nay thị trường chứng khoán tăng điểm mạnh...",
115
+ lines=5,
116
+ max_lines=10
117
+ )
118
+ classify_btn = gr.Button("Phân loại", variant="primary", size="lg")
119
+
120
+ with gr.Column():
121
+ output = gr.Markdown(label="Kết quả phân loại")
122
+
123
+ # Ví dụ
124
+ gr.Markdown("### 📝 Ví dụ:")
125
+ examples = gr.Examples(
126
+ examples=[
127
+ ["Hôm nay thị trường chứng khoán tăng điểm mạnh, nhiều mã cổ phiếu đạt trần."],
128
+ ["Đội tuyển bóng đá Việt Nam giành chiến thắng trong trận đấu tối qua."],
129
+ ["Các nhà khoa học phát hiện ra phương pháp mới trong điều trị ung thư."],
130
+ ["Chính phủ ban hành luật mới về bảo vệ môi trường."],
131
+ ["Công nghệ AI đang phát triển mạnh mẽ trong lĩnh vực y tế."]
132
+ ],
133
+ inputs=text_input
134
+ )
135
+
136
+ # Xử lý sự kiện
137
+ classify_btn.click(
138
+ fn=classify_text,
139
+ inputs=text_input,
140
+ outputs=output
141
+ )
142
+
143
+ text_input.submit(
144
+ fn=classify_text,
145
+ inputs=text_input,
146
+ outputs=output
147
+ )
148
+
149
+ if __name__ == "__main__":
150
+ demo.launch(share=False)
151
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=5.49.1
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ accelerate>=0.20.0
5
+