AnnyNguyen committed
Commit 5f742b3 · verified · 1 Parent(s): 3c18a1a

Upload models.py with huggingface_hub

Files changed (1)
  1. models.py +227 -0
models.py ADDED
@@ -0,0 +1,227 @@
+"""
+Module defining the models for spam review detection
+"""
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoConfig, AutoModelForSequenceClassification
+from .custom_models import TextCNN, BiLSTM, RoBERTaGRU, SPhoBERT
+
+class TransformerForSpamDetection(nn.Module):
+    """
+    Base transformer model for spam review detection
+    """
+    def __init__(self, model_name: str, num_labels: int):
+        super().__init__()
+        config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
+        self.encoder = AutoModel.from_pretrained(model_name, config=config)
+        self.classifier = nn.Linear(config.hidden_size, num_labels)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
+        # Filter out arguments that BertModel doesn't expect
+        filtered_kwargs = {k: v for k, v in kwargs.items()
+                           if k not in ['num_items_in_batch', 'position_ids']}
+
+        # Pass filtered arguments to the encoder (including token_type_ids for BERT)
+        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **filtered_kwargs)
+        pooled = out.last_hidden_state[:, 0]  # CLS token
+        pooled = self.dropout(pooled)
+        logits = self.classifier(pooled)
+        loss = None
+        if labels is not None:
+            loss_fn = nn.CrossEntropyLoss()
+            loss = loss_fn(logits, labels)
+        return {"loss": loss, "logits": logits}
+
+class ViT5ForSpamDetection(nn.Module):
+    """
+    ViT5 model for spam review detection - uses an encoder-only approach
+    """
+    def __init__(self, model_name: str, num_labels: int):
+        super().__init__()
+        from transformers import T5EncoderModel, T5Config
+
+        # Load T5 encoder only
+        config = T5Config.from_pretrained(model_name)
+        self.t5_encoder = T5EncoderModel.from_pretrained(model_name, config=config)
+
+        # Classification head
+        self.classifier = nn.Linear(config.d_model, num_labels)
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
+        # Filter out arguments that T5EncoderModel doesn't expect
+        filtered_kwargs = {k: v for k, v in kwargs.items()
+                           if k not in ['num_items_in_batch', 'position_ids']}
+
+        # Use only the T5 encoder
+        encoder_outputs = self.t5_encoder(input_ids=input_ids, attention_mask=attention_mask, **filtered_kwargs)
+
+        # Take the pooled representation (first token)
+        pooled = encoder_outputs.last_hidden_state[:, 0]
+        pooled = self.dropout(pooled)
+        logits = self.classifier(pooled)
+
+        loss = None
+        if labels is not None:
+            loss_fn = nn.CrossEntropyLoss()
+            loss = loss_fn(logits, labels)
+
+        return {"loss": loss, "logits": logits}
+
+def get_model(model_name: str, num_labels: int, vocab_size: int = None):
+    """
+    Factory function that creates a model based on the model name
+
+    Args:
+        model_name: Model name (phobert-v2, textcnn, bilstm, etc.)
+        num_labels: Number of classes
+        vocab_size: Vocabulary size (only needed for BiLSTM-CRF)
+
+    Returns:
+        Model instance
+    """
+    # Mapping from model name to base model
+    model_mapping = {
+        "phobert-v1": "vinai/phobert-base",
+        "phobert-v2": "vinai/phobert-base-v2",
+        "bartpho": "vinai/bartpho-syllable",
+        "visobert": "uitnlp/visobert",
+        "xlm-r": "xlm-roberta-large",
+        "mbert": "bert-base-multilingual-cased",
+        "vit5": "VietAI/vit5-base"
+    }
+
+    if model_name == "vit5":
+        # Use ViT5ForSpamDetection for the T5 model
+        base_model_name = model_mapping[model_name]
+        return ViT5ForSpamDetection(base_model_name, num_labels)
+    elif model_name in model_mapping:
+        # Use a standard transformer model
+        base_model_name = model_mapping[model_name]
+        return TransformerForSpamDetection(base_model_name, num_labels)
+
+    elif model_name == "textcnn":
+        # TextCNN custom model
+        base_model_name = "vinai/phobert-base-v2"  # Uses PhoBERT embeddings
+        return TextCNN(base_model_name, num_labels)
+
+    elif model_name == "bilstm":
+        # BiLSTM custom model
+        base_model_name = "vinai/phobert-base-v2"
+        return BiLSTM(base_model_name, num_labels)
+
+    elif model_name == "roberta-gru":
+        # RoBERTa-GRU hybrid model
+        base_model_name = "vinai/phobert-base-v2"
+        return RoBERTaGRU(base_model_name, num_labels)
+
+    elif model_name == "sphobert":
+        # SPhoBERT fusion model
+        base_model_name = "vinai/phobert-base-v2"
+        return SPhoBERT(base_model_name, num_labels)
+
+    elif model_name == "bilstm-crf":
+        # BiLSTM-CRF model (placeholder implementation)
+        # In practice a CRF layer would need to be implemented
+        base_model_name = "vinai/phobert-base-v2"
+        return BiLSTM(base_model_name, num_labels)
+
+    else:
+        raise ValueError(f"Unknown model name: {model_name}. Available models: {list(model_mapping.keys()) + ['textcnn', 'bilstm', 'roberta-gru', 'sphobert', 'bilstm-crf']}")
+
+def get_model_config(model_name: str):
+    """
+    Get the configuration for a model
+
+    Args:
+        model_name: Model name
+
+    Returns:
+        Dict containing the model configuration
+    """
+    configs = {
+        "phobert-v1": {
+            "model_name": "vinai/phobert-base",
+            "description": "PhoBERT v1 - Pre-trained BERT for Vietnamese",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "phobert-v2": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "PhoBERT v2 - Improved PhoBERT for Vietnamese",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "bartpho": {
+            "model_name": "vinai/bartpho-syllable",
+            "description": "BART Pho - Vietnamese BART model",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "visobert": {
+            "model_name": "uitnlp/visobert",
+            "description": "ViSoBERT - Vietnamese Social BERT",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "xlm-r": {
+            "model_name": "xlm-roberta-large",
+            "description": "XLM-RoBERTa Large - Multilingual model",
+            "max_length": 256,
+            "learning_rate": 3e-5
+        },
+        "mbert": {
+            "model_name": "bert-base-multilingual-cased",
+            "description": "mBERT - Multilingual BERT model",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "vit5": {
+            "model_name": "VietAI/vit5-base",
+            "description": "ViT5 - Vietnamese T5",
+            "max_length": 256,
+            "learning_rate": 5e-5
+        },
+        "textcnn": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "TextCNN - Convolutional Neural Network for text",
+            "max_length": 256,
+            "learning_rate": 1e-3,
+            "custom_model": True
+        },
+        "bilstm": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "BiLSTM - Bidirectional LSTM for text classification",
+            "max_length": 256,
+            "learning_rate": 1e-3,
+            "custom_model": True
+        },
+        "roberta-gru": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "RoBERTa-GRU - Hybrid RoBERTa + GRU model",
+            "max_length": 256,
+            "learning_rate": 5e-5,
+            "custom_model": True
+        },
+        "sphobert": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "SPhoBERT - PhoBERT + SentenceBERT embedding fusion",
+            "max_length": 256,
+            "learning_rate": 5e-5,
+            "custom_model": True
+        },
+        "bilstm-crf": {
+            "model_name": "vinai/phobert-base-v2",
+            "description": "BiLSTM-CRF - Bidirectional LSTM with CRF",
+            "max_length": 256,
+            "learning_rate": 1e-3,
+            "custom_model": True
+        }
+    }
+
+    if model_name not in configs:
+        raise ValueError(f"Model {model_name} not found. Available models: {list(configs.keys())}")
+
+    return configs[model_name]
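
Usage note (not part of the uploaded file): a minimal sketch of how the two factory functions above could be called, assuming models.py is imported as part of a package that also provides the custom_models module referenced by its relative import; the model key "phobert-v2", num_labels=2, and the dummy tensor shapes are illustrative choices only.

from models import get_model, get_model_config  # assumes a package layout that satisfies the relative custom_models import
import torch

# Look up the suggested config and build the corresponding model
cfg = get_model_config("phobert-v2")           # {"model_name": "vinai/phobert-base-v2", "max_length": 256, ...}
model = get_model("phobert-v2", num_labels=2)  # TransformerForSpamDetection wrapping PhoBERT v2

# Dummy forward pass with already-tokenized inputs (shapes are illustrative only)
input_ids = torch.randint(0, 100, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 1])
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs["loss"], outputs["logits"].shape)  # scalar loss tensor, logits of shape (2, 2)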