rangoli-classifier / models /classifier.py
shashidharak99's picture
Upload 16 files
0b3dd07 verified
"""
============================================================
Rangoli Classification Models
============================================================
6 architectures for comparative study in IEEE paper:
1. ResNet-50 (Baseline CNN)
2. EfficientNet-B3 (Efficient Scaling)
3. ViT-Base (Vision Transformer)
4. ConvNeXt-Small (Modern CNN)
5. MobileNetV3 (Lightweight/Mobile)
6. Swin-Base (Hierarchical Transformer)
All use ImageNet-pretrained weights with custom classification heads.
============================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``(1 - p_t) ** gamma``.

    Down-weights well-classified samples so training focuses on hard ones,
    which helps with class imbalance.

    Args:
        alpha: Optional 1-D tensor of per-class weights (one entry per class),
            or None for no class weighting.
        gamma: Focusing exponent; ``0`` disables the focal modulation.
        label_smoothing: Smoothing factor forwarded to ``F.cross_entropy`` for
            hard (class-index) targets. It is not applied to soft targets.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            per-sample loss unreduced.
    """

    def __init__(self, alpha=None, gamma=2.0, label_smoothing=0.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha  # class weights tensor
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: (B, C) logits.
            targets: (B,) class indices OR (B, C) soft labels (for mixup/cutmix).

        Returns:
            A scalar tensor for ``"mean"``/``"sum"`` reduction, otherwise the
            (B,) per-sample loss.
        """
        if targets.dim() == 1:
            # Hard labels: weighting and smoothing are handled by cross_entropy.
            ce_loss = F.cross_entropy(
                inputs,
                targets,
                weight=self.alpha,
                label_smoothing=self.label_smoothing,
                reduction="none",
            )
            if self.alpha is None and self.label_smoothing == 0.0:
                plain_ce = ce_loss
            else:
                # The focal factor must come from the unweighted, unsmoothed
                # loss; otherwise exp(-ce) is no longer the model's probability.
                plain_ce = F.cross_entropy(inputs, targets, reduction="none")
        else:
            # Soft labels (MixUp/CutMix)
            log_probs = F.log_softmax(inputs, dim=1)
            plain_ce = -(targets * log_probs).sum(dim=1)
            if self.alpha is None:
                ce_loss = plain_ce
            else:
                # Apply the per-class weights here as well, so class weighting
                # is not silently dropped when soft targets are used.
                class_weights = self.alpha.to(
                    device=log_probs.device, dtype=log_probs.dtype
                )
                ce_loss = -(class_weights * targets * log_probs).sum(dim=1)

        pt = torch.exp(-plain_ce)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss
class RangoliClassifier(nn.Module):
    """
    Universal Rangoli Classifier wrapper.

    Wraps a timm backbone (created with ``num_classes=0`` so it returns
    features) and adds a custom classification head:
    BatchNorm -> Dropout -> Linear(512) -> GELU -> BatchNorm -> Dropout -> Linear.

    Args:
        architecture: Model name passed to ``timm.create_model``.
        num_classes: Number of output classes.
        pretrained: Whether timm should load pretrained weights.
        dropout: Dropout probability of the first head dropout; the second
            uses half of it.
        feature_dim: Backbone output width. When None it is measured with a
            single dummy forward pass on a 1x3x224x224 input.
    """

    def __init__(self, architecture, num_classes=8, pretrained=True,
                 dropout=0.5, feature_dim=None):
        super().__init__()
        self.architecture = architecture
        self.num_classes = num_classes

        # Create backbone via timm
        self.backbone = timm.create_model(
            architecture,
            pretrained=pretrained,
            num_classes=0,  # Remove original classifier
        )

        # Get feature dimension
        if feature_dim is None:
            # Probe in eval mode: a training-mode forward on random noise would
            # update the BatchNorm running statistics of the backbone.
            was_training = self.backbone.training
            self.backbone.eval()
            try:
                with torch.no_grad():
                    dummy = torch.randn(1, 3, 224, 224)
                    feature_dim = self.backbone(dummy).shape[-1]
            finally:
                self.backbone.train(was_training)
        self.feature_dim = feature_dim

        # Custom classification head
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(feature_dim),
            nn.Dropout(p=dropout),
            nn.Linear(feature_dim, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout * 0.5),
            nn.Linear(512, num_classes),
        )

        # Initialize classifier weights
        self._init_classifier()

        # Track total params
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f" [{architecture}] Total: {total_params:,} Trainable: {trainable_params:,}")

    def _init_classifier(self):
        """Initialise the head: truncated-normal Linear weights, zero biases, identity BatchNorm."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, features)`` when ``return_features`` is True."""
        features = self.backbone(x)
        logits = self.classifier(features)
        if return_features:
            return logits, features
        return logits

    def freeze_backbone(self):
        """Freeze backbone for initial fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print(f" [{self.architecture}] Backbone frozen")

    def unfreeze_backbone(self, unfreeze_from=0.5):
        """
        Gradually unfreeze backbone layers.

        unfreeze_from: fraction of layers to keep frozen (0=unfreeze all, 0.5=unfreeze last half)

        The split is made over the backbone's parameter tensors in
        ``parameters()`` order; the counts printed are parameter tensors.
        """
        params = list(self.backbone.parameters())
        freeze_until = int(len(params) * unfreeze_from)
        for i, param in enumerate(params):
            param.requires_grad = i >= freeze_until
        unfrozen = sum(1 for p in params if p.requires_grad)
        total = len(params)
        print(f" [{self.architecture}] Unfrozen {unfrozen}/{total} backbone layers")

    def get_layer_groups(self):
        """Get parameter groups with different learning rates.

        Returns four dicts with ``"params"`` and ``"lr_scale"`` keys: three
        consecutive thirds of the backbone parameters (scales 0.01, 0.1, 0.5)
        followed by the classifier head (scale 1.0).
        """
        backbone_params = list(self.backbone.parameters())
        n = len(backbone_params)
        # Split backbone into 3 groups (early, middle, late)
        groups = [
            {"params": backbone_params[:n // 3], "lr_scale": 0.01},
            {"params": backbone_params[n // 3:2 * n // 3], "lr_scale": 0.1},
            {"params": backbone_params[2 * n // 3:], "lr_scale": 0.5},
            {"params": list(self.classifier.parameters()), "lr_scale": 1.0},
        ]
        return groups
def build_model(model_name, config, num_classes=None):
    """Construct a RangoliClassifier from the ``config["models"][model_name]`` entry.

    ``num_classes`` falls back to ``config["num_classes"]`` when not given.
    The entry must provide ``"architecture"``; ``"pretrained"`` (default True),
    ``"dropout"`` (default 0.5) and ``"feature_dim"`` (default None) are optional.
    """
    spec = config["models"][model_name]
    resolved_classes = config["num_classes"] if num_classes is None else num_classes

    return RangoliClassifier(
        architecture=spec["architecture"],
        num_classes=resolved_classes,
        pretrained=spec.get("pretrained", True),
        dropout=spec.get("dropout", 0.5),
        feature_dim=spec.get("feature_dim", None),
    )
def build_loss_function(config, class_weights=None, device="cpu"):
    """Create the training criterion described by ``config["training"]``.

    Class weights are used only when ``class_weights`` is given and
    ``use_weighted_loss`` is enabled; they are taken from the mapping's values
    in iteration order and moved to ``device``. ``use_focal_loss`` selects
    FocalLoss, otherwise ``nn.CrossEntropyLoss`` is returned. Label smoothing
    defaults to 0.1 and the focal gamma to 2.0.
    """
    train_opts = config["training"]

    weight_tensor = None
    wants_weights = train_opts.get("use_weighted_loss", False)
    if class_weights is not None and wants_weights:
        per_class = list(class_weights.values())
        weight_tensor = torch.tensor(per_class, dtype=torch.float32).to(device)

    if not train_opts.get("use_focal_loss", False):
        return nn.CrossEntropyLoss(
            weight=weight_tensor,
            label_smoothing=train_opts.get("label_smoothing", 0.1),
        )

    return FocalLoss(
        alpha=weight_tensor,
        gamma=train_opts.get("focal_loss_gamma", 2.0),
        label_smoothing=train_opts.get("label_smoothing", 0.1),
    )
def get_model_summary(model, input_size=(1, 3, 224, 224)):
    """Get detailed model summary for paper.

    Args:
        model: Module exposing ``architecture``, ``feature_dim`` and
            ``num_classes`` attributes.
        input_size: Shape of the random dummy input used for FLOP counting.

    Returns:
        OrderedDict with parameter counts, size in MB (parameters + buffers),
        FLOPs and the model's architecture / feature dimension / class count.
        ``"FLOPs"`` is a placeholder string when fvcore is missing or the
        analysis fails.
    """
    from collections import OrderedDict

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Estimate FLOPs (best effort)
    total_flops = "N/A (install fvcore for FLOP count)"
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError:
        FlopCountAnalysis = None

    if FlopCountAnalysis is not None:
        # Build the dummy input on the model's device so the trace does not
        # fail on a device mismatch.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
        # Trace in eval mode: with a batch of one, training-mode BatchNorm1d
        # raises, and training-mode tracing would also touch running stats.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(input_size, device=device)
                total_flops = FlopCountAnalysis(model, dummy).total()
        except Exception:
            # Keep the summary usable even when the analysis cannot run.
            total_flops = "N/A (install fvcore for FLOP count)"
        finally:
            model.train(was_training)

    # Model size in MB
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    model_size_mb = (param_size + buffer_size) / (1024 ** 2)

    summary = OrderedDict({
        "Architecture": model.architecture,
        "Total Parameters": f"{total_params:,}",
        "Trainable Parameters": f"{trainable_params:,}",
        "Model Size (MB)": f"{model_size_mb:.2f}",
        "FLOPs": total_flops,
        "Feature Dimension": model.feature_dim,
        "Num Classes": model.num_classes,
    })
    return summary
# ---- Ensemble Model ----
class EnsembleModel(nn.Module):
    """Ensemble of multiple models for best accuracy.

    Combines the members' softmax outputs as a weighted sum.

    Args:
        models: Non-empty sequence of modules that map an input batch to
            (B, C) logits.
        weights: Optional per-model weights, one per model. Defaults to a
            uniform ``1 / len(models)`` for every member.

    Raises:
        ValueError: If ``models`` is empty or ``weights`` has a different
            length than ``models``.
    """

    def __init__(self, models, weights=None):
        super().__init__()
        self.models = nn.ModuleList(models)
        n_models = len(self.models)
        if n_models == 0:
            raise ValueError("EnsembleModel requires at least one model")

        if weights is None:
            weights = [1.0 / n_models] * n_models
        else:
            weights = list(weights)
            # zip() in forward would silently drop the surplus models/weights.
            if len(weights) != n_models:
                raise ValueError(
                    f"Expected {n_models} weights, got {len(weights)}"
                )
        self.weights = weights

    def forward(self, x):
        """Return the weighted sum of the members' class probabilities.

        Each member is switched to eval mode and run without gradient
        tracking, so the output is for inference only.
        """
        outputs = []
        for model, w in zip(self.models, self.weights):
            model.eval()
            with torch.no_grad():
                out = F.softmax(model(x), dim=1)
            outputs.append(w * out)
        return torch.stack(outputs).sum(dim=0)