import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.swin_transformer import swin_t, Swin_T_Weights


class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The architecture must match the one used during training so that the
    saved state dicts in the ``MMIM_best*.pth`` checkpoints load cleanly.
    """

    def __init__(self, num_classes):
        super(MMIM, self).__init__()
        # Pretrained Swin-T; its built-in classification head is replaced by
        # Identity so the backbone emits raw 768-dim features.
        self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
        self.backbone.head = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)


# Checkpoint file paired with the number of classes its head was trained on.
# Order here fixes the return order of load_all_models().
# NOTE(review): per the original comments, model1 covers class1-9,
# model2 class14-25, model3 class10-13 -- confirm against the training setup.
_MODEL_SPECS = (
    ("MMIM_best1.pth", 9),
    ("MMIM_best2.pth", 12),
    ("MMIM_best3.pth", 4),
)


def load_all_models():
    """Load the three trained MMIM models from disk.

    Returns:
        (model1, model2, model3): the three models with their checkpoints
        loaded, in eval mode, on CPU.
    """
    models = []
    for path, num_classes in _MODEL_SPECS:
        model = MMIM(num_classes=num_classes)
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code -- only load checkpoints from a trusted source (consider
        # weights_only=True on torch >= 1.13).
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)
    return models[0], models[1], models[2]


def predict_image(image, model1, model2, model3, transform, class_names):
    """Classify one image by fusing the three models' raw logits.

    The logits are concatenated in the order
    [model1 (9), model3 (4), model2 (12)] and passed through a single
    softmax, producing one unified 25-way probability distribution.
    ``class_names`` must be indexed in that same concatenated order.

    Args:
        image: PIL image (or anything ``transform`` accepts).
        model1, model2, model3: models as returned by load_all_models().
        transform: preprocessing pipeline producing a [3, H, W] tensor.
        class_names: sequence of 25 labels in concatenation order.

    Returns:
        (predicted_class_name, confidence) where confidence is the softmax
        probability of the predicted class.
    """
    batch = transform(image).unsqueeze(0)  # [1, 3, 224, 224]
    with torch.no_grad():
        # Concatenation order MUST match the class_names layout:
        # model1 -> indices 0-8, model3 -> 9-12, model2 -> 13-24.
        combined_logits = torch.cat(
            [model1(batch), model3(batch), model2(batch)], dim=1
        )  # [1, 25]
        # NOTE(review): the heads were trained independently, so their logit
        # scales are not calibrated against one another; a single softmax
        # over the concatenation implicitly weights them equally.
        combined_probs = F.softmax(combined_logits, dim=1)
    pred_idx = combined_probs.argmax(dim=1).item()
    confidence = combined_probs[0, pred_idx].item()
    return class_names[pred_idx], confidence