"""Model factory for binary classification: ResNet-50, DenseNet-121, ViT-B/16."""
# Recovered from a hosted file view (revision 604f286, file size 2,141 bytes).
import timm
import torch.nn as nn
# Model names accepted by create_model, mapped to the architecture identifier
# handed to timm.create_model. Every key currently maps to itself.
SUPPORTED_MODELS = dict(
    resnet50="resnet50",
    densenet121="densenet121",
    vit_base_patch16_224="vit_base_patch16_224",
)
def _replace_relu_with_gelu(module):
    """Swap every nn.ReLU found below ``module`` for a fresh nn.GELU, in place.

    Walks the direct children of ``module`` and recurses into each child that
    is not itself an nn.ReLU. ``module`` itself is never replaced, only its
    descendants. Returns None.

    Args:
        module: nn.Module whose submodule tree is rewritten.
    """
    for attr_name, submodule in module.named_children():
        if not isinstance(submodule, nn.ReLU):
            # Not an activation to swap: look for ReLUs further down instead.
            _replace_relu_with_gelu(submodule)
            continue
        # Rebinding the same attribute name keeps the child's position.
        setattr(module, attr_name, nn.GELU())
def create_model(model_name="resnet50", pretrained=True, dropout=0.3, modified=False):
    """Build a single-output classification model through timm.

    Args:
        model_name: Key of SUPPORTED_MODELS ('resnet50', 'densenet121' or
            'vit_base_patch16_224').
        pretrained: Forwarded to timm.create_model as ``pretrained``.
        dropout: Forwarded to timm.create_model as ``drop_rate``; also used as
            the dropout probability inside the replacement head of the
            modified ResNet-50.
        modified: Only has an effect together with model_name='resnet50'. In
            that case every nn.ReLU in the network is swapped for nn.GELU and
            ``fc`` is replaced by Linear -> GELU -> Dropout -> Linear(…, 1).
            For any other model the flag is ignored.

    Returns:
        The constructed nn.Module with one output unit. No sigmoid is applied
        here; presumably the caller or the loss handles it (verify against the
        training code).

    Raises:
        ValueError: If model_name is not a key of SUPPORTED_MODELS.
    """
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(f"Unknown model: {model_name}. Choose from {list(SUPPORTED_MODELS)}")

    net = timm.create_model(
        SUPPORTED_MODELS[model_name],
        pretrained=pretrained,
        num_classes=1,
        drop_rate=dropout,
    )

    use_modified_resnet = modified and model_name == "resnet50"
    if not use_modified_resnet:
        return net

    # NOTE(review): the indentation of this file was lost; the head swap below
    # is assumed to belong to the modified-ResNet branch because it assigns
    # ``fc`` — confirm against the original file.
    _replace_relu_with_gelu(net)

    feature_dim = net.get_classifier().in_features
    bottleneck_dim = feature_dim // 2
    # Layers are built in the same order they run so that parameter
    # initialisation consumes the RNG in a stable sequence.
    head_layers = [
        nn.Linear(feature_dim, bottleneck_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(bottleneck_dim, 1),
    ]
    net.fc = nn.Sequential(*head_layers)
    return net
def freeze_backbone(model):
    """Disable gradients for every parameter outside the classifier.

    The classifier is whatever ``model.get_classifier()`` returns. Its
    parameters are left exactly as they are (their ``requires_grad`` flag is
    not touched); all other parameters get ``requires_grad = False``.

    Args:
        model: Module exposing ``get_classifier()`` and ``parameters()``.
    """
    # Compare by identity: the same tensor objects appear in both iterations.
    head_param_ids = {id(tensor) for tensor in model.get_classifier().parameters()}
    backbone_params = (
        tensor for tensor in model.parameters() if id(tensor) not in head_param_ids
    )
    for tensor in backbone_params:
        tensor.requires_grad = False
def unfreeze_backbone(model):
    """Re-enable gradients for every parameter of ``model``.

    Sets ``requires_grad = True`` on each tensor yielded by
    ``model.parameters()``, classifier included.

    Args:
        model: Module exposing ``parameters()``.
    """
    for weight in model.parameters():
        weight.requires_grad = True
# End of module.