cochi1706 commited on
Commit
12c8e7c
·
1 Parent(s): 9cef669

Refactor text classification logic to dynamically set max_length based on model configuration and streamline tokenization process, enhancing error handling with detailed traceback.

Browse files
Files changed (1) hide show
  1. app.py +48 -53
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
- from torch.utils.data import Dataset, DataLoader
5
 
6
  # Định nghĩa các nhãn
7
  LABELS = ['Thế giới', 'Văn hóa', 'Chính trị Xã hội', 'Vi tính', 'Đời sống',
@@ -25,31 +24,18 @@ except Exception as e:
25
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
26
  model.to(device)
27
  model.eval()
28
- print("Model đã được tải thành công!")
29
 
30
- # Dataset class cho inference
31
- class TextDataset(Dataset):
32
- def __init__(self, texts, tokenizer, max_length=512):
33
- self.texts = texts
34
- self.tokenizer = tokenizer
35
- self.max_length = max_length
36
-
37
- def __len__(self):
38
- return len(self.texts)
39
-
40
- def __getitem__(self, idx):
41
- text = str(self.texts[idx])
42
- encoding = self.tokenizer(
43
- text,
44
- truncation=True,
45
- padding='max_length',
46
- max_length=self.max_length,
47
- return_tensors='pt'
48
- )
49
- return {
50
- 'input_ids': encoding['input_ids'].flatten(),
51
- 'attention_mask': encoding['attention_mask'].flatten()
52
- }
53
 
54
  def classify_text(text):
55
  """
@@ -59,39 +45,48 @@ def classify_text(text):
59
  return "Vui lòng nhập văn bản cần phân loại!"
60
 
61
  try:
62
- # Tạo dataset và dataloader
63
- dataset = TextDataset([text], tokenizer)
64
- dataloader = DataLoader(dataset, batch_size=1)
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Dự đoán
67
  with torch.no_grad():
68
- for batch in dataloader:
69
- batch = {k: v.to(device) for k, v in batch.items()}
70
- outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
71
- pred_label_id = torch.argmax(outputs.logits, dim=1).item()
72
-
73
- # Lấy xác suất cho tất cả các lớp
74
- probabilities = torch.softmax(outputs.logits, dim=1)[0]
75
-
76
- # Tạo kết quả
77
- predicted_label = LABELS[pred_label_id]
78
- confidence = probabilities[pred_label_id].item() * 100
79
-
80
- # Tạo danh sách xác suất cho tất cả các nhãn
81
- results = []
82
- for i, label in enumerate(LABELS):
83
- prob = probabilities[i].item() * 100
84
- results.append(f"{label}: {prob:.2f}%")
85
-
86
- result_text = f"**Nhãn dự đoán: {predicted_label}**\n"
87
- result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n"
88
- result_text += "**Xác suất cho tất cả các nhãn:**\n"
89
- result_text += "\n".join(results)
90
-
91
- return result_text
92
 
93
  except Exception as e:
94
- return f"Lỗi khi phân loại: {str(e)}"
 
95
 
96
  # Tạo giao diện Gradio
97
  with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo:
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
4
 
5
  # Định nghĩa các nhãn
6
  LABELS = ['Thế giới', 'Văn hóa', 'Chính trị Xã hội', 'Vi tính', 'Đời sống',
 
24
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
25
  model.to(device)
26
  model.eval()
 
27
 
28
+ # Lấy max_length từ model config (nếu có) hoặc dùng giá trị mặc định
29
+ # Dựa trên lỗi, model có vẻ được train với max_length=258
30
+ try:
31
+ if hasattr(model.config, 'max_position_embeddings'):
32
+ max_length = min(model.config.max_position_embeddings, 258)
33
+ else:
34
+ max_length = 258 # Giá trị dựa trên lỗi
35
+ except:
36
+ max_length = 258 # Giá trị mặc định dựa trên lỗi
37
+
38
+ print(f"Model đã được tải thành công! Max length: {max_length}")
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def classify_text(text):
41
  """
 
45
  return "Vui lòng nhập văn bản cần phân loại!"
46
 
47
  try:
48
+ # Tokenize văn bản
49
+ # Model có vẻ được train với max_length=258, nên cần pad đến đúng độ dài này
50
+ encoding = tokenizer(
51
+ text,
52
+ truncation=True,
53
+ padding='max_length',
54
+ max_length=max_length,
55
+ return_tensors='pt'
56
+ )
57
+
58
+ # Chuyển sang device
59
+ input_ids = encoding['input_ids'].to(device)
60
+ attention_mask = encoding['attention_mask'].to(device)
61
 
62
  # Dự đoán
63
  with torch.no_grad():
64
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
65
+ pred_label_id = torch.argmax(outputs.logits, dim=1).item()
66
+
67
+ # Lấy xác suất cho tất cả các lớp
68
+ probabilities = torch.softmax(outputs.logits, dim=1)[0]
69
+
70
+ # Tạo kết quả
71
+ predicted_label = LABELS[pred_label_id]
72
+ confidence = probabilities[pred_label_id].item() * 100
73
+
74
+ # Tạo danh sách xác suất cho tất cả các nhãn
75
+ results = []
76
+ for i, label in enumerate(LABELS):
77
+ prob = probabilities[i].item() * 100
78
+ results.append(f"{label}: {prob:.2f}%")
79
+
80
+ result_text = f"**Nhãn dự đoán: {predicted_label}**\n"
81
+ result_text += f"**Độ tin cậy: {confidence:.2f}%**\n\n"
82
+ result_text += "**Xác suất cho tất cả các nhãn:**\n"
83
+ result_text += "\n".join(results)
84
+
85
+ return result_text
 
 
86
 
87
  except Exception as e:
88
+ import traceback
89
+ return f"Lỗi khi phân loại: {str(e)}\n\nTraceback: {traceback.format_exc()}"
90
 
91
  # Tạo giao diện Gradio
92
  with gr.Blocks(title="Phân loại văn bản tiếng Việt", theme=gr.themes.Soft()) as demo: