AnnyNguyen committed (verified) · Commit 6985390 · Parent(s): cf3d50a

Upload models.py with huggingface_hub

Files changed (1): models.py added (+546 lines)
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, T5ForConditionalGeneration, T5Config
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging

logger = logging.get_logger(__name__)

class BaseHateSpeechModel(nn.Module):
    """Base class for all hate speech detection models."""
    def __init__(self, model_name: str, num_labels: int = 3):
        super().__init__()
        self.num_labels = num_labels
        self.model_name = model_name

    def forward(self, input_ids, attention_mask, labels=None):
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a model from a pretrained checkpoint.
        Transformers loads the state_dict automatically after the model is instantiated.
        """
        # Extract the config from kwargs (transformers passes the config in here)
        config = kwargs.pop("config", None)

        # Load the config if it was not provided
        if config is None:
            try:
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception:
                config = {}

        # Get num_labels from the config or kwargs
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Get the base model name from the config
        base_model_name = None
        if hasattr(config, "_name_or_path"):
            base_model_name = config._name_or_path
        elif isinstance(config, dict) and "_name_or_path" in config:
            base_model_name = config["_name_or_path"]

        # Instantiate the model with the base model name
        if base_model_name:
            model = cls(model_name=base_model_name, num_labels=num_labels, **kwargs)
        else:
            # Fallback: use the class's default model_name
            model = cls(num_labels=num_labels, **kwargs)

        return model

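# Usage sketch (illustrative; the names below mirror the subclasses defined in
# this file). Note that from_pretrained only *constructs* the model; the actual
# weights are restored by the surrounding Transformers loading machinery, or
# can be restored by hand:
#
#     model = PhoBERTV2Model.from_pretrained("path/to/checkpoint", num_labels=3)
#     # or manually:
#     # model.load_state_dict(torch.load("path/to/checkpoint/pytorch_model.bin"))
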
class PhoBERTV2Model(BaseHateSpeechModel):
    """PhoBERT-v2 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class BartPhoModel(BaseHateSpeechModel):
    """BARTpho for hate speech detection."""
    def __init__(self, model_name: str = "vinai/bartpho-syllable-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the last hidden states over the sequence dimension
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

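# The mean pooling above averages over *all* positions, including padding.
# A mask-aware variant (an illustrative sketch; not wired into the classes in
# this file) weights the average by the attention mask instead:
def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool hidden states while ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).float()      # [batch, seq_len, 1]
    summed = (last_hidden_state * mask).sum(dim=1)   # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1e-9)         # [batch, 1]
    return summed / counts
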
class ViSoBERTModel(BaseHateSpeechModel):
    """ViSoBERT for hate speech detection."""
    def __init__(self, model_name: str = "uitnlp/visobert", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use pooler_output if available; otherwise fall back to last_hidden_state
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            # Fallback: mean-pool the last_hidden_state
            pooled_output = outputs.last_hidden_state.mean(dim=1)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class PhoBERTV1Model(BaseHateSpeechModel):
    """PhoBERT-v1 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Some encoders do not expose pooler_output
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class MBERTModel(BaseHateSpeechModel):
    """mBERT (bert-base-multilingual-cased) for hate speech detection."""
    def __init__(self, model_name: str = "bert-base-multilingual-cased", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class SPhoBERTModel(BaseHateSpeechModel):
    """SPhoBERT (syllable-level PhoBERT variant) for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class ViHateT5Model(BaseHateSpeechModel):
    """ViHateT5 for hate speech detection."""
    def __init__(self, model_name: str = "VietAI/vit5-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = T5Config.from_pretrained(model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        # Only the T5 encoder stack is used; the decoder is not needed for classification
        outputs = self.encoder.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the last hidden states over the sequence dimension
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class XLMRModel(BaseHateSpeechModel):
    """XLM-R Large for hate speech detection."""
    def __init__(self, model_name: str = "xlm-roberta-large", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class RoBERTaGRUModel(BaseHateSpeechModel):
    """RoBERTa + GRU hybrid model."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3, hidden_size: int = 256):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.gru = nn.GRU(
            input_size=self.config.hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # GRU processing
        gru_output, _ = self.gru(hidden_states)  # [batch_size, seq_len, hidden_size*2]

        # Global average pooling
        pooled_output = gru_output.mean(dim=1)  # [batch_size, hidden_size*2]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

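# Shape walk-through for RoBERTaGRUModel with its defaults (hidden_size=256 on
# top of a 768-dim PhoBERT-base encoder; illustrative batch of 32, seq len 64):
#   encoder last_hidden_state: [32, 64, 768]
#   bidirectional GRU output:  [32, 64, 512]   # 2 directions x 256
#   mean over time:            [32, 512]
#   classifier:                [32, num_labels]
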
class TextCNNModel(BaseHateSpeechModel):
    """TextCNN for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, num_labels: int = 3,
                 num_filters: int = 100, filter_sizes: tuple = (3, 4, 5), dropout: float = 0.5):
        super().__init__("textcnn", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_dim))
            for filter_size in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(num_filters * len(filter_sizes), num_labels)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override that detects vocab_size from the state_dict or checkpoint file."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from the checkpoint file
        if vocab_size is None:
            state_dict = None
            # Try to load the state_dict from a local path to detect vocab_size
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Detect vocab_size from embedding.weight
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # Default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Embedding
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # Add a channel dimension for Conv2d
        embedded = embedded.unsqueeze(1)  # [batch_size, 1, seq_len, embedding_dim]

        # Convolutional layers
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # [batch_size, num_filters, seq_len', 1]
            conv_out = conv_out.squeeze(3)  # [batch_size, num_filters, seq_len']
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # [batch_size, num_filters, 1]
            pooled = pooled.squeeze(2)  # [batch_size, num_filters]
            conv_outputs.append(pooled)

        # Concatenate all conv outputs
        concatenated = torch.cat(conv_outputs, dim=1)  # [batch_size, num_filters * len(filter_sizes)]

        # Classification
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

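# Shape walk-through for one TextCNNModel conv branch with the defaults
# (embedding_dim=128, num_filters=100, filter_size=3; illustrative batch of 32,
# sequence length 64):
#   embedded:                [32, 1, 64, 128]
#   conv + relu:             [32, 100, 62, 1]   # 64 - 3 + 1 = 62
#   squeeze(3):              [32, 100, 62]
#   max_pool1d + squeeze(2): [32, 100]
# The three branches (filter sizes 3, 4, 5) concatenate to [32, 300].
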
class BiLSTMModel(BaseHateSpeechModel):
    """BiLSTM for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, hidden_size: int = 256,
                 num_labels: int = 3, num_layers: int = 2, dropout: float = 0.5):
        super().__init__("bilstm", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override that detects vocab_size from the state_dict or checkpoint file."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from the checkpoint file
        if vocab_size is None:
            state_dict = None
            # Try to load the state_dict from a local path to detect vocab_size
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Detect vocab_size from embedding.weight
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # Default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Embedding
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # BiLSTM
        lstm_output, (hidden, cell) = self.lstm(embedded)  # [batch_size, seq_len, hidden_size*2]

        # Option 1: global average pooling (max pooling or the last hidden state
        # could be used instead)
        pooled_output = lstm_output.mean(dim=1)  # [batch_size, hidden_size*2]

        # Option 2: last hidden state (uncomment if preferred)
        # pooled_output = lstm_output[:, -1, :]  # [batch_size, hidden_size*2]

        # Option 3: max pooling (uncomment if preferred)
        # pooled_output = torch.max(lstm_output, dim=1)[0]  # [batch_size, hidden_size*2]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class EnsembleModel(BaseHateSpeechModel):
    """Ensemble model that combines several deep learning models."""
    def __init__(self, models: list, num_labels: int = 3, weights: list = None):
        super().__init__("ensemble", num_labels)
        self.models = nn.ModuleList(models)
        self.num_models = len(models)
        weights = torch.tensor(weights if weights else [1.0] * self.num_models, dtype=torch.float32)
        # Register as a buffer so the normalized weights move with the model (.to(device))
        self.register_buffer("weights", weights / weights.sum())

    def forward(self, input_ids, attention_mask, labels=None):
        all_logits = []
        total_loss = None

        for i, model in enumerate(self.models):
            model_output = model(input_ids, attention_mask, labels)
            all_logits.append(model_output["logits"])

            if model_output["loss"] is not None:
                weighted_loss = self.weights[i] * model_output["loss"]
                total_loss = weighted_loss if total_loss is None else total_loss + weighted_loss

        # Weighted average of the logits
        ensemble_logits = torch.zeros_like(all_logits[0])
        for i, logits in enumerate(all_logits):
            ensemble_logits += self.weights[i] * logits

        return {
            "loss": total_loss,
            "logits": ensemble_logits
        }

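# Illustrative construction (a sketch: it assumes all sub-models accept the
# same input_ids, i.e. share a tokenizer, which EnsembleModel does not check):
#
#     ensemble = EnsembleModel(
#         models=[PhoBERTV1Model(num_labels=3), PhoBERTV2Model(num_labels=3)],
#         num_labels=3,
#         weights=[0.4, 0.6],   # normalized internally
#     )
#     out = ensemble(input_ids, attention_mask)   # weighted-average logits
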
def get_model(model_name: str, num_labels: int = 3, **kwargs):
    """
    Factory function that builds a model from its name.

    Args:
        model_name: Model name ("phobert-v1", "phobert-v2", "bartpho", "visobert",
                    "vihate-t5", "xlm-r", "mbert", "sphobert", "roberta-gru",
                    "textcnn", "bilstm", "ensemble")
        num_labels: Number of labels (3 for hate speech: CLEAN, OFFENSIVE, HATE)
        **kwargs: Extra model-specific parameters ("textcnn" and "bilstm" require
                  vocab_size; "ensemble" requires models)

    Returns:
        Model instance
    """
    model_mapping = {
        "phobert-v1": PhoBERTV1Model,
        "phobert-v2": PhoBERTV2Model,
        "bartpho": BartPhoModel,
        "visobert": ViSoBERTModel,
        "vihate-t5": ViHateT5Model,
        "xlm-r": XLMRModel,
        "mbert": MBERTModel,
        "sphobert": SPhoBERTModel,
        "roberta-gru": RoBERTaGRUModel,
        "textcnn": TextCNNModel,
        "bilstm": BiLSTMModel,
        "ensemble": EnsembleModel
    }

    if model_name not in model_mapping:
        raise ValueError(f"Unknown model: {model_name}. Available models: {list(model_mapping.keys())}")

    model_class = model_mapping[model_name]
    return model_class(num_labels=num_labels, **kwargs)
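
# Minimal smoke test (illustrative; downloads vinai/phobert-base-v2 from the
# Hugging Face Hub on first run, so it assumes network access):
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
    model = get_model("phobert-v2", num_labels=3)
    batch = tokenizer(["xin chào"], return_tensors="pt", padding=True)
    out = model(batch["input_ids"], batch["attention_mask"])
    print(out["logits"].shape)  # torch.Size([1, 3])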