File size: 2,858 Bytes
7751e3d
288ff31
 
 
 
 
 
62768c6
7751e3d
288ff31
 
 
4c002d5
288ff31
 
 
 
 
 
 
976bda1
288ff31
 
 
 
 
 
 
7751e3d
 
288ff31
 
 
 
 
 
 
7751e3d
 
288ff31
 
 
 
 
 
62768c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
from io import BytesIO

# Directory holding this module; the model weight files are looked up here.
BASE_DIR = os.path.dirname(__file__)

# ========== CLASS NAMES ==========
# The position of each label is the class index it stands for: predict()
# indexes these lists with the argmax of the model output. The strings are
# returned to callers verbatim, so their spelling must not be "corrected".
VARIETY_CLASSES = [
    "Anaji1",
    "Banana Lady Finger ( Señorita )",
    "Banana Red",
    "Bichi",
    "Canvendish(Bungulan)",
    "Lakatan",
    "Saba",
    "Sabri Kola",
    "Unknow Data",
]

DISEASE_CLASSES = [
    "Banana Anthracnose Disease",
    "Banana Bract Mosaic Virus Disease",
    "Banana Cordana Leaf Disease",
    "Banana Healthy",
    "Banana Naturally Leaf Dead",
    "Banana Panama Leaf Disease",
    "Banana Panama Tree Disease",
    "Banana Pestalotiopsis Disease",
    "Banana Rhizome Root Tree Disease",
    "Banana Sigatoka Leaf Disease",
]

# ========== MODEL LOADER ==========
def _load_resnet18_classifier(weights_filename, num_classes):
    """Build a ResNet-18 with a ``num_classes``-way head and load its weights.

    Args:
        weights_filename: Name of the state-dict file, relative to ``BASE_DIR``.
        num_classes: Output size of the replacement fully connected layer.

    Returns:
        The model with the stored weights loaded (mapped to CPU), switched to
        evaluation mode.
    """
    model = models.resnet18(weights=None)
    # Swap the stock final layer for one sized to this task's label set.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    weights_path = os.path.join(BASE_DIR, weights_filename)
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered checkpoint file cannot run arbitrary code while loading.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_variety_model():
    """Return the banana-variety classifier loaded from ``variety_model.bin``."""
    return _load_resnet18_classifier("variety_model.bin", len(VARIETY_CLASSES))


def load_disease_model():
    """Return the banana-disease classifier loaded from ``disease_model.bin``."""
    return _load_resnet18_classifier("disease_model.bin", len(DISEASE_CLASSES))

# Build both classifiers a single time, at import, so every predict() call
# reuses the same in-memory models instead of re-reading the weight files.
variety_model, disease_model = load_variety_model(), load_disease_model()

# ========== IMAGE PREPROCESSING ==========
# Per-channel (R, G, B) normalization statistics, applied by
# preprocess_image() to pixel values already scaled to the [0, 1] range.
MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float64)
STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float64)

def preprocess_image(image_bytes):
    """Decode encoded image bytes into a normalized float32 batch tensor.

    The image is converted to RGB, resized to 256x256 (the aspect ratio is
    not preserved), center-cropped to 224x224, scaled to [0, 1], normalized
    per channel with ``MEAN`` and ``STD``, and rearranged to shape
    ``(1, 3, 224, 224)``.

    Args:
        image_bytes: Raw bytes of an image file in a format Pillow can open.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, 224, 224)``.
    """
    resize_side, crop_side = 256, 224
    offset = (resize_side - crop_side) // 2  # equal margin on every side

    picture = Image.open(BytesIO(image_bytes)).convert('RGB')
    picture = picture.resize((resize_side, resize_side))
    picture = picture.crop(
        (offset, offset, offset + crop_side, offset + crop_side)
    )

    pixels = np.asarray(picture).astype(np.float32) / 255.0
    normalized = (pixels - MEAN) / STD
    # Channels first (HWC to CHW), then a leading batch axis of size one.
    batch = np.transpose(normalized, (2, 0, 1))[np.newaxis, ...]
    return torch.tensor(batch, dtype=torch.float32)

# ========== PREDICTION FUNCTION ==========
def predict(image_bytes, model_type="variety"):
    """Classify an image with the variety or the disease model.

    Args:
        image_bytes: Encoded image data, forwarded to ``preprocess_image``.
        model_type: ``"variety"`` (default) or ``"disease"``; selects both
            the model and the label list used to name the result.

    Returns:
        A dict with two keys: ``"prediction"``, the label of the
        highest-probability class, and ``"confidence"``, that class's
        softmax probability expressed as a percentage (0-100).

    Raises:
        ValueError: If ``model_type`` is neither ``"variety"`` nor
            ``"disease"``.
    """
    # Validate the selector before touching the image: a bad model_type then
    # fails fast with its own error instead of after (or hidden behind) the
    # cost and possible failure of decoding the image.
    if model_type == "variety":
        model, class_names = variety_model, VARIETY_CLASSES
    elif model_type == "disease":
        model, class_names = disease_model, DISEASE_CLASSES
    else:
        raise ValueError("model_type must be 'variety' or 'disease'")

    img_tensor = preprocess_image(image_bytes)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = model(img_tensor)
        # output[0] is the single batch item's scores; softmax over classes.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_idx = torch.max(probabilities, 0)
        return {
            "prediction": class_names[predicted_idx.item()],
            "confidence": float(confidence) * 100
        }