k23064919's picture
remove half implemented batch processing feature
bcc26c6
"""
Utility functions for the Plant Disease Detection UI
"""
import os
import re

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
# Target (height, width) passed to transforms.Resize in preprocess_image.
IMAGE_SIZE = (256, 256)
# Per-channel normalisation constants used by preprocess_image. These match the
# commonly used ImageNet statistics — presumably what the model was trained
# with; confirm against the training pipeline before changing.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# Default number of entries kept by postprocess_predictions.
TOP_K_PREDICTIONS = 5
# Default minimum probability kept by format_prediction_for_display.
CONFIDENCE_THRESHOLD = 0.01
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to classNames.txt, resolved relative to this file (not the CWD).
CLASS_NAMES_FILE = os.path.join(BASE_DIR, "classNames.txt")
# Read with an explicit encoding so results do not depend on the platform
# locale; "utf-8-sig" also drops a leading BOM that strip() would not remove.
# The handle uses a private name so it is not exported by `import *`.
with open(CLASS_NAMES_FILE, "r", encoding="utf-8-sig") as _class_names_fh:
    # One class name per non-blank line, surrounding whitespace removed.
    CLASS_NAMES = [line.strip() for line in _class_names_fh if line.strip()]
def preprocess_image(image):
    """
    Preprocess an image into a normalised, batched model input.

    Args:
        image: A PIL image, or a numpy array (cast to uint8 and converted with
            ``Image.fromarray``).

    Returns:
        A float tensor of shape (1, 3, H, W), where (H, W) == IMAGE_SIZE,
        normalised with NORMALIZE_MEAN / NORMALIZE_STD.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    # The normalisation constants have three channels, so the input must be
    # 3-channel RGB. Converting every non-RGB mode (RGBA, L, LA, P, CMYK, ...)
    # prevents channel-count mismatches for grayscale or palette images, which
    # previously only worked for RGBA.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ])
    tensor = transform(image)
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
def postprocess_predictions(logits, class_names=None, top_k=TOP_K_PREDICTIONS):
    """
    Convert model logits into per-class probabilities.

    Args:
        logits: Model output tensor, either batched ``(batch, num_classes)``
            or a single unbatched ``(num_classes,)`` vector. For a batch, only
            the first sample is used.
        class_names: Labels in the same order as the logit columns. Defaults
            to CLASS_NAMES.
        top_k: Number of highest-probability entries in the first result.

    Returns:
        A tuple ``(top_predictions, predictions)``. ``top_predictions`` maps
        the ``top_k`` most probable class names to their probabilities, in
        descending order. ``predictions`` maps every class name to its
        probability.

    Raises:
        ValueError: If the number of class names differs from the number of
            logits. Without this check, zip() silently truncates and
            predictions could be mislabelled.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    class_names = list(class_names)
    # Accept a single unbatched sample by adding the batch dimension.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    probs = torch.nn.functional.softmax(logits.detach(), dim=1)
    probs = probs[0].cpu().numpy()
    if len(class_names) != len(probs):
        raise ValueError(
            f"Got {len(class_names)} class names for {len(probs)} logits; "
            "they must match one-to-one."
        )
    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
    top_predictions = sorted(
        predictions.items(), key=lambda item: item[1], reverse=True
    )[:top_k]
    return dict(top_predictions), predictions
def format_prediction_for_display(predictions, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Keep only the predictions whose probability reaches the threshold.

    Args:
        predictions: Mapping of class name to probability.
        confidence_threshold: Inclusive lower bound an entry must meet to be
            kept.

    Returns:
        A new dict with the surviving entries, in their original order.
    """
    kept = {}
    for name, score in predictions.items():
        if score >= confidence_threshold:
            kept[name] = score
    return kept
def format_class_name(class_name):
    """
    Turn a raw class label into human-readable text.

    A label with exactly one ``___`` separator (``Plant___Disease``) becomes
    ``"Plant - Disease"``. Any other label only has its underscores replaced
    by spaces.
    """
    pieces = class_name.split("___")
    if len(pieces) != 2:
        return class_name.replace("_", " ")
    plant_text, disease_text = (piece.replace("_", " ") for piece in pieces)
    return f"{plant_text} - {disease_text}"
def get_disease_info(class_name):
    """
    Extract structured disease info from a ``Plant___Disease`` class label.

    Args:
        class_name: Raw class label, e.g. ``"Apple___Apple_scab"``.

    Returns:
        A dict with keys:
            "plant": text before the first ``___``, underscores as spaces.
            "disease": text between the first and second ``___``, underscores
                as spaces, or "Unknown" when the label has no ``___``.
            "is_healthy": True when "healthy" appears as a whole word in the
                label (case-insensitive).
            "formatted_name": the result of format_class_name(class_name).
    """
    parts = class_name.split("___")
    # str.split always returns at least one element, so parts[0] exists.
    plant = parts[0].replace("_", " ")
    disease = parts[1].replace("_", " ") if len(parts) > 1 else "Unknown"
    # Whole-word match: a plain substring test would wrongly report labels
    # such as "unhealthy" as healthy. Underscores and punctuation separate words.
    words = re.split(r"[\W_]+", class_name.lower())
    return {
        "plant": plant,
        "disease": disease,
        "is_healthy": "healthy" in words,
        "formatted_name": format_class_name(class_name),
    }
def create_confidence_label(predictions, top_k=TOP_K_PREDICTIONS):
    """
    Render the most probable predictions as a numbered multiline string.

    Args:
        predictions: Mapping of class name to probability (0-1).
        top_k: Number of lines to render. Defaults to TOP_K_PREDICTIONS, the
            same default used by postprocess_predictions.

    Returns:
        Lines of the form ``"1. Plant - Disease: 97.50%"``, highest
        probability first, joined with newlines. Returns an empty string when
        predictions is empty.
    """
    top_preds = sorted(predictions.items(), key=lambda item: item[1], reverse=True)[:top_k]
    lines = [
        f"{rank}. {format_class_name(name)}: {prob * 100:.2f}%"
        for rank, (name, prob) in enumerate(top_preds, 1)
    ]
    return "\n".join(lines)
def get_class_names():
    """
    Return the class labels that were read from classNames.txt at import time.

    The shared module-level CLASS_NAMES list is returned as-is (not a copy).
    """
    loaded_names = CLASS_NAMES
    return loaded_names