"""Per-crop image classification with PyTorch and Keras models.

At import time every supported model file in ``models/`` is loaded together
with its label list from ``labels/<crop>_labels.json``; :func:`predict` then
classifies an image with the model registered for a given crop.
"""

import os
import json

import torch
import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import models, transforms

# ================= PATHS =================
# abspath() so the directories do not depend on the current working directory
# when __file__ happens to be a relative path.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
LABELS_DIR = os.path.join(BASE_DIR, "labels")


# ================= TF CUSTOM OBJECTS =================
@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Dropout layer registered as ``FixedDropout`` for model deserialization.

    Saved Keras models that reference a ``FixedDropout`` layer (presumably
    produced by the third-party ``efficientnet`` package -- confirm) cannot be
    loaded unless a class with this name is supplied.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)


# Dummy class to satisfy EfficientNet deserialization
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Placeholder so saved models referencing ``EfficientNetB3`` can load."""


# ================= INPUT SIZE PER MODEL =================
# Square input side length (pixels) per Keras crop model; crops not listed
# here fall back to 224 in preprocess_keras().
KERAS_INPUT_SIZES = {
    "corn": 300,
}

# File extensions (lower-case) that load_models() treats as model files.
PYTORCH_EXTENSIONS = (".pth",)
KERAS_EXTENSIONS = (".keras", ".h5")


# ================= IMAGE PREPROCESS =================
def preprocess_pytorch(img, size=224):
    """Turn a PIL image into a normalized ``(1, 3, size, size)`` float tensor.

    The image is forced to RGB first: RGBA, palette and grayscale inputs would
    otherwise produce the wrong channel count for ``Normalize`` and for the
    ResNet input layer. It is then resized to ``size x size``, scaled to
    [0, 1] by ``ToTensor`` and normalized with the ImageNet mean/std used by
    torchvision's ImageNet-trained models.
    """
    img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return transform(img).unsqueeze(0)


def preprocess_keras(img, crop_name):
    """Turn a PIL image into a ``(1, size, size, 3)`` float32 array in [0, 1].

    ``size`` is looked up in KERAS_INPUT_SIZES for *crop_name* (default 224).
    """
    img = img.convert("RGB")
    size = KERAS_INPUT_SIZES.get(crop_name, 224)
    img = img.resize((size, size))
    arr = np.array(img).astype("float32") / 255.0
    return np.expand_dims(arr, axis=0)


# ================= MODEL REGISTRIES =================
# All three are keyed by the lower-case crop name derived in load_models().
PYTORCH_MODELS = {}
KERAS_MODELS = {}
LABELS = {}


# ================= LOAD MODELS =================
def _load_pytorch_model(model_path, num_classes):
    """Build a ResNet-18 with a *num_classes*-way head and load its weights."""
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load is pickle-based -- only load trusted model
    # files from MODELS_DIR.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model


def _load_keras_model(model_path):
    """Load a Keras model, without compiling it, with its custom objects."""
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "swish": tf.keras.activations.swish,
            "FixedDropout": FixedDropout,
            "EfficientNetB3": EfficientNetB3,
        },
        compile=False,
    )


def load_models():
    """Populate LABELS, PYTORCH_MODELS and KERAS_MODELS from MODELS_DIR.

    A model file is registered under the crop name obtained by removing
    ``"_model"`` from its file stem and lower-casing the result. ``.pth``
    files are loaded as ResNet-18 state dicts; ``.keras`` / ``.h5`` files as
    Keras models. Entries with any other extension (READMEs, ``.gitkeep``,
    ...) are skipped instead of being required to have a label file.

    Raises:
        FileNotFoundError: if a model file has no matching
            ``LABELS_DIR/<crop>_labels.json`` (or MODELS_DIR does not exist).
    """
    for file in sorted(os.listdir(MODELS_DIR)):
        name, ext = os.path.splitext(file)
        ext = ext.lower()
        if ext not in PYTORCH_EXTENSIONS and ext not in KERAS_EXTENSIONS:
            # Not a model file: don't demand a label file for it.
            continue

        crop_name = name.replace("_model", "").lower()
        model_path = os.path.join(MODELS_DIR, file)
        label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")

        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Missing label file: {label_path}")
        with open(label_path, "r", encoding="utf-8") as f:
            labels = json.load(f)
        LABELS[crop_name] = labels

        if ext in PYTORCH_EXTENSIONS:
            # The classifier head is sized from the number of labels.
            PYTORCH_MODELS[crop_name] = _load_pytorch_model(
                model_path, len(labels)
            )
        else:
            KERAS_MODELS[crop_name] = _load_keras_model(model_path)


# Load models at startup
load_models()


# ================= PREDICTION =================
def predict(image, crop_name):
    """Classify *image* with the model registered for *crop_name*.

    Args:
        image: PIL image (``convert`` / ``resize`` are called on it).
        crop_name: crop key; surrounding whitespace and case are ignored.

    Returns:
        ``(label, score)`` for the highest-scoring class, where ``label`` is
        ``LABELS[crop][idx]`` and ``score`` is a float: the softmax
        probability for PyTorch models, the raw model output value for Keras
        models (presumably already softmax -- confirm).

    Raises:
        ValueError: if no model is registered for *crop_name*.
    """
    crop_name = crop_name.strip().lower()
    # NOTE(review): assumes each labels JSON is a list ordered by class
    # index -- TODO confirm against the label files.

    # PyTorch takes precedence if both formats exist for the same crop.
    if crop_name in PYTORCH_MODELS:
        model = PYTORCH_MODELS[crop_name]
        labels = LABELS[crop_name]
        tensor = preprocess_pytorch(image)
        with torch.no_grad():
            output = model(tensor)
        probs = torch.softmax(output[0], dim=0)
        idx = probs.argmax().item()
        return labels[idx], float(probs[idx])

    if crop_name in KERAS_MODELS:
        model = KERAS_MODELS[crop_name]
        labels = LABELS[crop_name]
        arr = preprocess_keras(image, crop_name)
        preds = model.predict(arr, verbose=0)[0]
        idx = int(np.argmax(preds))
        return labels[idx], float(preds[idx])

    raise ValueError(f"No model found for crop: {crop_name}")