File size: 965 Bytes
"""
EfficientNetV2 model for image classification. The model definition and pretrained weights come from the timm library.
"""
import timm
def crete_efficientnet_v2_model(model_name='efficientnetv2_m', num_classes=2, pretrained=True, in_22k=False):
    """
    Create an EfficientNetV2 model for image classification via timm.

    The TensorFlow-ported variant is selected by prefixing ``tf_`` to
    ``model_name`` (unless already present). If ``model_name`` carries no
    explicit pretrained tag (no ``.`` in the name), ``.in21k`` or
    ``.in21k_ft_in1k`` is appended depending on ``in_22k``.

    Args:
        model_name (str): Name of the EfficientNetV2 variant, e.g.
            'efficientnetv2_m'. A name that already includes a pretrained
            tag (e.g. 'tf_efficientnetv2_m.in1k') is used as-is.
        num_classes (int): Number of output classes (e.g. 0 to build the
            model without initializing a classification head).
        pretrained (bool): Whether to load pretrained weights.
        in_22k (bool): If True, append the '.in21k' tag; otherwise append
            '.in21k_ft_in1k'. Ignored when ``model_name`` already has a tag.
            NOTE(review): only some variants publish these tags in timm
            (presumably s/m/l/xl) - verify before using other variants.

    Returns:
        tuple: ``(model, num_features)`` - the timm model, and the width of
        the feature vector fed into the classifier.
    """
    if not model_name.startswith('tf_'):
        model_name = 'tf_' + model_name
    # Append the default pretrained tag only when the caller did not supply
    # one; otherwise a name like 'tf_efficientnetv2_m.in1k' would turn into
    # the invalid 'tf_efficientnetv2_m.in1k.in21k_ft_in1k'.
    if '.' not in model_name:
        model_name += '.in21k' if in_22k else '.in21k_ft_in1k'
    print(f"Creating EfficientNet model: {model_name}")
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    # Use `num_features` rather than `classifier.in_features`: with
    # num_classes=0 timm replaces the classifier with nn.Identity, which has
    # no `in_features`, while `num_features` is defined on every timm model.
    num_features = model.num_features
    return model, num_features


# Correctly spelled alias; the original (misspelled) name is kept so existing
# callers keep working.
create_efficientnet_v2_model = crete_efficientnet_v2_model