Spaces:
Runtime error
Runtime error
File size: 2,172 Bytes
bf07f10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch.nn as nn
import timm
import torchvision.models as tvmodels
def get_model(name: str, num_classes: int, pretrained: bool = True):
"""Loads and adapts model architecture."""
name = name.lower()
if name.startswith('swin'):
model = timm.create_model('swin_small_patch4_window7_224', pretrained=pretrained)
if hasattr(model, 'reset_classifier'):
model.reset_classifier(num_classes=num_classes)
else:
model.head = nn.Linear(model.head.in_features, num_classes)
return model
if name.startswith('convnext'):
model = timm.create_model('convnext_tiny', pretrained=pretrained)
if hasattr(model, 'reset_classifier'):
model.reset_classifier(num_classes=num_classes)
else:
model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
return model
if name.startswith('densenet'):
model = tvmodels.densenet169(pretrained=pretrained)
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
return model
if name.startswith('mobilenet'):
model = timm.create_model('mobilenetv2_100', pretrained=pretrained)
if hasattr(model, 'reset_classifier'):
model.reset_classifier(num_classes=num_classes)
else:
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
return model
if name.startswith('efficientnet'):
model = timm.create_model('efficientnet_b0', pretrained=pretrained)
if hasattr(model, 'reset_classifier'):
model.reset_classifier(num_classes=num_classes)
else:
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
return model
if name.startswith('maxvit'):
model = timm.create_model('maxvit_tiny_tf_224', pretrained=pretrained)
if hasattr(model, 'reset_classifier'):
model.reset_classifier(num_classes=num_classes)
else:
model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
return model
raise ValueError(f'Unknown model: {name}')
|