"""
Gradient Service
================
Gradient computation using Integrated Gradients (Single Responsibility)
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple

from app.models.phobert_model import PhoBERTFineTuned
from app.core.config import settings


class GradientService:
    """
    Gradient computation service
    
    Responsibilities:
    - Compute integrated gradients
    - Calculate importance scores
    """
    
    @staticmethod
    def compute_integrated_gradients(
        model: PhoBERTFineTuned,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        device: torch.device,
        target_class: int | None = None,
        steps: int | None = None
    ) -> Tuple[np.ndarray, int, float]:
        """
        Compute integrated gradients
        
        Args:
            model: Model to analyze
            input_ids: Input token IDs
            attention_mask: Attention mask
            device: Device
            target_class: Target class (optional)
            steps: Number of integration steps
            
        Returns:
            importance_scores: Token importance scores
            predicted_class: Predicted class
            confidence: Prediction confidence
        """
        if steps is None:
            steps = settings.GRADIENT_STEPS
        
        model.eval()
        
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        # Run the encoder once to get the original hidden states (IG path endpoint)
        with torch.no_grad():
            outputs = model.embedding(input_ids, attention_mask=attention_mask)
            original_hidden = outputs.last_hidden_state
        
        baseline_hidden = torch.zeros_like(original_hidden)
        integrated_grads = torch.zeros_like(original_hidden)
        
        # Integrate gradients along the baseline -> input path. Alpha runs
        # from 1 down to 1/steps so the first iteration sees the unperturbed
        # input (the Riemann sum itself is order-independent).
        for step in range(steps):
            alpha = (steps - step) / steps
            interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
            interpolated = interpolated.detach().clone().requires_grad_(True)
            
            # Forward pass through classification head
            if model.pooling == 'cls':
                pooled = interpolated[:, 0, :]
            else:
                mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
                sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
                sum_mask = mask_expanded.sum(1).clamp(min=1e-9)  # avoid division by zero
                pooled = sum_embeddings / sum_mask
            
            pooled = model.layer_norm(pooled)
            out = model.dropout(pooled)
            out = model.relu(model.fc1(out))
            out = model.dropout(out)
            logits = model.fc2(out)
            
            # Record the prediction on the first step (alpha == 1, i.e. the
            # unperturbed input)
            if step == 0:
                probs = F.softmax(logits, dim=1)
                predicted_class = torch.argmax(probs, dim=1).item()
                confidence = probs[0, predicted_class].item()
                if target_class is None:
                    target_class = predicted_class
            
            # Backward pass w.r.t. the target logit (batch size 1 assumed)
            model.zero_grad()
            logits[0, target_class].backward()
            integrated_grads += interpolated.grad
        
        # Average over steps and scale by (input - baseline), per the IG formula
        integrated_grads = integrated_grads / steps
        integrated_grads = integrated_grads * (original_hidden - baseline_hidden)
        
        # Token importance = L1 norm of attributions over the hidden dimension
        importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
        importance_scores = importance_scores[0].cpu().detach().numpy()
        
        # Keep only positions covered by the attention mask (drop padding)
        valid_length = attention_mask[0].sum().item()
        importance_scores = importance_scores[:valid_length]
        
        return importance_scores, predicted_class, confidence
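    
    # Example call (a sketch; `tokenizer`, `model`, and `device` come from the
    # surrounding app and are illustrative names here):
    #
    #   enc = tokenizer("text to explain", return_tensors="pt")
    #   scores, pred, conf = GradientService.compute_integrated_gradients(
    #       model, enc["input_ids"], enc["attention_mask"], device
    #   )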
    
    @staticmethod
    def normalize_scores(scores: np.ndarray) -> np.ndarray:
        """
        Min-max normalize scores to [0, 1]
        
        Args:
            scores: Raw scores
            
        Returns:
            Normalized scores
        """
        min_score = scores.min()
        max_score = scores.max()
        
        # Degenerate case: all scores (nearly) identical; return a neutral 0.5
        if max_score - min_score < 1e-8:
            return np.ones_like(scores) * 0.5
        
        return (scores - min_score) / (max_score - min_score)
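    
    # Worked example: normalize_scores(np.array([0.2, 0.5, 0.8]))
    # -> array([0.0, 0.5, 1.0])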
    
    @staticmethod
    def compute_threshold(
        scores: np.ndarray,
        is_toxic: bool,
        percentile: int | None = None
    ) -> float:
        """
        Compute threshold for toxicity
        
        Args:
            scores: Word scores
            is_toxic: Whether text is toxic
            percentile: Percentile for threshold
            
        Returns:
            Threshold value
        """
        if percentile is None:
            percentile = settings.PERCENTILE_THRESHOLD
        
        # No scores to threshold; fall back to a mid-range default
        if len(scores) == 0:
            return 0.6
        
        # Blend the percentile cut-off with the mean so a few extreme scores
        # cannot dominate the threshold
        mean_score = np.mean(scores)
        percentile_score = np.percentile(scores, percentile)
        threshold = 0.6 * percentile_score + 0.4 * mean_score
        
        if is_toxic:
            threshold = max(threshold, 0.55)
        else:
            # Raise the floor for non-toxic text so fewer words are flagged
            threshold = max(threshold, 0.75)
        
        threshold = np.clip(threshold, 0.45, 0.90)
        
        return float(threshold)
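
# Putting the pieces together (a sketch, not part of the service API; it
# assumes the tokenizer/model pairing used elsewhere in the app, and that
# class index 1 means "toxic"; both are illustrative assumptions here):
#
#   scores, pred, conf = GradientService.compute_integrated_gradients(
#       model, enc["input_ids"], enc["attention_mask"], device
#   )
#   norm = GradientService.normalize_scores(scores)
#   cutoff = GradientService.compute_threshold(norm, is_toxic=(pred == 1))
#   flagged = [i for i, s in enumerate(norm) if s >= cutoff]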