Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for the Plant Disease Detection UI | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import config | |
def preprocess_image(image, image_size=config.IMAGE_SIZE):
    """
    Preprocess an image for model input.

    Args:
        image: PIL Image or numpy array. Arrays are cast to uint8 before
            conversion, so float arrays are expected to already be in 0-255.
            NOTE(review): a float array in [0, 1] would become all zeros — confirm callers.
        image_size: Size passed to ``transforms.Resize``. A (height, width)
            sequence resizes to exactly that shape; a single int resizes the
            shorter edge and keeps the aspect ratio.

    Returns:
        Normalized float tensor with a leading batch dimension of 1.
    """
    # Convert to PIL Image if numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))

    # Force a 3-channel RGB image. Previously only RGBA was handled, so
    # grayscale ("L", e.g. from a 2-D array), palette ("P"), CMYK, etc. reached
    # ToTensor/Normalize with the wrong number of channels.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD),
    ])

    tensor = transform(image)
    # Add batch dimension: (C, H, W) -> (1, C, H, W)
    return tensor.unsqueeze(0)
def postprocess_predictions(logits, class_names=config.CLASS_NAMES, top_k=config.TOP_K_PREDICTIONS):
    """
    Convert model logits to human-readable predictions.

    Args:
        logits: Raw model output as a torch tensor, either (batch, num_classes)
            or a single (num_classes,) vector. Only the first sample is used.
        class_names: Class names in the same order as the logit columns.
        top_k: Number of top predictions to keep in the first return value.

    Returns:
        Tuple ``(top_predictions, predictions)``:
            top_predictions: dict of the ``top_k`` most probable classes,
                ordered from highest to lowest probability.
            predictions: dict mapping every class name to its probability,
                in ``class_names`` order.
        If ``class_names`` and the number of logits differ in length, the
        surplus entries of the longer one are ignored (``zip`` semantics).
    """
    # Accept a bare (num_classes,) vector as well as a batched tensor;
    # softmax over dim=1 would otherwise fail on 1-D input.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    probs = torch.nn.functional.softmax(logits, dim=1)
    # Only the first sample of the batch is reported.
    probs = probs[0].detach().cpu().numpy()

    predictions = {name: float(prob) for name, prob in zip(class_names, probs)}
    top_predictions = sorted(predictions.items(), key=lambda item: item[1], reverse=True)[:top_k]
    return dict(top_predictions), predictions
def format_prediction_for_display(predictions, threshold=None):
    """
    Format predictions for the Gradio Label component.

    Args:
        predictions: Mapping of class name -> probability.
        threshold: Minimum probability an entry needs to be kept. Defaults to
            ``config.CONFIDENCE_THRESHOLD``, looked up at call time.

    Returns:
        New dict with only the entries whose probability is >= threshold
        (original order preserved). ``None`` or an empty mapping yields {}.
    """
    if not predictions:
        return {}
    if threshold is None:
        threshold = config.CONFIDENCE_THRESHOLD
    # Filter out very low confidence predictions
    return {name: prob for name, prob in predictions.items() if prob >= threshold}
def format_class_name(class_name):
    """
    Format a class name for display.

    Args:
        class_name: Original class name (e.g., "Tomato___Late_blight").

    Returns:
        "Plant - disease" when the name contains exactly one "___" separator
        (e.g., "Tomato - Late blight"); otherwise the whole name with
        underscores turned into spaces. Leading, trailing, or repeated
        underscores never produce stray spaces
        ("Corn_(maize)___Common_rust_" -> "Corn (maize) - Common rust").
    """
    def humanize(text):
        # Drop empty pieces so "Common_rust_" becomes "Common rust", not "Common rust ".
        return " ".join(piece for piece in text.split("_") if piece)

    parts = class_name.split("___")
    if len(parts) == 2:
        plant, disease = parts
        return f"{humanize(plant)} - {humanize(disease)}"
    return humanize(class_name)
def get_disease_info(class_name):
    """
    Split a class name into plant/disease metadata (placeholder for richer info).

    Args:
        class_name: Class name in "Plant___Disease" form, e.g. "Tomato___Late_blight".

    Returns:
        Dict with keys:
            "plant": text before the first "___", underscores as spaces
                ("Unknown" if that part is empty).
            "disease": text between the first and second "___", underscores
                as spaces ("Unknown" if missing or empty).
            "is_healthy": True if "healthy" occurs anywhere in the name,
                case-insensitively.
            "formatted_name": ``format_class_name(class_name)``.
    """
    # This is a placeholder - it could be expanded with actual disease information.
    def humanize(text):
        # Drop empty pieces so trailing/doubled underscores leave no stray spaces.
        return " ".join(piece for piece in text.split("_") if piece)

    # str.split always returns at least one element, so parts[0] always exists;
    # the "Unknown" fallback therefore has to key off emptiness, not length.
    parts = class_name.split("___")
    plant = humanize(parts[0])
    disease = humanize(parts[1]) if len(parts) > 1 else ""

    return {
        "plant": plant or "Unknown",
        "disease": disease or "Unknown",
        # NOTE(review): substring match — a name containing e.g. "unhealthy"
        # would also be flagged healthy; confirm against the class list.
        "is_healthy": "healthy" in class_name.lower(),
        "formatted_name": format_class_name(class_name),
    }
def batch_preprocess_images(images, image_size=None):
    """
    Preprocess multiple images for batch prediction.

    Args:
        images: Iterable of PIL Images or numpy arrays.
        image_size: Optional target size forwarded to ``preprocess_image``;
            when omitted, ``preprocess_image``'s own default is used.

    Returns:
        Tensor of shape (len(images), C, H, W) built by concatenating the
        per-image (1, C, H, W) tensors along the batch dimension.

    ``torch.cat`` raises if ``images`` is empty, or if the processed images
    do not all share the same channel/height/width. The latter can happen
    when ``image_size`` is a single int, because Resize then keeps each
    image's aspect ratio.
    """
    # Forward image_size only when given, so preprocess_image keeps its default.
    size_kwargs = {} if image_size is None else {"image_size": image_size}
    tensors = [preprocess_image(img, **size_kwargs) for img in images]
    return torch.cat(tensors, dim=0)
def create_confidence_label(predictions, top_k=5):
    """
    Build a numbered text summary of the most confident predictions.

    Args:
        predictions: Mapping of class name -> score (a fraction; it is
            multiplied by 100 for display).
        top_k: Maximum number of entries to include.

    Returns:
        Newline-joined lines of the form "<rank>. <formatted name>: <pct>%",
        ranked from highest to lowest score; "" if there is nothing to show.
    """
    ranked = sorted(predictions.items(), key=lambda pair: pair[1], reverse=True)
    return "\n".join(
        f"{rank}. {format_class_name(name)}: {score*100:.2f}%"
        for rank, (name, score) in enumerate(ranked[:top_k], start=1)
    )
if __name__ == "__main__":
    # Manual smoke test: print the output of each helper for a few samples.
    print("Testing utility functions...")

    sample_names = [
        "Tomato___Late_blight",
        "Apple___healthy",
        "Corn_(maize)___Common_rust_",
    ]

    print("\nClass name formatting:")
    for sample in sample_names:
        print(f" {sample} -> {format_class_name(sample)}")

    print("\nDisease info:")
    info_fields = (("Plant", "plant"), ("Disease", "disease"), ("Healthy", "is_healthy"))
    for sample in sample_names:
        details = get_disease_info(sample)
        print(f" {sample}:")
        for label, key in info_fields:
            print(f" {label}: {details[key]}")

    print("\nImage preprocessing:")
    red_square = Image.new('RGB', (512, 512), color='red')
    image_tensor = preprocess_image(red_square)
    print(f" Input size: {red_square.size}")
    print(f" Output tensor shape: {image_tensor.shape}")

    print("\nMock predictions:")
    from models.mock_model import create_mock_predictions

    mock_scores = create_mock_predictions(config.CLASS_NAMES)
    # NOTE(review): postprocess_predictions applies softmax, i.e. treats these
    # values as logits. If create_mock_predictions already returns
    # probabilities, the printed confidences are re-softmaxed (flattened) — confirm.
    # NOTE(review): assumes the dict's value order matches config.CLASS_NAMES.
    mock_logits = torch.tensor([list(mock_scores.values())])
    top_preds, all_preds = postprocess_predictions(mock_logits, config.CLASS_NAMES)
    print(create_confidence_label(top_preds))