AnnyNguyen committed
Commit 8fbdeda · verified · 1 Parent(s): ab6110b

Upload models.py with huggingface_hub

Files changed (1)
  1. models.py +978 -0
models.py ADDED
@@ -0,0 +1,978 @@
"""
Model architectures for Aspect-Based Sentiment Analysis.
Supports several architectures: Transformer-based, CNN, LSTM, and hybrid models.
"""
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    RobertaPreTrainedModel, RobertaModel,
    BertPreTrainedModel, BertModel,
    XLMRobertaPreTrainedModel, XLMRobertaModel,
    BartPreTrainedModel, BartModel, BartForSequenceClassification,
    T5PreTrainedModel, T5EncoderModel,
    AutoConfig, AutoModel, AutoTokenizer,
    PreTrainedModel
)
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional


class BaseABSA(PreTrainedModel):
    """Base class for all ABSA models"""
    def __init__(self, config):
        super().__init__(config)
        self.num_aspects = config.num_aspects
        self.num_sentiments = config.num_sentiments

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        raise NotImplementedError

    def get_sentiment_classifiers(self, hidden_size):
        """Create one sentiment classifier per aspect"""
        return nn.ModuleList([
            nn.Linear(hidden_size, self.num_sentiments + 1)  # +1 for the "none" class
            for _ in range(self.num_aspects)
        ])


# ========== Transformer-based Models ==========

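# Note on shared shape conventions (added for clarity; not part of the training logic):
# every model below builds one nn.Linear head per aspect with num_sentiments + 1
# outputs, where the extra class encodes "aspect not mentioned". Stacking the
# per-aspect logits therefore yields
#     all_logits: [batch_size, num_aspects, num_sentiments + 1]
# and, since the loss flattens both logits and labels, `labels` is expected to be an
# integer tensor of shape [batch_size, num_aspects] with values in [0, num_sentiments]
# (which index represents "none" is decided by the data pipeline, not by this file).
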
class TransformerForABSA(RobertaPreTrainedModel):
    """RoBERTa-based model (for PhoBERT, ViSoBERT, RoBERTa-GRU)"""
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # RoBERTa-based models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for RobertaModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            B, A, _ = all_logits.size()
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # hidden_states/attentions are only populated when they were explicitly requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        # Ensure directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save backbone
        self.roberta.save_pretrained(save_directory, **kwargs)

        # Update and save config with custom attributes
        config = self.roberta.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1  # -1 because the "none" class is not counted
        # auto_map lets AutoModel load the correct class automatically
        # (models.py is uploaded to the root of the repo)
        config.auto_map = {
            "AutoModel": "models.TransformerForABSA",
            "AutoModelForSequenceClassification": "models.TransformerForABSA"
        }
        # Store extra information in the config so the model can be reloaded easily
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'TransformerForABSA'
        config.save_pretrained(save_directory, **kwargs)

        # Save the full state_dict (including the sentiment_classifiers)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments

        model = cls(config)

        # Load backbone weights
        model.roberta = RobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available (including sentiment_classifiers)
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


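# Illustrative reload sketch (not part of the original file). Assuming a checkpoint was
# written by TransformerForABSA.save_pretrained() to a hypothetical directory
# "./checkpoints/phobert-absa", it can be restored with:
#
#     model = TransformerForABSA.from_pretrained("./checkpoints/phobert-absa")
#
# num_aspects / num_sentiments may be omitted there because save_pretrained() stores
# them in config.json; pass them explicitly when starting from a plain backbone such
# as "vinai/phobert-base" that has never been fine-tuned with these heads.
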
class BERTForABSA(BertPreTrainedModel):
    """BERT-based model (for mBERT)"""
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, token_type_ids=None, **kwargs):
        # BERT models can use token_type_ids, but for single-sentence tasks they are usually all zeros.
        # Filter kwargs to only include valid arguments for BertModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        # BERT expects token_type_ids; if not provided, it defaults to all zeros
        if token_type_ids is not None:
            model_kwargs['token_type_ids'] = token_type_ids

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # hidden_states/attentions are only populated when they were explicitly requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model together with its custom attributes"""
        os.makedirs(save_directory, exist_ok=True)
        self.bert.save_pretrained(save_directory, **kwargs)
        config = self.bert.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.BERTForABSA",
            "AutoModelForSequenceClassification": "models.BERTForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'BERTForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.bert = BertModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class XLMRobertaForABSA(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa-based model"""
    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict)
        pooled = self.dropout(outputs.pooler_output)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # hidden_states/attentions are only populated when they were explicitly requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.roberta = XLMRobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available (including sentiment_classifiers)
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class RoBERTaGRUForABSA(RobertaPreTrainedModel):
    """RoBERTa + GRU hybrid model"""
    base_model_prefix = "roberta"

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.gru = nn.GRU(
            config.hidden_size, config.hidden_size,
            num_layers=2, batch_first=True, bidirectional=True, dropout=0.2
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size * 2, config.num_sentiments + 1)  # *2 because the GRU is bidirectional
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(input_ids, attention_mask=attention_mask, return_dict=return_dict)

        # Use last_hidden_state instead of pooler_output
        sequence_output = outputs.last_hidden_state  # [B, L, H]

        # GRU layer
        gru_out, _ = self.gru(sequence_output)  # [B, L, H*2]
        # Take the last timestep
        pooled = self.dropout(gru_out[:, -1, :])  # [B, H*2]

        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # hidden_states/attentions are only populated when they were explicitly requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.roberta = RobertaModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available (including the GRU and sentiment_classifiers)
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class BartForABSA(BartPreTrainedModel):
    """BART-based model (for BartPho)"""
    def __init__(self, config):
        super().__init__(config)
        self.model = BartModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.d_model, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # BART models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for BartModel;
        # remove training-specific arguments that BartModel doesn't accept
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # IMPORTANT: for BART we need the encoder output directly.
        # BartModel.forward() returns the decoder output in last_hidden_state,
        # so the encoder is called separately to get the encoder hidden states.
        # Only the encoder is called (not the full model.forward()) to avoid double computation.
        encoder_outputs = self.model.get_encoder()(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **{k: v for k, v in model_kwargs.items()}
        )
        sequence_output = encoder_outputs.last_hidden_state  # [B, L, H] - encoder output

        # Mean pooling with attention mask (weighted mean to ignore padding tokens)
        if attention_mask is not None:
            # Expand the attention mask to match sequence_output dimensions
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
            # Sum over sequence length, divide by the number of non-padding tokens
            sum_embeddings = torch.sum(sequence_output * attention_mask_expanded, dim=1)
            sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_embeddings / sum_mask  # [B, H]
        else:
            pooled = sequence_output.mean(dim=1)  # [B, H]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + (encoder_outputs.hidden_states, encoder_outputs.attentions)) if loss is not None else (all_logits,)

        # Use encoder outputs for hidden_states and attentions
        hidden_states = getattr(encoder_outputs, 'hidden_states', None)
        attentions = getattr(encoder_outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model together with its custom attributes"""
        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory, **kwargs)
        config = self.model.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.BartForABSA",
            "AutoModelForSequenceClassification": "models.BartForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'BartForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.model = BartModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


class T5ForABSA(T5PreTrainedModel):
    """T5-based model (for ViT5) - uses the encoder only"""
    def __init__(self, config):
        super().__init__(config)
        self.encoder = T5EncoderModel(config)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(config.d_model, config.num_sentiments + 1)
            for _ in range(config.num_aspects)
        ])
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=None, **kwargs):
        # T5 models don't use token_type_ids, so we ignore it
        kwargs.pop('token_type_ids', None)
        # Filter kwargs to only include valid arguments for T5EncoderModel
        model_kwargs = {
            k: v for k, v in kwargs.items()
            if k in ['position_ids', 'head_mask', 'inputs_embeds',
                     'output_attentions', 'output_hidden_states']
        }
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder(input_ids, attention_mask=attention_mask, return_dict=return_dict, **model_kwargs)

        # Mean pooling with attention mask (weighted mean to ignore padding tokens)
        sequence_output = outputs.last_hidden_state  # [B, L, H]
        if attention_mask is not None:
            # Expand the attention mask to match sequence_output dimensions
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
            # Sum over sequence length, divide by the number of non-padding tokens
            sum_embeddings = torch.sum(sequence_output * attention_mask_expanded, dim=1)
            sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_embeddings / sum_mask  # [B, H]
        else:
            pooled = sequence_output.mean(dim=1)  # [B, H]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if not return_dict:
            return ((loss, all_logits) + outputs[2:]) if loss is not None else (all_logits,) + outputs[2:]

        # The T5 encoder returns a BaseModelOutput; hidden_states/attentions are only present if requested
        hidden_states = getattr(outputs, 'hidden_states', None)
        attentions = getattr(outputs, 'attentions', None)

        return SequenceClassifierOutput(
            loss=loss, logits=all_logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model together with its custom attributes"""
        os.makedirs(save_directory, exist_ok=True)
        self.encoder.save_pretrained(save_directory, **kwargs)
        config = self.encoder.config
        config.num_aspects = len(self.sentiment_classifiers)
        config.num_sentiments = self.sentiment_classifiers[0].out_features - 1
        config.auto_map = {
            "AutoModel": "models.T5ForABSA",
            "AutoModelForSequenceClassification": "models.T5ForABSA"
        }
        if not hasattr(config, 'custom_model_type'):
            config.custom_model_type = 'T5ForABSA'
        config.save_pretrained(save_directory, **kwargs)
        sd = kwargs.get("state_dict", None) or self.state_dict()
        torch.save(sd, os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, num_aspects: int = None, num_sentiments: int = None, **kwargs):
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # If num_aspects and num_sentiments are not passed in, read them from the config
        if num_aspects is None:
            num_aspects = getattr(config, 'num_aspects', None)
            if num_aspects is None:
                raise ValueError("num_aspects must be provided or present in config")

        if num_sentiments is None:
            num_sentiments = getattr(config, 'num_sentiments', None)
            if num_sentiments is None:
                raise ValueError("num_sentiments must be provided or present in config")

        config.num_aspects = num_aspects
        config.num_sentiments = num_sentiments
        model = cls(config)
        model.encoder = T5EncoderModel.from_pretrained(
            pretrained_model_name_or_path, config=config,
            **{k: v for k, v in kwargs.items() if k not in ("config", "state_dict")},
        )

        # Load the full state_dict if available
        try:
            state_dict_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            if os.path.exists(state_dict_path):
                state_dict = torch.load(state_dict_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            elif "state_dict" in kwargs:
                model.load_state_dict(kwargs["state_dict"], strict=False)
        except Exception as e:
            print(f"⚠ Warning: Could not load full state_dict: {e}")

        return model


# ========== Non-Transformer Models ==========

class TextCNNForABSA(nn.Module):
    """TextCNN model - does not use a transformer backbone"""
    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, num_aspects, num_sentiments, max_length=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(0.5)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(len(filter_sizes) * num_filters, num_sentiments + 1)
            for _ in range(num_aspects)
        ])

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True):
        # input_ids: [B, L]
        x = self.embedding(input_ids)  # [B, L, E]
        x = x.permute(0, 2, 1)  # [B, E, L]

        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(x))  # [B, F, L']
            pooled = F.max_pool1d(conv_out, kernel_size=conv_out.size(2))  # [B, F, 1]
            conv_outputs.append(pooled.squeeze(2))  # [B, F]

        x = torch.cat(conv_outputs, dim=1)  # [B, F*len(filter_sizes)]
        x = self.dropout(x)

        all_logits = torch.stack([cls(x) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss, logits=all_logits,
                hidden_states=None, attentions=None
            )
        return (loss, all_logits) if loss is not None else (all_logits,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load a TextCNN model from a saved checkpoint directory"""
        import json

        # Load config
        config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        else:
            # Fall back to model_config.json
            config_path = os.path.join(pretrained_model_name_or_path, 'model_config.json')
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            else:
                raise ValueError(f"Config file not found in {pretrained_model_name_or_path}")

        # Create model with config
        model = cls(
            vocab_size=config.get('vocab_size', 30000),
            embed_dim=kwargs.get('embed_dim', 300),
            num_filters=kwargs.get('num_filters', 100),
            filter_sizes=kwargs.get('filter_sizes', [3, 4, 5]),
            num_aspects=config['num_aspects'],
            num_sentiments=config['num_sentiments'],
            max_length=kwargs.get('max_length', 256)
        )

        # Load weights
        weights_path = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict)

        return model


class BiLSTMForABSA(nn.Module):
    """BiLSTM model - does not use a transformer backbone"""
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_aspects, num_sentiments, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.sentiment_classifiers = nn.ModuleList([
            nn.Linear(hidden_dim * 2, num_sentiments + 1)  # *2 because the LSTM is bidirectional
            for _ in range(num_aspects)
        ])

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True):
        x = self.embedding(input_ids)  # [B, L, E]
        lstm_out, (h_n, c_n) = self.lstm(x)  # [B, L, H*2]

        # Use the last non-padding hidden state instead of always using the last timestep.
        # This matters because padding tokens can sit at the end of the sequence.
        if attention_mask is not None:
            # Find the last non-padding token for each sequence
            # attention_mask: [B, L] where 1 = real token, 0 = padding
            seq_lengths = attention_mask.sum(dim=1) - 1  # -1 for 0-indexing
            # Ensure seq_lengths are within the valid range
            seq_lengths = torch.clamp(seq_lengths, min=0, max=lstm_out.size(1) - 1)
            # Get the last hidden state for each sequence: [B, H*2]
            batch_size = lstm_out.size(0)
            pooled = lstm_out[torch.arange(batch_size, device=lstm_out.device), seq_lengths, :]
        else:
            # Fallback: use the last timestep if no attention mask is given
            pooled = lstm_out[:, -1, :]  # [B, H*2]

        pooled = self.dropout(pooled)
        all_logits = torch.stack([cls(pooled) for cls in self.sentiment_classifiers], dim=1)

        loss = None
        if labels is not None:
            logits_flat = all_logits.view(-1, all_logits.size(-1))
            targets_flat = labels.view(-1)
            loss = nn.CrossEntropyLoss()(logits_flat, targets_flat)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss, logits=all_logits,
                hidden_states=None, attentions=None
            )
        return (loss, all_logits) if loss is not None else (all_logits,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load a BiLSTM model from a saved checkpoint directory"""
        import json

        # Load config
        config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        else:
            # Fall back to model_config.json
            config_path = os.path.join(pretrained_model_name_or_path, 'model_config.json')
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            else:
                raise ValueError(f"Config file not found in {pretrained_model_name_or_path}")

        # Create model with config
        model = cls(
            vocab_size=config.get('vocab_size', 30000),
            embed_dim=kwargs.get('embed_dim', 300),
            hidden_dim=kwargs.get('hidden_dim', 256),
            num_layers=kwargs.get('num_layers', 2),
            num_aspects=config['num_aspects'],
            num_sentiments=config['num_sentiments'],
            dropout=kwargs.get('dropout', 0.3)
        )

        # Load weights
        weights_path = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict)

        return model


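# Note (added for clarity): unlike the transformer classes above, TextCNNForABSA and
# BiLSTMForABSA own their nn.Embedding layers, so the caller must supply integer
# input_ids produced by its own vocabulary/tokenizer, and vocab_size must match that
# vocabulary. Their from_pretrained() helpers read num_aspects / num_sentiments /
# vocab_size from config.json (or model_config.json) in the checkpoint directory and
# then load pytorch_model.bin; hyperparameters such as embed_dim or hidden_dim come
# from kwargs and are expected to match the values used at training time.
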
# ========== Model Factory ==========

# Mapping from model key to Hugging Face model ID.
# These keys are used in main.py and must be mapped to the actual Hugging Face model IDs.
MODEL_ID_MAPPING = {
    "phobert-v1": "vinai/phobert-base",  # PhoBERT v1 (phobert-base or phobert-base-v1 can be used)
    "phobert-v2": "vinai/phobert-base",  # PhoBERT v2
    "bartpho": "vinai/bartpho-syllable",
    "visobert": "uitnlp/visobert",
    "mbert": "bert-base-multilingual-cased",
    "vit5": "VietAI/vit5-base",
    "xlm-r": "xlm-roberta-base",
    "xlm-roberta": "xlm-roberta-base",
    "roberta-gru": "roberta-base",
    "roberta-base-gru": "roberta-base",
}

def get_hf_model_id(model_name: str) -> str:
    """
    Convert a model key to a Hugging Face model ID.

    Args:
        model_name: Model key (e.g. "phobert-v1") or a Hugging Face ID (e.g. "vinai/phobert-base")

    Returns:
        Hugging Face model ID
    """
    model_name_lower = model_name.lower()

    # If it is already a Hugging Face ID (contains a "/"), return it unchanged
    if '/' in model_name:
        return model_name

    # If it is a known key, map it to the Hugging Face ID
    if model_name_lower in MODEL_ID_MAPPING:
        return MODEL_ID_MAPPING[model_name_lower]

    # Not in the mapping: assume it is already a Hugging Face ID
    return model_name

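# Example resolutions (derived from MODEL_ID_MAPPING above):
#     get_hf_model_id("phobert-v2")       -> "vinai/phobert-base"
#     get_hf_model_id("vit5")             -> "VietAI/vit5-base"
#     get_hf_model_id("uitnlp/visobert")  -> "uitnlp/visobert"   (already an HF ID, returned unchanged)
#     get_hf_model_id("some-unknown-key") -> "some-unknown-key"  (assumed to already be an HF ID)
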
def get_model_class(model_name: str):
    """Factory function that returns the model class for a given model name"""
    model_name_lower = model_name.lower()

    # RoBERTa-GRU (check first to avoid confusion)
    if 'roberta' in model_name_lower and ('gru' in model_name_lower or 'roberta-base-gru' in model_name_lower):
        return RoBERTaGRUForABSA

    # XLM-RoBERTa (checked before the generic RoBERTa branch, otherwise
    # "xlm-roberta" would be caught by the 'roberta' substring test below)
    if 'xlm-roberta' in model_name_lower or 'xlm_roberta' in model_name_lower:
        return XLMRobertaForABSA

    # RoBERTa-based (PhoBERT v1/v2, ViSoBERT)
    if 'phobert' in model_name_lower or 'visobert' in model_name_lower or 'roberta' in model_name_lower:
        return TransformerForABSA

    # BERT
    elif 'bert' in model_name_lower and 'roberta' not in model_name_lower:
        return BERTForABSA

    # BART
    elif 'bart' in model_name_lower:
        return BartForABSA

    # T5
    elif 't5' in model_name_lower or 'vit5' in model_name_lower:
        return T5ForABSA

    # TextCNN
    elif 'textcnn' in model_name_lower or 'cnn' in model_name_lower:
        return TextCNNForABSA

    # BiLSTM
    elif 'bilstm' in model_name_lower or 'lstm' in model_name_lower:
        return BiLSTMForABSA

    # Default: fall back to the RoBERTa-based model
    else:
        return TransformerForABSA


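# Example dispatches (following the substring checks above):
#     get_model_class("phobert-v2")       -> TransformerForABSA
#     get_model_class("roberta-base-gru") -> RoBERTaGRUForABSA
#     get_model_class("xlm-roberta-base") -> XLMRobertaForABSA
#     get_model_class("mbert")            -> BERTForABSA
#     get_model_class("bartpho")          -> BartForABSA
#     get_model_class("vit5")             -> T5ForABSA
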
def create_model(model_name: str, num_aspects: int, num_sentiments: int, vocab_size=None, **kwargs):
    """
    Create a model instance based on the model name.

    Args:
        model_name: Model name or Hugging Face model ID
        num_aspects: Number of aspects
        num_sentiments: Number of sentiment classes
        vocab_size: Vocabulary size (only needed for TextCNN/BiLSTM)
        **kwargs: Additional arguments
    """
    model_class = get_model_class(model_name)

    # RoBERTa-GRU needs its own base model
    if model_class == RoBERTaGRUForABSA:
        # Use roberta-base as the base model for RoBERTa-GRU
        base_model_name = 'roberta-base'
        return model_class.from_pretrained(
            base_model_name,
            num_aspects=num_aspects,
            num_sentiments=num_sentiments,
            trust_remote_code=True,
            **kwargs
        )

    # Non-transformer models
    if model_class in [TextCNNForABSA, BiLSTMForABSA]:
        if vocab_size is None:
            raise ValueError(f"vocab_size is required for {model_class.__name__}")

        if model_class == TextCNNForABSA:
            return TextCNNForABSA(
                vocab_size=vocab_size,
                embed_dim=kwargs.get('embed_dim', 300),
                num_filters=kwargs.get('num_filters', 100),
                filter_sizes=kwargs.get('filter_sizes', [3, 4, 5]),
                num_aspects=num_aspects,
                num_sentiments=num_sentiments,
                max_length=kwargs.get('max_length', 256)
            )
        elif model_class == BiLSTMForABSA:
            return BiLSTMForABSA(
                vocab_size=vocab_size,
                embed_dim=kwargs.get('embed_dim', 300),
                hidden_dim=kwargs.get('hidden_dim', 256),
                num_layers=kwargs.get('num_layers', 2),
                num_aspects=num_aspects,
                num_sentiments=num_sentiments,
                dropout=kwargs.get('dropout', 0.3)
            )

    # Transformer models
    else:
        # Map the model key to a Hugging Face model ID
        hf_model_id = get_hf_model_id(model_name)

        return model_class.from_pretrained(
            hf_model_id,
            num_aspects=num_aspects,
            num_sentiments=num_sentiments,
            trust_remote_code=True,
            **kwargs
        )