AnnyNguyen committed · verified
Commit 0dcc5d0 · 1 Parent(s): b8c313f

Delete models.py with huggingface_hub

Files changed (1): models.py (+0 -546)
models.py DELETED
@@ -1,546 +0,0 @@
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel, AutoConfig, AutoTokenizer,
    T5ForConditionalGeneration, T5Config,
    AutoModelForSequenceClassification,
    PreTrainedModel, PretrainedConfig
)
# WEIGHTS_NAME and SAFE_WEIGHTS_NAME live in transformers.utils,
# not in transformers.modeling_utils
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging

logger = logging.get_logger(__name__)

class BaseHateSpeechModel(nn.Module):
    """Base class for all hate speech detection models."""
    def __init__(self, model_name: str, num_labels: int = 3):
        super().__init__()
        self.num_labels = num_labels
        self.model_name = model_name

    def forward(self, input_ids, attention_mask, labels=None):
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a model from a pretrained checkpoint.
        Transformers loads the state_dict automatically after the model
        has been instantiated.
        """
        # Extract the config from kwargs (transformers passes it in here)
        config = kwargs.pop("config", None)

        # Load the config if it was not provided
        if config is None:
            try:
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception:
                config = {}

        # Get num_labels from kwargs or fall back to the config
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Read the base model name from the config
        base_model_name = None
        if hasattr(config, "_name_or_path"):
            base_model_name = config._name_or_path
        elif isinstance(config, dict) and "_name_or_path" in config:
            base_model_name = config["_name_or_path"]

        # Instantiate the model with the base model name
        if base_model_name:
            model = cls(model_name=base_model_name, num_labels=num_labels, **kwargs)
        else:
            # Fallback: use the class's default model_name
            model = cls(num_labels=num_labels, **kwargs)

        return model

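# A minimal usage sketch of the from_pretrained flow above (hypothetical
# checkpoint directory; this assumes the checkpoint's config.json keeps
# _name_or_path pointing at the base encoder and stores num_labels):
#
#   model = PhoBERTV2Model.from_pretrained("./checkpoints/phobert-v2-hsd")
#   out = model(input_ids=batch["input_ids"],
#               attention_mask=batch["attention_mask"])
#   preds = out["logits"].argmax(dim=-1)
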
class PhoBERTV2Model(BaseHateSpeechModel):
    """PhoBERT-v2 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        # ignore_mismatched_sizes is a weight-loading option, so it belongs on
        # the model's from_pretrained call, not on AutoConfig
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class BartPhoModel(BaseHateSpeechModel):
    """BARTpho for hate speech detection."""
    def __init__(self, model_name: str = "vinai/bartpho-syllable-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the final hidden states over the sequence dimension
        # (note: this average also includes padding positions)
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

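# The mean pooling above averages over padding tokens as well. A mask-aware
# variant (a sketch, not part of the original module) would weight only the
# real tokens; any of the mean-pooling models here could call it instead:
def masked_mean_pool(last_hidden_state, attention_mask):
    """Average hidden states over non-padding positions only."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, T, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts
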
class ViSoBERTModel(BaseHateSpeechModel):
    """ViSoBERT for hate speech detection."""
    def __init__(self, model_name: str = "uitnlp/visobert", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use pooler_output when the encoder provides one; otherwise
        # fall back to mean pooling of last_hidden_state
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class PhoBERTV1Model(BaseHateSpeechModel):
    """PhoBERT-v1 for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Some encoders do not expose pooler_output
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class MBERTModel(BaseHateSpeechModel):
    """mBERT (bert-base-multilingual-cased) for hate speech detection."""
    def __init__(self, model_name: str = "bert-base-multilingual-cased", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class SPhoBERTModel(BaseHateSpeechModel):
    """SPhoBERT (a syllable-level PhoBERT variant) for hate speech detection."""
    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class ViHateT5Model(BaseHateSpeechModel):
    """ViHateT5 for hate speech detection."""
    def __init__(self, model_name: str = "VietAI/vit5-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = T5Config.from_pretrained(model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        # Run only the T5 encoder stack; the decoder is unused for classification
        outputs = self.encoder.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the encoder hidden states over the sequence dimension
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

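# Since only the encoder stack is used above, loading the full seq2seq model
# keeps unused decoder weights in memory. A leaner variant (a sketch with the
# same classifier head) could load just the encoder:
#
#   from transformers import T5EncoderModel
#   self.encoder = T5EncoderModel.from_pretrained(model_name, config=self.config)
#   outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
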
class XLMRModel(BaseHateSpeechModel):
    """XLM-R Large for hate speech detection."""
    def __init__(self, model_name: str = "xlm-roberta-large", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class RoBERTaGRUModel(BaseHateSpeechModel):
    """RoBERTa + GRU hybrid model."""
    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3, hidden_size: int = 256):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.gru = nn.GRU(
            input_size=self.config.hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # GRU over the contextual embeddings
        gru_output, _ = self.gru(hidden_states)  # [batch_size, seq_len, hidden_size*2]

        # Global average pooling over the sequence dimension
        pooled_output = gru_output.mean(dim=1)  # [batch_size, hidden_size*2]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

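# Shape check for the hybrid head above (illustrative numbers): with
# hidden_size=256, the bidirectional GRU maps the [B, T, 768] PhoBERT states
# to [B, T, 512], and average pooling leaves [B, 512] for the classifier.
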
class TextCNNModel(BaseHateSpeechModel):
    """TextCNN for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, num_labels: int = 3,
                 num_filters: int = 100, filter_sizes: tuple = (3, 4, 5), dropout: float = 0.5):
        super().__init__("textcnn", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (filter_size, embedding_dim))
            for filter_size in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(num_filters * len(filter_sizes), num_labels)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to detect vocab_size from the checkpoint's state_dict."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from a local checkpoint
        if vocab_size is None:
            state_dict = None
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Infer vocab_size from the embedding matrix
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # Default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Token embeddings
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # Add a channel dimension for Conv2d
        embedded = embedded.unsqueeze(1)  # [batch_size, 1, seq_len, embedding_dim]

        # Convolution + ReLU + max-over-time pooling per filter size
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # [batch_size, num_filters, seq_len', 1]
            conv_out = conv_out.squeeze(3)  # [batch_size, num_filters, seq_len']
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # [batch_size, num_filters, 1]
            pooled = pooled.squeeze(2)  # [batch_size, num_filters]
            conv_outputs.append(pooled)

        # Concatenate the pooled features from all filter sizes
        concatenated = torch.cat(conv_outputs, dim=1)  # [batch_size, num_filters * len(filter_sizes)]

        # Classification head
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

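# A worked shape example for the TextCNN forward pass (illustrative numbers):
# with batch_size=2, seq_len=64, embedding_dim=128, num_filters=100 and
# filter_sizes=(3, 4, 5), each conv yields [2, 100, 64-k+1] after squeezing,
# max-over-time pooling reduces that to [2, 100], and concatenating the three
# branches gives [2, 300] going into the classifier:
#
#   cnn = TextCNNModel(vocab_size=30000)
#   ids = torch.randint(0, 30000, (2, 64))
#   out = cnn(ids, attention_mask=torch.ones_like(ids))
#   assert out["logits"].shape == (2, 3)
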
class BiLSTMModel(BaseHateSpeechModel):
    """BiLSTM for hate speech detection."""
    def __init__(self, vocab_size: int, embedding_dim: int = 128, hidden_size: int = 256,
                 num_labels: int = 3, num_layers: int = 2, dropout: float = 0.5):
        super().__init__("bilstm", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size * 2, num_labels)  # *2 for bidirectional

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to detect vocab_size from the checkpoint's state_dict."""
        # Get vocab_size from kwargs or the config
        vocab_size = kwargs.pop("vocab_size", None)
        config = kwargs.pop("config", None)

        # If vocab_size is still unknown, try to detect it from a local checkpoint
        if vocab_size is None:
            state_dict = None
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME))
                    except Exception:
                        pass
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    try:
                        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
                    except Exception:
                        pass

            # Infer vocab_size from the embedding matrix
            if state_dict is not None and "embedding.weight" in state_dict:
                vocab_size = state_dict["embedding.weight"].shape[0]
            else:
                vocab_size = 30000  # Default

        # Get num_labels
        num_labels = kwargs.pop("num_labels", None)
        if num_labels is None:
            if config is not None and hasattr(config, "num_labels"):
                num_labels = config.num_labels
            elif isinstance(config, dict) and "num_labels" in config:
                num_labels = config["num_labels"]
            else:
                num_labels = 3

        # Instantiate the model
        model = cls(vocab_size=vocab_size, num_labels=num_labels, **kwargs)

        return model

    def forward(self, input_ids, attention_mask, labels=None):
        # Token embeddings
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # BiLSTM
        lstm_output, (hidden, cell) = self.lstm(embedded)  # [batch_size, seq_len, hidden_size*2]

        # Sequence pooling (could also use max pooling or the last hidden state)
        # Option 1: global average pooling
        pooled_output = lstm_output.mean(dim=1)  # [batch_size, hidden_size*2]

        # Option 2: last hidden state (uncomment if preferred)
        # pooled_output = lstm_output[:, -1, :]  # [batch_size, hidden_size*2]

        # Option 3: max pooling (uncomment if preferred)
        # pooled_output = torch.max(lstm_output, dim=1)[0]  # [batch_size, hidden_size*2]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

class EnsembleModel(BaseHateSpeechModel):
    """Ensemble model that combines several deep learning models."""
    def __init__(self, models: list, num_labels: int = 3, weights: list = None):
        super().__init__("ensemble", num_labels)
        self.models = nn.ModuleList(models)
        self.num_models = len(models)
        weights = weights if weights else [1.0] * self.num_models
        weights = torch.tensor(weights, dtype=torch.float32)
        # Register as a buffer so the weights follow the module across devices
        self.register_buffer("weights", weights / weights.sum())  # normalized

    def forward(self, input_ids, attention_mask, labels=None):
        all_logits = []
        total_loss = 0

        for i, model in enumerate(self.models):
            model_output = model(input_ids, attention_mask, labels)
            all_logits.append(model_output["logits"])

            if model_output["loss"] is not None:
                total_loss += self.weights[i] * model_output["loss"]

        # Weighted average of the member logits
        ensemble_logits = torch.zeros_like(all_logits[0])
        for i, logits in enumerate(all_logits):
            ensemble_logits += self.weights[i] * logits

        return {
            "loss": total_loss if labels is not None else None,
            "logits": ensemble_logits
        }

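# A minimal construction sketch for the ensemble (hypothetical weights; the
# member models must share tokenization for the shared input_ids to be valid):
#
#   ensemble = EnsembleModel(
#       models=[PhoBERTV2Model(), SPhoBERTModel()],  # both PhoBERT tokenizers
#       weights=[0.6, 0.4],
#   )
#   out = ensemble(input_ids, attention_mask)
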
def get_model(model_name: str, num_labels: int = 3, **kwargs):
    """
    Factory function that builds a model from its name.

    Args:
        model_name: Model name ("phobert-v1", "phobert-v2", "bartpho", "visobert",
                    "vihate-t5", "xlm-r", "mbert", "sphobert", "roberta-gru",
                    "textcnn", "bilstm", "ensemble")
        num_labels: Number of labels (3 for hate speech: CLEAN, OFFENSIVE, HATE)
        **kwargs: Extra keyword arguments forwarded to the model constructor

    Returns:
        Model instance
    """
    model_mapping = {
        "phobert-v1": PhoBERTV1Model,
        "phobert-v2": PhoBERTV2Model,
        "bartpho": BartPhoModel,
        "visobert": ViSoBERTModel,
        "vihate-t5": ViHateT5Model,
        "xlm-r": XLMRModel,
        "mbert": MBERTModel,
        "sphobert": SPhoBERTModel,
        "roberta-gru": RoBERTaGRUModel,
        "textcnn": TextCNNModel,
        "bilstm": BiLSTMModel,
        "ensemble": EnsembleModel
    }

    if model_name not in model_mapping:
        raise ValueError(f"Unknown model: {model_name}. Available models: {list(model_mapping.keys())}")

    model_class = model_mapping[model_name]
    return model_class(num_labels=num_labels, **kwargs)

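# A short usage sketch for the factory (hypothetical call sites; note that
# "textcnn"/"bilstm" additionally require vocab_size and "ensemble" requires
# a models list, forwarded through **kwargs):
#
#   model = get_model("phobert-v2", num_labels=3)
#   cnn = get_model("textcnn", num_labels=3, vocab_size=30000)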