Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| Rangoli Classification Models | |
| ============================================================ | |
| 6 architectures for comparative study in IEEE paper: | |
| 1. ResNet-50 (Baseline CNN) | |
| 2. EfficientNet-B3 (Efficient Scaling) | |
| 3. ViT-Base (Vision Transformer) | |
| 4. ConvNeXt-Small (Modern CNN) | |
| 5. MobileNetV3 (Lightweight/Mobile) | |
| 6. Swin-Base (Hierarchical Transformer) | |
| All use ImageNet-pretrained weights with custom classification heads. | |
| ============================================================ | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification under class imbalance.

    Per sample the loss is ``alpha_t * (1 - p_t) ** gamma * CE`` where ``CE``
    is the (optionally label-smoothed) cross-entropy and ``p_t = exp(-CE)``.

    The class weight ``alpha_t`` is applied *after* ``p_t`` is computed.
    Passing the weights into ``F.cross_entropy`` would make ``exp(-CE)`` equal
    ``p_t ** alpha_t`` and distort the focal modulation.

    Args:
        alpha: Optional per-class weights of length ``C`` (tensor or
            sequence of floats). ``None`` disables class weighting.
        gamma: Focusing exponent; ``0`` reduces to plain cross-entropy.
        label_smoothing: Smoothing factor forwarded to ``F.cross_entropy``
            for hard (index) targets. It is not applied to soft targets.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-sample loss of shape ``(B,)``.
    """

    def __init__(self, alpha=None, gamma=2.0, label_smoothing=0.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha  # per-class weights, or None
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def _alpha_per_sample(self, targets, like):
        """Return the class weight for each sample, shaped like ``like``."""
        alpha = torch.as_tensor(self.alpha, dtype=like.dtype, device=like.device)
        if targets.dim() == 1:
            # Entries equal to cross_entropy's ignore_index (-100) already
            # have zero loss; clamping only keeps the lookup in range.
            return alpha[targets.clamp(min=0)]
        # Soft labels: expected class weight under the label distribution.
        # For one-hot rows this equals the hard-label weight.
        return (targets.to(like.dtype) * alpha).sum(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: (B, C) logits
            targets: (B,) class indices OR (B, C) soft labels (for mixup/cutmix)

        Returns:
            Scalar loss for ``"mean"``/``"sum"`` reduction, else a ``(B,)`` tensor.
        """
        if targets.dim() == 1:
            # Unweighted CE so that exp(-CE) is not skewed by class weights.
            ce_loss = F.cross_entropy(
                inputs, targets,
                label_smoothing=self.label_smoothing, reduction="none",
            )
        else:
            # Soft labels (MixUp/CutMix)
            log_probs = F.log_softmax(inputs, dim=1)
            ce_loss = -(targets * log_probs).sum(dim=1)

        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            focal_loss = self._alpha_per_sample(targets, ce_loss) * focal_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss
class RangoliClassifier(nn.Module):
    """
    Universal Rangoli Classifier wrapper.

    Wraps a timm backbone (created with ``num_classes=0`` so it has no
    classifier of its own) and attaches a custom head:
    BatchNorm -> Dropout -> Linear(512) -> GELU -> BatchNorm -> Dropout -> Linear.

    Args:
        architecture: timm model name passed to ``timm.create_model``.
        num_classes: Number of output classes of the head.
        pretrained: Whether timm should load pretrained weights.
        dropout: Dropout probability of the first head dropout; the second
            uses half of this value.
        feature_dim: Width of the backbone output. When ``None`` it is
            measured with a single dummy forward pass.
    """

    def __init__(self, architecture, num_classes=8, pretrained=True,
                 dropout=0.5, feature_dim=None):
        super().__init__()
        self.architecture = architecture
        self.num_classes = num_classes

        # Create backbone via timm
        self.backbone = timm.create_model(
            architecture,
            pretrained=pretrained,
            num_classes=0,  # Remove original classifier
        )

        # Get feature dimension
        if feature_dim is None:
            feature_dim = self._infer_feature_dim(self.backbone)
        self.feature_dim = feature_dim

        # Custom classification head
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(feature_dim),
            nn.Dropout(p=dropout),
            nn.Linear(feature_dim, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout * 0.5),
            nn.Linear(512, num_classes),
        )

        # Initialize classifier weights
        self._init_classifier()

        # Track total params
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f" [{architecture}] Total: {total_params:,} Trainable: {trainable_params:,}")

    @staticmethod
    def _infer_feature_dim(backbone, input_size=(1, 3, 224, 224)):
        """Measure the backbone's output width with one dummy forward pass.

        The probe runs in eval mode so that random noise does not update
        BatchNorm running statistics or trigger stochastic layers; the
        previous train/eval mode is restored afterwards.
        """
        first_param = next(backbone.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(*input_size, device=device)
                return backbone(dummy).shape[-1]
        finally:
            backbone.train(was_training)

    def _init_classifier(self):
        """Initialise the head: truncated-normal Linear weights, zero biases,
        identity BatchNorm affine parameters."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, features)`` when ``return_features``
        is true; ``features`` is the raw backbone output fed to the head."""
        features = self.backbone(x)
        logits = self.classifier(features)
        if return_features:
            return logits, features
        return logits

    def freeze_backbone(self):
        """Freeze backbone for initial fine-tuning.

        Only ``requires_grad`` is cleared; module train/eval mode is not
        changed.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False
        print(f" [{self.architecture}] Backbone frozen")

    def unfreeze_backbone(self, unfreeze_from=0.5):
        """
        Gradually unfreeze backbone layers.

        unfreeze_from: fraction of layers to keep frozen (0=unfreeze all, 0.5=unfreeze last half)

        "Layers" here are individual parameter tensors in the order yielded
        by ``backbone.parameters()``, not whole modules.
        """
        params = list(self.backbone.parameters())
        freeze_until = int(len(params) * unfreeze_from)
        for i, param in enumerate(params):
            param.requires_grad = i >= freeze_until
        unfrozen = sum(1 for p in params if p.requires_grad)
        total = len(params)
        print(f" [{self.architecture}] Unfrozen {unfrozen}/{total} backbone layers")

    def get_layer_groups(self):
        """Get parameter groups with different learning rates.

        Returns four dicts with ``"params"`` and ``"lr_scale"`` keys: the
        backbone parameters split into thirds (early, middle, late) followed
        by the classifier head. ``lr_scale`` is a plain multiplier; applying
        it to a base learning rate is left to the caller.
        """
        backbone_params = list(self.backbone.parameters())
        n = len(backbone_params)
        first_cut = n // 3
        second_cut = (2 * n) // 3
        # Split backbone into 3 groups (early, middle, late)
        return [
            {"params": backbone_params[:first_cut], "lr_scale": 0.01},
            {"params": backbone_params[first_cut:second_cut], "lr_scale": 0.1},
            {"params": backbone_params[second_cut:], "lr_scale": 0.5},
            {"params": list(self.classifier.parameters()), "lr_scale": 1.0},
        ]
def build_model(model_name, config, num_classes=None):
    """Construct the ``RangoliClassifier`` described by ``config["models"][model_name]``.

    ``num_classes`` overrides ``config["num_classes"]`` when given. The
    per-model entry must provide ``"architecture"``; ``"pretrained"``,
    ``"dropout"`` and ``"feature_dim"`` are optional and default to
    ``True``, ``0.5`` and ``None``.
    """
    spec = config["models"][model_name]
    n_out = config["num_classes"] if num_classes is None else num_classes
    return RangoliClassifier(
        architecture=spec["architecture"],
        num_classes=n_out,
        pretrained=spec.get("pretrained", True),
        dropout=spec.get("dropout", 0.5),
        feature_dim=spec.get("feature_dim", None),
    )
def build_loss_function(config, class_weights=None, device="cpu"):
    """Build loss function based on config.

    Args:
        config: Dict with a ``"training"`` section. Keys read from it:
            ``use_weighted_loss`` (default ``False``), ``use_focal_loss``
            (default ``False``), ``focal_loss_gamma`` (default ``2.0``) and
            ``label_smoothing`` (default ``0.1``).
        class_weights: Optional per-class weights. Either a mapping whose
            values are the weights, or a sequence/tensor of weights.
            NOTE(review): for a mapping the values are taken in iteration
            order, which is assumed to match class-index order; confirm
            against the caller.
        device: Device on which the weight tensor is created.

    Returns:
        ``FocalLoss`` when ``use_focal_loss`` is set, else ``nn.CrossEntropyLoss``.
    """
    from collections.abc import Mapping

    training_cfg = config["training"]

    alpha = None
    if class_weights is not None and training_cfg.get("use_weighted_loss", False):
        if isinstance(class_weights, Mapping):
            class_weights = list(class_weights.values())
        alpha = torch.as_tensor(class_weights, dtype=torch.float32, device=device)

    label_smoothing = training_cfg.get("label_smoothing", 0.1)
    if training_cfg.get("use_focal_loss", False):
        return FocalLoss(
            alpha=alpha,
            gamma=training_cfg.get("focal_loss_gamma", 2.0),
            label_smoothing=label_smoothing,
        )
    return nn.CrossEntropyLoss(weight=alpha, label_smoothing=label_smoothing)
def get_model_summary(model, input_size=(1, 3, 224, 224)):
    """Get detailed model summary for paper.

    Args:
        model: Module exposing ``architecture``, ``feature_dim`` and
            ``num_classes`` attributes.
        input_size: Shape of the dummy input used for FLOP counting.

    Returns:
        ``OrderedDict`` with parameter counts, model size in MB, FLOPs,
        feature dimension and class count. ``"FLOPs"`` is an integer when
        fvcore is available and tracing succeeds, otherwise an ``"N/A (...)"``
        string saying why.
    """
    from collections import OrderedDict

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Estimate FLOPs
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError:
        total_flops = "N/A (install fvcore for FLOP count)"
    else:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        # Trace in eval mode: BatchNorm1d rejects a batch of one while
        # training, and a dummy pass must not touch running statistics.
        # Flags are saved per module so mixed train/eval setups are restored.
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(input_size, device=device)
                total_flops = FlopCountAnalysis(model, dummy).total()
        except Exception as exc:  # best-effort metric; report instead of failing
            total_flops = f"N/A (FLOP count failed: {exc})"
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    # Model size in MB
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    model_size_mb = (param_size + buffer_size) / (1024 ** 2)

    summary = OrderedDict({
        "Architecture": model.architecture,
        "Total Parameters": f"{total_params:,}",
        "Trainable Parameters": f"{trainable_params:,}",
        "Model Size (MB)": f"{model_size_mb:.2f}",
        "FLOPs": total_flops,
        "Feature Dimension": model.feature_dim,
        "Num Classes": model.num_classes,
    })
    return summary
| # ---- Ensemble Model ---- | |
class EnsembleModel(nn.Module):
    """Ensemble of multiple models for best accuracy.

    The prediction is the weighted sum of each member's softmax
    probabilities. Inference only: every member is switched to eval mode
    and evaluated under ``torch.no_grad()`` on each forward call.

    Args:
        models: Non-empty sequence of modules mapping an input batch to
            ``(B, C)`` logits.
        weights: Optional per-model weights, one per model. Defaults to
            uniform ``1 / len(models)``. Weights are used as given and are
            not normalised.

    Raises:
        ValueError: If ``models`` is empty or ``weights`` does not have one
            entry per model.
    """

    def __init__(self, models, weights=None):
        super().__init__()
        self.models = nn.ModuleList(models)
        n_models = len(self.models)
        if n_models == 0:
            raise ValueError("EnsembleModel requires at least one model")
        if weights is None:
            weights = [1.0 / n_models] * n_models
        elif len(weights) != n_models:
            # zip() in forward would otherwise silently drop members.
            raise ValueError(
                f"Expected {n_models} ensemble weights, got {len(weights)}"
            )
        self.weights = weights

    def forward(self, x):
        """Return the weighted sum of member softmax outputs, shape ``(B, C)``."""
        outputs = []
        for model, w in zip(self.models, self.weights):
            model.eval()
            with torch.no_grad():
                probs = F.softmax(model(x), dim=1)
            outputs.append(w * probs)
        return torch.stack(outputs).sum(dim=0)