# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision.models import swin_t
# from torchvision import transforms
# from PIL import Image
# import os
# # --- MMIM model class ---
# class MMIM(nn.Module):
# def __init__(self, num_classes):
# super(MMIM, self).__init__()
# self.backbone = swin_t(weights='IMAGENET1K_V1')
# self.backbone.head = nn.Identity()
# self.classifier = nn.Sequential(
# nn.Linear(768, 512),
# nn.ReLU(),
# nn.Dropout(0.3),
# nn.Linear(512, num_classes)
# )
# def forward(self, x):
# x = self.backbone(x)
# return self.classifier(x)
# # --- Load models with offsets ---
# def load_all_models():
# model_defs = [
# ("MMIM_best1.pth", 9),
# ("MMIM_best3.pth", 4),
# ("MMIM_best2.pth", 12)
# ]
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# models = []
# offsets = []
# total_classes = 0
# for path, num_classes in model_defs:
# model = MMIM(num_classes)
# state_dict = torch.load(path, map_location=device)
# model.load_state_dict(state_dict)
# model.to(device)
# model.eval()
# models.append(model)
# offsets.append(total_classes)
# total_classes += num_classes
# # Generate dummy class names like class0, class1, ...
# idx_to_class = {i: f"class{i}" for i in range(total_classes)}
# return models, offsets, idx_to_class
# # --- Inference on one image ---
# def predict_image(image, models, offsets, idx_to_class):
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize([0.5]*3, [0.5]*3)
# ])
# image_tensor = transform(image).unsqueeze(0).to(device)
# temperatures = [1.2, 1.0, 0.8] # Adjust for balancing confidence
# max_score = float('-inf')
# final_pred = -1
# probs_combined = {}
# for model, offset, temp in zip(models, offsets, temperatures):
# with torch.no_grad():
# logits = model(image_tensor) / temp
# probs = F.softmax(logits, dim=1).squeeze(0)
# top_score, top_class = torch.max(probs, dim=0)
# if top_score.item() > max_score:
# max_score = top_score.item()
# final_pred = top_class.item() + offset
# # Also collect probabilities for all classes
# for i, p in enumerate(probs):
# probs_combined[offset + i] = p.item()
# # Sort top 3
# top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
# return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}
import torch
import torch.nn as nn
from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
import torch.nn.functional as F
# --- Define MMIM architecture (same as used during training) ---
class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The layout must match the architecture used at training time:
    attribute names ``backbone`` and ``classifier`` are part of the
    checkpoint's state-dict key namespace, so they must not be renamed.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Pretrained Swin-T; replace its built-in head with an identity
        # so the backbone emits raw 768-d pooled features.
        self.backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
        self.backbone.head = nn.Identity()
        # Two-layer MLP head: 768 -> 512 -> num_classes, with dropout.
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Extract backbone features, then map them to class logits."""
        return self.classifier(self.backbone(x))
# --- Load all 3 models ---
def load_all_models():
    """Instantiate the three MMIM experts and restore their checkpoints.

    Returns the tuple ``(model1, model2, model3)`` in eval mode. Per the
    checkpoint naming convention, the experts cover disjoint label
    ranges: model1 -> classes 1-9, model2 -> classes 14-25,
    model3 -> classes 10-13.
    """
    specs = [
        ("MMIM_best1.pth", 9),   # model1: classes 1-9
        ("MMIM_best2.pth", 12),  # model2: classes 14-25
        ("MMIM_best3.pth", 4),   # model3: classes 10-13
    ]
    experts = []
    for ckpt_path, n_classes in specs:
        net = MMIM(num_classes=n_classes)
        # map_location='cpu' keeps loading functional on GPU-less hosts.
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        net.eval()
        experts.append(net)
    return tuple(experts)
# --- Inference combining raw logits before softmax ---
def predict_image(image, model1, model2, model3, transform, class_names):
    """Classify one image by fusing the three experts' raw logits.

    The logits are concatenated in label order (model1 -> first 9 slots,
    model3 -> next 4, model2 -> last 12) and a single softmax is applied
    over the full 25-way distribution, so confidences are directly
    comparable across experts.

    Returns ``(predicted_class_name, confidence)``.
    """
    batch = transform(image).unsqueeze(0)  # [1, 3, 224, 224]
    with torch.no_grad():
        # Combine raw logits (not per-model softmaxes), then normalize
        # once over the unified label space.
        fused = torch.cat(
            [model1(batch), model3(batch), model2(batch)], dim=1
        )  # [1, 25]
        probs = F.softmax(fused, dim=1)
    best = int(probs.argmax(dim=1))
    return class_names[best], float(probs[0, best])