# toxic-api/app/models/phobert_model.py

"""
PhoBERT Model
=============
Model architecture definition (Single Responsibility)
"""
import torch
import torch.nn as nn
from typing import Tuple, Optional


class PhoBERTFineTuned(nn.Module):
"""
Fine-tuned PhoBERT model for toxic text classification
Responsibilities:
- Define model architecture
- Forward pass computation
"""
def __init__(
self,
embedding_model: nn.Module,
hidden_dim: int = 768,
dropout: float = 0.3,
num_classes: int = 2,
num_layers_to_finetune: int = 4,
pooling: str = 'mean'
):
        super().__init__()
self.embedding = embedding_model
self.pooling = pooling
self.num_layers_to_finetune = num_layers_to_finetune
# Freeze all parameters
for param in self.embedding.parameters():
param.requires_grad = False
# Unfreeze last N layers
if num_layers_to_finetune > 0:
total_layers = len(self.embedding.encoder.layer)
layers_to_train = list(range(
total_layers - num_layers_to_finetune,
total_layers
))
for layer_idx in layers_to_train:
for param in self.embedding.encoder.layer[layer_idx].parameters():
param.requires_grad = True
if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None:
for param in self.embedding.pooler.parameters():
param.requires_grad = True
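
        # NOTE: the layer indexing above assumes a Hugging Face BERT/RoBERTa-style
        # encoder (PhoBERT qualifies), i.e. `embedding_model.encoder.layer` is a
        # ModuleList of transformer blocks; other encoders may need adapting.
        # The pooler is unfrozen for completeness, but forward() below reads
        # last_hidden_state directly, so the pooler receives no gradient unless
        # it is used elsewhere.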
        # Classification head; applied in forward() as:
        #   LayerNorm -> Dropout -> fc1 -> ReLU -> Dropout -> fc2
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
return_embeddings: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass
Args:
input_ids: Input token IDs
attention_mask: Attention mask
return_embeddings: Whether to return embeddings
Returns:
logits: Classification logits
embeddings: Hidden states (if return_embeddings=True)
"""
# Get embeddings
outputs = self.embedding(input_ids, attention_mask=attention_mask)
embeddings = outputs.last_hidden_state
        # Pooling: collapse (batch, seq_len, hidden) states to (batch, hidden)
        if self.pooling == 'cls':
            # Use the hidden state of the first token (<s>, RoBERTa's CLS equivalent)
            pooled = embeddings[:, 0, :]
        elif self.pooling == 'mean':
            # Masked mean: average token states, excluding padding positions
            mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
            sum_embeddings = torch.sum(embeddings * mask_expanded, dim=1)
            sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
            pooled = sum_embeddings / sum_mask
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")
# Classification
pooled = self.layer_norm(pooled)
out = self.dropout(pooled)
out = self.relu(self.fc1(out))
out = self.dropout(out)
logits = self.fc2(out)
if return_embeddings:
return logits, embeddings
return logits, None
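

if __name__ == "__main__":
    # Minimal usage sketch (not part of the API surface): assumes the
    # `transformers` package is installed and that the public
    # `vinai/phobert-base` checkpoint is the intended encoder.
    from transformers import AutoModel, AutoTokenizer

    encoder = AutoModel.from_pretrained("vinai/phobert-base")
    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERTFineTuned(embedding_model=encoder)
    model.eval()

    # Only the last `num_layers_to_finetune` encoder layers (plus the pooler
    # and the classification head) should be trainable after construction.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} / {total:,}")

    batch = tokenizer(["xin chào"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits, _ = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # -> torch.Size([1, 2]) for the default num_classes=2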