""" visualization.py - FastAPI-Compatible Visualization Functions (WITH GATE SUPPORT) ================================================================================== Pure visualization functions that return image bytes or JSON data for FastAPI endpoints. No Streamlit dependencies. NOW INCLUDES: 5-channel support with GATE predictions visualization! """ import numpy as np import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') # Non-interactive backend for server-side rendering from matplotlib.patches import Rectangle import plotly.graph_objects as go import io from PIL import Image from typing import Dict, List, Tuple, Optional import json import base64 # ============================================================================== # 2D PATCH VISUALIZATION # ============================================================================== def generate_2d_patch_preview( patches: Dict[str, np.ndarray], aoi_name: str, patch_id: str, layer_order: List[str] = None ) -> bytes: """ Generate 2D visualization of patch layers Args: patches: Dictionary of patch data {layer_name: array} aoi_name: Name of AOI patch_id: Patch identifier layer_order: Order of layers to display Returns: PNG image as bytes """ if layer_order is None: layer_order = ['dtm', 'slope', 'ndvi', 'ndwi', 'flow_acc'] show_layers = [l for l in layer_order if l in patches][:3] if not show_layers: # Return empty image fig, ax = plt.subplots(figsize=(5, 5)) ax.text(0.5, 0.5, 'No data available', ha='center', va='center') ax.axis('off') else: fig, axes = plt.subplots(1, len(show_layers), figsize=(5*len(show_layers), 5)) if len(show_layers) == 1: axes = [axes] cmaps = { 'dtm': 'terrain', 'slope': 'YlOrRd', 'ndvi': 'RdYlGn', 'ndwi': 'Blues', 'flow_acc': 'viridis', 'roughness': 'hot' } for idx, layer_name in enumerate(show_layers): data = patches[layer_name] im = axes[idx].imshow(data, cmap=cmaps.get(layer_name, 'viridis'), interpolation='bilinear') valid = data[~np.isnan(data)] if len(valid) > 0: 
axes[idx].set_title( f'{layer_name.upper()}\nμ={np.mean(valid):.2f}, σ={np.std(valid):.2f}', fontweight='bold', fontsize=11 ) else: axes[idx].set_title(f'{layer_name.upper()}', fontweight='bold', fontsize=11) axes[idx].axis('off') plt.colorbar(im, ax=axes[idx], fraction=0.046) plt.suptitle(f'{aoi_name} - {patch_id}', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() # Convert to bytes buf = io.BytesIO() plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') plt.close(fig) buf.seek(0) return buf.getvalue() # ============================================================================== # 3D TERRAIN VISUALIZATION # ============================================================================== def generate_3d_terrain_json( patches: Dict[str, np.ndarray], aoi_name: str, patch_id: str, lat: float, lon: float ) -> Dict: """ Generate 3D terrain visualization data for Plotly Returns: Dictionary containing Plotly figure JSON """ if 'dtm' not in patches: return { 'error': 'No DTM data available', 'figure': None } dtm = patches['dtm'] # Handle NaN values dtm_clean = dtm.copy() valid_mask = ~np.isnan(dtm) if not valid_mask.any(): return { 'error': 'No valid DTM data', 'figure': None } # Fill NaNs with median dtm_clean[~valid_mask] = np.nanmedian(dtm) # Create coordinate grids rows, cols = dtm.shape x = np.arange(cols) y = np.arange(rows) X, Y = np.meshgrid(x, y) # Color by slope if available if 'slope' in patches: slope = patches['slope'] slope_clean = slope.copy() slope_clean[np.isnan(slope)] = np.nanmedian(slope) colorscale = 'YlOrRd' surfacecolor = slope_clean colorbar_title = 'Slope (°)' else: colorscale = 'earth' surfacecolor = dtm_clean colorbar_title = 'Elevation (m)' # Create 3D surface fig = go.Figure(data=[go.Surface( z=dtm_clean, x=X, y=Y, surfacecolor=surfacecolor, colorscale=colorscale, colorbar=dict(title=colorbar_title), lighting=dict( ambient=0.4, diffuse=0.8, fresnel=0.2, specular=0.3, roughness=0.5 ), contours=dict( z=dict( show=True, 
usecolormap=True, highlightcolor="limegreen", project=dict(z=True) ) ) )]) # Update layout fig.update_layout( title=f'3D Terrain View - {patch_id}
Lat: {lat:.6f}, Lon: {lon:.6f}', scene=dict( xaxis_title='X (pixels)', yaxis_title='Y (pixels)', zaxis_title='Elevation (m)', camera=dict( eye=dict(x=1.5, y=1.5, z=1.3) ), aspectmode='manual', aspectratio=dict(x=1, y=1, z=0.5) ), width=800, height=600, margin=dict(l=0, r=0, t=40, b=0) ) # Statistics stats = { 'min_elevation': float(np.nanmin(dtm)), 'max_elevation': float(np.nanmax(dtm)), 'relief': float(np.nanmax(dtm) - np.nanmin(dtm)), 'mean_elevation': float(np.nanmean(dtm)), 'std_elevation': float(np.nanstd(dtm)) } return { 'figure': fig.to_json(), 'stats': stats } # ============================================================================== # AI MODEL ANALYSIS VISUALIZATION # ============================================================================== def generate_ai_analysis_visualization( patches: Dict[str, np.ndarray], results: Dict, patch_id: str ) -> bytes: """ Generate AI model analysis visualization Args: patches: Patch data results: Results from analyze_patch_multi_model() patch_id: Patch identifier Returns: PNG image as bytes """ # Determine number of plots n_plots = 3 if results['cluster_id'] is not None else 2 fig, axes = plt.subplots(1, n_plots, figsize=(6*n_plots, 5)) if n_plots == 2: axes = list(axes) # Plot 1: Reconstruction Error ax = axes[0] dtm_orig = patches['dtm'] dtm_recon = results['reconstruction'][0] diff = np.abs(dtm_orig - dtm_recon) im = ax.imshow(diff, cmap='Reds', interpolation='bilinear') ax.set_title('Reconstruction Error\n(Model 1: Autoencoder)', fontweight='bold') ax.axis('off') plt.colorbar(im, ax=ax, fraction=0.046) # Plot 2: IForest decision ax = axes[1] status = "ANOMALY" if results['iforest_is_anomaly'] else "NORMAL" color = 'red' if results['iforest_is_anomaly'] else 'green' ax.text(0.5, 0.5, f"{status}\nScore: {results['iforest_score']:.4f}", ha='center', va='center', fontsize=20, fontweight='bold', color=color) ax.set_title('Isolation Forest\n(Model 2)', fontweight='bold') ax.axis('off') # Plot 3: Cluster 
similarities (if available) if results['cluster_id'] is not None: ax = axes[2] clusters = list(range(len(results['all_cluster_similarities']))) sims = results['all_cluster_similarities'] colors = ['red' if i == results['cluster_id'] else 'gray' for i in clusters] ax.bar(clusters, sims, color=colors) ax.set_xlabel('Cluster ID') ax.set_ylabel('Similarity') ax.set_title('Cluster Similarities\n(Model 3: K-Means)', fontweight='bold') ax.set_ylim([0, 1]) plt.tight_layout() # Convert to bytes buf = io.BytesIO() plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') plt.close(fig) buf.seek(0) return buf.getvalue() def generate_ai_analysis_json(results: Dict) -> Dict: """ Generate JSON-serializable AI analysis results Args: results: Results from analyze_patch_multi_model() Returns: Dictionary with analysis data """ return { 'reconstruction_error': float(results['reconstruction_error']), 'iforest_score': float(results['iforest_score']), 'iforest_is_anomaly': bool(results['iforest_is_anomaly']), 'cluster_id': int(results['cluster_id']) if results['cluster_id'] is not None else None, 'cluster_similarity': float(results['cluster_similarity']) if results['cluster_similarity'] is not None else None, 'combined_anomaly_score': float(results['combined_anomaly_score']), 'verdict': get_anomaly_verdict(results['combined_anomaly_score']) } def get_anomaly_verdict(combined_score: float) -> str: """Get human-readable verdict""" if combined_score > 0.7: return "HIGH ANOMALY LIKELIHOOD - Investigate!" elif combined_score > 0.5: return "MODERATE ANOMALY - Worth checking" else: return "LIKELY NORMAL TERRAIN" # ============================================================================== # UNIFIED PROBABILITY MATRIX VISUALIZATION (NOW WITH 5 CHANNELS!) 
# ==============================================================================

def generate_probability_matrix_visualization(
    unified_matrix: np.ndarray,
    patch_idx: int,
    channel_names: List[str]
) -> bytes:
    """
    Visualize all probability channels for a single patch (4, 5, or N channels).

    Args:
        unified_matrix: Shape (num_patches, 64, 64, N) where N=4 or 5.
        patch_idx: Index of the patch to visualize.
        channel_names: Display names for the channels (length >= N).

    Returns:
        PNG image as bytes.
    """
    patch_data = unified_matrix[patch_idx]  # Shape: (64, 64, N)
    n_channels = patch_data.shape[-1]

    # Grid layout: fixed layouts for the two known channel counts,
    # otherwise a 3-wide grid with as many rows as needed.
    if n_channels == 4:
        nrows, ncols = 2, 2
    elif n_channels == 5:
        nrows, ncols = 2, 3
    else:
        nrows = (n_channels + 2) // 3
        ncols = 3

    fig, axes = plt.subplots(nrows, ncols, figsize=(6*ncols, 6*nrows))
    # FIX: plt.subplots() returns a bare Axes only for a 1x1 grid. The old
    # `axes.flatten() if n_channels > 1 else [axes]` wrapped an Axes *array*
    # in a list when n_channels == 1 (grid is 1x3), so axes[0] was an ndarray
    # and imshow crashed. np.asarray(...).ravel() handles every grid shape.
    axes = np.asarray(axes).ravel()

    for i in range(n_channels):
        channel_data = patch_data[:, :, i]
        im = axes[i].imshow(channel_data, cmap='hot', vmin=0, vmax=1,
                            interpolation='bilinear')
        axes[i].set_title(
            f'{channel_names[i]}\n(Range: {channel_data.min():.3f} - {channel_data.max():.3f})',
            fontweight='bold'
        )
        axes[i].axis('off')
        plt.colorbar(im, ax=axes[i], fraction=0.046)

    # Hide unused subplots in the grid.
    for i in range(n_channels, len(axes)):
        axes[i].axis('off')

    plt.suptitle(f'Probability Matrix - Patch {patch_idx}', fontsize=14, fontweight='bold')
    plt.tight_layout()

    # Render to PNG bytes.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf.getvalue()


def generate_full_aoi_heatmap(
    unified_matrix: np.ndarray,
    metadata: List[Dict],
    aoi_shape: Tuple[int, int],
    channel_idx: int = 0,
    patch_size: int = 64
) -> bytes:
    """
    Generate a full-AOI heatmap for one probability channel by mosaicking
    patches back onto the AOI grid (overlaps are averaged).

    Args:
        unified_matrix: Shape (num_patches, 64, 64, N) where N=4 or 5.
        metadata: Per-patch metadata dicts with 'row' and 'col' pixel offsets.
        aoi_shape: (height, width) of the full AOI in pixels.
        channel_idx: Which channel to visualize (0 to N-1).
        patch_size: Nominal size of patches in pixels.

    Returns:
        PNG image as bytes.
    """
    # Accumulate probabilities and a hit count so overlaps can be averaged.
    heatmap = np.zeros(aoi_shape, dtype=np.float32)
    count_map = np.zeros(aoi_shape, dtype=np.float32)

    for i, meta in enumerate(metadata):
        row = meta['row']
        col = meta['col']
        # Clip patches that extend past the AOI boundary.
        row_end = min(row + patch_size, aoi_shape[0])
        col_end = min(col + patch_size, aoi_shape[1])
        patch_h = row_end - row
        patch_w = col_end - col

        if patch_h > 0 and patch_w > 0:
            patch_prob = unified_matrix[i, :patch_h, :patch_w, channel_idx]
            heatmap[row:row_end, col:col_end] += patch_prob
            count_map[row:row_end, col:col_end] += 1

    # Average overlapping areas; the max(…, 1) also keeps untouched pixels at 0.
    count_map = np.maximum(count_map, 1)
    heatmap = heatmap / count_map

    # Create visualization.
    fig, ax = plt.subplots(figsize=(15, 10))
    im = ax.imshow(heatmap, cmap='hot', interpolation='bilinear')
    ax.set_title(f'Full AOI Probability Heatmap\n{aoi_shape[0]}Ɨ{aoi_shape[1]} pixels',
                 fontsize=14, fontweight='bold')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046, label='Probability')
    plt.tight_layout()

    # Render to PNG bytes.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf.getvalue()


# ==============================================================================
# šŸ†• GATE PREDICTIONS VISUALIZATION (NEW!)
# ==============================================================================

def generate_gate_prediction_heatmap(
    unified_matrix: np.ndarray,
    metadata: List[Dict],
    aoi_shape: Tuple[int, int],
    aoi_name: str,
    threshold: float = 0.5,
    patch_size: int = 64
) -> bytes:
    """
    Generate a GATE prediction heatmap showing archaeological site candidates.

    Args:
        unified_matrix: Shape (num_patches, 64, 64, 5) - channel 4 is GATE predictions.
        metadata: Per-patch metadata dicts with 'row' and 'col' pixel offsets.
        aoi_shape: (height, width) of the full AOI in pixels.
        aoi_name: Name of the AOI for the title.
        threshold: Classification threshold (default 0.5).
        patch_size: Patch size in pixels.

    Returns:
        PNG image as bytes.

    Raises:
        ValueError: If unified_matrix does not contain the GATE channel.
    """
    # Extract GATE channel (channel 4); fail fast with a clear message
    # instead of an opaque IndexError deep in the mosaic loop.
    gate_channel_idx = 4
    if unified_matrix.shape[-1] <= gate_channel_idx:
        raise ValueError(
            f'unified_matrix has {unified_matrix.shape[-1]} channels; '
            f'GATE predictions require channel {gate_channel_idx}'
        )

    # Reconstruct the full heatmap; overlapping patches are averaged.
    heatmap = np.zeros(aoi_shape, dtype=np.float32)
    count_map = np.zeros(aoi_shape, dtype=np.float32)

    for i, meta in enumerate(metadata):
        row = meta['row']
        col = meta['col']
        # Clip patches that extend past the AOI boundary.
        row_end = min(row + patch_size, aoi_shape[0])
        col_end = min(col + patch_size, aoi_shape[1])
        patch_h = row_end - row
        patch_w = col_end - col

        if patch_h > 0 and patch_w > 0:
            gate_prob = unified_matrix[i, :patch_h, :patch_w, gate_channel_idx]
            heatmap[row:row_end, col:col_end] += gate_prob
            count_map[row:row_end, col:col_end] += 1

    # Average overlapping areas.
    count_map = np.maximum(count_map, 1)
    heatmap = heatmap / count_map

    # Fraction of the AOI classified positive at the given threshold.
    positive_mask = heatmap >= threshold
    num_positive_pixels = positive_mask.sum()
    percent_positive = (num_positive_pixels / heatmap.size) * 100

    # Create visualization.
    fig, ax = plt.subplots(figsize=(16, 12))

    im = ax.imshow(heatmap, cmap='RdYlGn', vmin=0, vmax=1, interpolation='bilinear')

    # Overlay contours for increasingly confident areas.
    contour_levels = [threshold, 0.7, 0.9]
    contours = ax.contour(heatmap, levels=contour_levels,
                          colors=['yellow', 'orange', 'red'],
                          linewidths=2, alpha=0.7)
    ax.clabel(contours, inline=True, fontsize=10)

    # Title with summary statistics (garbled line split reconstructed as one string).
    ax.set_title(
        f'GATE Archaeological Site Predictions - {aoi_name}\n'
        f'Positive Area: {percent_positive:.2f}% (threshold={threshold})',
        fontsize=16, fontweight='bold', pad=20
    )
    ax.axis('off')

    # Colorbar.
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('GATE Prediction Probability', fontsize=12, fontweight='bold')

    plt.tight_layout()

    # Render to PNG bytes.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf.getvalue()


def generate_gate_positive_patches_visualization(
    unified_matrix: np.ndarray,
    metadata: List[Dict],
    patches_data: Dict[str, np.ndarray],
    threshold: float = 0.5,
    top_n: int = 16
) -> bytes:
    """
    Visualize the top-N patches with the highest GATE predictions
    (archaeological candidates).

    Args:
        unified_matrix: Shape (num_patches, 64, 64, 5).
        metadata: Per-patch metadata dicts with 'row' and 'col'.
        patches_data: Original patch data dict with 'dtm', 'slope', etc.
        threshold: Minimum mean GATE score to consider a patch positive.
        top_n: Number of top patches to show.

    Returns:
        PNG image as bytes showing the DTM of the top patches (or the GATE
        probability map when no DTM is available).
    """
    # Extract GATE predictions (channel 4) and score each patch by its mean.
    gate_predictions = unified_matrix[:, :, :, 4]   # Shape: (num_patches, 64, 64)
    mean_scores = gate_predictions.mean(axis=(1, 2))  # Shape: (num_patches,)

    # Keep only patches above the threshold.
    positive_mask = mean_scores >= threshold
    positive_indices = np.where(positive_mask)[0]

    if len(positive_indices) == 0:
        # No positive predictions: render an informative placeholder.
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(0.5, 0.5, f'No patches above threshold {threshold}',
                ha='center', va='center', fontsize=16, fontweight='bold')
        ax.axis('off')
    else:
        # Sort descending by mean score and keep the top N.
        sorted_indices = positive_indices[np.argsort(mean_scores[positive_indices])[::-1]]
        top_indices = sorted_indices[:top_n]

        # Grid of up to 4 columns.
        n_show = len(top_indices)
        ncols = 4
        nrows = (n_show + ncols - 1) // ncols

        fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))
        # FIX: plt.subplots() returns a bare Axes only for a 1x1 grid. The old
        # `axes.flatten() if n_show > 1 else [axes]` wrapped an Axes *array*
        # in a list when n_show == 1 (grid is 1x4), breaking axes[idx] below.
        axes = np.asarray(axes).ravel()

        for idx, patch_idx in enumerate(top_indices):
            ax = axes[idx]

            # Prefer the DTM for this patch; fall back to the GATE map.
            dtm_patch = patches_data['dtm'][patch_idx] if 'dtm' in patches_data else None

            if dtm_patch is not None:
                im = ax.imshow(dtm_patch, cmap='terrain', interpolation='bilinear')
                plt.colorbar(im, ax=ax, fraction=0.046)
            else:
                im = ax.imshow(gate_predictions[patch_idx], cmap='hot', vmin=0, vmax=1)
                plt.colorbar(im, ax=ax, fraction=0.046)

            score = mean_scores[patch_idx]
            row = metadata[patch_idx]['row']
            col = metadata[patch_idx]['col']
            ax.set_title(f'Patch {patch_idx}\nGATE: {score:.3f}\n(r={row}, c={col})',
                         fontsize=10, fontweight='bold')
            ax.axis('off')

        # Hide unused subplots in the grid.
        for idx in range(n_show, len(axes)):
            axes[idx].axis('off')

        plt.suptitle(f'Top {n_show} Archaeological Site Candidates (GATE > {threshold})',
                     fontsize=16, fontweight='bold', y=1.0)

    plt.tight_layout()

    # Render to PNG bytes.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return buf.getvalue()


def generate_gate_statistics_json(
    unified_matrix: np.ndarray,
    threshold: float = 0.5
) -> Dict:
    """
    Generate statistics about GATE predictions for a JSON response.

    Args:
        unified_matrix: Shape (num_patches, 64, 64, 5); channel 4 is GATE.
        threshold: Classification threshold.

    Returns:
        Dictionary with GATE statistics (plain Python scalars).
    """
    gate_predictions = unified_matrix[:, :, :, 4]
    mean_scores = gate_predictions.mean(axis=(1, 2))
    total_patches = len(mean_scores)

    # FIX: guard against an empty matrix — the old code divided by zero and
    # produced NaN aggregate stats for num_patches == 0.
    if total_patches == 0:
        return {
            'total_patches': 0,
            'positive_patches': 0,
            'positive_percentage': 0.0,
            'mean_gate_score': 0.0,
            'max_gate_score': 0.0,
            'min_gate_score': 0.0,
            'std_gate_score': 0.0,
            'threshold': float(threshold)
        }

    positive_patches = (mean_scores >= threshold).sum()

    return {
        'total_patches': int(total_patches),
        'positive_patches': int(positive_patches),
        'positive_percentage': float(positive_patches / total_patches * 100),
        'mean_gate_score': float(mean_scores.mean()),
        'max_gate_score': float(mean_scores.max()),
        'min_gate_score': float(mean_scores.min()),
        'std_gate_score': float(mean_scores.std()),
        'threshold': float(threshold)
    }


# ==============================================================================
# HELPER: Convert bytes to base64 (for JSON responses)
# ==============================================================================
def image_bytes_to_base64(image_bytes: bytes) -> str:
    """Encode raw image bytes as a UTF-8 base64 string for JSON embedding."""
    encoded = base64.b64encode(image_bytes)
    return encoded.decode('utf-8')


def base64_to_image_bytes(base64_str: str) -> bytes:
    """Decode a base64 string (as produced above) back into raw image bytes."""
    return base64.b64decode(base64_str)


# ==============================================================================
# šŸ†• EXPORT FUNCTIONS FOR FASTAPI (WITH GATE SUPPORT)
# ==============================================================================

def get_patch_visualizations(
    patches: Dict[str, np.ndarray],
    aoi_name: str,
    patch_id: str,
    lat: float,
    lon: float,
    ai_results: Optional[Dict] = None
) -> Dict[str, str]:
    """
    Generate all visualizations for a patch and return as base64.

    Returns:
        Dictionary with base64-encoded images:
        {
            '2d_preview': 'base64...',
            '3d_terrain': {...},  # Plotly JSON
            'ai_analysis': 'base64...',  # Only if ai_results provided
            'ai_data': {...}  # Only if ai_results provided
        }
    """
    # Always-present entries: flat 2D preview plus the Plotly 3D payload.
    preview_png = generate_2d_patch_preview(patches, aoi_name, patch_id)
    payload = {
        '2d_preview': image_bytes_to_base64(preview_png),
        '3d_terrain': generate_3d_terrain_json(patches, aoi_name, patch_id, lat, lon),
    }

    # Optional AI panels — included only when model results were supplied.
    if ai_results is not None:
        ai_png = generate_ai_analysis_visualization(patches, ai_results, patch_id)
        payload['ai_analysis'] = image_bytes_to_base64(ai_png)
        payload['ai_data'] = generate_ai_analysis_json(ai_results)

    return payload


def get_gate_visualizations(
    unified_matrix: np.ndarray,
    metadata: List[Dict],
    aoi_shape: Tuple[int, int],
    aoi_name: str,
    patches_data: Optional[Dict[str, np.ndarray]] = None,
    threshold: float = 0.5
) -> Dict:
    """
    Generate all GATE-related visualizations for FastAPI.

    Args:
        unified_matrix: Shape (num_patches, 64, 64, 5) with GATE channel.
        metadata: Patch metadata.
        aoi_shape: AOI dimensions.
        aoi_name: AOI identifier.
        patches_data: Original patch data (optional, for showing DTM).
        threshold: GATE classification threshold.

    Returns:
        Dictionary with base64 images and statistics:
        {
            'gate_heatmap': 'base64...',
            'positive_patches': 'base64...',
            'statistics': {...}
        }
    """
    # 1. Full AOI GATE heatmap.
    heatmap_png = generate_gate_prediction_heatmap(
        unified_matrix, metadata, aoi_shape, aoi_name, threshold
    )
    output = {'gate_heatmap': image_bytes_to_base64(heatmap_png)}

    # 2. Top positive patches — only possible when raw patch data is given.
    if patches_data is not None:
        positives_png = generate_gate_positive_patches_visualization(
            unified_matrix, metadata, patches_data, threshold
        )
        output['positive_patches'] = image_bytes_to_base64(positives_png)

    # 3. Summary statistics (always included).
    output['statistics'] = generate_gate_statistics_json(unified_matrix, threshold)

    return output