File size: 965 Bytes
dae5c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
""" 

   EfficientNet model creation for image classification, using the timm library for the model definition.

"""

import timm

def crete_efficientnet_v2_model(model_name: str = 'efficientnetv2_m', num_classes: int = 2, pretrained: bool = True, in_22k: bool = False) -> tuple:
    """Create a timm EfficientNet model and report its feature width.

    The model name is normalized before being handed to
    ``timm.create_model``: a ``tf_`` prefix is added when missing, and a
    pretrained tag (``.in21k`` or ``.in21k_ft_in1k``) is appended when the
    name does not already carry one.

    NOTE: the misspelled function name ("crete") is kept as-is so existing
    callers keep working.

    Args:
        model_name (str): Name of the EfficientNet model variant to use,
            with or without the ``tf_`` prefix. If it already contains a
            pretrained tag (a ``.`` suffix), that tag is used unchanged and
            ``in_22k`` is ignored.
        num_classes (int): Number of output classes (e.g. 0 for not
            initializing head).
        pretrained (bool): Whether to use pretrained weights.
        in_22k (bool): Selects the appended pretrained tag: ``.in21k`` when
            True, ``.in21k_ft_in1k`` otherwise.

    Returns:
        tuple: ``(model, num_features)`` where ``model`` is the object
        returned by ``timm.create_model`` and ``num_features`` is the input
        width of its classifier head.
    """
    if not model_name.startswith('tf_'):
        model_name = 'tf_' + model_name

    # Only append a pretrained tag when the caller did not supply one;
    # blindly appending would yield names like 'tf_x.in21k.in21k_ft_in1k'.
    if '.' not in model_name:
        model_name += '.in21k' if in_22k else '.in21k_ft_in1k'

    print(f"Creating EfficientNet model: {model_name}")
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    # With num_classes=0 the classifier is a pass-through module without
    # `in_features`; fall back to the model-level feature width in that case.
    classifier = getattr(model, 'classifier', None)
    num_features = getattr(classifier, 'in_features', None)
    if num_features is None:
        num_features = model.num_features

    return model, num_features