# ELIAS-epiblepharon / model.py
# Uploaded by cahsu — "Upload 5 files" (commit fa50b6c, verified).
"""
ELIAS — Eyelid Lesion Intelligent Analysis System
model.py
Frozen ResNet-18 classifier for epiblepharon detection.
Compatible with Hugging Face model loading.
"""
import torch
import torch.nn as nn
from torchvision import models
def build_elias_model(
    num_classes: int = 2,
    freeze_backbone: bool = True,
    *,
    dropout: float = 0.3,
    pretrained: bool = True,
) -> nn.Module:
    """
    Build the ELIAS classifier: a ResNet-18 backbone with a dropout + linear head.

    Args:
        num_classes: Number of output logits. Use 2 for binary classification
            trained with CrossEntropyLoss.
        freeze_backbone: If True, set ``requires_grad = False`` on every
            backbone parameter; only the newly created FC head stays trainable.
        dropout: Dropout probability applied before the final linear layer
            (default 0.3, the value this model has always used).
        pretrained: If True (default), initialise the backbone with the
            torchvision ImageNet-1k V1 weights (downloaded on first use).
            Pass False when the weights will be overwritten anyway, e.g.
            right before loading a trained checkpoint.

    Returns:
        ResNet-18 model whose ``fc`` is
        ``Sequential(Dropout(dropout), Linear(512, num_classes))``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)

    if freeze_backbone:
        # Freeze before replacing fc: the head built below is created
        # afterwards, so its parameters keep the default requires_grad=True.
        for param in model.parameters():
            param.requires_grad = False

    # Replace the 1000-way ImageNet FC with the task-specific head.
    in_features = model.fc.in_features  # 512 for ResNet-18
    model.fc = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features, num_classes),
    )
    return model
def load_elias_model(checkpoint_path: str, device: str = "cpu") -> nn.Module:
    """
    Load a trained ELIAS model from a state-dict checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file holding a ``state_dict``
            compatible with ``build_elias_model()`` (default 2-class head).
        device: Device the tensors are loaded onto and the returned model is
            placed on (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``).

    Returns:
        The model on ``device``, switched to eval mode (dropout disabled,
        BatchNorm using its running statistics).

    Usage:
        model = load_elias_model("pytorch_model.pt")
    """
    # Architecture must match the one used at training time.
    # NOTE(review): this also fetches the ImageNet weights, which the
    # checkpoint then overwrites.
    model = build_elias_model()
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code (requires torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # map_location only controls where the loaded tensors live; load_state_dict
    # copies them into the model's existing (CPU) parameters, so the model
    # itself has to be moved explicitly.
    model.to(device)
    model.eval()
    return model
if __name__ == "__main__":
    net = build_elias_model()

    # Compare the full parameter count with the part left trainable
    # after the backbone has been frozen.
    all_params = list(net.parameters())
    total = sum(p.numel() for p in all_params)
    trainable = sum(p.numel() for p in all_params if p.requires_grad)
    print(f"Trainable parameters: {trainable:,} / {total:,}")

    # Smoke test: a random batch of two 3x224x224 inputs should produce one
    # logit row per sample, i.e. shape (2, 2) for the default 2-class head.
    dummy_batch = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        logits = net(dummy_batch)
    print(f"Output shape: {logits.shape}")