""" Gradient Service ================ Gradient computation using Integrated Gradients (Single Responsibility) """ import torch import torch.nn.functional as F import numpy as np from typing import Tuple from app.models.phobert_model import PhoBERTFineTuned from app.core.config import settings class GradientService: """ Gradient computation service Responsibilities: - Compute integrated gradients - Calculate importance scores """ @staticmethod def compute_integrated_gradients( model: PhoBERTFineTuned, input_ids: torch.Tensor, attention_mask: torch.Tensor, device: torch.device, target_class: int | None = None, steps: int | None = None ) -> Tuple[np.ndarray, int, float]: """ Compute integrated gradients Args: model: Model to analyze input_ids: Input token IDs attention_mask: Attention mask device: Device target_class: Target class (optional) steps: Number of integration steps Returns: importance_scores: Token importance scores predicted_class: Predicted class confidence: Prediction confidence """ if steps is None: steps = settings.GRADIENT_STEPS model.eval() input_ids = input_ids.to(device) attention_mask = attention_mask.to(device) # Get original embeddings with torch.no_grad(): outputs = model.embedding(input_ids, attention_mask=attention_mask) original_hidden = outputs.last_hidden_state baseline_hidden = torch.zeros_like(original_hidden) integrated_grads = torch.zeros_like(original_hidden) # Integrate gradients for step in range(steps): alpha = (step + 1) / steps interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden) interpolated = interpolated.detach().clone() interpolated.requires_grad = True # Forward pass through classification head if model.pooling == 'cls': pooled = interpolated[:, 0, :] else: mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float() sum_embeddings = torch.sum(interpolated * mask_expanded, 1) sum_mask = mask_expanded.sum(1) pooled = sum_embeddings / sum_mask pooled = model.layer_norm(pooled) out = model.dropout(pooled) out = model.relu(model.fc1(out)) out = model.dropout(out) logits = model.fc2(out) # Get prediction on first step if step == 0: probs = F.softmax(logits, dim=1) predicted_class = torch.argmax(probs, dim=1).item() confidence = probs[0, predicted_class].item() if target_class is None: target_class = predicted_class # Backward pass model.zero_grad() logits[0, target_class].backward() integrated_grads += interpolated.grad # Average and scale integrated_grads = integrated_grads / steps integrated_grads = integrated_grads * (original_hidden - baseline_hidden) # Compute importance scores importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1) importance_scores = importance_scores[0].cpu().detach().numpy() valid_length = attention_mask[0].sum().item() importance_scores = importance_scores[:valid_length] return importance_scores, predicted_class, confidence @staticmethod def normalize_scores(scores: np.ndarray) -> np.ndarray: """ Normalize scores to [0, 1] Args: scores: Raw scores Returns: Normalized scores """ min_score = scores.min() max_score = scores.max() if max_score - min_score < 1e-8: return np.ones_like(scores) * 0.5 return (scores - min_score) / (max_score - min_score) @staticmethod def compute_threshold( scores: np.ndarray, is_toxic: bool, percentile: int | None = None ) -> float: """ Compute threshold for toxicity Args: scores: Word scores is_toxic: Whether text is toxic percentile: Percentile for threshold Returns: Threshold value """ if percentile is None: percentile = 
settings.PERCENTILE_THRESHOLD if len(scores) == 0: return 0.6 mean_score = np.mean(scores) percentile_score = np.percentile(scores, percentile) threshold = 0.6 * percentile_score + 0.4 * mean_score if is_toxic: threshold = max(threshold, 0.55) else: threshold = max(threshold, 0.75) threshold = np.clip(threshold, 0.45, 0.90) return float(threshold)
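

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This assumes a trained PhoBERTFineTuned
# checkpoint and a matching Hugging Face tokenizer are available; how the
# model is actually constructed and loaded depends on the rest of this
# repository, so the no-arg constructor below is a hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERTFineTuned()  # hypothetical constructor; adapt to the real one
    model.to(device)

    encoded = tokenizer(
        "văn bản cần phân tích",  # "text to analyze"
        return_tensors="pt",
        truncation=True,
    )

    scores, predicted_class, confidence = GradientService.compute_integrated_gradients(
        model=model,
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        device=device,
    )

    # Normalize token scores and derive a flagging threshold,
    # assuming class 1 denotes "toxic" (adjust to the real label mapping).
    scores = GradientService.normalize_scores(scores)
    threshold = GradientService.compute_threshold(scores, is_toxic=(predicted_class == 1))

    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    for token, score in zip(tokens, scores):
        flag = "TOXIC" if score >= threshold else ""
        print(f"{token:>15s}  {score:.3f}  {flag}")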