Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
class MediSimFusionModel(nn.Module):
    """Two-branch image+text fusion classifier.

    An image branch (ResNet-18 with its final FC replaced) and a text branch
    (embedding -> single-layer biLSTM -> linear) each produce a `feat_dim`
    vector; the two are concatenated and fed to a small MLP classifier.

    Args:
        vocab_size: Size of the text vocabulary for the embedding table.
        num_classes: Number of output classes.
        embed_dim: Token embedding width (default 64, matching the original
            hard-coded value).
        lstm_hidden: Hidden size of each LSTM direction (default 64).
        feat_dim: Per-modality feature width before fusion (default 128).
    """

    def __init__(self, vocab_size, num_classes, embed_dim=64, lstm_hidden=64,
                 feat_dim=128):
        super().__init__()
        # Image encoder (ResNet-18). weights=None keeps initialization
        # lightweight for hosted runtimes; custom checkpoint loading in
        # main.py populates the final weights.
        self.resnet = models.resnet18(weights=None)
        # ResNet-18's penultimate feature is 512-d; project to feat_dim.
        self.resnet.fc = nn.Linear(512, feat_dim)
        # Text encoder (biLSTM).
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden, bidirectional=True,
                            batch_first=True)
        # biLSTM concatenates forward + backward states -> 2 * lstm_hidden in.
        self.text_fc = nn.Linear(2 * lstm_hidden, feat_dim)
        # Fusion & classification over the concatenated modality features.
        self.classifier = nn.Sequential(
            nn.Linear(2 * feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, img, text):
        """Fuse image and text inputs into class logits.

        Args:
            img: Image batch, (batch, 3, 224, 224) per the original comment.
            text: Token-index batch, (batch, seq_len), dtype long.

        Returns:
            Logits of shape (batch, num_classes).
        """
        v_feat = self.resnet(img)  # (batch, feat_dim)
        embedded = self.embedding(text)  # (batch, seq_len, embed_dim)
        # For a single-layer biLSTM, hn is (2, batch, lstm_hidden):
        # hn[-2] is the forward direction's final state, hn[-1] the backward's.
        _, (hn, _) = self.lstm(embedded)
        t_feat = torch.cat((hn[-2], hn[-1]), dim=1)  # (batch, 2*lstm_hidden)
        t_feat = self.text_fc(t_feat)  # (batch, feat_dim)
        # Fusion: simple feature concatenation.
        fused = torch.cat((v_feat, t_feat), dim=1)
        return self.classifier(fused)
def get_model(vocab_size, num_classes, device="cpu"):
    """Construct a MediSimFusionModel and place it on *device* (default CPU)."""
    return MediSimFusionModel(
        vocab_size=vocab_size,
        num_classes=num_classes,
    ).to(device)