# leafbuddy / model.py
# Author: Inoue1 — last commit "Update model.py" (010286a, verified)
import os
import json
import torch
import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import models, transforms
# ================= PATHS =================
# Anchor every path at this file's absolute location. os.path.dirname(__file__)
# alone returns "" when __file__ is relative, which would silently make the
# model/label lookups depend on the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
LABELS_DIR = os.path.join(BASE_DIR, "labels")
# ================= TF CUSTOM OBJECTS =================
@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Registered alias of ``tf.keras.layers.Dropout`` named ``FixedDropout``.

    It adds no behavior of its own; it exists so that saved models whose
    config refers to a ``FixedDropout`` layer can be deserialized.
    NOTE(review): presumably those models were built with the standalone
    EfficientNet package, which defines such a layer — confirm.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        # Forward every argument untouched to the stock Dropout constructor.
        super().__init__(
            rate=rate,
            noise_shape=noise_shape,
            seed=seed,
            **kwargs,
        )
# Placeholder registered so that saved models referring to an "EfficientNetB3"
# class can be deserialized.
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Empty ``tf.keras.Model`` subclass; defines no layers or logic of its own."""
# ================= INPUT SIZE PER MODEL =================
# Square input edge length (in pixels) for Keras models, keyed by crop name.
# Crops missing from this table are resized to 224 by preprocess_keras().
KERAS_INPUT_SIZES = dict(corn=300)
# ================= IMAGE PREPROCESS =================
def preprocess_pytorch(img, size=224):
    """Turn a PIL image into a normalized ``(1, 3, size, size)`` float tensor.

    The image is converted to RGB first, the same way ``preprocess_keras``
    does. Without that, RGBA, palette or grayscale inputs would yield 4 or 1
    channels, which the 3-channel mean/std normalization and the ResNet input
    layer cannot accept.

    Args:
        img: A PIL image in any mode.
        size: Edge length of the square resize target, in pixels.

    Returns:
        A batch of one ``(3, size, size)`` tensor, normalized with the
        ImageNet channel mean/std used by torchvision models.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return transform(img.convert("RGB")).unsqueeze(0)
def preprocess_keras(img, crop_name):
    """Turn a PIL image into a ``(1, side, side, 3)`` float32 batch in [0, 1].

    Args:
        img: A PIL image in any mode; it is converted to RGB.
        crop_name: Key into ``KERAS_INPUT_SIZES``. It selects the square edge
            length ``side`` and falls back to 224 for unknown crops.

    Returns:
        A NumPy array holding the single resized image with a leading batch axis.
    """
    side = KERAS_INPUT_SIZES.get(crop_name, 224)
    resized = img.convert("RGB").resize((side, side))
    pixels = np.asarray(resized, dtype="float32") / 255.0  # scale 0..255 -> 0..1
    return pixels[np.newaxis, ...]
# ================= MODEL REGISTRIES =================
# Filled in by load_models(); every key is a lower-cased crop name.
PYTORCH_MODELS = dict()  # crop -> torch model, switched to eval mode
KERAS_MODELS = dict()    # crop -> loaded Keras model
LABELS = dict()          # crop -> parsed contents of <crop>_labels.json
# ================= LOAD MODELS =================
# Model-file extensions understood by load_models() (matched case-insensitively).
_PYTORCH_EXTS = (".pth",)
_KERAS_EXTS = (".keras", ".h5")


def _read_labels(crop_name):
    """Return the parsed JSON from ``LABELS_DIR/<crop_name>_labels.json``.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """
    label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
    if not os.path.exists(label_path):
        raise FileNotFoundError(f"Missing label file: {label_path}")
    with open(label_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pytorch_model(model_path, num_classes):
    """Rebuild a ResNet-18 with a ``num_classes``-way head and load its saved weights."""
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the checkpoint. Only ship trusted files
    # here, and consider weights_only=True on torch versions that support it.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference behavior for dropout / batch-norm layers
    return model


def _load_keras_model(model_path):
    """Load a saved Keras model, resolving the custom objects this module defines."""
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "swish": tf.keras.activations.swish,
            "FixedDropout": FixedDropout,
            "EfficientNetB3": EfficientNetB3,
        },
        compile=False,  # predict() only runs inference; skip training config
    )


def load_models():
    """Populate ``PYTORCH_MODELS``, ``KERAS_MODELS`` and ``LABELS`` from ``MODELS_DIR``.

    The crop key is the file's stem with every ``_model`` removed, lower-cased
    (e.g. ``Corn_model.h5`` -> ``corn``). ``.pth`` files are loaded as ResNet-18
    state dicts and ``.keras``/``.h5`` files as Keras models. Any other entry
    (README, ``.gitkeep``, sub-directories, ...) is skipped.

    Raises:
        FileNotFoundError: if a model file has no matching label file.
    """
    # sorted() makes the load order deterministic across platforms.
    for file in sorted(os.listdir(MODELS_DIR)):
        model_path = os.path.join(MODELS_DIR, file)
        name, ext = os.path.splitext(file)
        ext = ext.lower()
        # Only recognised model files need (and get) a label file.
        if ext not in _PYTORCH_EXTS + _KERAS_EXTS or not os.path.isfile(model_path):
            continue

        crop_name = name.replace("_model", "").lower()
        # NOTE(review): predict() indexes labels by class id, so the JSON is
        # presumably a list ordered by class index — confirm.
        LABELS[crop_name] = _read_labels(crop_name)

        if ext in _PYTORCH_EXTS:
            PYTORCH_MODELS[crop_name] = _load_pytorch_model(
                model_path, len(LABELS[crop_name])
            )
        else:
            KERAS_MODELS[crop_name] = _load_keras_model(model_path)
# Load models at startup: the registries are filled once, when this module is
# imported. Any exception raised while loading (e.g. FileNotFoundError for a
# missing label file) propagates and aborts the import.
load_models()
# ================= PREDICTION =================
def predict(image, crop_name):
    """Classify ``image`` with the model registered for ``crop_name``.

    Args:
        image: A PIL image handed to the matching preprocess function.
        crop_name: Crop identifier; surrounding whitespace and case are ignored.

    Returns:
        ``(label, score)`` for the highest-scoring class. For PyTorch models the
        score is a softmax probability. For Keras models it is the raw model
        output at that index (presumably a probability if the network ends in
        softmax — confirm).

    Raises:
        ValueError: if no model is registered for the normalized crop name.
    """
    key = crop_name.strip().lower()

    if key in PYTORCH_MODELS:
        net = PYTORCH_MODELS[key]
        class_names = LABELS[key]
        batch = preprocess_pytorch(image)
        with torch.no_grad():
            logits = net(batch)
            scores = torch.softmax(logits[0], dim=0)
            best = scores.argmax().item()
            return class_names[best], float(scores[best])

    if key in KERAS_MODELS:
        net = KERAS_MODELS[key]
        class_names = LABELS[key]
        batch = preprocess_keras(image, key)
        scores = net.predict(batch, verbose=0)[0]
        best = int(np.argmax(scores))
        return class_names[best], float(scores[best])

    raise ValueError(f"No model found for crop: {key}")