Spaces:
Sleeping
Sleeping
File size: 3,088 Bytes
505fc99 27a811d 505fc99 39b825d 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 27a811d e346658 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 e346658 505fc99 8bd7feb 505fc99 e346658 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 8bd7feb 505fc99 e346658 bcc26c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""
Utility functions for the Plant Disease Detection UI
"""
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os
# Target size handed to transforms.Resize in preprocess_image.
IMAGE_SIZE = (256, 256)
# Per-channel mean / std handed to transforms.Normalize in preprocess_image.
# NOTE(review): these equal the widely used ImageNet statistics - presumably
# the values the model was trained with; confirm against the training code.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# Default number of entries returned by postprocess_predictions.
TOP_K_PREDICTIONS = 5
# Default minimum probability kept by format_prediction_for_display.
CONFIDENCE_THRESHOLD = 0.01
# Directory containing this module; the class-name file is resolved against it
# so loading does not depend on the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to classNames.txt relative to this file
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "classNames.txt")
# Loaded once at import time: every non-blank line, stripped of surrounding
# whitespace, becomes one entry. The encoding is pinned so the result does not
# vary with the platform's locale default.
with open(CLASS_NAMES_FILE, "r", encoding="utf-8") as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]
def preprocess_image(image):
    """
    Preprocess an image into a normalized, batched tensor for model input.

    Args:
        image: A PIL image or a NumPy array. Arrays are cast to ``uint8`` and
            wrapped with ``Image.fromarray``.

    Returns:
        The tensor obtained by resizing to ``IMAGE_SIZE``, applying
        ``ToTensor`` and ``Normalize(NORMALIZE_MEAN, NORMALIZE_STD)``, with a
        leading batch dimension of size 1 added.
    """
    if isinstance(image, np.ndarray):
        # NOTE: this is a plain cast, not a rescale - a float array holding
        # values in [0, 1] would truncate to 0. Assumes callers pass 0-255
        # data; verify against the UI component feeding this function.
        image = Image.fromarray(image.astype('uint8'))

    # The normalization statistics are three-channel, so every non-RGB mode
    # (RGBA, grayscale "L", palette "P", CMYK, ...) has to be converted - not
    # just RGBA - or Normalize fails on the channel mismatch.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ])
    tensor = transform(image)
    # Add the batch dimension expected by the model: (C, H, W) -> (1, C, H, W).
    return tensor.unsqueeze(0)
def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
    """
    Convert logits to formatted predictions.

    Args:
        logits: Tensor of raw model outputs, either batched
            ``(batch, num_classes)`` or a single unbatched ``(num_classes,)``
            vector. Only the first row of a batch is used.
        class_names: Labels paired positionally with the probabilities.
            Defaults to the module-level ``CLASS_NAMES``.
        top_k: Number of highest-probability entries to return.

    Returns:
        A ``(top_predictions, predictions)`` tuple of dicts mapping class
        name to probability as a Python ``float``: the ``top_k`` most likely
        classes in descending order, and all classes.

    Note:
        Names and probabilities are combined with ``zip``, which stops at the
        shorter input, so surplus names or surplus logits are silently
        dropped.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    # Accept a single unbatched vector; softmax over dim=1 would raise on it.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    probs = torch.nn.functional.softmax(logits, dim=1)
    first_row = probs[0].detach().cpu().numpy()

    predictions = {name: float(prob) for name, prob in zip(class_names, first_row)}
    top_predictions = sorted(
        predictions.items(), key=lambda item: item[1], reverse=True
    )[:top_k]
    return dict(top_predictions), predictions
def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Keep only the predictions worth showing in the Gradio display.

    Args:
        predictions: Mapping of class name to probability.
        confidence_threshold: Entries whose value is below this are dropped.

    Returns:
        A new dict with the surviving entries, in their original order.
    """
    visible = {}
    for label, score in predictions.items():
        if score >= confidence_threshold:
            visible[label] = score
    return visible
def format_class_name(class_name):
    """
    Format a raw class name into readable form.

    A name of the shape ``<plant>___<disease>`` (exactly one triple-underscore
    separator) becomes ``"<plant> - <disease>"``; any other name is returned
    as a single label. In both cases underscores are replaced by spaces and
    leading/trailing whitespace is removed.

    Args:
        class_name: Raw label, e.g. ``"Tomato___Late_blight"``.

    Returns:
        The human-readable label, e.g. ``"Tomato - Late blight"``.
    """
    parts = class_name.split("___")
    if len(parts) == 2:
        # Strip each side: labels ending in an underscore (such as
        # "Common_rust_") would otherwise carry a trailing space into the UI.
        plant, disease = (part.replace("_", " ").strip() for part in parts)
        return f"{plant} - {disease}"
    return class_name.replace("_", " ").strip()
def get_disease_info(class_name):
    """
    Extract structured disease info from a class name.

    The name is split on the triple-underscore separator; the first piece is
    taken as the plant and the second, when present, as the disease.

    Args:
        class_name: Raw label, e.g. ``"Tomato___Late_blight"``.

    Returns:
        A dict with the keys:
            ``plant``: plant part with underscores as spaces, or ``"Unknown"``
                when it is empty.
            ``disease``: disease part with underscores as spaces, or
                ``"Unknown"`` when it is missing or empty.
            ``is_healthy``: ``True`` when the substring ``"healthy"`` occurs
                anywhere in the name (case-insensitive).
            ``formatted_name``: result of ``format_class_name(class_name)``.
    """
    parts = class_name.split("___")
    # str.split always yields at least one element, so a length check can
    # never detect a missing plant; test the cleaned text for emptiness.
    plant = parts[0].replace("_", " ").strip()
    disease = parts[1].replace("_", " ").strip() if len(parts) > 1 else ""
    return {
        "plant": plant or "Unknown",
        "disease": disease or "Unknown",
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name),
    }
def create_confidence_label(predictions, top_k=5):
    """
    Build a numbered, newline-separated summary of the strongest predictions.

    Args:
        predictions: Mapping of class name to probability.
        top_k: Maximum number of entries to list.

    Returns:
        One line per entry, highest probability first, in the form
        ``"<rank>. <formatted name>: <percent>%"`` with two decimals.
    """
    ranked = list(predictions.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    rows = []
    for rank, (label, score) in enumerate(ranked[:top_k], start=1):
        rows.append(f"{rank}. {format_class_name(label)}: {score*100:.2f}%")
    return "\n".join(rows)
def get_class_names():
    """
    Return the class names loaded from the txt file.

    Returns:
        The module-level ``CLASS_NAMES`` list itself (not a copy), so
        mutations by the caller are visible module-wide.
    """
    return CLASS_NAMES