Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| import torch.nn.functional as F | |
| import torch | |
class Model(nn.Module):
    """Single-object detector: ResNet-18 backbone with two linear heads.

    The backbone's final classification layer is replaced by ``nn.Identity``
    so the backbone returns its pooled feature vector. That vector feeds two
    heads, giving an overall output of dimension 1 + 4:

    * ``fc_prob`` -> 1 value, squashed with a sigmoid in ``forward``: the
      probability that an object (of any class) is present.
    * ``fc_bbox`` -> 4 unconstrained values: the bounding box of the object.
      If no object is present (probability below a caller-chosen threshold),
      these 4 numbers carry no meaning.
    """

    def __init__(self):
        super().__init__()
        # ``weights`` must be an enum *member* (or its string name). Passing
        # the ``ResNet18_Weights`` class itself makes torchvision raise
        # ``TypeError("Invalid Weight class provided ...")`` at construction.
        self.feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
        in_channels = self.feature_extractor.fc.in_features
        # Drop the ImageNet classifier so the backbone outputs features.
        self.feature_extractor.fc = nn.Identity()
        # NOTE(review): each head stacks two Linear layers with no
        # nonlinearity in between, so it is mathematically one linear map.
        # Inserting e.g. nn.ReLU() would add capacity, but it shifts the
        # Sequential indices (fc_prob.1 -> fc_prob.2) and would break loading
        # existing checkpoints, so the structure is intentionally unchanged.
        self.fc_prob = nn.Sequential(
            nn.Linear(in_channels, 512),
            nn.Linear(512, 1),
        )
        self.fc_bbox = nn.Sequential(
            nn.Linear(in_channels, 512),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Predict object presence probability and bounding box.

        Args:
            x: input batch for the ResNet-18 backbone (presumably images of
                shape (N, 3, H, W) -- TODO confirm against the caller).

        Returns:
            A tuple ``(pred_prob, pred_bbox)``. ``pred_prob`` is the sigmoid
            of ``fc_prob`` (values in [0, 1], last dim 1). ``pred_bbox`` is
            the raw output of ``fc_bbox`` (last dim 4).
        """
        # Run the backbone once and share its features between both heads.
        # Running it per head would double the compute and, in train mode,
        # update the BatchNorm running statistics twice per batch.
        features = self.feature_extractor(x)
        pred_prob = torch.sigmoid(self.fc_prob(features))
        pred_bbox = self.fc_bbox(features)
        return (pred_prob, pred_bbox)