"""
PhoBERT Model
=============
Model architecture definition (Single Responsibility)
"""

import torch
import torch.nn as nn
from typing import Tuple, Optional


class PhoBERTFineTuned(nn.Module):
    """
    Fine-tuned PhoBERT model for toxic text classification
    
    Responsibilities:
    - Define model architecture
    - Forward pass computation
    """
    
    def __init__(
        self,
        embedding_model: nn.Module,
        hidden_dim: int = 768,
        dropout: float = 0.3,
        num_classes: int = 2,
        num_layers_to_finetune: int = 4,
        pooling: str = 'mean'
    ):
        super().__init__()
        
        self.embedding = embedding_model
        self.pooling = pooling
        self.num_layers_to_finetune = num_layers_to_finetune
        
        # Freeze all parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        
        # Unfreeze last N layers
        if num_layers_to_finetune > 0:
            total_layers = len(self.embedding.encoder.layer)
            layers_to_train = list(range(
                total_layers - num_layers_to_finetune, 
                total_layers
            ))
            
            for layer_idx in layers_to_train:
                for param in self.embedding.encoder.layer[layer_idx].parameters():
                    param.requires_grad = True
            
            if hasattr(self.embedding, 'pooler') and self.embedding.pooler is not None:
                for param in self.embedding.pooler.parameters():
                    param.requires_grad = True
        
        # Classification head
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_embeddings: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass
        
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            return_embeddings: Whether to return embeddings
            
        Returns:
            logits: Classification logits
            embeddings: Hidden states (if return_embeddings=True)
        """
        # Get embeddings
        outputs = self.embedding(input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        
        # Pooling
        if self.pooling == 'cls':
            # Use the representation of the first ([CLS]/<s>) token
            pooled = embeddings[:, 0, :]
        elif self.pooling == 'mean':
            # Masked mean pooling: average token embeddings, ignoring padding positions
            mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
            sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)  # avoid division by zero
            pooled = sum_embeddings / sum_mask
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling}")
        
        # Classification
        pooled = self.layer_norm(pooled)
        out = self.dropout(pooled)
        out = self.relu(self.fc1(out))
        out = self.dropout(out)
        logits = self.fc2(out)
        
        if return_embeddings:
            return logits, embeddings
        return logits, None
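

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# Assumes the embedding model is a HuggingFace RoBERTa-style encoder such as
# "vinai/phobert-base"; the checkpoint name, tokenizer and sample text below
# are placeholders, swap in whatever your pipeline actually loads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    phobert = AutoModel.from_pretrained("vinai/phobert-base")

    model = PhoBERTFineTuned(
        embedding_model=phobert,
        hidden_dim=768,
        dropout=0.3,
        num_classes=2,
        num_layers_to_finetune=4,
        pooling='mean'
    )

    # Sanity check on partial fine-tuning: only the last N encoder layers,
    # the pooler (if present) and the classification head stay trainable.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,}")

    encoded = tokenizer(
        ["xin chào"],  # placeholder input; PhoBERT expects word-segmented Vietnamese
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )

    model.eval()
    with torch.no_grad():
        logits, _ = model(encoded["input_ids"], encoded["attention_mask"])
    print(logits.shape)  # expected: torch.Size([1, 2])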