"""Banana variety / disease classification with two ResNet-18 checkpoints.

Both models are built and loaded from state-dict files located next to this
module at import time, and are then reused for every prediction.
"""
import os
from io import BytesIO

import numpy as np
import torch
import torchvision.models as models
from PIL import Image

BASE_DIR = os.path.dirname(__file__)

# ========== CLASS NAMES ==========
# Entry i of each list is the label for output unit i of the matching model,
# and len(list) sets the size of that model's final layer. The order must
# therefore match the label order the checkpoint was trained with. The strings
# are returned verbatim to callers, so spellings such as "Canvendish" and
# "Unknow Data" are kept unchanged on purpose.
VARIETY_CLASSES = [
    'Anaji1',
    'Banana Lady Finger ( Señorita )',
    'Banana Red',
    'Bichi',
    'Canvendish(Bungulan)',
    'Lakatan',
    'Saba',
    'Sabri Kola',
    'Unknow Data',
]

DISEASE_CLASSES = [
    'Banana Anthracnose Disease',
    'Banana Bract Mosaic Virus Disease',
    'Banana Cordana Leaf Disease',
    'Banana Healthy',
    'Banana Naturally Leaf Dead',
    'Banana Panama Leaf Disease',
    'Banana Panama Tree Disease',
    'Banana Pestalotiopsis Disease',
    'Banana Rhizome Root Tree Disease',
    'Banana Sigatoka Leaf Disease',
]


# ========== MODEL LOADER ==========
def _load_resnet18(num_classes, filename):
    """Build a ResNet-18 with a ``num_classes``-way head and load its weights.

    Args:
        num_classes: Output size of the replacement ``fc`` layer; it must
            match the head stored in the checkpoint.
        filename: State-dict file name, resolved relative to ``BASE_DIR``.

    Returns:
        The model with weights loaded onto CPU, switched to eval mode.
    """
    # Architecture only; every weight comes from the checkpoint file.
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model_path = os.path.join(BASE_DIR, filename)
    # NOTE(review): torch.load unpickles the file. That is fine for checkpoints
    # shipped with this code, but never point it at an untrusted file. On
    # torch >= 1.13, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_variety_model():
    """Return the variety classifier (one output per ``VARIETY_CLASSES`` entry)."""
    return _load_resnet18(len(VARIETY_CLASSES), "variety_model.bin")


def load_disease_model():
    """Return the disease classifier (one output per ``DISEASE_CLASSES`` entry)."""
    return _load_resnet18(len(DISEASE_CLASSES), "disease_model.bin")


# Load models once, at import time, so each prediction reuses them.
variety_model = load_variety_model()
disease_model = load_disease_model()

# ========== IMAGE PREPROCESSING ==========
# Standard ImageNet per-channel mean/std (RGB order). Stored as float32 so that
# normalizing the float32 image does not silently upcast it to float64.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_image(image_bytes, resize_size=256, crop_size=224):
    """Decode image bytes into a normalized ``(1, 3, crop_size, crop_size)`` tensor.

    Steps, in order:
    1. Convert the image to RGB.
    2. Resize it to ``resize_size`` x ``resize_size``. The aspect ratio is not
       preserved.
    3. Center-crop it to ``crop_size`` x ``crop_size``.
    4. Scale pixel values to [0, 1].
    5. Normalize with ``MEAN``/``STD``.

    Args:
        image_bytes: Encoded image data in a format PIL can open.
        resize_size: Side length of the square resize (default 256).
        crop_size: Side length of the square center crop (default 224).

    Returns:
        A float32 tensor of shape ``(1, 3, crop_size, crop_size)``.

    Raises:
        ValueError: If ``crop_size`` is larger than ``resize_size``. PIL would
            otherwise pad the crop with black pixels.
        PIL errors propagate unchanged if the bytes cannot be decoded.
    """
    if crop_size > resize_size:
        raise ValueError("crop_size must not exceed resize_size")

    # convert() forces a full decode and returns a new image, so the source
    # can be closed right away.
    with Image.open(BytesIO(image_bytes)) as src:
        img = src.convert('RGB')
    img = img.resize((resize_size, resize_size))

    # The resized image is square, so the same offset applies to both axes.
    offset = (resize_size - crop_size) // 2
    img = img.crop((offset, offset, offset + crop_size, offset + crop_size))

    arr = np.asarray(img).astype(np.float32) / 255.0
    arr = (arr - MEAN) / STD
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    arr = np.expand_dims(arr, 0)        # add batch dimension of 1
    return torch.tensor(arr, dtype=torch.float32)
# ========== PREDICTION FUNCTION ==========
def predict(image_bytes, model_type="variety"):
    """Classify an image with either the variety or the disease model.

    Args:
        image_bytes: Encoded image data, decoded by ``preprocess_image``.
        model_type: ``"variety"`` or ``"disease"``; selects the model and its
            label list.

    Returns:
        A dict with two keys:
        - ``"prediction"``: the top-1 class name.
        - ``"confidence"``: that class's softmax probability, expressed as a
          percentage (0-100).

    Raises:
        ValueError: If ``model_type`` is not ``"variety"`` or ``"disease"``.
    """
    # Validate the model choice before decoding the image, so a bad
    # model_type fails immediately instead of after a full decode/resize.
    if model_type == "variety":
        model, class_names = variety_model, VARIETY_CLASSES
    elif model_type == "disease":
        model, class_names = disease_model, DISEASE_CLASSES
    else:
        raise ValueError("model_type must be 'variety' or 'disease'")

    img_tensor = preprocess_image(image_bytes)

    with torch.no_grad():
        logits = model(img_tensor)

    # preprocess_image produces a batch of one image, so use row 0.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    confidence, predicted_idx = torch.max(probabilities, 0)

    return {
        "prediction": class_names[predicted_idx.item()],
        "confidence": float(confidence) * 100,
    }