| """ | |
| Gradient Service | |
| ================ | |
| Gradient computation using Integrated Gradients (Single Responsibility) | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import Tuple | |
| from app.models.phobert_model import PhoBERTFineTuned | |
| from app.core.config import settings | |
class GradientService:
    """
    Gradient computation service.

    Responsibilities:
    - Compute integrated gradients
    - Calculate importance scores
    """
    @staticmethod
    def compute_integrated_gradients(
        model: PhoBERTFineTuned,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        device: torch.device,
        target_class: int | None = None,
        steps: int | None = None,
    ) -> Tuple[np.ndarray, int, float]:
        """
        Compute integrated gradients for a single input.

        Args:
            model: Model to analyze
            input_ids: Input token IDs
            attention_mask: Attention mask
            device: Device to run on
            target_class: Target class (defaults to the predicted class)
            steps: Number of integration steps
                (defaults to settings.GRADIENT_STEPS)

        Returns:
            importance_scores: Per-token importance scores
            predicted_class: Predicted class index
            confidence: Prediction confidence
        """
        if steps is None:
            steps = settings.GRADIENT_STEPS

        model.eval()
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Get original embeddings from the encoder (no gradients needed here)
        with torch.no_grad():
            outputs = model.embedding(input_ids, attention_mask=attention_mask)
            original_hidden = outputs.last_hidden_state

        baseline_hidden = torch.zeros_like(original_hidden)
        integrated_grads = torch.zeros_like(original_hidden)

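        # Riemann-sum approximation of integrated gradients
        # (Sundararajan et al., 2017):
        #   IG_i(x) ~= (x_i - x'_i) * (1/m) * sum_{k=1..m} dF(x' + (k/m)(x - x'))/dx_i
        # where x' is the all-zeros baseline and m = steps. The loop below
        # accumulates the gradient term; the (x - x') scaling is applied after it.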
        for step in range(steps):
            alpha = (step + 1) / steps
            interpolated = baseline_hidden + alpha * (original_hidden - baseline_hidden)
            interpolated = interpolated.detach().clone()
            interpolated.requires_grad = True

            # Forward pass through the classification head
            if model.pooling == 'cls':
                pooled = interpolated[:, 0, :]
            else:
                # Mean pooling over non-padding tokens
                mask_expanded = attention_mask.unsqueeze(-1).expand(interpolated.size()).float()
                sum_embeddings = torch.sum(interpolated * mask_expanded, 1)
                sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)  # guard against all-zero masks
                pooled = sum_embeddings / sum_mask

            pooled = model.layer_norm(pooled)
            out = model.dropout(pooled)
            out = model.relu(model.fc1(out))
            out = model.dropout(out)
            logits = model.fc2(out)

            # Record the prediction on the first step
            if step == 0:
                probs = F.softmax(logits, dim=1)
                predicted_class = torch.argmax(probs, dim=1).item()
                confidence = probs[0, predicted_class].item()
                if target_class is None:
                    target_class = predicted_class

            # Backward pass w.r.t. the interpolated embeddings
            model.zero_grad()
            logits[0, target_class].backward()
            integrated_grads += interpolated.grad

        # Average over steps and scale by (input - baseline)
        integrated_grads = integrated_grads / steps
        integrated_grads = integrated_grads * (original_hidden - baseline_hidden)

        # Reduce hidden dimensions to per-token importance scores
        importance_scores = torch.sum(torch.abs(integrated_grads), dim=-1)
        importance_scores = importance_scores[0].detach().cpu().numpy()

        # Keep scores only for non-padding tokens
        valid_length = int(attention_mask[0].sum().item())
        importance_scores = importance_scores[:valid_length]

        return importance_scores, predicted_class, confidence
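
    # A minimal usage sketch (illustrative only; `model` is a loaded
    # PhoBERTFineTuned checkpoint and `tokenizer` a matching Hugging Face
    # tokenizer -- both are assumptions, not defined in this module):
    #
    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #   encoded = tokenizer(text, return_tensors="pt", truncation=True)
    #   scores, pred, conf = GradientService.compute_integrated_gradients(
    #       model, encoded["input_ids"], encoded["attention_mask"], device
    #   )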

    @staticmethod
    def normalize_scores(scores: np.ndarray) -> np.ndarray:
        """
        Normalize scores to [0, 1] with min-max scaling.

        Args:
            scores: Raw scores

        Returns:
            Normalized scores (uniform 0.5 when the scores are constant)
        """
        min_score = scores.min()
        max_score = scores.max()
        if max_score - min_score < 1e-8:
            # Degenerate case: constant scores carry no ranking information
            return np.ones_like(scores) * 0.5
        return (scores - min_score) / (max_score - min_score)
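
    # Example: normalize_scores(np.array([1.0, 3.0, 5.0]))
    # -> array([0. , 0.5, 1. ])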

    @staticmethod
    def compute_threshold(
        scores: np.ndarray,
        is_toxic: bool,
        percentile: int | None = None,
    ) -> float:
        """
        Compute the score threshold used to flag toxic words.

        Args:
            scores: Word scores
            is_toxic: Whether the text was classified as toxic
            percentile: Percentile for the threshold
                (defaults to settings.PERCENTILE_THRESHOLD)

        Returns:
            Threshold value, clipped to [0.45, 0.90]
        """
        if percentile is None:
            percentile = settings.PERCENTILE_THRESHOLD
        if len(scores) == 0:
            return 0.6

        # Blend the percentile score with the mean, then apply a floor:
        # toxic texts get a lower floor so that more words can be flagged
        mean_score = np.mean(scores)
        percentile_score = np.percentile(scores, percentile)
        threshold = 0.6 * percentile_score + 0.4 * mean_score
        if is_toxic:
            threshold = max(threshold, 0.55)
        else:
            threshold = max(threshold, 0.75)
        threshold = np.clip(threshold, 0.45, 0.90)
        return float(threshold)
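
    # Worked example (percentile passed explicitly to avoid depending on
    # settings.PERCENTILE_THRESHOLD):
    #   compute_threshold(np.array([0.2, 0.5, 0.9]), is_toxic=True, percentile=75)
    #   percentile_score = 0.7, mean ~= 0.533, blend = 0.6*0.7 + 0.4*0.533 ~= 0.633;
    #   the 0.55 floor and the [0.45, 0.90] clip leave it unchanged -> ~0.633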