"""Model definitions for spam review detection."""

import warnings
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, AutoModelForSequenceClassification

from .custom_models import TextCNN, BiLSTM, RoBERTaGRU, SPhoBERT


def _classification_loss(logits, labels, num_items_in_batch=None):
    """Return the cross-entropy loss of ``logits`` against ``labels``, or None.

    Args:
        logits: Classifier output.
        labels: Target class indices, or None (then no loss is computed).
        num_items_in_batch: Optional total number of labelled items, as
            injected by the Hugging Face Trainer.

    Without ``num_items_in_batch`` this is the batch mean, identical to
    ``nn.CrossEntropyLoss()`` with default arguments. Recent transformers
    Trainer versions pass ``num_items_in_batch`` to models whose ``forward``
    accepts ``**kwargs``. In that mode they do not divide the loss by
    ``gradient_accumulation_steps`` themselves, so the model must normalise by
    the item count. Summing and dividing by it keeps the loss scale correct
    under gradient accumulation. With a single micro-batch it equals the mean.
    """
    if labels is None:
        return None
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(logits, labels)
    loss = nn.functional.cross_entropy(logits, labels, reduction="sum")
    if torch.is_tensor(num_items_in_batch):
        # The count may arrive as a tensor on a different device.
        num_items_in_batch = num_items_in_batch.to(loss.device)
    return loss / num_items_in_batch


class TransformerForSpamDetection(nn.Module):
    """Pretrained Hugging Face backbone (``AutoModel``) with a linear head.

    The hidden state at sequence position 0 goes through dropout (p=0.1) and
    a linear layer that produces ``num_labels`` logits.
    """

    # Keyword arguments that may reach ``forward`` (e.g. from the HF Trainer)
    # but must not be forwarded to the backbone.
    _DROPPED_KWARGS = frozenset({"num_items_in_batch", "position_ids"})

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
        self.encoder = AutoModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Classify a batch.

        Any remaining ``kwargs`` (e.g. ``token_type_ids`` for BERT) are passed
        to the backbone, except those in ``_DROPPED_KWARGS``.

        Returns:
            ``{"loss": loss, "logits": logits}``; ``loss`` is None when
            ``labels`` is None.
            NOTE(review): the HF Trainer's label-free prediction path may
            treat the ``None`` loss entry as an extra prediction output.
            Confirm if you run ``trainer.predict`` on unlabeled data.
        """
        num_items_in_batch = kwargs.get("num_items_in_batch")
        encoder_kwargs = {k: v for k, v in kwargs.items() if k not in self._DROPPED_KWARGS}
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs)
        # First-position hidden state: [CLS]/<s> for BERT/RoBERTa-style encoders.
        # NOTE(review): for encoder-decoder checkpoints (e.g. bartpho), AutoModel
        # returns the full seq2seq model, so this is a decoder state; confirm intent.
        pooled = self.dropout(out.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        loss = _classification_loss(logits, labels, num_items_in_batch)
        return {"loss": loss, "logits": logits}


class ViT5ForSpamDetection(nn.Module):
    """ViT5 (T5) classifier that uses only the T5 encoder plus a linear head."""

    # T5EncoderModel.forward accepts neither position_ids nor token_type_ids;
    # num_items_in_batch is a Trainer loss kwarg consumed here.
    _DROPPED_KWARGS = frozenset({"num_items_in_batch", "position_ids", "token_type_ids"})

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        from transformers import T5EncoderModel, T5Config

        # Load only the T5 encoder stack.
        config = T5Config.from_pretrained(model_name)
        self.t5_encoder = T5EncoderModel.from_pretrained(model_name, config=config)
        # Classification head on top of the encoder output.
        self.classifier = nn.Linear(config.d_model, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Classify a batch using the first encoder position as the pooled vector.

        Returns:
            ``{"loss": loss, "logits": logits}``; ``loss`` is None when
            ``labels`` is None.
        """
        num_items_in_batch = kwargs.get("num_items_in_batch")
        encoder_kwargs = {k: v for k, v in kwargs.items() if k not in self._DROPPED_KWARGS}
        encoder_outputs = self.t5_encoder(input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs)
        # NOTE(review): T5 has no dedicated [CLS] token, so position 0 is simply the
        # first input token; kept as-is so trained checkpoints behave identically.
        pooled = self.dropout(encoder_outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        loss = _classification_loss(logits, labels, num_items_in_batch)
        return {"loss": loss, "logits": logits}


# Single source of truth for every supported model name. get_model_config()
# returns copies of these entries; get_model() reads the checkpoint IDs.
_MODEL_CONFIGS = {
    "phobert-v1": {
        "model_name": "vinai/phobert-base",
        "description": "PhoBERT v1 - Pre-trained BERT for Vietnamese",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "phobert-v2": {
        "model_name": "vinai/phobert-base-v2",
        "description": "PhoBERT v2 - Improved PhoBERT for Vietnamese",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "bartpho": {
        "model_name": "vinai/bartpho-syllable",
        "description": "BART Pho - Vietnamese BART model",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "visobert": {
        "model_name": "uitnlp/visobert",
        "description": "ViSoBERT - Vietnamese Social BERT",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "xlm-r": {
        "model_name": "xlm-roberta-large",
        "description": "XLM-RoBERTa Large - Multilingual model",
        "max_length": 256,
        "learning_rate": 3e-5,
    },
    "mbert": {
        "model_name": "bert-base-multilingual-cased",
        "description": "mBERT - Multilingual BERT model",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "vit5": {
        "model_name": "VietAI/vit5-base",
        "description": "ViT5 - Vietnamese T5",
        "max_length": 256,
        "learning_rate": 5e-5,
    },
    "textcnn": {
        "model_name": "vinai/phobert-base-v2",
        "description": "TextCNN - Convolutional Neural Network for text",
        "max_length": 256,
        "learning_rate": 1e-3,
        "custom_model": True,
    },
    "bilstm": {
        "model_name": "vinai/phobert-base-v2",
        "description": "BiLSTM - Bidirectional LSTM for text classification",
        "max_length": 256,
        "learning_rate": 1e-3,
        "custom_model": True,
    },
    "roberta-gru": {
        "model_name": "vinai/phobert-base-v2",
        "description": "RoBERTa-GRU - Hybrid RoBERTa + GRU model",
        "max_length": 256,
        "learning_rate": 5e-5,
        "custom_model": True,
    },
    "sphobert": {
        "model_name": "vinai/phobert-base-v2",
        "description": "SPhoBERT - PhoBERT + SentenceBERT embedding fusion",
        "max_length": 256,
        "learning_rate": 5e-5,
        "custom_model": True,
    },
    "bilstm-crf": {
        "model_name": "vinai/phobert-base-v2",
        "description": "BiLSTM-CRF - Bidirectional LSTM with CRF",
        "max_length": 256,
        "learning_rate": 1e-3,
        "custom_model": True,
    },
}

# Custom architectures from .custom_models, keyed by model name. Every other
# name in _MODEL_CONFIGS (except "vit5") uses TransformerForSpamDetection.
_CUSTOM_MODEL_CLASSES = {
    "textcnn": TextCNN,
    "bilstm": BiLSTM,
    "roberta-gru": RoBERTaGRU,
    "sphobert": SPhoBERT,
    # Placeholder: no CRF layer is implemented yet, so this is a plain BiLSTM.
    "bilstm-crf": BiLSTM,
}


def get_model(model_name: str, num_labels: int, vocab_size: Optional[int] = None):
    """Build the model registered under ``model_name``.

    Args:
        model_name: Registered name, e.g. "phobert-v2", "textcnn", "bilstm".
        num_labels: Number of output classes.
        vocab_size: Accepted for backward compatibility; currently unused.

    Returns:
        An ``nn.Module`` instance.

    Raises:
        ValueError: If ``model_name`` is not registered.
    """
    if model_name not in _MODEL_CONFIGS:
        raise ValueError(f"Unknown model name: {model_name}. Available models: {list(_MODEL_CONFIGS)}")

    base_model_name = _MODEL_CONFIGS[model_name]["model_name"]

    if model_name == "vit5":
        return ViT5ForSpamDetection(base_model_name, num_labels)

    custom_cls = _CUSTOM_MODEL_CLASSES.get(model_name)
    if custom_cls is None:
        return TransformerForSpamDetection(base_model_name, num_labels)

    if model_name == "bilstm-crf":
        # Surface the substitution instead of silently training a different model.
        warnings.warn(
            "'bilstm-crf' has no CRF layer yet; falling back to a plain BiLSTM.",
            UserWarning,
            stacklevel=2,
        )
    return custom_cls(base_model_name, num_labels)


def get_model_config(model_name: str):
    """Return a fresh copy of the configuration dict for ``model_name``.

    Each dict contains "model_name", "description", "max_length" and
    "learning_rate". Custom models also include "custom_model": True.

    Raises:
        ValueError: If ``model_name`` is not registered.
    """
    if model_name not in _MODEL_CONFIGS:
        raise ValueError(f"Model {model_name} not found. Available models: {list(_MODEL_CONFIGS.keys())}")
    # Copy so callers can mutate the result without affecting later calls.
    return dict(_MODEL_CONFIGS[model_name])