# banaescan / model.py
# Inoue1 — "Update model.py" (commit 62768c6, verified)
import os
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
from io import BytesIO
# Directory containing this module; the model weight files are resolved against it.
# abspath() guards against a relative __file__ (e.g. `python model.py` on
# Python < 3.9), which would otherwise make the weight paths depend on the
# current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ========== CLASS NAMES ==========
# Each list is index-aligned with the output logits of its model. The
# classifier head is sized with len() of the list, and predict() maps the
# argmax index back through it, so the order must not change.
# NOTE(review): entries such as 'Canvendish(Bungulan)' and 'Unknow Data' look
# like typos, but they are returned verbatim to callers and presumably mirror
# the training labels. Confirm before editing them.
VARIETY_CLASSES = [
    'Anaji1',
    'Banana Lady Finger ( Señorita )',
    'Banana Red',
    'Bichi',
    'Canvendish(Bungulan)',
    'Lakatan',
    'Saba',
    'Sabri Kola',
    'Unknow Data',
]
DISEASE_CLASSES = [
    'Banana Anthracnose Disease',
    'Banana Bract Mosaic Virus Disease',
    'Banana Cordana Leaf Disease',
    'Banana Healthy',
    'Banana Naturally Leaf Dead',
    'Banana Panama Leaf Disease',
    'Banana Panama Tree Disease',
    'Banana Pestalotiopsis Disease',
    'Banana Rhizome Root Tree Disease',
    'Banana Sigatoka Leaf Disease',
]
# ========== MODEL LOADER ==========
def _load_resnet18(weights_filename, num_classes):
    """Build a ResNet-18 with a ``num_classes``-way head and load saved weights.

    Args:
        weights_filename: File name of the saved state dict, resolved relative
            to ``BASE_DIR``.
        num_classes: Number of output logits for the replacement ``fc`` layer.
            It must match the head the weights were saved with, otherwise
            ``load_state_dict`` raises a size-mismatch ``RuntimeError``.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    model = models.resnet18(weights=None)  # architecture only; weights come from disk
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model_path = os.path.join(BASE_DIR, weights_filename)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. torch >= 2.6 already defaults to weights_only=True. On older
    # versions (>= 1.13), consider passing weights_only=True explicitly.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference-only usage: disable dropout / batch-norm updates
    return model


def load_variety_model():
    """Load the variety classifier (one output logit per VARIETY_CLASSES entry)."""
    return _load_resnet18("variety_model.bin", len(VARIETY_CLASSES))


def load_disease_model():
    """Load the disease classifier (one output logit per DISEASE_CLASSES entry)."""
    return _load_resnet18("disease_model.bin", len(DISEASE_CLASSES))
# Build both classifiers a single time, at import. predict() then reuses these
# module-level instances on every call. The variety model is loaded first.
variety_model, disease_model = load_variety_model(), load_disease_model()
# ========== IMAGE PREPROCESSING ==========
# Per-channel (RGB) normalization constants: the standard ImageNet mean/std.
# They are stored as float32 so the normalization below stays in float32
# instead of silently promoting the image array to float64.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_image(image_bytes, *, resize_to=256, crop_to=224):
    """Decode an encoded image and turn it into a normalized model input.

    Steps:
      1. Convert the image to RGB.
      2. Resize it to ``resize_to`` x ``resize_to``. The aspect ratio is
         not preserved.
      3. Center-crop it to ``crop_to`` x ``crop_to``.
      4. Scale to [0, 1] and normalize per channel with MEAN/STD.

    Args:
        image_bytes: Raw bytes of an image file in a format PIL can decode.
        resize_to: Side length, in pixels, of the square resize.
        crop_to: Side length, in pixels, of the square center crop. It must
            be positive and not exceed ``resize_to``.

    Returns:
        A float32 tensor of shape ``(1, 3, crop_to, crop_to)``.

    Raises:
        ValueError: If ``crop_to`` is not positive or is larger than
            ``resize_to``. Without this check, PIL would silently pad the
            crop with black pixels.
        PIL.UnidentifiedImageError: If the bytes are not a decodable image.
    """
    if not 0 < crop_to <= resize_to:
        raise ValueError(
            f"crop_to ({crop_to}) must be positive and <= resize_to ({resize_to})"
        )
    # Close the source image once decoded; convert() returns an independent copy.
    with Image.open(BytesIO(image_bytes)) as src:
        img = src.convert('RGB')  # grayscale / RGBA / palette inputs -> 3 channels
    img = img.resize((resize_to, resize_to))
    offset = (resize_to - crop_to) // 2
    img = img.crop((offset, offset, offset + crop_to, offset + crop_to))
    arr = np.asarray(img).astype(np.float32) / 255.0
    arr = (arr - MEAN) / STD  # broadcasts over the trailing channel axis
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW, the layout PyTorch convs expect
    arr = np.expand_dims(arr, 0)  # add a batch dimension of 1
    return torch.tensor(arr, dtype=torch.float32)
# ========== PREDICTION FUNCTION ==========
def predict(image_bytes, model_type="variety"):
    """Classify an image with either the variety or the disease model.

    Args:
        image_bytes: Raw bytes of an encoded image; decoded by preprocess_image.
        model_type: ``"variety"`` or ``"disease"``. Selects the model and the
            matching label list.

    Returns:
        A dict with two keys:
          - ``"prediction"``: the top-1 class name.
          - ``"confidence"``: that class's softmax probability, expressed as a
            percentage in [0, 100].

    Raises:
        ValueError: If ``model_type`` is not one of the supported values. This
            is checked before the image is decoded.
    """
    # Resolve the model first, so a bad request fails without decoding the image.
    if model_type == "variety":
        model = variety_model
        class_names = VARIETY_CLASSES
    elif model_type == "disease":
        model = disease_model
        class_names = DISEASE_CLASSES
    else:
        raise ValueError("model_type must be 'variety' or 'disease'")

    img_tensor = preprocess_image(image_bytes)
    # inference_mode disables autograd tracking (stricter and cheaper than no_grad).
    with torch.inference_mode():
        logits = model(img_tensor)
        probabilities = torch.softmax(logits[0], dim=0)  # batch of one -> row 0
        confidence, predicted_idx = torch.max(probabilities, dim=0)
        return {
            "prediction": class_names[predicted_idx.item()],
            "confidence": float(confidence) * 100,
        }