""" Visualization Module for CropDoctor-Semantic ============================================= This module provides visualization functions for: - Segmentation masks overlay - Severity heatmaps - Diagnostic dashboards - Comparison views """ import numpy as np from PIL import Image import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib.colors import LinearSegmentedColormap from typing import List, Optional, Tuple, Union, Dict from pathlib import Path import logging logger = logging.getLogger(__name__) # Color schemes SEVERITY_COLORS = { 'healthy': '#2ECC71', # Green 'mild': '#F1C40F', # Yellow 'moderate': '#E67E22', # Orange 'severe': '#E74C3C' # Red } SYMPTOM_COLORS = [ '#E74C3C', # Red '#9B59B6', # Purple '#3498DB', # Blue '#E67E22', # Orange '#1ABC9C', # Teal '#F39C12', # Yellow '#D35400', # Dark Orange '#8E44AD', # Dark Purple ] def create_diagnostic_visualization( image: Union[str, Path, Image.Image, np.ndarray], masks: Optional[np.ndarray] = None, severity_label: str = "unknown", disease_name: str = "Unknown", affected_percent: float = 0.0, prompt_labels: Optional[List[str]] = None, figsize: Tuple[int, int] = (16, 6) ) -> plt.Figure: """ Create a comprehensive diagnostic visualization. Args: image: Input image masks: Segmentation masks array (N, H, W) severity_label: Severity classification result disease_name: Identified disease name affected_percent: Percentage of affected area prompt_labels: Labels for each mask figsize: Figure size Returns: matplotlib Figure object """ # Load image if isinstance(image, (str, Path)): image = Image.open(image).convert("RGB") elif isinstance(image, np.ndarray): image = Image.fromarray(image) img_array = np.array(image) # Create figure with subplots fig, axes = plt.subplots(1, 3, figsize=figsize) fig.suptitle(f'CropDoctor Diagnostic Report', fontsize=14, fontweight='bold') # Panel 1: Original Image axes[0].imshow(img_array) axes[0].set_title('Original Image', fontsize=12) axes[0].axis('off') # Panel 2: Segmentation Overlay if masks is not None and len(masks) > 0: overlay = create_mask_overlay(img_array, masks, alpha=0.5) axes[1].imshow(overlay) # Create legend if prompt_labels: patches = [] for i, label in enumerate(prompt_labels[:len(SYMPTOM_COLORS)]): color = SYMPTOM_COLORS[i % len(SYMPTOM_COLORS)] patches.append(mpatches.Patch(color=color, label=label, alpha=0.7)) axes[1].legend(handles=patches, loc='upper right', fontsize=8) else: axes[1].imshow(img_array) axes[1].text(0.5, 0.5, 'No disease regions detected', transform=axes[1].transAxes, ha='center', va='center', fontsize=12, color='green') axes[1].set_title('Disease Detection', fontsize=12) axes[1].axis('off') # Panel 3: Diagnostic Summary axes[2].axis('off') # Create summary text severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6') summary_text = f""" ╔════════════════════════════════╗ ║ DIAGNOSTIC SUMMARY ║ ╚════════════════════════════════╝ 📋 Disease: {disease_name} ⚠️ Severity: {severity_label.upper()} 📊 Affected Area: {affected_percent:.1f}% """ # Add severity indicator axes[2].text(0.5, 0.65, summary_text, transform=axes[2].transAxes, fontsize=11, fontfamily='monospace', verticalalignment='top', horizontalalignment='center') # Add severity color bar severity_bar = plt.Rectangle((0.15, 0.25), 0.7, 0.1, facecolor=severity_color, edgecolor='black', transform=axes[2].transAxes) axes[2].add_patch(severity_bar) axes[2].text(0.5, 0.30, severity_label.upper(), transform=axes[2].transAxes, ha='center', va='center', fontsize=12, fontweight='bold', color='white') # Add affected area progress bar bar_width = 0.7 * (affected_percent / 100) bg_bar = plt.Rectangle((0.15, 0.12), 0.7, 0.06, facecolor='#EEEEEE', edgecolor='black', transform=axes[2].transAxes) progress_bar = plt.Rectangle((0.15, 0.12), max(0.01, bar_width), 0.06, facecolor='#E74C3C', transform=axes[2].transAxes) axes[2].add_patch(bg_bar) axes[2].add_patch(progress_bar) axes[2].text(0.5, 0.08, f'Affected Area: {affected_percent:.1f}%', transform=axes[2].transAxes, ha='center', fontsize=10) plt.tight_layout() return fig def create_mask_overlay( image: np.ndarray, masks: np.ndarray, alpha: float = 0.5, colors: Optional[List[str]] = None ) -> np.ndarray: """ Create an overlay of segmentation masks on an image. Args: image: RGB image array (H, W, 3) masks: Binary masks (N, H, W) alpha: Transparency of overlay colors: Optional list of colors for masks Returns: Image array with mask overlay """ if colors is None: colors = SYMPTOM_COLORS # Start with the original image overlay = image.copy().astype(np.float32) for i, mask in enumerate(masks): if mask.any(): # Get color for this mask color_hex = colors[i % len(colors)] color_rgb = hex_to_rgb(color_hex) # Create colored mask colored_mask = np.zeros_like(overlay) colored_mask[mask] = color_rgb # Blend with overlay mask_3d = np.stack([mask] * 3, axis=-1) overlay = np.where( mask_3d, overlay * (1 - alpha) + colored_mask * alpha, overlay ) return overlay.astype(np.uint8) def create_severity_heatmap( image: Union[str, Path, Image.Image, np.ndarray], severity_map: np.ndarray, figsize: Tuple[int, int] = (12, 5) ) -> plt.Figure: """ Create a heatmap showing severity distribution across the image. Args: image: Input image severity_map: Array of severity values (H, W) with values 0-3 figsize: Figure size Returns: matplotlib Figure object """ # Load image if isinstance(image, (str, Path)): image = Image.open(image).convert("RGB") elif isinstance(image, np.ndarray): image = Image.fromarray(image) img_array = np.array(image) # Create custom colormap colors = ['#2ECC71', '#F1C40F', '#E67E22', '#E74C3C'] # Green to Red cmap = LinearSegmentedColormap.from_list('severity', colors, N=256) fig, axes = plt.subplots(1, 2, figsize=figsize) # Original image axes[0].imshow(img_array) axes[0].set_title('Original Image') axes[0].axis('off') # Heatmap overlay axes[1].imshow(img_array) heatmap = axes[1].imshow(severity_map, cmap=cmap, alpha=0.6, vmin=0, vmax=3) axes[1].set_title('Severity Heatmap') axes[1].axis('off') # Add colorbar cbar = plt.colorbar(heatmap, ax=axes[1], fraction=0.046, pad=0.04) cbar.set_ticks([0, 1, 2, 3]) cbar.set_ticklabels(['Healthy', 'Mild', 'Moderate', 'Severe']) plt.tight_layout() return fig def create_comparison_view( images: List[Union[str, Path, Image.Image]], results: List[Dict], cols: int = 4, figsize_per_image: Tuple[float, float] = (4, 5) ) -> plt.Figure: """ Create a grid comparison view of multiple diagnoses. Args: images: List of images results: List of diagnostic results (dicts with 'severity_label', 'disease_name', etc.) cols: Number of columns in grid figsize_per_image: Size per image in the grid Returns: matplotlib Figure object """ n_images = len(images) rows = (n_images + cols - 1) // cols fig, axes = plt.subplots( rows, cols, figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows) ) if rows == 1: axes = [axes] if cols == 1: axes = [[ax] for ax in axes] for i, (img, result) in enumerate(zip(images, results)): row = i // cols col = i % cols ax = axes[row][col] if rows > 1 else axes[col] # Load image if isinstance(img, (str, Path)): img = Image.open(img).convert("RGB") ax.imshow(img) ax.axis('off') # Add colored border based on severity severity = result.get('severity_label', 'unknown') color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6') for spine in ax.spines.values(): spine.set_edgecolor(color) spine.set_linewidth(4) spine.set_visible(True) # Add label ax.set_title( f"{result.get('disease_name', 'Unknown')}\n{severity.upper()}", fontsize=10, color=color ) # Hide empty subplots for i in range(n_images, rows * cols): row = i // cols col = i % cols ax = axes[row][col] if rows > 1 else axes[col] ax.axis('off') ax.set_visible(False) plt.tight_layout() return fig def create_treatment_card( result: Dict, figsize: Tuple[int, int] = (8, 10) ) -> plt.Figure: """ Create a treatment recommendation card. Args: result: Diagnostic result dictionary figsize: Figure size Returns: matplotlib Figure object """ fig, ax = plt.subplots(figsize=figsize) ax.axis('off') severity_color = SEVERITY_COLORS.get( result.get('severity_label', 'unknown').lower(), '#95A5A6' ) # Title ax.text(0.5, 0.95, '🌿 TREATMENT CARD', ha='center', va='top', fontsize=16, fontweight='bold', transform=ax.transAxes) # Disease info disease_text = f""" ╔═══════════════════════════════════════════╗ ║ Disease: {result.get('disease_name', 'Unknown'):<32}║ ║ Type: {result.get('disease_type', 'unknown'):<35}║ ║ Severity: {result.get('severity_label', 'unknown').upper():<31}║ ║ Affected Area: {result.get('affected_area_percent', 0):.1f}%{' ' * 25}║ ╚═══════════════════════════════════════════╝ """ ax.text(0.5, 0.85, disease_text, ha='center', va='top', fontfamily='monospace', fontsize=10, transform=ax.transAxes) # Treatments y_pos = 0.60 # Organic treatments ax.text(0.1, y_pos, '🌱 ORGANIC TREATMENTS', fontweight='bold', fontsize=11, transform=ax.transAxes) y_pos -= 0.03 for treatment in result.get('organic_treatments', [])[:4]: ax.text(0.12, y_pos, f'• {treatment[:50]}', fontsize=9, transform=ax.transAxes) y_pos -= 0.03 y_pos -= 0.02 # Chemical treatments if result.get('chemical_treatments'): ax.text(0.1, y_pos, '🧪 CHEMICAL TREATMENTS', fontweight='bold', fontsize=11, transform=ax.transAxes) y_pos -= 0.03 for treatment in result.get('chemical_treatments', [])[:3]: ax.text(0.12, y_pos, f'• {treatment[:50]}', fontsize=9, transform=ax.transAxes) y_pos -= 0.03 y_pos -= 0.02 # Prevention ax.text(0.1, y_pos, '🛡️ PREVENTION', fontweight='bold', fontsize=11, transform=ax.transAxes) y_pos -= 0.03 for measure in result.get('preventive_measures', [])[:4]: ax.text(0.12, y_pos, f'• {measure[:50]}', fontsize=9, transform=ax.transAxes) y_pos -= 0.03 # Timing y_pos -= 0.02 ax.text(0.1, y_pos, f"⏰ TIMING: {result.get('treatment_timing', 'Consult expert')[:60]}", fontsize=9, transform=ax.transAxes) # Add border rect = plt.Rectangle((0.05, 0.05), 0.9, 0.92, fill=False, edgecolor=severity_color, linewidth=3, transform=ax.transAxes) ax.add_patch(rect) return fig def hex_to_rgb(hex_color: str) -> List[int]: """Convert hex color to RGB.""" hex_color = hex_color.lstrip('#') return [int(hex_color[i:i+2], 16) for i in (0, 2, 4)] def save_visualization( fig: plt.Figure, output_path: Union[str, Path], dpi: int = 150 ): """Save figure to file.""" output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white') plt.close(fig) logger.info(f"Visualization saved to {output_path}") if __name__ == "__main__": # Test visualizations import numpy as np # Create test image test_img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) test_img[:, :, 1] = 139 # Greenish tint # Create test masks test_masks = np.zeros((2, 480, 640), dtype=bool) test_masks[0, 100:200, 100:200] = True # Square mask test_masks[1, 300:400, 400:500] = True # Another square # Test diagnostic visualization fig = create_diagnostic_visualization( test_img, test_masks, severity_label="moderate", disease_name="Leaf Spot Disease", affected_percent=15.5, prompt_labels=["brown spots", "yellowing"] ) save_visualization(fig, "/tmp/test_diagnostic.png") print("Test visualization saved to /tmp/test_diagnostic.png")