Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
965 Bytes
"""
EfficientNet model for image classification, built on the timm library for model definition.
"""
import timm
def crete_efficientnet_v2_model(model_name='efficientnetv2_m', num_classes=2, pretrained=True, in_22k=False):
    """
    Create an EfficientNet model for image classification via timm.

    The model name is normalized to timm's TensorFlow-ported variant
    ('tf_' prefix) and, unless the caller already supplied one, a pretrained
    tag is appended: '.in21k' when ``in_22k`` is True, otherwise
    '.in21k_ft_in1k'.

    Args:
        model_name (str): Name of the EfficientNet model variant to use,
            e.g. 'efficientnetv2_m'. A 'tf_' prefix is added if missing. If the
            name already contains a pretrained tag (a '.'), the tag is kept
            as-is and ``in_22k`` is ignored.
        num_classes (int): Number of output classes. Pass 0 to build the model
            without a classification head.
        pretrained (bool): Whether to load pretrained weights.
        in_22k (bool): Select the '.in21k' tag instead of the default
            '.in21k_ft_in1k' tag (only used when ``model_name`` has no tag).

    Returns:
        tuple: ``(model, num_features)``, where ``num_features`` is the width of
        the feature vector that feeds (or would feed) the classifier.
    """
    if not model_name.startswith('tf_'):
        model_name = 'tf_' + model_name
    # Only add a pretrained tag when none was given; appending a second tag
    # (e.g. 'tf_efficientnetv2_m.in21k.in21k_ft_in1k') yields an invalid name.
    if '.' not in model_name:
        model_name += '.in21k' if in_22k else '.in21k_ft_in1k'
    print(f"Creating EfficientNet model: {model_name}")
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    # With num_classes=0 timm replaces the classifier by an identity module that
    # has no `in_features`; `num_features` is defined on the model regardless
    # and matches the classifier's input width when a head is present.
    num_features = model.num_features
    return model, num_features