import os
import json
import torch
import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import models, transforms
# ================= PATHS =================
# All paths are resolved relative to this file so the module works
# regardless of the current working directory.
BASE_DIR = os.path.dirname(__file__)
MODELS_DIR = os.path.join(BASE_DIR, "models")  # serialized weights (.pth / .keras / .h5)
LABELS_DIR = os.path.join(BASE_DIR, "labels")  # per-crop <crop>_labels.json files
# ================= TF CUSTOM OBJECTS =================
@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Serialization alias for the `FixedDropout` layer found in some
    saved EfficientNet models; behaves exactly like `tf.keras.layers.Dropout`.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(rate=rate, noise_shape=noise_shape, seed=seed, **kwargs)
# Dummy class to satisfy EfficientNet deserialization
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Placeholder registered under the name `EfficientNetB3` so that saved
    models referencing that class can be deserialized; adds no behavior.
    """
# ================= INPUT SIZE PER MODEL =================
# Square input edge length (pixels) per Keras model; crops not listed here
# fall back to 224 in preprocess_keras().
KERAS_INPUT_SIZES = {
    "corn": 300,  # presumably EfficientNetB3's default 300px input — verify against the saved model
}
# ================= IMAGE PREPROCESS =================
def preprocess_pytorch(img, size=224):
    """Convert a PIL image into a normalized batch tensor for a ResNet.

    Args:
        img: PIL.Image in any mode; it is converted to RGB first.
        size: target square edge length in pixels (default 224).

    Returns:
        torch.Tensor of shape (1, 3, size, size), ImageNet-normalized.
    """
    # Force 3 channels: Normalize below uses 3-channel ImageNet statistics
    # and would fail on grayscale/RGBA/palette input. The Keras path
    # (preprocess_keras) already converts — this keeps both paths consistent.
    img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet channel means
            std=[0.229, 0.224, 0.225]    # ImageNet channel stds
        )
    ])
    return transform(img).unsqueeze(0)  # add leading batch dimension
def preprocess_keras(img, crop_name):
    """Resize a PIL image for a Keras model and scale pixels to [0, 1].

    Returns a float32 numpy array of shape (1, size, size, 3), where size
    is looked up per crop in KERAS_INPUT_SIZES (default 224).
    """
    target = KERAS_INPUT_SIZES.get(crop_name, 224)
    rgb = img.convert("RGB").resize((target, target))
    scaled = np.asarray(rgb, dtype="float32") / 255.0
    return scaled[np.newaxis, ...]  # add leading batch dimension
# ================= MODEL REGISTRIES =================
# Keyed by lowercase crop name; populated once by load_models() below.
PYTORCH_MODELS = {}  # crop -> torch model in eval mode
KERAS_MODELS = {}    # crop -> loaded Keras model
LABELS = {}          # crop -> labels parsed from <crop>_labels.json
# ================= LOAD MODELS =================
def load_models():
    """Scan MODELS_DIR and populate PYTORCH_MODELS / KERAS_MODELS / LABELS.

    ``.pth`` files are loaded as PyTorch ResNet-18 checkpoints; ``.keras`` /
    ``.h5`` files are loaded as Keras models. The crop name is derived from
    the filename (``<crop>_model.<ext>``) and each supported model must have
    a matching ``<crop>_labels.json`` in LABELS_DIR.

    Raises:
        FileNotFoundError: if a supported model file has no label file.
    """
    for file in os.listdir(MODELS_DIR):
        name, ext = os.path.splitext(file)
        ext = ext.lower()
        # Skip anything that is not a supported model file (README,
        # .DS_Store, subdirectories, ...). Previously labels were demanded
        # for EVERY directory entry before the extension check, so any
        # stray file raised a spurious FileNotFoundError.
        if ext not in (".pth", ".keras", ".h5"):
            continue
        crop_name = name.replace("_model", "").lower()
        model_path = os.path.join(MODELS_DIR, file)
        label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Missing label file: {label_path}")
        with open(label_path, "r") as f:
            LABELS[crop_name] = json.load(f)
        # ---------- PyTorch ----------
        if ext == ".pth":
            num_classes = len(LABELS[crop_name])
            model = models.resnet18(weights=None)
            # Replace the final FC layer to match this crop's class count.
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
            model.eval()  # inference mode: freeze dropout / batch-norm stats
            PYTORCH_MODELS[crop_name] = model
        # ---------- Keras ----------
        else:  # ext is ".keras" or ".h5"
            KERAS_MODELS[crop_name] = tf.keras.models.load_model(
                model_path,
                # Custom objects some EfficientNet exports reference by name.
                custom_objects={
                    "swish": tf.keras.activations.swish,
                    "FixedDropout": FixedDropout,
                    "EfficientNetB3": EfficientNetB3,
                },
                compile=False  # inference only; no optimizer state needed
            )
# Load models at startup
# NOTE: runs at import time — importing this module performs disk I/O and
# raises if MODELS_DIR is missing or a supported model lacks its label file.
load_models()
# ================= PREDICTION =================
def predict(image, crop_name):
    """Classify `image` using the model registered for `crop_name`.

    Returns:
        (label, confidence): predicted label from LABELS and its softmax
        probability as a float.

    Raises:
        ValueError: if no model is registered for the crop.
    """
    key = crop_name.strip().lower()

    if key in PYTORCH_MODELS:
        net = PYTORCH_MODELS[key]
        class_names = LABELS[key]
        batch = preprocess_pytorch(image)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            logits = net(batch)
        scores = torch.softmax(logits[0], dim=0)
        best = scores.argmax().item()
        return class_names[best], float(scores[best])

    if key in KERAS_MODELS:
        net = KERAS_MODELS[key]
        class_names = LABELS[key]
        scores = net.predict(preprocess_keras(image, key), verbose=0)[0]
        best = int(np.argmax(scores))
        return class_names[best], float(scores[best])

    raise ValueError(f"No model found for crop: {key}")