"""
CIFAR-10 ViT-B/16 Model Handler

Handles prediction, attention-map visualization, and calibration
for the ViT-B/16 model trained on CIFAR-10.
"""
import os
import types

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from typing import Dict, List, Optional, Any
import torchvision.transforms as transforms
from torchvision.models import vit_b_16

from app.shared.model_registry import (
    BaseModelHandler,
    PredictionResult,
    CalibrationResult,
)
from app.shared.artifact_utils import (
    get_best_accuracy_from_history,
    load_precomputed_calibration_result,
)
from app.image.data import create_cifar10_test_dataset
| # CIFAR-10 class labels | |
| CIFAR10_LABELS = [ | |
| 'airplane', 'automobile', 'bird', 'cat', 'deer', | |
| 'dog', 'frog', 'horse', 'ship', 'truck' | |
| ] | |
| # CIFAR-10 normalization values | |
| CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) | |
| CIFAR10_STD = (0.2470, 0.2435, 0.2616) | |
| # Image size ViT expects | |
| IMAGE_SIZE = 224 | |
| def create_vit_model(num_classes=10): | |
| """Create ViT-B/16 with modified classifier for CIFAR-10.""" | |
| model = vit_b_16(weights=None) | |
| # Replace classifier head | |
| model.heads.head = nn.Linear(model.heads.head.in_features, num_classes) | |
| return model | |
| class ViTAttentionVisualizer: | |
| """ | |
| Attention visualization for ViT. | |
| Shows which patches the model attends to. | |
| """ | |
| def __init__(self, model): | |
| self.model = model | |
| self.attentions = None | |
| self._patch_last_encoder_block() | |
| def _patch_last_encoder_block(self): | |
| """ | |
| Torchvision's ViT encoder block calls MultiheadAttention with | |
| need_weights=False, so a normal forward hook never receives attention | |
| maps. We patch only the last block to request weights during inference. | |
| """ | |
| last_block = self.model.encoder.layers[-1] | |
| visualizer = self | |
| def forward_with_attention(block, input_tensor): | |
| torch._assert( | |
| input_tensor.dim() == 3, | |
| f"Expected (batch_size, seq_length, hidden_dim) got {input_tensor.shape}", | |
| ) | |
| x = block.ln_1(input_tensor) | |
| attn_output, attn_weights = block.self_attention( | |
| x, | |
| x, | |
| x, | |
| need_weights=True, | |
| average_attn_weights=False, | |
| ) | |
| visualizer.attentions = attn_weights.detach() | |
| x = block.dropout(attn_output) | |
| x = x + input_tensor | |
| y = block.ln_2(x) | |
| y = block.mlp(y) | |
| return x + y | |
| last_block.forward = types.MethodType(forward_with_attention, last_block) | |
| def generate_attention_map(self, input_tensor): | |
| """Generate attention map from input tensor.""" | |
| self.model.eval() | |
| # Forward pass | |
| with torch.no_grad(): | |
| _ = self.model(input_tensor) | |
| if self.attentions is None: | |
| return None | |
| # Get the [CLS] token attention across all heads | |
| # Shape: (batch, heads, seq_len, seq_len) -> take cls token row | |
| cls_attention = self.attentions[0, :, 0, 1:].mean(dim=0) # Average over heads | |
| # Reshape to patch grid (assuming 16x16 patches for 224x224 image) | |
| num_patches = int(cls_attention.shape[0] ** 0.5) | |
| if num_patches * num_patches != cls_attention.shape[0]: | |
| # Fallback: just return raw attention | |
| return cls_attention.cpu().numpy() | |
| # Reshape to 2D grid | |
| attention_map = cls_attention.reshape(num_patches, num_patches).cpu().numpy() | |
| # Normalize | |
| attention_map = attention_map - attention_map.min() | |
| if attention_map.max() > 0: | |
| attention_map = attention_map / attention_map.max() | |
| return attention_map | |
| def create_attention_overlay(image_np, attention_map, alpha=0.5): | |
| """Create overlay of attention map on original image.""" | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| if attention_map is None: | |
| return image_np | |
| # Resize attention map to image size | |
| from PIL import Image as PILImage | |
| attention_uint8 = (attention_map * 255).astype(np.uint8) | |
| attention_resized = PILImage.fromarray(attention_uint8).resize( | |
| (IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR | |
| ) | |
| attention_resized = np.array(attention_resized).astype(np.float32) / 255.0 | |
| if image_np.shape[:2] != (IMAGE_SIZE, IMAGE_SIZE): | |
| image_np = np.array( | |
| PILImage.fromarray(image_np).resize((IMAGE_SIZE, IMAGE_SIZE), PILImage.BILINEAR) | |
| ) | |
| # Apply colormap | |
| colormap = cm.jet(attention_resized)[:, :, :3] | |
| colormap = (colormap * 255).astype(np.uint8) | |
| # Create overlay | |
| overlay = (alpha * colormap + (1 - alpha) * image_np).astype(np.uint8) | |
| # Create figure | |
| fig, axes = plt.subplots(1, 3, figsize=(15, 5)) | |
| fig.patch.set_facecolor('#0d1117') | |
| titles = ['Original Image', 'Attention Map', 'Overlay'] | |
| images = [image_np, colormap, overlay] | |
| for ax, img, title in zip(axes, images, titles): | |
| ax.imshow(img) | |
| ax.set_title(title, color='white', fontsize=14, fontweight='bold', pad=10) | |
| ax.axis('off') | |
| ax.set_facecolor('#0d1117') | |
| plt.tight_layout(pad=2) | |
| fig.canvas.draw() | |
| rgba_buffer = fig.canvas.buffer_rgba() | |
| result = np.array(rgba_buffer)[:, :, :3] | |
| plt.close(fig) | |
| return result | |
| class Cifar10ViTHandler(BaseModelHandler): | |
| """Model handler for CIFAR-10 ViT-B/16.""" | |
| def __init__(self, model_path: str): | |
| self.model_path = model_path | |
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| self.model = None | |
| self.attention_viz = None | |
| self.history = {} | |
| self.best_accuracy = None | |
| self._calibration_cache = {} | |
| self.transform = transforms.Compose([ | |
| transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD), | |
| ]) | |
| self._load_model() | |
| def _load_model(self): | |
| """Load the trained model.""" | |
| self.model = create_vit_model(num_classes=10) | |
| if os.path.exists(self.model_path): | |
| checkpoint = torch.load(self.model_path, map_location=self.device, | |
| weights_only=True) | |
| if isinstance(checkpoint, dict): | |
| self.history = checkpoint.get('history', {}) or {} | |
| self.best_accuracy = get_best_accuracy_from_history(self.history) | |
| if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: | |
| self.model.load_state_dict(checkpoint['model_state_dict']) | |
| else: | |
| self.model.load_state_dict(checkpoint) | |
| self.model = self.model.to(self.device) | |
| self.model.eval() | |
| # Initialize attention visualizer | |
| self.attention_viz = ViTAttentionVisualizer(self.model) | |
| precomputed_full = load_precomputed_calibration_result("vit_b16") | |
| if precomputed_full is not None: | |
| self._calibration_cache["full"] = precomputed_full | |
| def get_model_name(self) -> str: | |
| return "ViT-B/16" | |
| def get_dataset_name(self) -> str: | |
| return "CIFAR-10" | |
| def get_data_type(self) -> str: | |
| return "image" | |
| def get_class_labels(self) -> List[str]: | |
| return CIFAR10_LABELS | |
| def get_model_info(self) -> Dict[str, str]: | |
| total_params = sum(p.numel() for p in self.model.parameters()) | |
| best_accuracy = ( | |
| f"{self.best_accuracy:.2f}%" | |
| if self.best_accuracy is not None | |
| else "N/A" | |
| ) | |
| info = { | |
| "Architecture": "ViT-B/16 (Transfer Learning from ImageNet)", | |
| "Dataset": "CIFAR-10 (10 classes, 60,000 images)", | |
| "Parameters": f"{total_params:,}", | |
| "Input Size": f"{IMAGE_SIZE}×{IMAGE_SIZE}×3", | |
| "Training": "Full fine-tune, AdamW, Cosine Annealing LR", | |
| "Best Accuracy": best_accuracy, | |
| "Device": str(self.device), | |
| } | |
| if self.history: | |
| info["Epochs"] = str(len(self.history.get("val_acc", []))) | |
| full_result = self._calibration_cache.get("full") | |
| if full_result is not None: | |
| info["Full-Test ECE"] = f"{full_result.ece:.6f}" | |
| return info | |
| def predict(self, input_data) -> PredictionResult: | |
| """Run prediction with attention visualization.""" | |
| if input_data is None: | |
| raise ValueError("No input image provided") | |
| # Convert to PIL Image if numpy array | |
| if isinstance(input_data, np.ndarray): | |
| original_image = input_data.copy() | |
| pil_image = Image.fromarray(input_data).convert('RGB') | |
| else: | |
| pil_image = input_data.convert('RGB') | |
| original_image = np.array(pil_image) | |
| # Preprocess | |
| input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device) | |
| # Forward pass | |
| with torch.no_grad(): | |
| output = self.model(input_tensor) | |
| probabilities = torch.softmax(output, dim=1)[0] | |
| probs = probabilities.cpu().numpy() | |
| pred_idx = probs.argmax() | |
| pred_label = CIFAR10_LABELS[pred_idx] | |
| pred_conf = float(probs[pred_idx]) | |
| # Generate attention visualization | |
| attention_map = self.attention_viz.generate_attention_map(input_tensor) | |
| explanation_image = create_attention_overlay(original_image, attention_map) | |
| return PredictionResult( | |
| label=pred_label, | |
| confidence=pred_conf, | |
| all_labels=CIFAR10_LABELS, | |
| all_confidences=probs.tolist(), | |
| explanation_image=explanation_image, | |
| ) | |
| def get_example_inputs(self) -> List[Any]: | |
| return [] | |
| def get_calibration_data( | |
| self, max_samples: Optional[int] = None | |
| ) -> Optional[CalibrationResult]: | |
| """Compute calibration metrics on test set.""" | |
| cache_key = "full" if max_samples is None else f"subset:{max_samples}" | |
| if cache_key in self._calibration_cache: | |
| return self._calibration_cache[cache_key] | |
| try: | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| test_dataset = create_cifar10_test_dataset(transform=self.transform) | |
| if max_samples is not None and 0 < max_samples < len(test_dataset): | |
| indices = np.linspace( | |
| 0, len(test_dataset) - 1, num=max_samples, dtype=int | |
| ).tolist() | |
| test_dataset = torch.utils.data.Subset(test_dataset, indices) | |
| test_loader = torch.utils.data.DataLoader( | |
| test_dataset, batch_size=128, shuffle=False, num_workers=0 | |
| ) | |
| all_probs = [] | |
| all_preds = [] | |
| all_targets = [] | |
| self.model.eval() | |
| with torch.inference_mode(): | |
| for inputs, targets in test_loader: | |
| inputs = inputs.to(self.device) | |
| outputs = self.model(inputs) | |
| probs = torch.softmax(outputs, dim=1) | |
| preds = outputs.argmax(1) | |
| all_probs.extend(probs.cpu().numpy()) | |
| all_preds.extend(preds.cpu().numpy()) | |
| all_targets.extend(targets.numpy()) | |
| all_probs = np.array(all_probs) | |
| all_preds = np.array(all_preds) | |
| all_targets = np.array(all_targets) | |
| # Compute ECE | |
| n_bins = 10 | |
| max_probs = np.max(all_probs, axis=1) | |
| correctness = (all_preds == all_targets).astype(float) | |
| bin_boundaries = np.linspace(0, 1, n_bins + 1) | |
| bin_accuracies = [] | |
| bin_confidences = [] | |
| bin_counts = [] | |
| for i in range(n_bins): | |
| lower = bin_boundaries[i] | |
| upper = bin_boundaries[i + 1] | |
| mask = (max_probs > lower) & (max_probs <= upper) | |
| count = mask.sum() | |
| bin_counts.append(int(count)) | |
| if count > 0: | |
| bin_acc = correctness[mask].mean() | |
| bin_conf = max_probs[mask].mean() | |
| else: | |
| bin_acc = 0.0 | |
| bin_conf = 0.0 | |
| bin_accuracies.append(float(bin_acc)) | |
| bin_confidences.append(float(bin_conf)) | |
| # Compute ECE | |
| total = len(all_preds) | |
| ece = sum( | |
| (count / total) * abs(acc - conf) | |
| for count, acc, conf in zip(bin_counts, bin_accuracies, bin_confidences) | |
| ) | |
| # Create reliability diagram | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6)) | |
| fig.patch.set_facecolor('#0d1117') | |
| # Reliability Diagram | |
| ax1.set_facecolor('#161b22') | |
| bin_centers = [(bin_boundaries[i] + bin_boundaries[i + 1]) / 2 for i in range(n_bins)] | |
| width = 0.08 | |
| ax1.bar([c - width/2 for c in bin_centers], bin_accuracies, width, | |
| label='Accuracy', color='#58a6ff', alpha=0.9, edgecolor='#58a6ff') | |
| ax1.bar([c + width/2 for c in bin_centers], bin_confidences, width, | |
| label='Avg Confidence', color='#f97583', alpha=0.9, edgecolor='#f97583') | |
| ax1.plot([0, 1], [0, 1], '--', color='#8b949e', linewidth=2, | |
| label='Perfect Calibration') | |
| ax1.set_xlim(0, 1) | |
| ax1.set_ylim(0, 1) | |
| ax1.set_xlabel('Confidence', color='white', fontsize=12) | |
| ax1.set_ylabel('Accuracy / Confidence', color='white', fontsize=12) | |
| ax1.set_title(f'Reliability Diagram (ECE: {ece:.4f})', | |
| color='white', fontsize=14, fontweight='bold', pad=15) | |
| ax1.legend(facecolor='#161b22', edgecolor='#30363d', labelcolor='white', fontsize=10) | |
| ax1.tick_params(colors='white') | |
| for spine in ax1.spines.values(): | |
| spine.set_edgecolor('#30363d') | |
| ax1.grid(True, alpha=0.1, color='white') | |
| # Confidence histogram | |
| ax2.set_facecolor('#161b22') | |
| ax2.bar(bin_centers, [c / total for c in bin_counts], 0.08, | |
| color='#56d364', alpha=0.9, edgecolor='#56d364') | |
| ax2.set_xlim(0, 1) | |
| ax2.set_xlabel('Confidence', color='white', fontsize=12) | |
| ax2.set_ylabel('Fraction of Samples', color='white', fontsize=12) | |
| ax2.set_title('Confidence Distribution', | |
| color='white', fontsize=14, fontweight='bold', pad=15) | |
| ax2.tick_params(colors='white') | |
| for spine in ax2.spines.values(): | |
| spine.set_edgecolor('#30363d') | |
| ax2.grid(True, alpha=0.1, color='white') | |
| plt.tight_layout(pad=3) | |
| fig.canvas.draw() | |
| rgba_buffer = fig.canvas.buffer_rgba() | |
| diagram = np.array(rgba_buffer)[:, :, :3] | |
| plt.close(fig) | |
| self._calibration_cache[cache_key] = CalibrationResult( | |
| ece=ece, | |
| bin_accuracies=bin_accuracies, | |
| bin_confidences=bin_confidences, | |
| bin_counts=bin_counts, | |
| reliability_diagram=diagram, | |
| source="Live computation", | |
| ) | |
| return self._calibration_cache[cache_key] | |
| except Exception as e: | |
| print(f"Error computing calibration: {e}") | |
| return None | |