"""
Gradient Service
================
Gradient computation using Integrated Gradients (Single Responsibility)
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from app.models.phobert_model import PhoBERTFineTuned
from app.core.config import settings


class GradientService:
    """
    Gradient computation service.

    Responsibilities:
    - Compute integrated gradients
    - Calculate importance scores
    """

    @staticmethod
    def compute_integrated_gradients(
        model: PhoBERTFineTuned,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        device: torch.device,
        target_class: int | None = None,
        steps: int | None = None
    ) -> Tuple[np.ndarray, int, float]:
        """
        Compute integrated gradients for a single input.

        Args:
            model: Model to analyze
            input_ids: Input token IDs
            attention_mask: Attention mask
            device: Device to run the computation on
            target_class: Target class (optional; defaults to the predicted class)
            steps: Number of integration steps (defaults to settings.GRADIENT_STEPS)

        Returns:
            importance_scores: Per-token importance scores
            predicted_class: Predicted class index
            confidence: Prediction confidence
        """
        if steps is None:
            steps = settings.GRADIENT_STEPS

        model.eval()
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Get the original hidden states from the encoder (no gradients needed here)
        with torch.no_grad():
            outputs = model.embedding(input_ids, attention_mask=attention_mask)
            original_hidden = outputs.last_hidden_state

        # Baseline is an all-zeros tensor with the same shape as the hidden states
        baseline_hidden = torch.zeros_like(original_hidden)
        integrated_grads = torch.zeros_like(original_hidden)
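
        # The loop below is a Riemann-sum approximation of Integrated Gradients
        # over the encoder hidden states, with the all-zeros baseline x':
        #
        #     IG_i(x) ≈ (x_i - x'_i) * (1/m) * sum_{k=1..m} dF(x' + (k/m) * (x - x')) / dx_i
        #
        # where m = steps and F is the logit of the target class.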
        for step in range(steps):
            alpha = (step + 1) / steps
            interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
            interpolated = interpolated.detach().clone()
            interpolated.requires_grad = True

            # Forward pass through the classification head
            if model.pooling == 'cls':
                pooled = interpolated[:, 0, :]
            else:
                # Mean pooling over non-padding tokens
                mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
                sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
                sum_mask = mask_expanded.sum(1)
                pooled = sum_embeddings / sum_mask

            pooled = model.layer_norm(pooled)
            out = model.dropout(pooled)
            out = model.relu(model.fc1(out))
            out = model.dropout(out)
            logits = model.fc2(out)

            # Record the prediction on the first step
            if step == 0:
                probs = F.softmax(logits, dim=1)
                predicted_class = torch.argmax(probs, dim=1).item()
                confidence = probs[0, predicted_class].item()
                if target_class is None:
                    target_class = predicted_class

            # Backward pass: accumulate gradients w.r.t. the interpolated hidden states
            model.zero_grad()
            logits[0, target_class].backward()
            integrated_grads += interpolated.grad

        # Average the accumulated gradients and scale by (input - baseline)
        integrated_grads = integrated_grads / steps
        integrated_grads = integrated_grads * (original_hidden - baseline_hidden)

        # Reduce over the hidden dimension to get one importance score per token
        importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
        importance_scores = importance_scores[0].cpu().detach().numpy()

        # Keep only the non-padding positions
        valid_length = attention_mask[0].sum().item()
        importance_scores = importance_scores[:valid_length]

        return importance_scores, predicted_class, confidence

    @staticmethod
    def normalize_scores(scores: np.ndarray) -> np.ndarray:
        """
        Normalize scores to [0, 1] with min-max scaling.

        Args:
            scores: Raw scores

        Returns:
            Normalized scores
        """
        min_score = scores.min()
        max_score = scores.max()
        if max_score - min_score < 1e-8:
            # All scores are (nearly) identical: return a neutral 0.5 everywhere
            return np.ones_like(scores) * 0.5
        return (scores - min_score) / (max_score - min_score)
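
    # Example (illustrative): normalize_scores(np.array([0.2, 0.5, 1.0]))
    # returns array([0.0, 0.375, 1.0]).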

    @staticmethod
    def compute_threshold(
        scores: np.ndarray,
        is_toxic: bool,
        percentile: int | None = None
    ) -> float:
        """
        Compute the threshold above which word scores are treated as toxic.

        Args:
            scores: Word scores
            is_toxic: Whether the overall text was classified as toxic
            percentile: Percentile used for the threshold (defaults to
                settings.PERCENTILE_THRESHOLD)

        Returns:
            Threshold value
        """
        if percentile is None:
            percentile = settings.PERCENTILE_THRESHOLD
        if len(scores) == 0:
            return 0.6

        # Blend the percentile score with the mean, apply a class-dependent floor
        # (higher when the text is not toxic), and clip to [0.45, 0.90]
        mean_score = np.mean(scores)
        percentile_score = np.percentile(scores, percentile)
        threshold = 0.6 * percentile_score + 0.4 * mean_score
        if is_toxic:
            threshold = max(threshold, 0.55)
        else:
            threshold = max(threshold, 0.75)
        threshold = np.clip(threshold, 0.45, 0.90)
        return float(threshold)
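

# Minimal usage sketch (illustrative only). How the fine-tuned model is loaded and
# which tokenizer is used are assumptions about the surrounding app, not part of
# this service; adjust to however PhoBERTFineTuned is actually constructed.
#
#     from transformers import AutoTokenizer
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
#     model = PhoBERTFineTuned(...)  # hypothetical: load the checkpoint as the app does
#     model.to(device)
#
#     enc = tokenizer("example input text", return_tensors="pt")
#     scores, pred, conf = GradientService.compute_integrated_gradients(
#         model, enc["input_ids"], enc["attention_mask"], device
#     )
#     norm = GradientService.normalize_scores(scores)
#     threshold = GradientService.compute_threshold(norm, is_toxic=(pred == 1))
#     flagged = norm >= threshold  # boolean mask over valid tokens (assumes class 1 = toxic)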