Spaces:
Sleeping
Sleeping
| """ | |
| Visualization Module for CropDoctor-Semantic | |
| ============================================= | |
| This module provides visualization functions for: | |
| - Segmentation masks overlay | |
| - Severity heatmaps | |
| - Diagnostic dashboards | |
| - Comparison views | |
| """ | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from typing import List, Optional, Tuple, Union, Dict | |
| from pathlib import Path | |
| import logging | |
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Color schemes
# ---------------------------------------------------------------------------

# Severity class -> display colour (hex), listed from least to most severe.
# Lookups elsewhere in this module lower-case the label first.
SEVERITY_COLORS = dict(
    healthy='#2ECC71',   # green
    mild='#F1C40F',      # yellow
    moderate='#E67E22',  # orange
    severe='#E74C3C',    # red
)

# Palette for segmentation masks, indexed by mask position and cycled
# (index modulo length) when there are more masks than entries.
SYMPTOM_COLORS = [
    '#E74C3C',  # red
    '#9B59B6',  # purple
    '#3498DB',  # blue
    '#E67E22',  # orange
    '#1ABC9C',  # teal
    '#F39C12',  # amber / yellow-orange
    '#D35400',  # dark orange
    '#8E44AD',  # dark purple
]
def create_diagnostic_visualization(
    image: Union[str, Path, Image.Image, np.ndarray],
    masks: Optional[np.ndarray] = None,
    severity_label: str = "unknown",
    disease_name: str = "Unknown",
    affected_percent: float = 0.0,
    prompt_labels: Optional[List[str]] = None,
    figsize: Tuple[int, int] = (16, 6)
) -> plt.Figure:
    """
    Create a three-panel diagnostic visualization.

    Panels (left to right): the original image; the segmentation masks
    blended over it via ``create_mask_overlay`` (or a "No disease regions
    detected" note when there are no masks); and a text summary with a
    severity badge and an affected-area progress bar.

    Args:
        image: Input image as a file path, PIL image or numpy array. It is
            always converted to RGB, so grayscale and RGBA inputs work.
        masks: Segmentation masks array (N, H, W). ``None`` or an empty
            array means no regions were detected.
        severity_label: Severity classification result. Matched
            case-insensitively against ``SEVERITY_COLORS``; unknown labels
            are drawn in grey.
        disease_name: Identified disease name.
        affected_percent: Percentage of affected area. The number is shown
            as given, but the progress bar is clamped to 0-100 %.
        prompt_labels: Labels for each mask, shown in a legend that uses
            the same colour order as the overlay.
        figsize: Figure size in inches.

    Returns:
        matplotlib Figure object
    """
    severity_label = severity_label or "unknown"

    # Load the image and normalise it to RGB. Previously only file paths
    # were converted, so RGBA/grayscale PIL images or arrays broke the mask
    # overlay (channel-count mismatch with the 3-value colours).
    if isinstance(image, (str, Path)):
        with Image.open(image) as src:
            rgb_image = src.convert("RGB")
    elif isinstance(image, np.ndarray):
        rgb_image = Image.fromarray(image).convert("RGB")
    else:
        rgb_image = image.convert("RGB")
    img_array = np.array(rgb_image)

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle('CropDoctor Diagnostic Report', fontsize=14, fontweight='bold')

    # Panel 1: Original Image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')

    # Panel 2: Segmentation Overlay
    if masks is not None and len(masks) > 0:
        overlay = create_mask_overlay(img_array, masks, alpha=0.5)
        axes[1].imshow(overlay)
        if prompt_labels:
            # Same palette and order as create_mask_overlay's default;
            # zip() stops at the shorter sequence, so no colour repeats.
            patches = [
                mpatches.Patch(color=color, label=label, alpha=0.7)
                for color, label in zip(SYMPTOM_COLORS, prompt_labels)
            ]
            axes[1].legend(handles=patches, loc='upper right', fontsize=8)
    else:
        axes[1].imshow(img_array)
        axes[1].text(0.5, 0.5, 'No disease regions detected',
                     transform=axes[1].transAxes, ha='center', va='center',
                     fontsize=12, color='green')
    axes[1].set_title('Disease Detection', fontsize=12)
    axes[1].axis('off')

    # Panel 3: Diagnostic Summary
    summary_ax = axes[2]
    summary_ax.axis('off')
    severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6')

    # Build the header box from a single width so its borders always line
    # up in the monospace font.
    box_width = 32
    summary_text = "\n".join([
        "",
        "╔" + "═" * box_width + "╗",
        "║" + "DIAGNOSTIC SUMMARY".center(box_width) + "║",
        "╚" + "═" * box_width + "╝",
        "",
        f"🔍 Disease: {disease_name}",
        f"⚠️ Severity: {severity_label.upper()}",
        f"📊 Affected Area: {affected_percent:.1f}%",
        "",
    ])
    summary_ax.text(0.5, 0.65, summary_text, transform=summary_ax.transAxes,
                    fontsize=11, fontfamily='monospace',
                    verticalalignment='top', horizontalalignment='center')

    # Severity badge
    summary_ax.add_patch(plt.Rectangle((0.15, 0.25), 0.7, 0.1,
                                       facecolor=severity_color,
                                       edgecolor='black',
                                       transform=summary_ax.transAxes))
    summary_ax.text(0.5, 0.30, severity_label.upper(),
                    transform=summary_ax.transAxes, ha='center', va='center',
                    fontsize=12, fontweight='bold', color='white')

    # Affected-area progress bar. Clamp to 0-100 % so values above 100
    # cannot draw past the background track; keep a sliver visible at 0 %.
    clamped_percent = min(max(affected_percent, 0.0), 100.0)
    bar_width = max(0.01, 0.7 * clamped_percent / 100)
    summary_ax.add_patch(plt.Rectangle((0.15, 0.12), 0.7, 0.06,
                                       facecolor='#EEEEEE', edgecolor='black',
                                       transform=summary_ax.transAxes))
    summary_ax.add_patch(plt.Rectangle((0.15, 0.12), bar_width, 0.06,
                                       facecolor='#E74C3C',
                                       transform=summary_ax.transAxes))
    summary_ax.text(0.5, 0.08, f'Affected Area: {affected_percent:.1f}%',
                    transform=summary_ax.transAxes, ha='center', fontsize=10)

    fig.tight_layout()
    return fig
def create_mask_overlay(
    image: np.ndarray,
    masks: np.ndarray,
    alpha: float = 0.5,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None
) -> np.ndarray:
    """
    Create an overlay of segmentation masks on an image.

    Each non-empty mask is alpha-blended with its colour on top of the
    result of the previous masks, so later masks are drawn over earlier
    ones where they overlap. The input image is not modified.

    Args:
        image: RGB image array (H, W, 3). Grayscale (H, W) / (H, W, 1) and
            RGBA (H, W, 4) arrays are also accepted; the result is RGB.
        masks: Masks (N, H, W), or a single (H, W) mask. Any dtype is
            accepted; non-zero entries are treated as foreground.
        alpha: Opacity of the mask colour (0 = invisible, 1 = opaque).
        colors: Optional list of colours, each a hex string ('#RRGGBB') or
            an (R, G, B) tuple of 0-255 ints; cycled when there are more
            masks than colours. Defaults to SYMPTOM_COLORS when omitted
            or empty.

    Returns:
        uint8 RGB image array (H, W, 3) with the masks blended in.
    """
    if not colors:
        colors = SYMPTOM_COLORS

    # Normalise the image to 3 channels so 3-value colours broadcast.
    rgb = np.asarray(image)
    if rgb.ndim == 3 and rgb.shape[-1] == 1:
        rgb = rgb[..., 0]
    if rgb.ndim == 2:
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.shape[-1] == 4:
        rgb = rgb[..., :3]
    overlay = rgb.astype(np.float32)  # astype copies: the caller's array is untouched

    mask_stack = np.asarray(masks)
    if mask_stack.ndim == 2:
        mask_stack = mask_stack[np.newaxis]
    # Force boolean masks: with integer 0/1 masks, `array[mask]` is integer
    # (fancy) indexing of rows 0 and 1 rather than pixel selection, and
    # float masks would raise IndexError.
    mask_stack = mask_stack.astype(bool)

    for i, mask in enumerate(mask_stack):
        if not mask.any():
            continue
        color = colors[i % len(colors)]
        color_rgb = np.asarray(
            hex_to_rgb(color) if isinstance(color, str) else color,
            dtype=np.float32,
        )
        overlay[mask] = overlay[mask] * (1 - alpha) + color_rgb * alpha

    # Clip so an alpha outside [0, 1] saturates instead of wrapping around.
    return np.clip(overlay, 0, 255).astype(np.uint8)
def create_severity_heatmap(
    image: Union[str, Path, Image.Image, np.ndarray],
    severity_map: np.ndarray,
    figsize: Tuple[int, int] = (12, 5)
) -> plt.Figure:
    """
    Create a side-by-side view of an image and its severity heatmap.

    The right panel draws ``severity_map`` semi-transparently over the image
    using a green -> yellow -> orange -> red colormap, with a colorbar
    labelled Healthy / Mild / Moderate / Severe.

    Args:
        image: Input image as a file path, PIL image or numpy array; it is
            converted to RGB before display.
        severity_map: 2-D array of severity values in the range 0-3. It is
            stretched over the image, so it may have a different (e.g.
            coarser) resolution than the image.
        figsize: Figure size in inches.

    Returns:
        matplotlib Figure object
    """
    # Normalise to RGB: a grayscale input would otherwise be a 2-D array
    # that imshow renders with its default false-colour colormap.
    if isinstance(image, (str, Path)):
        with Image.open(image) as src:
            rgb_image = src.convert("RGB")
    elif isinstance(image, np.ndarray):
        rgb_image = Image.fromarray(image).convert("RGB")
    else:
        rgb_image = image.convert("RGB")
    img_array = np.array(rgb_image)
    height, width = img_array.shape[:2]

    # Colours come from SEVERITY_COLORS so the heatmap matches the severity
    # colours used by the other views in this module.
    cmap = LinearSegmentedColormap.from_list(
        'severity',
        [SEVERITY_COLORS[level] for level in ('healthy', 'mild', 'moderate', 'severe')],
        N=256,
    )

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Original image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Heatmap overlay. The explicit extent (imshow's default extent for an
    # image of this size) stretches the map over the image's pixel grid; a
    # map at another resolution was previously drawn at its own size and
    # misaligned with the image.
    axes[1].imshow(img_array)
    heatmap = axes[1].imshow(
        severity_map, cmap=cmap, alpha=0.6, vmin=0, vmax=3,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )
    axes[1].set_title('Severity Heatmap')
    axes[1].axis('off')

    # Colorbar with one tick per severity class
    cbar = fig.colorbar(heatmap, ax=axes[1], fraction=0.046, pad=0.04)
    cbar.set_ticks([0, 1, 2, 3])
    cbar.set_ticklabels(['Healthy', 'Mild', 'Moderate', 'Severe'])

    fig.tight_layout()
    return fig
def create_comparison_view(
    images: List[Union[str, Path, Image.Image]],
    results: List[Dict],
    cols: int = 4,
    figsize_per_image: Tuple[float, float] = (4, 5)
) -> plt.Figure:
    """
    Create a grid comparison view of multiple diagnoses.

    Each cell shows one image with a border and a title coloured by its
    severity; unused cells in the last row are hidden.

    Args:
        images: List of images (file paths, PIL images, or arrays that
            ``imshow`` accepts).
        results: Diagnostic results paired with ``images`` by position
            (dicts with 'severity_label', 'disease_name'). Only complete
            (image, result) pairs are drawn.
        cols: Number of columns in the grid (must be >= 1).
        figsize_per_image: Size in inches of each grid cell.

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: if ``cols`` < 1 or there are no (image, result) pairs.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")
    pairs = list(zip(images, results))
    if not pairs:
        raise ValueError("create_comparison_view needs at least one image/result pair")

    rows = (len(pairs) + cols - 1) // cols
    # squeeze=False always yields a 2-D array of Axes. The previous manual
    # unwrapping broke every single-row grid: it wrapped the axes array in a
    # list, then called imshow on that array or indexed past the list's end.
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows),
        squeeze=False,
    )
    cells = axes.ravel()

    for ax, (img, result) in zip(cells, pairs):
        if isinstance(img, (str, Path)):
            with Image.open(img) as src:
                img = src.convert("RGB")
        ax.imshow(img)
        # Hide ticks rather than calling axis('off'): with the axis off,
        # matplotlib does not draw the spines, so the severity border
        # configured below never appeared.
        ax.set_xticks([])
        ax.set_yticks([])

        # Colored border based on severity (grey for unknown labels)
        severity = result.get('severity_label') or 'unknown'
        color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6')
        for spine in ax.spines.values():
            spine.set_edgecolor(color)
            spine.set_linewidth(4)
            spine.set_visible(True)

        ax.set_title(
            f"{result.get('disease_name') or 'Unknown'}\n{severity.upper()}",
            fontsize=10,
            color=color,
        )

    # Hide empty cells in the last row
    for ax in cells[len(pairs):]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig
def create_treatment_card(
    result: Dict,
    figsize: Tuple[int, int] = (8, 10)
) -> plt.Figure:
    """
    Create a treatment recommendation card.

    The card shows a boxed header (disease, type, severity, affected area),
    then bulleted organic treatments (up to 4), chemical treatments (up to
    3; the section is omitted when there are none), preventive measures (up
    to 4) and the treatment timing, all inside a border coloured by
    severity.

    Args:
        result: Diagnostic result dictionary. Keys read, all optional:
            'disease_name', 'disease_type', 'severity_label',
            'affected_area_percent', 'organic_treatments',
            'chemical_treatments', 'preventive_measures' and
            'treatment_timing'. Missing or ``None`` values fall back to
            defaults instead of raising.
        figsize: Figure size in inches.

    Returns:
        matplotlib Figure object
    """
    disease_name = result.get('disease_name') or 'Unknown'
    disease_type = result.get('disease_type') or 'unknown'
    severity = result.get('severity_label') or 'unknown'
    affected = result.get('affected_area_percent') or 0
    severity_color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6')

    fig, ax = plt.subplots(figsize=figsize)
    ax.axis('off')

    # Title
    ax.text(0.5, 0.95, '🌿 TREATMENT CARD',
            ha='center', va='top', fontsize=16, fontweight='bold',
            transform=ax.transAxes)

    # Disease info box. Every row is padded or truncated to the same inner
    # width so the right border lines up in the monospace font. Previously
    # the affected-area row used a fixed 25-space pad that ran past the
    # border, and long names overflowed their row.
    inner_width = 42
    rows = [
        ('Disease', disease_name),
        ('Type', disease_type),
        ('Severity', severity.upper()),
        ('Affected Area', f'{affected:.1f}%'),
    ]
    box_lines = ['╔' + '═' * inner_width + '╗']
    for label, value in rows:
        cell = f' {label}: {value}'[:inner_width]
        box_lines.append('║' + cell.ljust(inner_width) + '║')
    box_lines.append('╚' + '═' * inner_width + '╝')
    disease_text = '\n' + '\n'.join(box_lines) + '\n'
    ax.text(0.5, 0.85, disease_text,
            ha='center', va='top', fontfamily='monospace', fontsize=10,
            transform=ax.transAxes)

    def add_section(y, title, items, limit):
        """Draw a bold heading and up to `limit` bullets; return the next y."""
        ax.text(0.1, y, title, fontweight='bold', fontsize=11,
                transform=ax.transAxes)
        y -= 0.03
        for item in list(items)[:limit]:
            ax.text(0.12, y, f'• {str(item)[:50]}', fontsize=9,
                    transform=ax.transAxes)
            y -= 0.03
        return y

    # Treatment sections, each followed by a small gap
    y_pos = 0.60
    y_pos = add_section(y_pos, '🌱 ORGANIC TREATMENTS',
                        result.get('organic_treatments') or [], 4) - 0.02
    chemical = result.get('chemical_treatments')
    if chemical:
        y_pos = add_section(y_pos, '🧪 CHEMICAL TREATMENTS', chemical, 3) - 0.02
    y_pos = add_section(y_pos, '🛡️ PREVENTION',
                        result.get('preventive_measures') or [], 4) - 0.02

    # Timing
    timing = result.get('treatment_timing') or 'Consult expert'
    ax.text(0.1, y_pos, f"⏰ TIMING: {str(timing)[:60]}",
            fontsize=9, transform=ax.transAxes)

    # Border coloured by severity
    ax.add_patch(plt.Rectangle((0.05, 0.05), 0.9, 0.92,
                               fill=False, edgecolor=severity_color, linewidth=3,
                               transform=ax.transAxes))
    return fig
def hex_to_rgb(hex_color: str) -> List[int]:
    """
    Convert a hex colour string to an ``[R, G, B]`` list of 0-255 ints.

    Accepts an optional leading '#', the 6-digit form ('#RRGGBB'), the
    3-digit shorthand ('#RGB', each digit doubled) and the 8-digit form
    'RRGGBBAA' (alpha is ignored), e.g. ``hex_to_rgb('#2ECC71')`` returns
    ``[46, 204, 113]``.

    Raises:
        ValueError: if the string is not one of those forms, instead of
            silently returning a wrong colour (e.g. for '#12345').
    """
    digits = hex_color.lstrip('#')
    # int(..., 16) alone would also accept signs and underscores.
    if any(ch not in '0123456789abcdefABCDEF' for ch in digits):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    elif len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
def save_visualization(
    fig: plt.Figure,
    output_path: Union[str, Path],
    dpi: int = 150
):
    """
    Save a figure to file and close it.

    Parent directories are created as needed. The figure is closed even if
    creating the directory or saving fails, so pyplot does not keep
    accumulating open figures.

    Args:
        fig: Figure to save.
        output_path: Destination file; the format follows its extension.
        dpi: Output resolution.
    """
    output_path = Path(output_path)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
    logger.info("Visualization saved to %s", output_path)
if __name__ == "__main__":
    # Smoke test: render a synthetic diagnosis and save it to the temp dir.
    import tempfile

    rng = np.random.default_rng(0)  # fixed seed -> reproducible output

    # Test image: random noise with a constant green channel (greenish tint).
    # The upper bound is exclusive, so 256 lets values reach 255.
    test_img = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
    test_img[:, :, 1] = 139

    # Two square test masks
    test_masks = np.zeros((2, 480, 640), dtype=bool)
    test_masks[0, 100:200, 100:200] = True
    test_masks[1, 300:400, 400:500] = True

    fig = create_diagnostic_visualization(
        test_img,
        test_masks,
        severity_label="moderate",
        disease_name="Leaf Spot Disease",
        affected_percent=15.5,
        prompt_labels=["brown spots", "yellowing"]
    )

    # Use the platform temp dir instead of a hard-coded /tmp (not portable).
    output_path = Path(tempfile.gettempdir()) / "test_diagnostic.png"
    save_visualization(fig, output_path)
    print(f"Test visualization saved to {output_path}")