# PlantDiseaseDetection / configuration.py
# Uploaded by BrandonFors ("uploading files to space", commit 8e1235e).
import transformers
from transformers import PretrainedConfig
# Hugging Face configuration class for the EfficientNetV2-S plant-disease classifier.
class EffNetPlantDiseaseConfig(PretrainedConfig):
    """Hugging Face configuration for the EfficientNetV2-S plant-disease classifier.

    Stores the model hyperparameters and the mapping between class indices
    and label names.

    Args:
        num_classes: Number of output classes. Used to build placeholder label
            names (``class_0``, ``class_1``, ...) when no names are available.
        image_size: Input image side length. Stored on the config; this class
            does not read it itself.
        dropout_rate: Dropout rate stored on the config (presumably consumed by
            the model's classifier head -- confirm against the model code).
        class_names: Optional ordered sequence of label names; position ``i``
            becomes the label for class id ``i``. Takes precedence over any
            ``id2label`` passed in ``kwargs``.
        **kwargs: Forwarded to ``PretrainedConfig``. When the config is reloaded
            with ``from_pretrained``, the saved ``id2label`` arrives here and is
            preserved rather than replaced by placeholders.

    Both label maps use *string* ids (``id2label["0"] -> name`` and
    ``label2id[name] -> "0"``), matching this config's original convention.
    """

    # Identifier for this model family; it is the key to use when registering
    # this class with ``AutoConfig.register``.
    model_type = "effnetv2_s_plant_disease"

    def __init__(self,
                 num_classes=38,
                 image_size=224,
                 dropout_rate=0.2,
                 class_names=None,
                 **kwargs):
        # Grab any saved label map before handing kwargs to the base class.
        # from_pretrained() rebuilds the config as cls(**config_dict), so this
        # is the only place the real label names from config.json show up.
        saved_id2label = kwargs.get("id2label")
        super().__init__(**kwargs)

        self.num_classes = num_classes
        # Honour the argument (it was previously hard-coded to 224).
        self.image_size = image_size
        self.dropout_rate = dropout_rate

        # NOTE(review): num_classes is not checked against len(class_names);
        # a mismatch is accepted silently, as before.
        if class_names:
            id2label = {str(i): name for i, name in enumerate(class_names)}
        elif saved_id2label:
            # Reloaded config: keep the saved names. PretrainedConfig converts
            # the keys to int, so normalise them back to this class's str ids.
            id2label = {str(i): name for i, name in saved_id2label.items()}
        else:
            id2label = {str(i): f"class_{i}" for i in range(num_classes)}

        self.id2label = id2label
        # Derive the reverse map from id2label so the two can never disagree
        # (for duplicate names the last id wins, as before).
        self.label2id = {name: i for i, name in id2label.items()}