Spaces:
Sleeping
Sleeping
File size: 1,426 Bytes
d581b00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import torch
import torch.nn as nn
from torchvision import models
def build_model(pretrained=True, *, freeze_backbone=True):
    """Build a ResNet-18 with a single-logit binary-classification head.

    Args:
        pretrained: If True, load torchvision's ImageNet-1k (V1) weights;
            otherwise the backbone is randomly initialised.
        freeze_backbone: If True (the default, matching the original
            behaviour), set ``requires_grad=False`` on every backbone
            parameter so only the new head is trained. Pass False to
            fine-tune the whole network -- usually what you want when
            ``pretrained`` is False, since a frozen random backbone only
            produces uninformative features.

    Returns:
        The ResNet-18 module with ``fc`` replaced by
        Dropout(0.6) -> Linear(in_features, 256) -> ReLU -> Dropout(0.4)
        -> Linear(256, 1). The head emits one raw logit per sample; no
        sigmoid is applied here.
    """
    # Enum form (rather than the bare string) matches build_efficientnet and
    # fails fast on a typo'd weight name.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)

    if freeze_backbone:
        # Freeze *before* swapping the head: the freshly created head layers
        # below default to requires_grad=True and so remain trainable.
        # NOTE: this only disables gradients; any BatchNorm layers in the
        # backbone still update their running statistics in train() mode.
        for param in model.parameters():
            param.requires_grad = False

    # Replace the ImageNet classifier with a small binary head.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.6),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, 1),
    )
    return model
def build_efficientnet(pretrained=True, *, freeze_backbone=True):
    """Build an EfficientNet-B0 with a single-logit binary-classification head.

    Args:
        pretrained: If True, load torchvision's ImageNet-1k (V1) weights;
            otherwise the backbone is randomly initialised.
        freeze_backbone: If True (the default, matching the original
            behaviour), set ``requires_grad=False`` on every backbone
            parameter so only the new classifier is trained. Pass False to
            fine-tune the whole network -- usually what you want when
            ``pretrained`` is False.

    Returns:
        The EfficientNet-B0 module with ``classifier`` replaced by
        Dropout(0.6) -> Linear(in_features, 256) -> ReLU -> Dropout(0.4)
        -> Linear(256, 1). The head emits one raw logit per sample; no
        sigmoid is applied here.
    """
    # Imported lazily so callers who only use build_model don't pay for it.
    from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

    weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
    model = efficientnet_b0(weights=weights)

    if freeze_backbone:
        # Freeze *before* swapping the classifier so the new layers created
        # below (requires_grad=True by default) stay trainable.
        # NOTE: this only disables gradients; any BatchNorm layers in the
        # backbone still update their running statistics in train() mode.
        for param in model.parameters():
            param.requires_grad = False

    # Index 1 of the stock classifier is the Linear layer whose input width
    # the new head must match.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.6),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, 1),
    )
    return model
if __name__ == "__main__":
    model = build_model()
    print(model.fc)

    # Quick shape smoke test. eval() makes Dropout/BatchNorm deterministic,
    # and no_grad() skips building an autograd graph we never use (the head's
    # parameters require grad, so one would otherwise be recorded).
    model.eval()
    dummy = torch.randn(4, 3, 224, 224)
    with torch.no_grad():
        out = model(dummy)
    print(f"Output shape: {out.shape}")  # should be [4, 1]

    # Actually enforce the expectation instead of relying on the comment;
    # raise explicitly because `assert` is stripped under `python -O`.
    if out.shape != (4, 1):
        raise RuntimeError(
            f"unexpected output shape {tuple(out.shape)}, expected (4, 1)"
        )