# Hugging Face Space (page status at capture time: Sleeping)
| # import torch | |
| # import torch.nn as nn | |
| # import torch.nn.functional as F | |
| # from torchvision.models import swin_t | |
| # from torchvision import transforms | |
| # from PIL import Image | |
| # import os | |
| # # --- MMIM model class --- | |
| # class MMIM(nn.Module): | |
| # def __init__(self, num_classes): | |
| # super(MMIM, self).__init__() | |
| # self.backbone = swin_t(weights='IMAGENET1K_V1') | |
| # self.backbone.head = nn.Identity() | |
| # self.classifier = nn.Sequential( | |
| # nn.Linear(768, 512), | |
| # nn.ReLU(), | |
| # nn.Dropout(0.3), | |
| # nn.Linear(512, num_classes) | |
| # ) | |
| # def forward(self, x): | |
| # x = self.backbone(x) | |
| # return self.classifier(x) | |
| # # --- Load models with offsets --- | |
| # def load_all_models(): | |
| # model_defs = [ | |
| # ("MMIM_best1.pth", 9), | |
| # ("MMIM_best3.pth", 4), | |
| # ("MMIM_best2.pth", 12) | |
| # ] | |
| # device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # models = [] | |
| # offsets = [] | |
| # total_classes = 0 | |
| # for path, num_classes in model_defs: | |
| # model = MMIM(num_classes) | |
| # state_dict = torch.load(path, map_location=device) | |
| # model.load_state_dict(state_dict) | |
| # model.to(device) | |
| # model.eval() | |
| # models.append(model) | |
| # offsets.append(total_classes) | |
| # total_classes += num_classes | |
| # # Generate dummy class names like class0, class1, ... | |
| # idx_to_class = {i: f"class{i}" for i in range(total_classes)} | |
| # return models, offsets, idx_to_class | |
| # # --- Inference on one image --- | |
| # def predict_image(image, models, offsets, idx_to_class): | |
| # device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # transform = transforms.Compose([ | |
| # transforms.Resize((224, 224)), | |
| # transforms.ToTensor(), | |
| # transforms.Normalize([0.5]*3, [0.5]*3) | |
| # ]) | |
| # image_tensor = transform(image).unsqueeze(0).to(device) | |
| # temperatures = [1.2, 1.0, 0.8] # Adjust for balancing confidence | |
| # max_score = float('-inf') | |
| # final_pred = -1 | |
| # probs_combined = {} | |
| # for model, offset, temp in zip(models, offsets, temperatures): | |
| # with torch.no_grad(): | |
| # logits = model(image_tensor) / temp | |
| # probs = F.softmax(logits, dim=1).squeeze(0) | |
| # top_score, top_class = torch.max(probs, dim=0) | |
| # if top_score.item() > max_score: | |
| # max_score = top_score.item() | |
| # final_pred = top_class.item() + offset | |
| # # Also collect probabilities for all classes | |
| # for i, p in enumerate(probs): | |
| # probs_combined[offset + i] = p.item() | |
| # # Sort top 3 | |
| # top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3] | |
| # return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3} | |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.swin_transformer import Swin_T_Weights, swin_t
# Define the MMIM architecture (must match the one used during training).
class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The pretrained Swin-T head is replaced by ``nn.Identity`` so the
    backbone emits its 768-d pooled features, which the classifier maps
    to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Pretrained backbone; its built-in classification head is disabled.
        self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
        self.backbone.head = nn.Identity()
        # 768 -> 512 -> num_classes MLP with dropout for regularization.
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # Backbone features, then classification logits.
        return self.classifier(self.backbone(x))
# Load the three trained MMIM checkpoints and put them in eval mode.
def load_all_models():
    """Restore the three ensemble members from their ``.pth`` checkpoints.

    Returns:
        (model1, model2, model3) — per the original comments, model1
        covers classes 1-9, model2 classes 14-25, model3 classes 10-13.
        All models are loaded onto CPU and switched to eval mode.
    """
    # (checkpoint path, number of output classes) for models 1, 2, 3.
    specs = [
        ("MMIM_best1.pth", 9),
        ("MMIM_best2.pth", 12),
        ("MMIM_best3.pth", 4),
    ]
    restored = []
    for ckpt_path, n_classes in specs:
        net = MMIM(num_classes=n_classes)
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        net.eval()  # disable dropout for inference
        restored.append(net)
    return restored[0], restored[1], restored[2]
# Inference: fuse the three models' raw logits, then apply one softmax.
def predict_image(image, model1, model2, model3, transform, class_names):
    """Classify *image* with the three-headed ensemble.

    The raw logits are concatenated in the order
    model1 (indices 0-8), model3 (indices 9-12), model2 (indices 13-24),
    and a single softmax is taken over all 25 entries so the confidences
    are comparable across models.

    NOTE(review): fusing raw logits from independently trained heads
    assumes their logit scales are comparable — confirm against training.

    Returns:
        (predicted_class_name, confidence) where confidence is the
        softmax probability of the winning class.
    """
    # Batch of one; expected [1, 3, 224, 224] for the usual 224x224
    # transform — actual shape depends on the caller-supplied transform.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        out1 = model1(batch)  # [1, 9]
        out3 = model3(batch)  # [1, 4]
        out2 = model2(batch)  # [1, 12]
        # Concatenate logits (not per-model softmaxes), then normalize once.
        fused = torch.cat([out1, out3, out2], dim=1)  # [1, 25]
        probs = F.softmax(fused, dim=1)
    best = probs.argmax(dim=1).item()
    return class_names[best], probs[0, best].item()