File size: 4,785 Bytes
4b4e2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5703104
4b4e2d5
84b609f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4e2d5
 
84b609f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision.models import swin_t
# from torchvision import transforms
# from PIL import Image
# import os

# # --- MMIM model class ---
# class MMIM(nn.Module):
#     def __init__(self, num_classes):
#         super(MMIM, self).__init__()
#         self.backbone = swin_t(weights='IMAGENET1K_V1')
#         self.backbone.head = nn.Identity()
#         self.classifier = nn.Sequential(
#             nn.Linear(768, 512),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(512, num_classes)
#         )

#     def forward(self, x):
#         x = self.backbone(x)
#         return self.classifier(x)

# # --- Load models with offsets ---
# def load_all_models():
#     model_defs = [
#         ("MMIM_best1.pth", 9),
#         ("MMIM_best3.pth", 4),
#         ("MMIM_best2.pth", 12)
#     ]
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'

#     models = []
#     offsets = []
#     total_classes = 0
#     for path, num_classes in model_defs:
#         model = MMIM(num_classes)
#         state_dict = torch.load(path, map_location=device)
#         model.load_state_dict(state_dict)
#         model.to(device)
#         model.eval()
#         models.append(model)
#         offsets.append(total_classes)
#         total_classes += num_classes

#     # Generate dummy class names like class0, class1, ...
#     idx_to_class = {i: f"class{i}" for i in range(total_classes)}
#     return models, offsets, idx_to_class

# # --- Inference on one image ---
# def predict_image(image, models, offsets, idx_to_class):
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     transform = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5]*3, [0.5]*3)
#     ])
#     image_tensor = transform(image).unsqueeze(0).to(device)

#     temperatures = [1.2, 1.0, 0.8]  # Adjust for balancing confidence

#     max_score = float('-inf')
#     final_pred = -1
#     probs_combined = {}

#     for model, offset, temp in zip(models, offsets, temperatures):
#         with torch.no_grad():
#             logits = model(image_tensor) / temp
#             probs = F.softmax(logits, dim=1).squeeze(0)
#             top_score, top_class = torch.max(probs, dim=0)
#             if top_score.item() > max_score:
#                 max_score = top_score.item()
#                 final_pred = top_class.item() + offset

#             # Also collect probabilities for all classes
#             for i, p in enumerate(probs):
#                 probs_combined[offset + i] = p.item()

#     # Sort top 3
#     top3 = sorted(probs_combined.items(), key=lambda x: x[1], reverse=True)[:3]
#     return {idx_to_class[k]: float(f"{v:.4f}") for k, v in top3}

import torch
import torch.nn as nn
from torchvision.models.swin_transformer import swin_t, Swin_T_Weights
import torch.nn.functional as F

# βœ… Define MMIM architecture (same as used during training)
class MMIM(nn.Module):
    """Swin-T backbone with a small MLP classification head.

    The torchvision Swin-T's own classification head is replaced by
    ``nn.Identity`` so the backbone emits 768-d features, which the MLP
    maps down to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super(MMIM, self).__init__()
        # ImageNet-pretrained backbone; strip its built-in head so it
        # behaves as a pure 768-d feature extractor.
        backbone = swin_t(weights=Swin_T_Weights.DEFAULT)
        backbone.head = nn.Identity()
        self.backbone = backbone
        # 768 -> 512 -> num_classes head; dropout for regularization.
        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # Features first, then the classification head.
        return self.classifier(self.backbone(x))

# βœ… Load all 3 models
def load_all_models():
    """Build the three MMIM experts and restore their trained weights.

    Returns:
        (model1, model2, model3): experts covering 9, 12 and 4 classes
        respectively (class1-9, class14-25, class10-13), each loaded on
        CPU and switched to eval mode.
    """
    # (checkpoint path, number of classes) per expert, in return order.
    specs = [
        ("MMIM_best1.pth", 9),   # class1-9
        ("MMIM_best2.pth", 12),  # class14-25
        ("MMIM_best3.pth", 4),   # class10-13
    ]

    experts = []
    for ckpt_path, n_classes in specs:
        net = MMIM(num_classes=n_classes)
        # NOTE(review): torch.load deserializes arbitrary pickles — only
        # load trusted checkpoints (or pass weights_only=True; confirm the
        # checkpoints contain plain state dicts first).
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        net.eval()
        experts.append(net)

    return tuple(experts)

# βœ… Inference combining raw logits before softmax
def predict_image(image, model1, model2, model3, transform, class_names):
    """Classify one image by fusing the three experts' raw logits.

    The logits are concatenated BEFORE softmax so a single, unified
    softmax ranks all 25 classes against each other. Concatenation order
    fixes the global index layout: model1 -> 0-8, model3 -> 9-12,
    model2 -> 13-24 (matching ``class_names``).

    Args:
        image: PIL image (or anything ``transform`` accepts).
        model1, model2, model3: the three expert networks.
        transform: preprocessing callable producing a [3, 224, 224] tensor.
        class_names: index -> label mapping for the combined 25 classes.

    Returns:
        (predicted label, softmax confidence) for the top class.
    """
    batch = transform(image).unsqueeze(0)  # [1, 3, 224, 224]

    with torch.no_grad():
        # Evaluate experts in the order that defines the class layout.
        fused_logits = torch.cat(
            [model1(batch), model3(batch), model2(batch)], dim=1
        )  # [1, 25]

    # One softmax over the fused logits, then take the best class.
    probs = F.softmax(fused_logits, dim=1).squeeze(0)
    confidence, best = torch.max(probs, dim=0)

    return class_names[best.item()], confidence.item()