# toxic-api/app/services/gradient_service.py
"""
Gradient Service
================
Gradient computation using Integrated Gradients (Single Responsibility)
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from app.models.phobert_model import PhoBERTFineTuned
from app.core.config import settings
class GradientService:
"""
Gradient computation service
Responsibilities:
- Compute integrated gradients
- Calculate importance scores
"""
@staticmethod
def compute_integrated_gradients(
model: PhoBERTFineTuned,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
device: torch.device,
target_class: int | None = None,
steps: int | None = None
) -> Tuple[np.ndarray, int, float]:
"""
Compute integrated gradients
Args:
model: Model to analyze
input_ids: Input token IDs
attention_mask: Attention mask
            device: Device to run the computation on
            target_class: Class to attribute; defaults to the predicted class
            steps: Number of integration steps; defaults to settings.GRADIENT_STEPS
Returns:
importance_scores: Token importance scores
predicted_class: Predicted class
confidence: Prediction confidence
"""
if steps is None:
steps = settings.GRADIENT_STEPS
model.eval()
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Get original embeddings
with torch.no_grad():
outputs = model.embedding(input_ids, attention_mask=attention_mask)
original_hidden = outputs.last_hidden_state
baseline_hidden = torch.zeros_like(original_hidden)
integrated_grads = torch.zeros_like(original_hidden)
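        # Integrated Gradients (Riemann approximation):
        #   IG_i ≈ (x_i - x'_i) * (1/m) * Σ_{k=1..m} ∂F(x' + (k/m)(x - x')) / ∂x_i
        # with the all-zeros hidden state as baseline x' and m = steps.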
# Integrate gradients
for step in range(steps):
            # Walk alpha from 1.0 down to 1/steps so the first iteration runs on the
            # original embeddings (needed for the reported prediction below); the set
            # of interpolation points, and hence the gradient sum, is unchanged.
            alpha = (steps - step) / steps
interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
interpolated = interpolated.detach().clone()
interpolated.requires_grad = True
# Forward pass through classification head
if model.pooling == 'cls':
pooled = interpolated[:, 0, :]
else:
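                # Mean pooling: average token embeddings over non-padded positions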
mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
sum_mask = mask_expanded.sum(1)
pooled = sum_embeddings / sum_mask
pooled = model.layer_norm(pooled)
out = model.dropout(pooled)
out = model.relu(model.fc1(out))
out = model.dropout(out)
logits = model.fc2(out)
            # Get prediction on the first step (alpha = 1.0, i.e. the original embeddings)
if step == 0:
probs = F.softmax(logits, dim=1)
predicted_class = torch.argmax(probs, dim=1).item()
confidence = probs[0, predicted_class].item()
if target_class is None:
target_class = predicted_class
            # Backward pass: gradient of the target-class logit w.r.t. the interpolated embeddings
model.zero_grad()
logits[0, target_class].backward()
integrated_grads += interpolated.grad
        # Average over steps and scale by (input - baseline), completing the IG formula
integrated_grads = integrated_grads / steps
integrated_grads = integrated_grads * (original_hidden - baseline_hidden)
# Compute importance scores
importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
importance_scores = importance_scores[0].cpu().detach().numpy()
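        # Keep scores only for real (non-padding) tokens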
valid_length = attention_mask[0].sum().item()
importance_scores = importance_scores[:valid_length]
return importance_scores, predicted_class, confidence
@staticmethod
def normalize_scores(scores: np.ndarray) -> np.ndarray:
"""
Normalize scores to [0, 1]
Args:
scores: Raw scores
Returns:
Normalized scores
"""
        if scores.size == 0:
            return scores
        min_score = scores.min()
        max_score = scores.max()
if max_score - min_score < 1e-8:
return np.ones_like(scores) * 0.5
return (scores - min_score) / (max_score - min_score)
@staticmethod
def compute_threshold(
scores: np.ndarray,
is_toxic: bool,
percentile: int | None = None
) -> float:
"""
        Compute the importance-score threshold above which words are flagged as toxic
Args:
scores: Word scores
is_toxic: Whether text is toxic
percentile: Percentile for threshold
Returns:
Threshold value
"""
if percentile is None:
percentile = settings.PERCENTILE_THRESHOLD
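        # No scores to threshold: fall back to a neutral default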
if len(scores) == 0:
return 0.6
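        # Blend a high percentile with the mean: salient words dominate, outliers are damped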
mean_score = np.mean(scores)
percentile_score = np.percentile(scores, percentile)
threshold = 0.6 * percentile_score + 0.4 * mean_score
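        # Raise the floor for non-toxic text so fewer words end up highlighted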
if is_toxic:
threshold = max(threshold, 0.55)
else:
threshold = max(threshold, 0.75)
threshold = np.clip(threshold, 0.45, 0.90)
return float(threshold)
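
# Minimal usage sketch for the score utilities. Illustrative only: the raw scores
# below are made-up stand-ins for compute_integrated_gradients output, which needs
# a loaded PhoBERTFineTuned model and tokenized input.
if __name__ == "__main__":
    _raw = np.array([0.2, 1.7, 0.4, 3.1])                         # hypothetical token importances
    _norm = GradientService.normalize_scores(_raw)                 # values in [0, 1]
    _thr = GradientService.compute_threshold(_norm, is_toxic=True)
    print("normalized:", _norm, "threshold:", _thr, "flagged:", _norm >= _thr)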