Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for the Plant Disease Detection UI | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import os | |
# (height, width) passed to transforms.Resize in preprocess_image.
IMAGE_SIZE = (256, 256)
# Per-channel mean/std passed to transforms.Normalize in preprocess_image.
# These are the widely used ImageNet statistics; presumably they match the
# normalization the model was trained with — confirm against training code.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# Default number of entries returned as the "top" predictions.
TOP_K_PREDICTIONS = 5
# Default minimum probability kept by format_prediction_for_display.
CONFIDENCE_THRESHOLD = 0.01

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to classNames.txt relative to this file
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "classNames.txt")

# Load one class name per non-blank line, with surrounding whitespace removed.
# An explicit encoding avoids locale-dependent decoding; "utf-8-sig" also drops
# a leading BOM, which str.strip() would otherwise leave glued to the first
# class name.
with open(CLASS_NAMES_FILE, "r", encoding="utf-8-sig") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
def preprocess_image(image):
    """
    Preprocess an image for model input.

    Args:
        image: A PIL image, or a numpy array. An array is cast to uint8 and
            wrapped with ``Image.fromarray`` first.

    Returns:
        A tensor of shape (1, 3, IMAGE_SIZE[0], IMAGE_SIZE[1]): the image is
        converted to RGB, resized to IMAGE_SIZE, converted with ToTensor, and
        normalized with NORMALIZE_MEAN / NORMALIZE_STD. It also gets a leading
        batch dimension.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))

    # Normalize is configured with three channel statistics, so the image must
    # have exactly three channels. Besides RGBA, modes such as grayscale ("L"),
    # palette ("P"), "LA" and "CMYK" would otherwise produce a tensor with the
    # wrong channel count and fail inside Normalize.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ])
    # Add a batch dimension of size 1 in front of (C, H, W).
    return transform(image).unsqueeze(0)
def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
    """
    Convert raw model logits into class-probability dictionaries.

    Only the first row of a batched ``logits`` tensor is used. A 1-D tensor is
    treated as a single row.

    Args:
        logits: torch tensor of shape (batch, num_classes) or (num_classes,).
        class_names: Labels aligned with the class axis of ``logits``.
            Defaults to the names loaded from classNames.txt (CLASS_NAMES).
        top_k: Number of highest-probability entries in the first returned dict.

    Returns:
        A tuple ``(top, all_probs)``:
            ``top`` maps the ``top_k`` most likely class names to their
            probabilities, in descending order of probability.
            ``all_probs`` maps every class name to its softmax probability.

    Raises:
        ValueError: If the number of class names differs from the number of
            logits for the row.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    if logits.dim() == 1:
        # Accept an un-batched logit vector by adding the batch axis that the
        # softmax over dim=1 below expects.
        logits = logits.unsqueeze(0)

    probs = torch.nn.functional.softmax(logits, dim=1)
    probs = probs.detach().cpu().numpy()[0]

    # zip() would silently truncate on a mismatch and attach labels to the
    # wrong outputs. Fail loudly instead.
    if len(class_names) != len(probs):
        raise ValueError(
            f"Expected {len(probs)} class names to match the model output, "
            f"got {len(class_names)}"
        )

    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
    top_predictions = sorted(
        predictions.items(), key=lambda item: item[1], reverse=True
    )[:top_k]
    return dict(top_predictions), predictions
def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Keep only the predictions confident enough to show in the Gradio UI.

    Args:
        predictions: Mapping of class name -> probability.
        confidence_threshold: Minimum probability (inclusive) for an entry
            to be kept.

    Returns:
        A new dict holding the entries whose probability is
        ``>= confidence_threshold``, in the same order as ``predictions``.
    """
    shown = {}
    for label, score in predictions.items():
        if score >= confidence_threshold:
            shown[label] = score
    return shown
def format_class_name(class_name):
    """
    Turn a raw dataset label into a human-readable string.

    A label containing exactly one ``"___"`` separator, i.e.
    ``"<plant>___<disease>"``, becomes ``"<plant> - <disease>"``. Any other
    label is returned unchanged except that each underscore becomes a space.
    In both cases every remaining underscore is rendered as a space.

    Examples:
        >>> format_class_name("Tomato___Late_blight")
        'Tomato - Late blight'
        >>> format_class_name("Background_without_leaves")
        'Background without leaves'
    """
    plant, separator, disease = class_name.partition("___")
    # "Exactly one separator": one was found, and none remains after it.
    has_single_separator = bool(separator) and "___" not in disease
    if not has_single_separator:
        return class_name.replace("_", " ")
    return "{} - {}".format(plant.replace("_", " "), disease.replace("_", " "))
def get_disease_info(class_name):
    """
    Break a raw class label into structured fields for the UI.

    Args:
        class_name: Raw label, e.g. ``"<plant>___<disease>"``.

    Returns:
        A dict with keys:
            "plant": text before the first ``"___"``, underscores as spaces.
            "disease": text between the first and second ``"___"``
                (underscores as spaces), or "Unknown" when the label contains
                no ``"___"``.
            "is_healthy": True when "healthy" occurs anywhere in the label,
                compared case-insensitively (a plain substring test).
            "formatted_name": format_class_name(class_name).
    """
    # str.split always yields at least one piece, so the plant part always exists.
    plant_part, *remaining = class_name.split("___")
    if remaining:
        disease = remaining[0].replace("_", " ")
    else:
        disease = "Unknown"

    info = {}
    info["plant"] = plant_part.replace("_", " ")
    info["disease"] = disease
    info["is_healthy"] = "healthy" in class_name.lower()
    info["formatted_name"] = format_class_name(class_name)
    return info
def create_confidence_label(predictions, top_k=TOP_K_PREDICTIONS):
    """
    Render the top predictions as a numbered, multi-line text list.

    Args:
        predictions: Mapping of raw class name -> probability. Each
            probability is multiplied by 100 for display, so it is presumably
            in [0, 1].
        top_k: Maximum number of lines. Defaults to TOP_K_PREDICTIONS (5) so
            this list stays in sync with postprocess_predictions instead of
            hard-coding the same number separately.

    Returns:
        One line per prediction, highest probability first, shaped like
        ``"1. Tomato - Late blight: 97.31%"``, joined with newlines. Returns an
        empty string when ``predictions`` is empty.
    """
    top_preds = sorted(predictions.items(), key=lambda item: item[1], reverse=True)[:top_k]
    lines = [
        f"{rank}. {format_class_name(name)}: {prob*100:.2f}%"
        for rank, (name, prob) in enumerate(top_preds, 1)
    ]
    return "\n".join(lines)
def get_class_names():
    """
    Return the class names loaded from classNames.txt.

    Returns a shallow copy. Callers therefore cannot accidentally mutate the
    module-level CLASS_NAMES list, which postprocess_predictions uses as its
    default label set.
    """
    return list(CLASS_NAMES)