| | import os |
| | import torch |
| | import torchvision.models as models |
| | from PIL import Image |
| | import numpy as np |
| | from io import BytesIO |
| |
|
# Absolute directory containing this module; the model weight files are
# resolved relative to it. ``abspath`` keeps the value stable even when
# ``__file__`` is a relative path and the working directory changes later.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
| |
|
| | |
# Banana variety labels. The position of each name is the class index that
# ``predict`` maps the variety model's output to, and the list length sets the
# size of that model's final layer -- do not reorder, rename or "fix" spellings
# without confirming against the trained weights.
VARIETY_CLASSES = [
    'Anaji1',
    'Banana Lady Finger ( Señorita )',
    'Banana Red',
    'Bichi',
    'Canvendish(Bungulan)',
    'Lakatan',
    'Saba',
    'Sabri Kola',
    'Unknow Data',
]
| |
|
# Banana disease labels. The position of each name is the class index that
# ``predict`` maps the disease model's output to, and the list length sets the
# size of that model's final layer -- do not reorder or rename entries without
# confirming against the trained weights.
DISEASE_CLASSES = [
    'Banana Anthracnose Disease',
    'Banana Bract Mosaic Virus Disease',
    'Banana Cordana Leaf Disease',
    'Banana Healthy',
    'Banana Naturally Leaf Dead',
    'Banana Panama Leaf Disease',
    'Banana Panama Tree Disease',
    'Banana Pestalotiopsis Disease',
    'Banana Rhizome Root Tree Disease',
    'Banana Sigatoka Leaf Disease',
]
| |
|
| | |
def _load_resnet18_classifier(weights_filename, num_classes):
    """Build a ResNet-18 with a ``num_classes``-way head and load its weights.

    Args:
        weights_filename: Name of the state-dict file, resolved relative to
            ``BASE_DIR``.
        num_classes: Number of output units for the replaced ``fc`` layer.

    Returns:
        The model with weights loaded on the CPU and switched to eval mode.
    """
    # No pretrained weights are requested; every parameter comes from the
    # state dict loaded below.
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

    weights_path = os.path.join(BASE_DIR, weights_filename)
    # NOTE(review): depending on the installed torch version, torch.load may
    # fully unpickle this file. Only ship trusted weight files; consider
    # passing ``weights_only=True`` once the minimum torch version supports it.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    return model


def load_variety_model():
    """Return the banana-variety classifier loaded from ``variety_model.bin``."""
    return _load_resnet18_classifier("variety_model.bin", len(VARIETY_CLASSES))


def load_disease_model():
    """Return the banana-disease classifier loaded from ``disease_model.bin``."""
    return _load_resnet18_classifier("disease_model.bin", len(DISEASE_CLASSES))
| |
|
| | |
# Both classifiers are built once, at import time, and shared by every
# ``predict`` call. The variety model is loaded first, then the disease model.
variety_model, disease_model = load_variety_model(), load_disease_model()
| |
|
| | |
# Per-channel mean and standard deviation subtracted/divided in
# ``preprocess_image`` (applied along the last, RGB, axis). The values match
# the widely used ImageNet normalisation constants.
MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float64)
STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float64)
| |
|
def preprocess_image(image_bytes):
    """Decode raw image bytes into a normalised ``(1, 3, 224, 224)`` tensor.

    The image is converted to RGB, resized to 256x256 (aspect ratio is not
    preserved), centre-cropped to 224x224, scaled to the [0, 1] range,
    normalised per channel with ``MEAN``/``STD`` and reordered to
    channels-first with a leading batch dimension.

    Args:
        image_bytes: Encoded image content readable by PIL.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, 224, 224)``.
    """
    resized_side, crop_side = 256, 224
    offset = (resized_side - crop_side) // 2
    crop_box = (offset, offset, offset + crop_side, offset + crop_side)

    picture = Image.open(BytesIO(image_bytes)).convert('RGB')
    picture = picture.resize((resized_side, resized_side)).crop(crop_box)

    # Height x width x channel array scaled to [0, 1].
    pixels = np.asarray(picture).astype(np.float32) / 255.0
    normalised = (pixels - MEAN) / STD

    # Channels-first layout, then a batch axis of size one.
    batch = np.expand_dims(np.transpose(normalised, (2, 0, 1)), 0)
    return torch.tensor(batch, dtype=torch.float32)
| |
|
| | |
def predict(image_bytes, model_type="variety"):
    """Classify an encoded image with the variety or disease model.

    Args:
        image_bytes: Encoded image content, passed to ``preprocess_image``.
        model_type: ``"variety"`` (default) or ``"disease"``; selects both the
            model and the label list used to name the result.

    Returns:
        A dict with ``"prediction"`` (the label of the highest-probability
        class) and ``"confidence"`` (that class's softmax probability
        multiplied by 100).

    Raises:
        ValueError: If ``model_type`` is not one of the two supported values.
    """
    # Resolve the model first so an unsupported ``model_type`` fails fast,
    # before any time is spent decoding and preprocessing the image.
    if model_type == "variety":
        model, class_names = variety_model, VARIETY_CLASSES
    elif model_type == "disease":
        model, class_names = disease_model, DISEASE_CLASSES
    else:
        raise ValueError("model_type must be 'variety' or 'disease'")

    img_tensor = preprocess_image(image_bytes)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img_tensor)
        # output[0] is the logit vector of the single image in the batch.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_idx = torch.max(probabilities, 0)

    return {
        "prediction": class_names[predicted_idx.item()],
        "confidence": float(confidence) * 100,
    }
| |
|