EyeDiseaseDetection / utils /ModelCreator.py
GilbertKrantz
FIX : ty Type Error
6e8e8fb
import torch
from torch import device
import torch.nn as nn
import timm
# Default device for created models: CUDA when available, otherwise CPU.
DEVICE = device("cuda" if torch.cuda.is_available() else "cpu")


class EyeDetectionModels:
    """
    Factory for pretrained timm classifiers adapted to eye-detection tasks.

    Each ``get_model_*`` method loads an ImageNet-pretrained timm model,
    optionally freezes part of its backbone, replaces its classification head
    with ``Linear(in, 512) -> ReLU -> Dropout(0.4) -> Linear(512, num_classes)``
    and returns the model moved to ``self.device``.
    """

    def __init__(
        self,
        num_classes: int,
        freeze_layers: bool = True,
        device: device = DEVICE,
    ):
        """
        Initialize the EyeDetectionModels factory.

        Args:
            num_classes: Number of output logits of every new classifier head.
            freeze_layers: When True, each factory method freezes the early
                part of the pretrained backbone (see the individual methods
                for exactly which modules stay trainable).
            device: Device every created model is moved to.
        """
        self.num_classes = num_classes
        self.freeze_layers = freeze_layers
        self.device = device
        # Registry mapping a model name to its (bound) factory method.
        self.models = {
            "mobilenetv4": self.get_model_mobilenetv4,
            "levit": self.get_model_levit,
            "efficientvit": self.get_model_efficientvit,
            "gernet": self.get_model_gernet,
            "regnetx": self.get_model_regnetx,
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _build_classifier(self, in_features: int) -> nn.Sequential:
        """Return the classification head shared by every model in this factory."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, self.num_classes),
        )

    @staticmethod
    def _get_feature_blocks(model: nn.Module) -> nn.ModuleList:
        """
        Locate the main sequence of feature blocks in a timm model.

        The first attribute among ``features``, ``blocks``, ``stages`` and
        ``layers`` that is an ``nn.Module`` with children is treated as the
        block container, and its direct children are returned in order. If no
        such attribute exists, all direct children of the model except the
        last one (assumed to be the classifier/head) are returned.

        ``stem`` is deliberately not a candidate. It is the *first* part of a
        network, so "freeze all but its last n children" would leave the rest
        of the backbone trainable. A stem that is a single layer is also not
        iterable.

        The returned ModuleList references the model's own submodules, so
        toggling ``requires_grad`` through it affects the model.
        """
        for attr in ("features", "blocks", "stages", "layers"):
            container = getattr(model, attr, None)
            if not isinstance(container, nn.Module):
                continue
            # .children() works for Sequential/ModuleList and any other module,
            # whereas iterating a plain module raises TypeError.
            children = list(container.children())
            if children:
                return nn.ModuleList(children)
        # Fallback: collect all children except the classifier/head.
        return nn.ModuleList(list(model.children())[:-1])

    @staticmethod
    def _freeze_except_last_n(blocks: nn.ModuleList, n: int) -> None:
        """
        Freeze every block except the last ``n``.

        Parameters of the last ``n`` blocks get ``requires_grad=True``. All
        earlier blocks get ``requires_grad=False``. ``n <= 0`` freezes every
        block, and ``n >= len(blocks)`` leaves all of them trainable. Modules
        not contained in ``blocks`` are not touched.
        """
        first_trainable = len(blocks) - n
        for idx, block in enumerate(blocks):
            trainable = idx >= first_trainable
            for param in block.parameters():
                param.requires_grad = trainable

    @staticmethod
    def _get_regnet_stages(model: nn.Module) -> "list[nn.Module]":
        """
        Return a RegNet's stages in registration order.

        Uses the children of a ``trunk`` container when the model has one.
        Otherwise it collects direct children named ``s1``, ``s2``, …
        (timm's RegNet is expected to register its stages that way; TODO
        confirm against the pinned timm version).
        """
        trunk = getattr(model, "trunk", None)
        if isinstance(trunk, nn.Module):
            return list(trunk.children())
        return [
            module
            for name, module in model.named_children()
            if name[:1] == "s" and name[1:].isdigit()
        ]

    # ------------------------------------------------------------------
    # Model architecture functions
    # ------------------------------------------------------------------

    def get_model_mobilenetv4(self) -> nn.Module:
        """
        Load MobileNetV4 (conv, medium) and attach a new classifier.

        When ``freeze_layers`` is set, all but the last two entries of the
        model's block container are frozen. Modules outside that container
        keep their default trainable state.
        """
        model = timm.create_model(
            "mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
        )
        if self.freeze_layers:
            self._freeze_except_last_n(self._get_feature_blocks(model), 2)
        model.classifier = self._build_classifier(model.classifier.in_features)
        return model.to(self.device)

    def get_model_levit(self) -> nn.Module:
        """
        Load LeViT-128S (distilled) and replace both of its heads.

        ``head`` and ``head_dist`` both receive a fresh classifier so that
        each outputs ``num_classes`` logits.
        """
        model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
        if self.freeze_layers:
            self._freeze_except_last_n(self._get_feature_blocks(model), 2)
        # Read the pooled feature width from timm rather than hard-coding it.
        # 384 is the value previously hard-coded for levit_128s.
        in_features = int(getattr(model, "num_features", 384))
        model.head = self._build_classifier(in_features)
        model.head_dist = self._build_classifier(in_features)
        return model.to(self.device)

    def get_model_efficientvit(self) -> nn.Module:
        """Load EfficientViT-M1 and replace the linear layer inside its head."""
        model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
        if self.freeze_layers:
            self._freeze_except_last_n(self._get_feature_blocks(model), 2)
        # Read the pooled feature width from timm rather than hard-coding it.
        # 192 is the value previously hard-coded for efficientvit_m1.
        in_features = int(getattr(model, "num_features", 192))
        model.head.linear = self._build_classifier(in_features)
        return model.to(self.device)

    def get_model_gernet(self) -> nn.Module:
        """
        Load a GENet (General and Efficient Network) model with a new classifier.

        When ``freeze_layers`` is set and the model exposes ``stem`` and
        ``stages``, the stem is frozen completely and all stages except the
        last two are frozen. Otherwise the generic block lookup is used.

        Returns:
            Configured GENet model.
        """
        model = timm.create_model("gernet_s.idstcv_in1k", pretrained=True)
        if self.freeze_layers:
            if hasattr(model, "stem") and hasattr(model, "stages"):
                for param in model.stem.parameters():
                    param.requires_grad = False
                self._freeze_except_last_n(
                    nn.ModuleList(model.stages.children()), 2
                )
            else:
                # Fallback to the generic approach.
                self._freeze_except_last_n(self._get_feature_blocks(model), 2)
        model.head.fc = self._build_classifier(model.head.fc.in_features)
        return model.to(self.device)

    def get_model_regnetx(self) -> nn.Module:
        """
        Load a RegNetX-800MF model with a new classifier.

        When ``freeze_layers`` is set, the whole network is frozen first. Then
        roughly the last 25% of its stages and the head are unfrozen. For
        example, with four stages only the last one is unfrozen.

        Returns:
            Configured RegNetX model.
        """
        model = timm.create_model("regnetx_008.tv2_in1k", pretrained=True)
        if self.freeze_layers:
            for param in model.parameters():
                param.requires_grad = False
            stages = self._get_regnet_stages(model)
            # Unfreeze approximately the last 25% of the stages.
            unfreeze_from = int(len(stages) * 0.75)
            for stage in stages[unfreeze_from:]:
                for param in stage.parameters():
                    param.requires_grad = True
            # Always unfreeze the classifier/head for fine-tuning.
            for param in model.head.parameters():
                param.requires_grad = True
        model.head.fc = self._build_classifier(model.head.fc.in_features)
        return model.to(self.device)