AnnyNguyen commited on
Commit
42a1fa7
·
verified ·
1 Parent(s): 6195332

Upload models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models.py +436 -0
models.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import (
5
+ AutoModel, AutoConfig, AutoTokenizer,
6
+ T5ForConditionalGeneration, T5Config,
7
+ AutoModelForSequenceClassification
8
+ )
9
+ from torchcrf import CRF
10
+ import numpy as np
11
+
12
class BaseHateSpeechModel(nn.Module):
    """Abstract base for every hate-speech detection model in this file.

    Stores the backbone identifier and the number of target classes;
    concrete subclasses must override :meth:`forward`.
    """

    def __init__(self, model_name: str, num_labels: int = 3):
        super().__init__()
        # Keep both pieces of metadata on the instance for checkpoint naming
        # and head sizing in subclasses.
        self.model_name = model_name
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        # Subclasses implement the actual computation.
        raise NotImplementedError
21
+
22
class PhoBERTV2Model(BaseHateSpeechModel):
    """PhoBERT-v2 encoder with a linear classification head for hate speech detection.

    Args (``__init__``):
        model_name: HF hub id of the backbone (default ``vinai/phobert-base-v2``).
        num_labels: number of output classes.

    ``forward`` returns ``{"loss": Tensor | None, "logits": Tensor}``; the loss
    is cross-entropy and is ``None`` when ``labels`` is not given.
    """

    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Fix for consistency with the sibling encoder classes: the original
        # read ``pooler_output`` unconditionally, which raises if the backbone
        # was loaded without a pooling layer. Fall back to mean pooling.
        if getattr(outputs, "pooler_output", None) is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
42
+
43
class BartPhoModel(BaseHateSpeechModel):
    """BARTpho for hate speech detection.

    Loads the full BARTpho seq2seq backbone via ``AutoModel`` and classifies
    the mean-pooled ``last_hidden_state``.

    NOTE(review): for a BART-style model loaded with ``AutoModel``,
    ``last_hidden_state`` is the *decoder*'s output (decoder inputs are
    derived from ``input_ids`` when none are supplied) — confirm pooling the
    decoder states rather than the encoder states is intended.
    """
    def __init__(self, model_name: str = "vinai/bartpho-syllable-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        # BART-family configs expose the hidden width as ``d_model``.
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool over all positions (the original comment said "last token",
        # but the code below is mean pooling).
        last_hidden_states = outputs.last_hidden_state
        pooled_output = last_hidden_states.mean(dim=1)  # Mean pooling
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Cross-entropy loss only when gold labels are provided.
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
65
+
66
class ViSoBERTModel(BaseHateSpeechModel):
    """ViSoBERT encoder with a linear head for hate speech detection.

    ``forward`` returns ``{"loss": Tensor | None, "logits": Tensor}``.
    """

    def __init__(self, model_name: str = "uitnlp/visobert", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        enc_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Prefer the CLS pooler when the backbone provides one; otherwise
        # fall back to mean pooling over the token representations.
        pooler = getattr(enc_out, "pooler_output", None)
        if pooler is not None:
            sentence_vec = pooler
        else:
            sentence_vec = enc_out.last_hidden_state.mean(dim=1)

        logits = self.classifier(self.dropout(sentence_vec))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
93
+
94
class PhoBERTV1Model(BaseHateSpeechModel):
    """PhoBERT (v1) encoder with a linear head for hate speech detection."""

    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Some encoders expose no pooler output; mean-pool the token states then.
        if getattr(encoded, "pooler_output", None) is not None:
            sentence_vec = encoded.pooler_output
        else:
            sentence_vec = encoded.last_hidden_state.mean(dim=1)

        logits = self.classifier(self.dropout(sentence_vec))

        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}
118
+
119
class MBERTModel(BaseHateSpeechModel):
    """mBERT (bert-base-multilingual-cased) with a linear head for hate speech detection."""

    def __init__(self, model_name: str = "bert-base-multilingual-cased", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use the pooler when available, else average the token states.
        pooler = getattr(out, "pooler_output", None)
        summary = out.last_hidden_state.mean(dim=1) if pooler is None else pooler

        summary = self.dropout(summary)
        logits = self.classifier(summary)

        loss = None
        if labels is not None:
            # Functional form of nn.CrossEntropyLoss.
            loss = F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
142
+
143
class SPhoBERTModel(BaseHateSpeechModel):
    """SPhoBERT (syllable-level PhoBERT variant) for hate speech detection."""

    def __init__(self, model_name: str = "vinai/phobert-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        backbone_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Fall back to mean pooling for backbones without a pooler head.
        pooled = getattr(backbone_out, "pooler_output", None)
        if pooled is None:
            pooled = backbone_out.last_hidden_state.mean(dim=1)

        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        if labels is None:
            return {"loss": None, "logits": logits}
        return {"loss": nn.CrossEntropyLoss()(logits, labels), "logits": logits}
166
+
167
class ViHateT5Model(BaseHateSpeechModel):
    """ViT5/ViHateT5 classifier for hate speech detection.

    Only the T5 encoder stack is used; the decoder of the loaded
    ``T5ForConditionalGeneration`` is bypassed entirely.
    """

    def __init__(self, model_name: str = "VietAI/vit5-base", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = T5Config.from_pretrained(model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(model_name, config=self.config)
        self.dropout = nn.Dropout(0.1)
        # T5 exposes its hidden width as ``d_model``.
        self.classifier = nn.Linear(self.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        enc_states = self.encoder.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # Mean-pool token states into a single sentence vector, then classify.
        sentence_vec = self.dropout(enc_states.mean(dim=1))
        logits = self.classifier(sentence_vec)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
189
+
190
class XLMRModel(BaseHateSpeechModel):
    """XLM-R Large encoder with a linear head for hate speech detection.

    ``forward`` returns ``{"loss": Tensor | None, "logits": Tensor}``; the
    loss is cross-entropy and is ``None`` when no labels are supplied.
    """

    def __init__(self, model_name: str = "xlm-roberta-large", num_labels: int = 3):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Fix for consistency with the other encoder classes in this file:
        # the original read ``pooler_output`` unconditionally, which raises
        # when the backbone is loaded without a pooling layer.
        if getattr(outputs, "pooler_output", None) is not None:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
210
+
211
class RoBERTaGRUModel(BaseHateSpeechModel):
    """Hybrid classifier: RoBERTa/PhoBERT encoder followed by a BiGRU head."""

    def __init__(self, model_name: str = "vinai/phobert-base-v2", num_labels: int = 3, hidden_size: int = 256):
        super().__init__(model_name, num_labels)
        self.config = AutoConfig.from_pretrained(model_name, ignore_mismatched_sizes=True)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes=True)
        # A 2-layer bidirectional GRU re-encodes the contextual token states.
        self.gru = nn.GRU(
            input_size=self.config.hidden_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(0.1)
        # Forward and backward GRU states are concatenated, hence the *2.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        token_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # [batch, seq_len, hidden]

        recoded, _ = self.gru(token_states)  # [batch, seq_len, 2*gru_hidden]

        # Global average pooling over time, then classify.
        sentence_vec = self.dropout(recoded.mean(dim=1))
        logits = self.classifier(sentence_vec)

        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}
245
+
246
class TextCNNModel(BaseHateSpeechModel):
    """TextCNN (Kim 2014-style) classifier over learned token embeddings.

    Args (``__init__``):
        vocab_size: size of the embedding vocabulary.
        embedding_dim: embedding width.
        num_labels: number of output classes.
        num_filters: filters per convolution width.
        filter_sizes: convolution window widths; defaults to [3, 4, 5].
        dropout: dropout probability before the classifier.

    ``forward`` returns ``{"loss": Tensor | None, "logits": Tensor}``.
    ``attention_mask`` is accepted for interface parity but unused.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 128, num_labels: int = 3,
                 num_filters: int = 100, filter_sizes: list = None, dropout: float = 0.5):
        super().__init__("textcnn", num_labels)
        # Fix: the original used a mutable default argument ([3, 4, 5]),
        # which is shared across all instances; use None + local default.
        if filter_sizes is None:
            filter_sizes = [3, 4, 5]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # One Conv2d per window width, spanning the full embedding dim.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, embedding_dim)) for k in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(num_filters * len(filter_sizes), num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        embedded = self.embedding(input_ids)  # [batch, seq_len, emb]

        # Add the channel dimension expected by Conv2d.
        embedded = embedded.unsqueeze(1)  # [batch, 1, seq_len, emb]

        # Convolve with each window width and max-pool over time.
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded)).squeeze(3)  # [batch, filters, seq']
            pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)  # [batch, filters]
            conv_outputs.append(pooled)

        # Concatenate the pooled features of every window width.
        concatenated = torch.cat(conv_outputs, dim=1)

        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
287
+
288
class BiLSTMModel(BaseHateSpeechModel):
    """Bidirectional LSTM classifier for hate speech detection.

    ``attention_mask`` is accepted for interface parity but unused; the
    sentence vector is the global average of the BiLSTM outputs.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 128, hidden_size: int = 256,
                 num_labels: int = 3, num_layers: int = 2, dropout: float = 0.5):
        super().__init__("bilstm", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout is only valid with more than one layer.
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward + backward directions are concatenated, hence the *2.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        token_vecs = self.embedding(input_ids)  # [batch, seq_len, emb]

        lstm_out, _ = self.lstm(token_vecs)  # [batch, seq_len, 2*hidden]

        # Global average pooling over time. Alternatives (last hidden state,
        # max pooling) would also work but are not used here.
        sentence_vec = self.dropout(lstm_out.mean(dim=1))
        logits = self.classifier(sentence_vec)

        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}
330
+
331
class BiLSTMCRFModel(BaseHateSpeechModel):
    """BiLSTM + CRF for hate speech detection.

    NOTE(review): a CRF is a sequence-labeling head, while the other models
    in this file do sentence-level classification. Sentence labels are
    broadcast to every token position below so the CRF loss can be computed
    — confirm this is the intended training objective.
    """
    def __init__(self, vocab_size: int, embedding_dim: int = 128, hidden_size: int = 256,
                 num_labels: int = 3, num_layers: int = 2):
        super().__init__("bilstm-crf", num_labels)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1,  # inter-layer dropout; only applies when num_layers > 1
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.1)
        # *2 because forward and backward LSTM states are concatenated.
        self.classifier = nn.Linear(hidden_size * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        # Embedding lookup
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, embedding_dim]

        # BiLSTM over the embedded tokens
        lstm_output, _ = self.lstm(embedded)  # [batch_size, seq_len, hidden_size*2]
        lstm_output = self.dropout(lstm_output)

        # Per-token emission scores for the CRF
        emissions = self.classifier(lstm_output)  # [batch_size, seq_len, num_labels]

        if labels is not None:
            # CRF loss - labels must match emissions' [batch, seq_len] shape
            if labels.dim() == 1:
                # Sentence-level labels: repeat the same label at every step
                labels = labels.unsqueeze(1).expand(-1, emissions.size(1))
            # torchcrf returns a log-likelihood; negate it to get a loss.
            # Padding positions are excluded via the attention mask.
            loss = -self.crf(emissions, labels, mask=attention_mask.bool())
            return {"loss": loss, "logits": emissions}
        else:
            # Inference: Viterbi-decode the best tag sequence per example
            predictions = self.crf.decode(emissions, mask=attention_mask.bool())
            return {"loss": None, "logits": emissions, "predictions": predictions}
371
+
372
class EnsembleModel(BaseHateSpeechModel):
    """Weighted ensemble over several hate-speech models.

    Args (``__init__``):
        models: list of sub-models, each returning ``{"loss", "logits"}``.
        num_labels: number of output classes.
        weights: optional per-model mixing weights; normalized to sum to 1.
            Defaults to uniform weights.

    ``forward`` returns the weighted average of the sub-model logits and, when
    labels are given, the weighted sum of the sub-model losses.
    """

    def __init__(self, models: list, num_labels: int = 3, weights: list = None):
        super().__init__("ensemble", num_labels)
        self.models = nn.ModuleList(models)
        self.num_models = len(models)
        raw = weights if weights else [1.0] * self.num_models
        w = torch.tensor(raw, dtype=torch.float32)
        self.weights = w / w.sum()  # normalized mixing weights

    def forward(self, input_ids, attention_mask, labels=None):
        all_logits = []
        total_loss = 0.0
        has_loss = False
        weights = None

        for i, model in enumerate(self.models):
            out = model(input_ids, attention_mask, labels)
            logits_i = out["logits"]
            if weights is None:
                # Fix: ``self.weights`` is a plain attribute and is not moved
                # by ``.to(device)``; move it to the activations' device once.
                weights = self.weights.to(logits_i.device)
            all_logits.append(logits_i)

            if out["loss"] is not None:
                total_loss = total_loss + weights[i] * out["loss"]
                has_loss = True

        # Weighted average of the sub-model logits.
        ensemble_logits = torch.zeros_like(all_logits[0])
        for i, logits_i in enumerate(all_logits):
            ensemble_logits = ensemble_logits + weights[i] * logits_i

        # Fix: report the loss whenever any sub-model produced one, even if
        # its numeric value is exactly 0 (the old ``> 0`` check dropped it).
        return {
            "loss": total_loss if has_loss else None,
            "logits": ensemble_logits,
        }
402
+
403
def get_model(model_name: str, num_labels: int = 3, **kwargs):
    """Factory: build a hate-speech model by registry key.

    Args:
        model_name: one of "phobert-v1", "phobert-v2", "bartpho", "visobert",
            "vihate-t5", "xlm-r", "mbert", "sphobert", "roberta-gru",
            "textcnn", "bilstm", "bilstm-crf", "ensemble".
        num_labels: number of target classes (3: CLEAN, OFFENSIVE, HATE).
        **kwargs: extra constructor arguments forwarded to the model class.

    Returns:
        An instantiated model.

    Raises:
        ValueError: if ``model_name`` is not in the registry.
    """
    registry = {
        "phobert-v1": PhoBERTV1Model,
        "phobert-v2": PhoBERTV2Model,
        "bartpho": BartPhoModel,
        "visobert": ViSoBERTModel,
        "vihate-t5": ViHateT5Model,
        "xlm-r": XLMRModel,
        "mbert": MBERTModel,
        "sphobert": SPhoBERTModel,
        "roberta-gru": RoBERTaGRUModel,
        "textcnn": TextCNNModel,
        "bilstm": BiLSTMModel,
        "bilstm-crf": BiLSTMCRFModel,
        "ensemble": EnsembleModel,
    }

    if model_name not in registry:
        raise ValueError(f"Unknown model: {model_name}. Available models: {list(registry.keys())}")

    return registry[model_name](num_labels=num_labels, **kwargs)