# cropscan-space / src/visualization.py
# davidsv
# Add disease detection app with RF-DETR and SAM2
# f8eb07d
"""
Visualization Module for CropDoctor-Semantic
=============================================
This module provides visualization functions for:
- Segmentation masks overlay
- Severity heatmaps
- Diagnostic dashboards
- Comparison views
"""
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
from typing import List, Optional, Tuple, Union, Dict
from pathlib import Path
import logging
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Colour palettes (hex strings) shared by the plotting functions below.
# ---------------------------------------------------------------------------

# Severity class -> colour, listed from least to most severe.
SEVERITY_COLORS = dict(
    healthy='#2ECC71',   # green
    mild='#F1C40F',      # yellow
    moderate='#E67E22',  # orange
    severe='#E74C3C',    # red
)

# Per-mask / per-symptom colours; consumers pick one with `index % len(...)`.
SYMPTOM_COLORS = [
    '#E74C3C', '#9B59B6', '#3498DB', '#E67E22',  # red, purple, blue, orange
    '#1ABC9C', '#F39C12', '#D35400', '#8E44AD',  # teal, yellow, dark orange, dark purple
]
def create_diagnostic_visualization(
    image: Union[str, Path, Image.Image, np.ndarray],
    masks: Optional[np.ndarray] = None,
    severity_label: str = "unknown",
    disease_name: str = "Unknown",
    affected_percent: float = 0.0,
    prompt_labels: Optional[List[str]] = None,
    figsize: Tuple[int, int] = (16, 6)
) -> plt.Figure:
    """
    Create a comprehensive diagnostic visualization.

    The figure has three panels:
      1. the original image,
      2. the segmentation overlay with a legend built from ``prompt_labels``,
         or a "No disease regions detected" note when no mask pixel is set,
      3. a text summary plus a severity colour bar and an affected-area
         progress bar.

    Args:
        image: Input image as a file path, PIL image or array; it is always
            converted to 3-channel RGB before plotting.
        masks: Segmentation masks array (N, H, W)
        severity_label: Severity classification result; its lower-cased value
            selects the colour from SEVERITY_COLORS (grey if not found)
        disease_name: Identified disease name
        affected_percent: Percentage of affected area. The text shows the raw
            value; the progress bar is clamped to the 0-100 range.
        prompt_labels: Labels for each mask, in mask order (the legend shows
            at most len(SYMPTOM_COLORS) entries)
        figsize: Figure size
    Returns:
        matplotlib Figure object
    """
    # Load the image and normalise it to RGB: grayscale, palette or RGBA
    # inputs would otherwise yield arrays the colour overlay cannot blend.
    if isinstance(image, (str, Path)):
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_array = np.array(image.convert("RGB"))

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle('CropDoctor Diagnostic Report', fontsize=14, fontweight='bold')

    # Panel 1: Original Image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')

    # Panel 2: Segmentation Overlay. Masks that are all empty mean nothing was
    # detected, so they get the "no regions" note instead of an unchanged
    # image with a misleading legend.
    has_regions = masks is not None and bool(np.any(masks))
    if has_regions:
        overlay = create_mask_overlay(img_array, masks, alpha=0.5)
        axes[1].imshow(overlay)
        if prompt_labels:
            # Label i gets SYMPTOM_COLORS[i], the same default colour
            # create_mask_overlay uses for mask i.
            patches = [
                mpatches.Patch(color=SYMPTOM_COLORS[i], label=label, alpha=0.7)
                for i, label in enumerate(prompt_labels[:len(SYMPTOM_COLORS)])
            ]
            axes[1].legend(handles=patches, loc='upper right', fontsize=8)
    else:
        axes[1].imshow(img_array)
        axes[1].text(0.5, 0.5, 'No disease regions detected',
                     transform=axes[1].transAxes, ha='center', va='center',
                     fontsize=12, color='green')
    axes[1].set_title('Disease Detection', fontsize=12)
    axes[1].axis('off')

    # Panel 3: Diagnostic Summary
    axes[2].axis('off')
    severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6')

    box_width = 32  # inner width of the title box, in monospace characters
    summary_text = "\n".join([
        "",
        f"╔{'═' * box_width}╗",
        f"║{'DIAGNOSTIC SUMMARY'.center(box_width)}║",
        f"╚{'═' * box_width}╝",
        f"📋 Disease: {disease_name}",
        f"⚠️ Severity: {severity_label.upper()}",
        f"📊 Affected Area: {affected_percent:.1f}%",
        "",
    ])
    axes[2].text(0.5, 0.65, summary_text, transform=axes[2].transAxes,
                 fontsize=11, fontfamily='monospace',
                 verticalalignment='top', horizontalalignment='center')

    # Severity colour bar with the label written on top of it.
    severity_bar = plt.Rectangle((0.15, 0.25), 0.7, 0.1,
                                 facecolor=severity_color,
                                 edgecolor='black',
                                 transform=axes[2].transAxes)
    axes[2].add_patch(severity_bar)
    axes[2].text(0.5, 0.30, severity_label.upper(),
                 transform=axes[2].transAxes, ha='center', va='center',
                 fontsize=12, fontweight='bold', color='white')

    # Affected-area progress bar. Clamp so out-of-range percentages cannot
    # draw past the background bar; keep a thin sliver visible at 0 %.
    clamped_percent = min(max(affected_percent, 0.0), 100.0)
    bar_width = 0.7 * (clamped_percent / 100)
    bg_bar = plt.Rectangle((0.15, 0.12), 0.7, 0.06,
                           facecolor='#EEEEEE', edgecolor='black',
                           transform=axes[2].transAxes)
    progress_bar = plt.Rectangle((0.15, 0.12), max(0.01, bar_width), 0.06,
                                 facecolor='#E74C3C',
                                 transform=axes[2].transAxes)
    axes[2].add_patch(bg_bar)
    axes[2].add_patch(progress_bar)
    axes[2].text(0.5, 0.08, f'Affected Area: {affected_percent:.1f}%',
                 transform=axes[2].transAxes, ha='center', fontsize=10)

    fig.tight_layout()
    return fig
def create_mask_overlay(
    image: np.ndarray,
    masks: np.ndarray,
    alpha: float = 0.5,
    colors: Optional[List[str]] = None
) -> np.ndarray:
    """
    Create an overlay of segmentation masks on an image.

    Each non-empty mask i is alpha-blended with ``colors[i % len(colors)]``;
    where masks overlap, later masks are blended on top of earlier ones.

    Args:
        image: RGB image array (H, W, 3). A grayscale (H, W) image is expanded
            to three channels and the alpha channel of an (H, W, 4) image is
            dropped.
        masks: Masks (N, H, W), or a single (H, W) mask. Any dtype is accepted;
            nonzero pixels belong to the mask (bool, 0/1 and 0/255 all work).
        alpha: Opacity of the mask colour (0 = invisible, 1 = solid colour)
        colors: Optional list of hex colours; None or an empty list falls back
            to SYMPTOM_COLORS
    Returns:
        uint8 image array (H, W, 3) with the masks blended in
    """
    if not colors:
        colors = SYMPTOM_COLORS

    # astype() copies, so the caller's image is never modified.
    overlay = np.asarray(image).astype(np.float32)
    if overlay.ndim == 2:
        overlay = np.repeat(overlay[..., np.newaxis], 3, axis=-1)
    elif overlay.ndim == 3 and overlay.shape[-1] == 4:
        overlay = overlay[..., :3]

    mask_stack = np.asarray(masks)
    if mask_stack.ndim == 2:
        mask_stack = mask_stack[np.newaxis]

    for i, mask in enumerate(mask_stack):
        # Cast to bool: an integer mask used as an index would do fancy
        # (row) indexing instead of selecting the masked pixels.
        region = mask.astype(bool)
        if not region.any():
            continue
        color = np.asarray(hex_to_rgb(colors[i % len(colors)]), dtype=np.float32)
        # Blend only the masked pixels; unmasked pixels keep their value.
        overlay[region] = overlay[region] * (1 - alpha) + color * alpha

    return overlay.astype(np.uint8)
def create_severity_heatmap(
    image: Union[str, Path, Image.Image, np.ndarray],
    severity_map: np.ndarray,
    figsize: Tuple[int, int] = (12, 5)
) -> plt.Figure:
    """
    Create a heatmap showing severity distribution across the image.

    The left panel shows the image; the right panel shows the image with the
    severity map drawn on top (alpha 0.6) using a green-to-red colour scale
    fixed to 0 (Healthy) .. 3 (Severe).

    Args:
        image: Input image as a file path, PIL image or array; converted to
            3-channel RGB before plotting.
        severity_map: 2-D array of severity values with values 0-3. It may
            be coarser than the image (e.g. a per-tile grid): it is stretched
            to cover the whole image.
        figsize: Figure size
    Returns:
        matplotlib Figure object
    Raises:
        ValueError: if ``severity_map`` is not 2-D.
    """
    # Load image and normalise to RGB so a grayscale input is not rendered
    # with a false colormap.
    if isinstance(image, (str, Path)):
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_array = np.array(image.convert("RGB"))
    height, width = img_array.shape[:2]

    severity_map = np.asarray(severity_map)
    if severity_map.ndim != 2:
        raise ValueError(
            f"severity_map must be a 2-D (H, W) array, got shape {severity_map.shape}"
        )

    # Custom colormap: green -> yellow -> orange -> red.
    colors = ['#2ECC71', '#F1C40F', '#E67E22', '#E74C3C']
    cmap = LinearSegmentedColormap.from_list('severity', colors, N=256)

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Original image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Heatmap overlay. An explicit extent equal to the image's pixel extent
    # stretches a lower-resolution map over the whole image; for a
    # same-sized map it matches imshow's default extent exactly.
    extent = (-0.5, width - 0.5, height - 0.5, -0.5)
    axes[1].imshow(img_array)
    heatmap = axes[1].imshow(severity_map, cmap=cmap, alpha=0.6,
                             vmin=0, vmax=3, extent=extent)
    axes[1].set_title('Severity Heatmap')
    axes[1].axis('off')

    # Colorbar labelled with the four severity classes.
    cbar = fig.colorbar(heatmap, ax=axes[1], fraction=0.046, pad=0.04)
    cbar.set_ticks([0, 1, 2, 3])
    cbar.set_ticklabels(['Healthy', 'Mild', 'Moderate', 'Severe'])

    fig.tight_layout()
    return fig
def create_comparison_view(
    images: List[Union[str, Path, Image.Image]],
    results: List[Dict],
    cols: int = 4,
    figsize_per_image: Tuple[float, float] = (4, 5)
) -> plt.Figure:
    """
    Create a grid comparison view of multiple diagnoses.

    Each cell shows one image framed by a border in its severity colour and
    titled with the disease name and severity. Unused grid cells are hidden.

    Args:
        images: List of images (file paths are opened and converted to RGB;
            anything else is passed straight to ``imshow``)
        results: List of diagnostic results (dicts with 'severity_label',
            'disease_name', etc.), paired with ``images`` by position. If
            the lists differ in length, only the shorter length is drawn.
        cols: Number of columns in grid (must be >= 1)
        figsize_per_image: Size per image in the grid
    Returns:
        matplotlib Figure object
    Raises:
        ValueError: if ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be at least 1, got {cols}")

    n_images = len(images)
    # Keep at least one row so an empty input gives an empty figure instead
    # of plt.subplots(0, cols) raising.
    rows = max(1, (n_images + cols - 1) // cols)

    # squeeze=False always returns a 2-D array of Axes, so 1-row, 1-column
    # and 1x1 grids need no special casing.
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows),
        squeeze=False,
    )
    flat_axes = axes.ravel()  # row-major: cell i is at (i // cols, i % cols)

    n_drawn = min(n_images, len(results))
    for ax, img, result in zip(flat_axes, images, results):
        if isinstance(img, (str, Path)):
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        ax.imshow(img)

        # Hide ticks but keep the frame: ax.axis('off') would also suppress
        # the spines that draw the severity border below.
        ax.set_xticks([])
        ax.set_yticks([])

        # Add colored border based on severity
        severity = str(result.get('severity_label') or 'unknown')
        color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6')
        for spine in ax.spines.values():
            spine.set_edgecolor(color)
            spine.set_linewidth(4)
            spine.set_visible(True)

        # Add label
        ax.set_title(
            f"{result.get('disease_name', 'Unknown')}\n{severity.upper()}",
            fontsize=10,
            color=color
        )

    # Hide empty subplots
    for ax in flat_axes[n_drawn:]:
        ax.axis('off')
        ax.set_visible(False)

    fig.tight_layout()
    return fig
def create_treatment_card(
    result: Dict,
    figsize: Tuple[int, int] = (8, 10)
) -> plt.Figure:
    """
    Create a treatment recommendation card.

    The card shows a boxed diagnosis summary followed by bulleted sections
    (up to 4 organic treatments, up to 3 chemical treatments only if any are
    present, up to 4 preventive measures) and a timing note, framed by a
    border in the severity colour. Bullet text is cut at 50 characters and
    the timing note at 60.

    Args:
        result: Diagnostic result dictionary. Keys read: 'disease_name',
            'disease_type', 'severity_label', 'affected_area_percent',
            'organic_treatments', 'chemical_treatments',
            'preventive_measures', 'treatment_timing'. Missing keys and
            None values fall back to placeholders.
        figsize: Figure size
    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis('off')

    # `or` (rather than a .get default) also covers keys present with None.
    severity_label = str(result.get('severity_label') or 'unknown')
    disease_name = result.get('disease_name') or 'Unknown'
    disease_type = result.get('disease_type') or 'unknown'
    affected = result.get('affected_area_percent') or 0
    timing = result.get('treatment_timing') or 'Consult expert'

    severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6')

    box_width = 42  # inner width of the summary box, in monospace characters

    def box_row(label, value) -> str:
        # Pad or truncate to a fixed width so every right border lines up,
        # whatever the length of the value.
        content = f" {label}: {value}"
        return f"║{content[:box_width].ljust(box_width)}║"

    def draw_section(y: float, heading: str, items, limit: int) -> float:
        # Draw a bold heading plus up to `limit` bullets; return the next y.
        ax.text(0.1, y, heading, fontweight='bold', fontsize=11,
                transform=ax.transAxes)
        y -= 0.03
        for item in list(items)[:limit]:
            ax.text(0.12, y, f'• {str(item)[:50]}', fontsize=9,
                    transform=ax.transAxes)
            y -= 0.03
        return y

    # Title
    ax.text(0.5, 0.95, '🌿 TREATMENT CARD',
            ha='center', va='top', fontsize=16, fontweight='bold',
            transform=ax.transAxes)

    # Disease info box
    disease_text = "\n".join([
        "",
        f"╔{'═' * box_width}╗",
        box_row("Disease", disease_name),
        box_row("Type", disease_type),
        box_row("Severity", severity_label.upper()),
        box_row("Affected Area", f"{affected:.1f}%"),
        f"╚{'═' * box_width}╝",
        "",
    ])
    ax.text(0.5, 0.85, disease_text,
            ha='center', va='top', fontfamily='monospace', fontsize=10,
            transform=ax.transAxes)

    # Treatments
    y_pos = 0.60
    y_pos = draw_section(y_pos, '🌱 ORGANIC TREATMENTS',
                         result.get('organic_treatments') or [], 4)
    y_pos -= 0.02

    chemical = result.get('chemical_treatments') or []
    if chemical:
        y_pos = draw_section(y_pos, '🧪 CHEMICAL TREATMENTS', chemical, 3)
        y_pos -= 0.02

    # Prevention
    y_pos = draw_section(y_pos, '🛡️ PREVENTION',
                         result.get('preventive_measures') or [], 4)

    # Timing
    y_pos -= 0.02
    ax.text(0.1, y_pos, f"⏰ TIMING: {str(timing)[:60]}",
            fontsize=9, transform=ax.transAxes)

    # Border in the severity colour
    rect = plt.Rectangle((0.05, 0.05), 0.9, 0.92,
                         fill=False, edgecolor=severity_color, linewidth=3,
                         transform=ax.transAxes)
    ax.add_patch(rect)

    return fig
def hex_to_rgb(hex_color: str) -> List[int]:
    """
    Convert a hex colour string to an ``[r, g, b]`` list of ints (0-255).

    Accepts an optional leading '#' and the forms RGB, RGBA, RRGGBB and
    RRGGBBAA (shorthand digits are doubled, so '#fa0' == '#ffaa00'). Any
    alpha component is ignored.

    Raises:
        ValueError: if the string is not one of the accepted forms.
    """
    from string import hexdigits

    digits = hex_color.strip().lstrip('#')
    if len(digits) in (3, 4):
        digits = ''.join(ch * 2 for ch in digits)
    # Explicit validation: int(..., 16) alone would accept '+f' or ' f' and
    # silently mis-parse 5- or 7-digit strings.
    if len(digits) not in (6, 8) or any(ch not in hexdigits for ch in digits):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
def save_visualization(
    fig: plt.Figure,
    output_path: Union[str, Path],
    dpi: int = 150
) -> None:
    """
    Save a figure to file and close it.

    Missing parent directories are created. The figure is closed with
    ``plt.close`` even when creating the directory or saving fails, so a
    failed save does not leave the figure open.

    Args:
        fig: Figure to save
        output_path: Destination file; the format follows its extension
        dpi: Output resolution
    """
    output_path = Path(output_path)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
    logger.info("Visualization saved to %s", output_path)
if __name__ == "__main__":
    # Smoke test: render a diagnostic figure from synthetic data.
    import tempfile

    # Seeded noise image with a fixed green channel for a greenish tint.
    rng = np.random.default_rng(0)
    test_img = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
    test_img[:, :, 1] = 139

    # Two square "lesion" masks.
    test_masks = np.zeros((2, 480, 640), dtype=bool)
    test_masks[0, 100:200, 100:200] = True
    test_masks[1, 300:400, 400:500] = True

    fig = create_diagnostic_visualization(
        test_img,
        test_masks,
        severity_label="moderate",
        disease_name="Leaf Spot Disease",
        affected_percent=15.5,
        prompt_labels=["brown spots", "yellowing"]
    )
    # Portable temp location instead of a hard-coded /tmp.
    output_file = Path(tempfile.gettempdir()) / "test_diagnostic.png"
    save_visualization(fig, output_file)
    print(f"Test visualization saved to {output_file}")