"""
Utility functions for the Plant Disease Detection UI
"""

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

import config


def preprocess_image(image, image_size=config.IMAGE_SIZE):
    """
    Preprocess image for model input

    Args:
        image: PIL Image or numpy array
        image_size: Target size forwarded to ``transforms.Resize``
            (a ``(height, width)`` pair, or an int to resize the shorter edge)

    Returns:
        Preprocessed tensor with a leading batch dimension of 1, ready for model
    """
    # Convert to PIL Image if numpy array
    if isinstance(image, np.ndarray):
        # NOTE(review): astype('uint8') assumes pixel values are already in
        # [0, 255]; a float array scaled to [0, 1] would be truncated to 0/1.
        # Confirm what callers pass.
        image = Image.fromarray(image.astype('uint8'))

    # Force 3-channel RGB. Converting only RGBA let grayscale ("L"),
    # palette ("P"), "LA" and "CMYK" images through with the wrong number
    # of channels, which breaks the per-channel Normalize and the model input.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Define preprocessing transforms
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD)
    ])

    # Apply transforms
    tensor = transform(image)

    # Add batch dimension
    tensor = tensor.unsqueeze(0)

    return tensor


def postprocess_predictions(logits, class_names=config.CLASS_NAMES, top_k=config.TOP_K_PREDICTIONS):
    """
    Convert model logits to human-readable predictions

    Args:
        logits: Raw model output, either batched ``(batch, num_classes)``
            or a single unbatched ``(num_classes,)`` vector. Only the first
            batch element is converted.
        class_names: List of class names, in the same order as the logits
        top_k: Number of top predictions to return

    Returns:
        Tuple ``(top_predictions, predictions)``:
            top_predictions: dict of the ``top_k`` most probable classes,
                ordered by descending probability
            predictions: dict mapping every class name to its probability
    """
    # Accept an unbatched logit vector; softmax below works on dim=1.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    # Convert logits to probabilities using softmax
    probs = torch.nn.functional.softmax(logits, dim=1)

    # Detach from the autograd graph before moving to CPU / numpy;
    # only the first sample of the batch is used.
    probs = probs.detach().cpu().numpy()[0]

    # Create predictions dictionary.
    # NOTE: zip stops at the shorter of class_names / probs, so a length
    # mismatch silently drops entries.
    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}

    # Get top-k predictions
    top_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]

    return dict(top_predictions), predictions


def format_prediction_for_display(predictions):
    """
    Format predictions for Gradio display

    Args:
        predictions: Dictionary of class names and probabilities

    Returns:
        Dictionary formatted for Gradio Label component, containing only
        entries whose probability is at least ``config.CONFIDENCE_THRESHOLD``
        (may be empty if nothing passes the threshold)
    """
    # Filter out very low confidence predictions
    filtered = {k: v for k, v in predictions.items() if v >= config.CONFIDENCE_THRESHOLD}
    return filtered


def format_class_name(class_name):
    """
    Format class name for better readability

    Args:
        class_name: Original class name (e.g., "Tomato___Late_blight")

    Returns:
        Formatted class name (e.g., "Tomato - Late blight"). Stray spaces
        left by leading/trailing underscores are stripped, so
        "Corn_(maize)___Common_rust_" becomes "Corn (maize) - Common rust".
    """
    # "___" separates plant from disease; single underscores are spaces.
    parts = class_name.split("___")
    if len(parts) == 2:
        plant, disease = parts
        plant = plant.replace("_", " ").strip()
        disease = disease.replace("_", " ").strip()
        return f"{plant} - {disease}"
    else:
        return class_name.replace("_", " ").strip()


def get_disease_info(class_name):
    """
    Get information about a disease (for future enhancement)

    Args:
        class_name: Disease class name (e.g., "Tomato___Late_blight")

    Returns:
        Dictionary with keys "plant", "disease", "is_healthy" and
        "formatted_name"
    """
    # This is a placeholder - you could expand this with actual disease information
    parts = class_name.split("___")

    info = {
        # str.split always returns at least one element, so parts[0] exists.
        "plant": parts[0].replace("_", " ").strip() if len(parts) > 0 else "Unknown",
        "disease": parts[1].replace("_", " ").strip() if len(parts) > 1 else "Unknown",
        # Substring check on the raw name (case-insensitive).
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name)
    }

    return info


def batch_preprocess_images(images, image_size=config.IMAGE_SIZE):
    """
    Preprocess multiple images for batch prediction

    Args:
        images: Non-empty list of PIL Images or numpy arrays
        image_size: Target size forwarded to ``preprocess_image``

    Returns:
        Batched tensor ready for model, one row per input image

    Raises:
        RuntimeError: from ``torch.cat`` if ``images`` is empty
    """
    tensors = [preprocess_image(img, image_size) for img in images]
    batch = torch.cat(tensors, dim=0)
    return batch


def create_confidence_label(predictions, top_k=5):
    """
    Create a formatted string showing top predictions

    Args:
        predictions: Dictionary of predictions (class name -> probability)
        top_k: Number of top predictions to show

    Returns:
        Formatted string, one "N. <name>: <percent>%" line per prediction
    """
    top_preds = sorted(predictions.items(), key=lambda x: x[1], reverse=True)[:top_k]

    lines = []
    for i, (class_name, prob) in enumerate(top_preds, 1):
        formatted_name = format_class_name(class_name)
        lines.append(f"{i}. {formatted_name}: {prob*100:.2f}%")

    return "\n".join(lines)


if __name__ == "__main__":
    # Test utilities
    print("Testing utility functions...")

    # Test class name formatting
    test_names = [
        "Tomato___Late_blight",
        "Apple___healthy",
        "Corn_(maize)___Common_rust_"
    ]

    print("\nClass name formatting:")
    for name in test_names:
        print(f"  {name} -> {format_class_name(name)}")

    # Test disease info
    print("\nDisease info:")
    for name in test_names:
        info = get_disease_info(name)
        print(f"  {name}:")
        print(f"    Plant: {info['plant']}")
        print(f"    Disease: {info['disease']}")
        print(f"    Healthy: {info['is_healthy']}")

    # Test image preprocessing
    print("\nImage preprocessing:")
    dummy_image = Image.new('RGB', (512, 512), color='red')
    tensor = preprocess_image(dummy_image)
    print(f"  Input size: {dummy_image.size}")
    print(f"  Output tensor shape: {tensor.shape}")

    # Test mock predictions
    print("\nMock predictions:")
    from models.mock_model import create_mock_predictions
    preds = create_mock_predictions(config.CLASS_NAMES)
    # NOTE: the mock values are fed through softmax again here; this is a
    # smoke test of the pipeline, not of calibrated probabilities.
    top_preds, all_preds = postprocess_predictions(
        torch.tensor([list(preds.values())]),
        config.CLASS_NAMES
    )
    print(create_confidence_label(top_preds))