"""
Utility functions for the Plant Disease Detection UI.
"""

import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

IMAGE_SIZE = (256, 256)
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
TOP_K_PREDICTIONS = 5
CONFIDENCE_THRESHOLD = 0.01

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Path to classNames.txt relative to this file
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "classNames.txt")

# Loaded once at import time: one class name per line, surrounding
# whitespace stripped and blank lines skipped. An explicit encoding keeps
# the result independent of the platform's default locale encoding.
with open(CLASS_NAMES_FILE, "r", encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]


def preprocess_image(image):
    """
    Preprocess an image for model input.

    Args:
        image: A PIL image, or a numpy array that is cast to uint8 and
            wrapped with ``Image.fromarray``.

    Returns:
        torch.Tensor of shape (1, 3, H, W), where (H, W) = IMAGE_SIZE.
        The tensor is normalized with NORMALIZE_MEAN / NORMALIZE_STD.
    """
    if isinstance(image, np.ndarray):
        # NOTE(review): the uint8 cast truncates, so float arrays scaled to
        # [0, 1] would become all 0/1 -- assumes callers pass 0-255 data.
        image = Image.fromarray(image.astype('uint8'))

    # Normalize() below uses 3-element mean/std, so the tensor must have
    # exactly three channels. Converting every non-RGB mode (RGBA, L, P,
    # LA, CMYK, ...) avoids a channel-count mismatch for grayscale or
    # palette images, not only RGBA ones.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ])

    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    return transform(image).unsqueeze(0)


def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
    """
    Convert raw model logits into class-probability dictionaries.

    Args:
        logits: Tensor of logits. Softmax is taken over dim 1 and only
            row 0 is used -- presumably shape (1, num_classes); confirm
            against the caller.
        class_names: Labels aligned with the logit columns. Defaults to
            CLASS_NAMES.
        top_k: Number of highest-probability entries to keep.

    Returns:
        A tuple ``(top_predictions, predictions)``:
        - ``predictions`` maps every class name to its probability (float).
        - ``top_predictions`` holds the ``top_k`` largest entries, in
          descending order of probability.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    probs = torch.nn.functional.softmax(logits, dim=1)
    probs = probs.detach().cpu().numpy()[0]

    # NOTE: zip() stops at the shorter input, so a mismatch between the
    # number of logits and the number of class names silently drops
    # entries instead of raising.
    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
    top_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
    return dict(top_predictions), predictions


def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Filter predictions for Gradio display.

    Args:
        predictions: Mapping of class name -> probability.
        confidence_threshold: Minimum probability (inclusive) to keep.

    Returns:
        A new dict with only the entries whose probability is
        >= confidence_threshold.
    """
    return {k: v for k, v in predictions.items() if v >= confidence_threshold}


def format_class_name(class_name):
    """
    Format a class name into a human-readable label.

    A name of the form ``Plant___Disease`` (exactly one ``___`` separator)
    becomes ``"Plant - Disease"``, with underscores replaced by spaces.
    Any other name just has its underscores replaced by spaces.
    """
    parts = class_name.split("___")
    if len(parts) == 2:
        plant, disease = parts
        plant = plant.replace("_", " ")
        disease = disease.replace("_", " ")
        return f"{plant} - {disease}"
    return class_name.replace("_", " ")


def get_disease_info(class_name):
    """
    Extract structured disease info from a class name.

    Returns:
        A dict with the following keys:
        - ``plant``: text before the first ``___``, underscores -> spaces.
        - ``disease``: text between the first and second ``___``
          (underscores -> spaces), or ``"Unknown"`` if there is no
          separator.
        - ``is_healthy``: True if "healthy" appears anywhere in the name,
          case-insensitively.
        - ``formatted_name``: the result of ``format_class_name``.
    """
    # str.split always returns at least one element, so parts[0] exists.
    parts = class_name.split("___")
    return {
        "plant": parts[0].replace("_", " "),
        "disease": parts[1].replace("_", " ") if len(parts) > 1 else "Unknown",
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name),
    }


def create_confidence_label(predictions, top_k=TOP_K_PREDICTIONS):
    """
    Render a numbered, multiline list of the top predictions.

    Each line reads ``"<rank>. <formatted name>: <percent>%"``, with the
    percentage shown to two decimals. Lines are ordered by descending
    probability.
    """
    top_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]
    lines = [
        f"{i}. {format_class_name(name)}: {prob*100:.2f}%"
        for i, (name, prob) in enumerate(top_preds, 1)
    ]
    return "\n".join(lines)


def get_class_names():
    """
    Return the class names loaded from classNames.txt.

    A copy is returned so callers cannot accidentally mutate the
    module-level CLASS_NAMES, which postprocess_predictions uses as its
    default labels.
    """
    return list(CLASS_NAMES)