| """ | |
| PhoBERT Model | |
| ============= | |
| Model architecture definition (Single Responsibility) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, Optional | |


class PhoBERTFineTuned(nn.Module):
    """
    Fine-tuned PhoBERT model for toxic text classification

    Responsibilities:
    - Define model architecture
    - Forward pass computation
    """

    def __init__(
        self,
        embedding_model: nn.Module,
        hidden_dim: int = 768,
        dropout: float = 0.3,
        num_classes: int = 2,
        num_layers_to_finetune: int = 4,
        pooling: str = 'mean'
    ):
        super(PhoBERTFineTuned, self).__init__()
        self.embedding = embedding_model
        self.pooling = pooling
        self.num_layers_to_finetune = num_layers_to_finetune

        # Freeze all parameters
        for param in self.embedding.parameters():
            param.requires_grad = False

        # Unfreeze last N layers
        if num_layers_to_finetune > 0:
            total_layers = len(self.embedding.encoder.layer)
            layers_to_train = list(range(
                total_layers - num_layers_to_finetune,
                total_layers
            ))
            for layer_idx in layers_to_train:
                for param in self.embedding.encoder.layer[layer_idx].parameters():
                    param.requires_grad = True

            if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None:
                for param in self.embedding.pooler.parameters():
                    param.requires_grad = True
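
        # Example: with a 12-layer PhoBERT-base backbone (an assumption about the
        # injected embedding_model) and num_layers_to_finetune=4, encoder layers
        # 8-11 and the pooler stay trainable while all other backbone weights
        # remain frozen. hidden_dim should match the backbone's hidden size
        # (768 for PhoBERT-base, 1024 for PhoBERT-large).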

        # Classification head
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_embeddings: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            return_embeddings: Whether to return embeddings

        Returns:
            logits: Classification logits
            embeddings: Hidden states (if return_embeddings=True)
        """
        # Get embeddings
        outputs = self.embedding(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state

        # Pooling: 'cls' takes the first token's hidden state; 'mean' averages
        # token states weighted by the attention mask so padding contributes nothing
        if self.pooling == 'cls':
            pooled = embeddings[:, 0, :]
        elif self.pooling == 'mean':
            mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
            sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
            # Clamp to avoid division by zero if a row's mask is all zeros
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")

        # Classification
        pooled = self.layer_norm(pooled)
        out = self.dropout(pooled)
        out = self.relu(self.fc1(out))
        out = self.dropout(out)
        logits = self.fc2(out)

        if return_embeddings:
            return logits, embeddings
        return logits, None
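

# Usage sketch: a minimal, hedged example of wiring this module to a Hugging Face
# backbone. The checkpoint name "vinai/phobert-base" and the tokenizer call below
# are illustrative assumptions; the project's actual loading code may differ.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    backbone = AutoModel.from_pretrained("vinai/phobert-base")
    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERTFineTuned(backbone, num_layers_to_finetune=4, pooling='mean')
    model.eval()

    # Encode a dummy sentence (PhoBERT expects word-segmented Vietnamese input)
    batch = tokenizer(["xin chào"], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits, _ = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([1, 2]) with the default num_classes=2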