Final / utils.py
NagashreePai's picture
Update utils.py
84b609f verified
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision.models import swin_t
# from torchvision import transforms
# from PIL import Image
# import os
# # --- MMIM model class ---
# class MMIM(nn.Module):
# def __init__(self, num_classes):
# super(MMIM, self).__init__()
# self.backbone = swin_t(weights='IMAGENET1K_V1')
# self.backbone.head = nn.Identity()
# self.classifier = nn.Sequential(
# nn.Linear(768, 512),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(512, num_classes)
# )
# def forward(self, x):
# x = self.backbone(x)
# return self.classifier(x)
# # --- Load models with offsets ---
# def load_all_models():
# model_defs = [
# ("MMIM_best1.pth", 9),
# ("MMIM_best3.pth", 4),
# ("MMIM_best2.pth", 12)
# ]
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# models = []
# offsets = []
# total_classes = 0
# for path, num_classes in model_defs:
# model = MMIM(num_classes)
# state_dict = torch.load(path, map_location=device)
# model.load_state_dict(state_dict)
# model.to(device)
# model.eval()
# models.append(model)
# offsets.append(total_classes)
# total_classes += num_classes
# # Generate dummy class names like class0, class1, ...
# idx_to_class = {i: f"class{i}" for i in range(total_classes)}
# return models, offsets, idx_to_class
# # --- Inference on one image ---
# def predict_image(image, models, offsets, idx_to_class):
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize([0.5]*3, [0.5]*3)
# ])
# image_tensor = transform(image).unsqueeze(0).to(device)
# temperatures = [1.2, 1.0, 0.8] # Adjust for balancing confidence
# max_score = float('-inf')
# final_pred = -1
# probs_combined = {}
# for model, offset, temp in zip(models, offsets, temperatures):
# with torch.no_grad():
# logits = model(image_tensor) / temp
# probs = F.softmax(logits, dim=1).squeeze(0)
# top_score, top_class = torch.max(probs, dim=0)
# if top_score.item() > max_score:
# max_score = top_score.item()
# final_pred = top_class.item() + offset
# # Also collect probabilities for all classes
# for i, p in enumerate(probs):
# probs_combined[offset + i] = p.item()
# # Sort top 3
# top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
# return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}
import torch
import torch.nn as nn
from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
import torch.nn.functional as F
# ✅ Define MMIM architecture (same as used during training)
class MMIM(nn.Module):
    """Swin-T backbone followed by a small MLP classification head.

    The torchvision ``swin_t`` network has its built-in ``head`` replaced by
    ``nn.Identity`` so the backbone output is fed straight into
    ``classifier`` (768 -> 512 -> ``num_classes``).

    The attribute names ``backbone`` and ``classifier`` determine the
    ``state_dict`` keys, so they must stay unchanged for existing
    checkpoints to keep loading.

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: When true (the default, which matches the previous
            behaviour) the backbone is created with
            ``Swin_T_Weights.DEFAULT``. Pass ``False`` to build the backbone
            without pretrained weights, e.g. when a full checkpoint is
            loaded into the model right afterwards.
    """

    def __init__(self, num_classes: int, *, pretrained: bool = True) -> None:
        super().__init__()
        backbone_weights = Swin_T_Weights.DEFAULT if pretrained else None
        self.backbone = swin_t(weights=backbone_weights)
        # Drop torchvision's own classification layer; `classifier` below
        # takes its place.
        self.backbone.head = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for the batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)
# ✅ Load all 3 models
def load_all_models():
    """Build the three MMIM networks, load their checkpoints, and return them.

    Checkpoints are read from the current working directory and mapped to
    the CPU. Every network is switched to evaluation mode before returning.

    Returns:
        A 3-tuple ``(model1, model2, model3)`` with 9, 12 and 4 output
        classes respectively.
    """
    # (checkpoint file, number of classes), listed in return order.
    checkpoint_specs = (
        ("MMIM_best1.pth", 9),   # class1–9
        ("MMIM_best2.pth", 12),  # class14–25
        ("MMIM_best3.pth", 4),   # class10–13
    )

    # Construct every network first, then load, then switch to eval mode.
    networks = [MMIM(num_classes=n_out) for _, n_out in checkpoint_specs]

    # NOTE(review): torch.load unpickles the file unless weights_only=True is
    # passed, so only trusted checkpoints should be placed at these paths.
    for network, (checkpoint_path, _) in zip(networks, checkpoint_specs):
        weights = torch.load(checkpoint_path, map_location='cpu')
        network.load_state_dict(weights)

    for network in networks:
        network.eval()

    return tuple(networks)
# ✅ Inference combining raw logits before softmax
def predict_image(image, model1, model2, model3, transform, class_names):
    """Classify one image with the three models and return the top label.

    The logits of the three models are concatenated in the order
    ``model1``, ``model3``, ``model2`` and a single softmax is taken over
    the combined vector, so ``class_names`` must follow that same order.
    With the models returned by ``load_all_models`` the combined vector has
    9 + 4 + 12 = 25 entries.

    Args:
        image: Input accepted by ``transform``.
        model1: Model whose logits occupy the first slots.
        model2: Model whose logits occupy the last slots.
        model3: Model whose logits sit between those of ``model1`` and
            ``model2``.
        transform: Callable that turns ``image`` into an un-batched tensor.
        class_names: Collection indexed by the combined class index.

    Returns:
        A ``(label, confidence)`` pair, where ``confidence`` is the softmax
        probability (a Python float) of the winning combined index.
    """
    image_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    # Run the input on the device that holds model1's weights so the call
    # also works after the models have been moved off the CPU.
    # NOTE(review): assumes all three models share one device — confirm.
    first_param = next(model1.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        logit1 = model1(image_tensor)
        logit3 = model3(image_tensor)
        logit2 = model2(image_tensor)

        # ✅ Combine logits (not softmax) → then apply one unified softmax
        combined_logits = torch.cat([logit1, logit3, logit2], dim=1)
        combined_probs = F.softmax(combined_logits, dim=1)

    pred_idx = combined_probs.argmax(dim=1).item()
    confidence = combined_probs[0, pred_idx].item()
    return class_names[pred_idx], confidence