AnnyNguyen committed (verified)
Commit aacd51d · 1 Parent(s): 3fc674d

Upload models.py with huggingface_hub

Files changed (1):
  1. models.py +559 -0
models.py ADDED
@@ -0,0 +1,559 @@
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel, AutoConfig,
    T5ForConditionalGeneration, T5Config,
)
from transformers.utils import (
    WEIGHTS_NAME,
    SAFE_WEIGHTS_NAME,
    logging,
)

logger = logging.get_logger(__name__)

class BaseHateSpeechModel(nn.Module):
    """Base class for all hate speech detection models."""
    def __init__(self, model_name: str, num_labels: int = 3):
        super().__init__()
        self.num_labels = num_labels
        self.model_name = model_name

    def forward(self, input_ids, attention_mask, labels=None):
        raise NotImplementedError

    def load_state_dict(self, state_dict, strict=True):
        """
        Override load_state_dict to bypass transformers' key renaming.
        Loads the state_dict into the model directly, with no key mapping.
        """
        # Load directly, skipping transformers' key renaming
        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
        if missing_keys and strict:
            logger.warning(f"Missing keys when loading state_dict: {missing_keys}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys when loading state_dict: {unexpected_keys}")
        return missing_keys, unexpected_keys

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load the model from a pretrained checkpoint.
        Transformers loads the state_dict automatically after the model is instantiated.
        """
        # Extract the config from kwargs (transformers passes it in here)
        config = kwargs.pop("config", None)

        # Load the config if it was not provided
        if config is None:
            try:
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception:
                config = {}

        # Get num_labels from the config or kwargs
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Get the base model name from the config
        base_model_name = None
        if hasattr(config, "_name_or_path"):
            base_model_name = config._name_or_path
        elif isinstance(config, dict) and "_name_or_path" in config:
            base_model_name = config["_name_or_path"]

        # Instantiate the model with the base model name
        if base_model_name:
            model = cls(model_name=base_model_name, num_labels=num_labels, **kwargs)
        else:
            # Fallback: use the class's default model_name
            model = cls(num_labels=num_labels, **kwargs)

        return model

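# Illustrative sketch (not part of the pipeline): the two overrides above are
# meant to compose when restoring a fine-tuned checkpoint. The checkpoint path
# below is hypothetical.
#
#   model = PhoBERTV2Model.from_pretrained("checkpoints/phobert-v2-hsd")
#   state_dict = torch.load("checkpoints/phobert-v2-hsd/pytorch_model.bin",
#                           map_location="cpu")
#   model.load_state_dict(state_dict)  # loaded directly, no key renaming
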
class PhoBERTV2Model(BaseHateSpeechModel):
    """PhoBERT-v2 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

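# Quick shape check (a sketch; every encoder classifier below behaves the same
# way, and the tokenizer is assumed to match the default checkpoint):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
#   batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
#   out = PhoBERTV2Model()(batch["input_ids"], batch["attention_mask"])
#   assert out["logits"].shape == (1, 3)  # [batch_size, num_labels]
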
class BartPhoModel(BaseHateSpeechModel):
    """BARTpho for hate speech detection."""
    def __init__(self, model_name: str = "vinai/bartpho-syllable-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the last hidden states over the sequence
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)  # mean pooling
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class ViSoBERTModel(BaseHateSpeechModel):
    """ViSoBERT for hate speech detection."""
    def __init__(self, model_name: str = "uitnlp/visobert", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use pooler_output if it exists; otherwise fall back to last_hidden_state
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            # Fallback: mean pooling over the last hidden states
            pooled_output = outputs.last_hidden_state.mean(dim=1)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class PhoBERTV1Model(BaseHateSpeechModel):
    """PhoBERT-v1 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Some encoders do not expose pooler_output
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class MBERTModel(BaseHateSpeechModel):
    """mBERT (bert-base-multilingual-cased) for hate speech detection."""
    def __init__(self, model_name: str = "bert-base-multilingual-cased", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class SPhoBERTModel(BaseHateSpeechModel):
    """SPhoBERT (a syllable-level PhoBERT variant) for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class ViHateT5Model(BaseHateSpeechModel):
    """ViHateT5 for hate speech detection."""
    def __init__(self, model_name: str = "VietAI/vit5-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = T5Config.from_pretrained(model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        # Only the T5 encoder stack is used; the decoder is never called
        outputs = self.encoder.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the encoder's last hidden states over the sequence
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)  # mean pooling
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class XLMRModel(BaseHateSpeechModel):
    """XLM-R Large for hate speech detection."""
    def __init__(self, model_name: str = "xlm-roberta-large", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class RoBERTaGRUModel(BaseHateSpeechModel):
    """RoBERTa + GRU hybrid model."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3, hidden_size: int = 256):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.gru = nn.GRU(
            input_size=self.config.hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # GRU processing
        gru_output, _ = self.gru(hidden_states)  # [batch_size, seq_len, hidden_size*2]

        # Global average pooling
        pooled_output = gru_output.mean(dim=1)  # [batch_size, hidden_size*2]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class TextCNNModel(BaseHateSpeechModel):
    """TextCNN for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, num_labels: int = 3,
                 num_filters: int = 100, filter_sizes: list = None, dropout: float = 0.5):
        super().__init__("textcnn", num_labels)
        if filter_sizes is None:  # avoid a mutable default argument
            filter_sizes = [3, 4, 5]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_dim))
            for filter_size in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(num_filters * len(filter_sizes), num_labels)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to detect vocab_size from the state_dict or checkpoint file."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from the checkpoint file
        if vocab_size is None:
            state_dict = None
            # Try to load the state_dict from a local path to detect vocab_size
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Detect vocab_size from embedding.weight
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Embedding
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # Add a channel dimension for Conv2d
        embedded = embedded.unsqueeze(1)  # [batch_size, 1, seq_len, embedding_dim]

        # Convolutional layers
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # [batch_size, num_filters, seq_len', 1]
            conv_out = conv_out.squeeze(3)  # [batch_size, num_filters, seq_len']
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # [batch_size, num_filters, 1]
            pooled = pooled.squeeze(2)  # [batch_size, num_filters]
            conv_outputs.append(pooled)

        # Concatenate all conv outputs
        concatenated = torch.cat(conv_outputs, dim=1)  # [batch_size, num_filters * len(filter_sizes)]

        # Classification
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

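# Note (a sketch): inputs must be at least max(filter_sizes) tokens long, or
# the Conv2d call fails; with the default sizes [3, 4, 5], right-pad short
# batches up to 5 tokens first, e.g. (assuming 0 is the pad id):
#
#   pad = max(0, 5 - input_ids.size(1))
#   input_ids = F.pad(input_ids, (0, pad))
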
class BiLSTMModel(BaseHateSpeechModel):
    """BiLSTM for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, hidden_size: int = 256,
                 num_labels: int = 3, num_layers: int = 2, dropout: float = 0.5):
        super().__init__("bilstm", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to detect vocab_size from the state_dict or checkpoint file."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from the checkpoint file
        if vocab_size is None:
            state_dict = None
            # Try to load the state_dict from a local path to detect vocab_size
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Detect vocab_size from embedding.weight
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Embedding
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # BiLSTM
        lstm_output, (hidden, cell) = self.lstm(embedded)  # [batch_size, seq_len, hidden_size*2]

        # Option 1: global average pooling (could instead use max pooling or the last hidden state)
        pooled_output = lstm_output.mean(dim=1)  # [batch_size, hidden_size*2]

        # Option 2: last hidden state (uncomment if preferred)
        # pooled_output = lstm_output[:, -1, :]  # [batch_size, hidden_size*2]

        # Option 3: max pooling (uncomment if preferred)
        # pooled_output = torch.max(lstm_output, dim=1)[0]  # [batch_size, hidden_size*2]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class EnsembleModel(BaseHateSpeechModel):
    """Ensemble model combining several deep learning models."""
    def __init__(self, models: list, num_labels: int = 3, weights: list = None):
        super().__init__("ensemble", num_labels)
        self.models = nn.ModuleList(models)
        self.num_models = len(models)
        weights = torch.tensor(weights if weights else [1.0] * self.num_models, dtype=torch.float32)
        # Register as a buffer so the weights follow the model across devices
        self.register_buffer("weights", weights / weights.sum())  # normalized

    def forward(self, input_ids, attention_mask, labels=None):
        all_logits = []
        total_loss = 0

        for i, model in enumerate(self.models):
            model_output = model(input_ids, attention_mask, labels)
            all_logits.append(model_output["logits"])

            if model_output["loss"] is not None:
                total_loss += self.weights[i] * model_output["loss"]

        # Weighted average of logits
        ensemble_logits = torch.zeros_like(all_logits[0])
        for i, logits in enumerate(all_logits):
            ensemble_logits += self.weights[i] * logits

        return {
            "loss": total_loss if isinstance(total_loss, torch.Tensor) else None,
            "logits": ensemble_logits
        }

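# Sketch of assembling an ensemble (the member models and weights here are
# illustrative, not a recommended configuration):
#
#   ensemble = EnsembleModel(
#       models=[PhoBERTV2Model(), ViSoBERTModel()],
#       weights=[0.6, 0.4],
#   )
#   out = ensemble(input_ids, attention_mask)  # logits = weighted average
#
# Caveat: every member receives the same input_ids, so the members must share
# a tokenizer/vocabulary (PhoBERT and ViSoBERT, for instance, do not).
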
def get_model(model_name: str, num_labels: int = 3, **kwargs):
    """
    Factory function that creates a model by name.

    Args:
        model_name: Model name ("phobert-v1", "phobert-v2", "bartpho", "visobert",
            "vihate-t5", "xlm-r", "mbert", "sphobert", "roberta-gru", "textcnn",
            "bilstm", "ensemble")
        num_labels: Number of labels (3 for hate speech: CLEAN, OFFENSIVE, HATE)
        **kwargs: Extra keyword arguments forwarded to the model constructor

    Returns:
        Model instance
    """
    model_mapping = {
        "phobert-v1": PhoBERTV1Model,
        "phobert-v2": PhoBERTV2Model,
        "bartpho": BartPhoModel,
        "visobert": ViSoBERTModel,
        "vihate-t5": ViHateT5Model,
        "xlm-r": XLMRModel,
        "mbert": MBERTModel,
        "sphobert": SPhoBERTModel,
        "roberta-gru": RoBERTaGRUModel,
        "textcnn": TextCNNModel,
        "bilstm": BiLSTMModel,
        "ensemble": EnsembleModel
    }

    if model_name not in model_mapping:
        raise ValueError(f"Unknown model: {model_name}. Available models: {list(model_mapping.keys())}")

    model_class = model_mapping[model_name]
    return model_class(num_labels=num_labels, **kwargs)
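
# Minimal smoke test: an illustrative sketch, not part of the training pipeline.
# TextCNN is used because it needs no downloaded checkpoint; the random ids
# stand in for tokenized text, and the vocab size is arbitrary.
if __name__ == "__main__":
    model = get_model("textcnn", vocab_size=30000)
    input_ids = torch.randint(0, 30000, (2, 64))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids, attention_mask, labels=torch.tensor([0, 2]))
    print(out["loss"].item(), out["logits"].shape)  # scalar loss, torch.Size([2, 3])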