# bartpho-spam-binary/models.py
"""
Module định nghĩa các mô hình cho spam review detection
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from .custom_models import TextCNN, BiLSTM, RoBERTaGRU, SPhoBERT
class TransformerForSpamDetection(nn.Module):
"""
Base transformer model cho spam review detection
"""
def __init__(self, model_name: str, num_labels: int):
super().__init__()
config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
self.encoder = AutoModel.from_pretrained(model_name, config=config)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.dropout = nn.Dropout(0.1)
def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        # Drop extra kwargs (e.g. Trainer-injected ones) that the underlying encoder may not accept
filtered_kwargs = {k: v for k, v in kwargs.items()
if k not in ['num_items_in_batch', 'position_ids']}
# Pass filtered arguments to encoder (including token_type_ids for BERT)
out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **filtered_kwargs)
pooled = out.last_hidden_state[:, 0] # CLS token
pooled = self.dropout(pooled)
logits = self.classifier(pooled)
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {"loss": loss, "logits": logits}
class ViT5ForSpamDetection(nn.Module):
"""
ViT5 model cho spam review detection - sử dụng encoder-only approach
"""
def __init__(self, model_name: str, num_labels: int):
super().__init__()
from transformers import T5EncoderModel, T5Config
# Load T5 encoder only
config = T5Config.from_pretrained(model_name)
self.t5_encoder = T5EncoderModel.from_pretrained(model_name, config=config)
# Classification head
self.classifier = nn.Linear(config.d_model, num_labels)
self.dropout = nn.Dropout(0.1)
def forward(self, input_ids, attention_mask, labels=None, **kwargs):
# Filter out arguments that T5EncoderModel doesn't expect
filtered_kwargs = {k: v for k, v in kwargs.items()
if k not in ['num_items_in_batch', 'position_ids']}
        # Use only the encoder stack of T5
encoder_outputs = self.t5_encoder(input_ids=input_ids, attention_mask=attention_mask, **filtered_kwargs)
        # Pooled representation: first token (T5 has no CLS token, so this is
        # a simple heuristic summary of the sequence)
        pooled = encoder_outputs.last_hidden_state[:, 0]
pooled = self.dropout(pooled)
logits = self.classifier(pooled)
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {"loss": loss, "logits": logits}
def get_model(model_name: str, num_labels: int, vocab_size: int = None):
"""
Factory function để tạo model dựa trên tên model
Args:
model_name: Tên model (phobert-v2, textcnn, bilstm, etc.)
num_labels: Số lượng classes
vocab_size: Kích thước vocabulary (chỉ cần cho BiLSTM-CRF)
Returns:
Model instance
"""
    # Mapping from short model names to Hugging Face checkpoints
model_mapping = {
"phobert-v1": "vinai/phobert-base",
"phobert-v2": "vinai/phobert-base-v2",
"bartpho": "vinai/bartpho-syllable",
"visobert": "uitnlp/visobert",
"xlm-r": "xlm-roberta-large",
"mbert": "bert-base-multilingual-cased",
"vit5": "VietAI/vit5-base"
}
if model_name == "vit5":
        # Use ViT5ForSpamDetection for the T5-based model
base_model_name = model_mapping[model_name]
return ViT5ForSpamDetection(base_model_name, num_labels)
elif model_name in model_mapping:
        # Use the standard transformer wrapper
base_model_name = model_mapping[model_name]
return TransformerForSpamDetection(base_model_name, num_labels)
elif model_name == "textcnn":
# TextCNN custom model
base_model_name = "vinai/phobert-base-v2" # Sử dụng PhoBERT embeddings
return TextCNN(base_model_name, num_labels)
elif model_name == "bilstm":
# BiLSTM custom model
base_model_name = "vinai/phobert-base-v2"
return BiLSTM(base_model_name, num_labels)
elif model_name == "roberta-gru":
# RoBERTa-GRU hybrid model
base_model_name = "vinai/phobert-base-v2"
return RoBERTaGRU(base_model_name, num_labels)
elif model_name == "sphobert":
# SPhoBERT fusion model
base_model_name = "vinai/phobert-base-v2"
return SPhoBERT(base_model_name, num_labels)
elif model_name == "bilstm-crf":
        # BiLSTM-CRF model (placeholder: currently falls back to plain BiLSTM)
        # A full implementation would add a CRF layer on top of the BiLSTM
base_model_name = "vinai/phobert-base-v2"
return BiLSTM(base_model_name, num_labels)
else:
raise ValueError(f"Unknown model name: {model_name}. Available models: {list(model_mapping.keys()) + ['textcnn', 'bilstm', 'roberta-gru', 'sphobert', 'bilstm-crf']}")
def get_model_config(model_name: str):
"""
Lấy cấu hình cho model
Args:
model_name: Tên model
Returns:
Dict chứa cấu hình model
"""
configs = {
"phobert-v1": {
"model_name": "vinai/phobert-base",
"description": "PhoBERT v1 - Pre-trained BERT for Vietnamese",
"max_length": 256,
"learning_rate": 5e-5
},
"phobert-v2": {
"model_name": "vinai/phobert-base-v2",
"description": "PhoBERT v2 - Improved PhoBERT for Vietnamese",
"max_length": 256,
"learning_rate": 5e-5
},
"bartpho": {
"model_name": "vinai/bartpho-syllable",
"description": "BART Pho - Vietnamese BART model",
"max_length": 256,
"learning_rate": 5e-5
},
"visobert": {
"model_name": "uitnlp/visobert",
"description": "ViSoBERT - Vietnamese Social BERT",
"max_length": 256,
"learning_rate": 5e-5
},
"xlm-r": {
"model_name": "xlm-roberta-large",
"description": "XLM-RoBERTa Large - Multilingual model",
"max_length": 256,
"learning_rate": 3e-5
},
"mbert": {
"model_name": "bert-base-multilingual-cased",
"description": "mBERT - Multilingual BERT model",
"max_length": 256,
"learning_rate": 5e-5
},
"vit5": {
"model_name": "VietAI/vit5-base",
"description": "ViT5 - Vietnamese T5",
"max_length": 256,
"learning_rate": 5e-5
},
"textcnn": {
"model_name": "vinai/phobert-base-v2",
"description": "TextCNN - Convolutional Neural Network for text",
"max_length": 256,
"learning_rate": 1e-3,
"custom_model": True
},
"bilstm": {
"model_name": "vinai/phobert-base-v2",
"description": "BiLSTM - Bidirectional LSTM for text classification",
"max_length": 256,
"learning_rate": 1e-3,
"custom_model": True
},
"roberta-gru": {
"model_name": "vinai/phobert-base-v2",
"description": "RoBERTa-GRU - Hybrid RoBERTa + GRU model",
"max_length": 256,
"learning_rate": 5e-5,
"custom_model": True
},
"sphobert": {
"model_name": "vinai/phobert-base-v2",
"description": "SPhoBERT - PhoBERT + SentenceBERT embedding fusion",
"max_length": 256,
"learning_rate": 5e-5,
"custom_model": True
},
"bilstm-crf": {
"model_name": "vinai/phobert-base-v2",
"description": "BiLSTM-CRF - Bidirectional LSTM with CRF",
"max_length": 256,
"learning_rate": 1e-3,
"custom_model": True
}
}
if model_name not in configs:
raise ValueError(f"Model {model_name} not found. Available models: {list(configs.keys())}")
return configs[model_name]
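

# Putting the two helpers together (illustrative sketch; the training loop and
# optimizer settings below are assumptions, not the repository's actual trainer):
#
#   cfg = get_model_config("phobert-v2")
#   model = get_model("phobert-v2", num_labels=2)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["learning_rate"])
#   # Tokenize with truncation to cfg["max_length"], then for each batch:
#   #   out = model(**batch, labels=labels)
#   #   out["loss"].backward(); optimizer.step(); optimizer.zero_grad()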