import numpy as np import plotly.graph_objects as go import plotly.express as px from typing import Optional, Tuple, Dict, Any, Union from PIL import Image import base64 from io import BytesIO class SpatialImageExtractor: """Extract spatial background images from AnnData objects""" @staticmethod def get_spatial_image( adata, library_id: Optional[str] = None, prefer_lowres: bool = True, ) -> Optional[Tuple[np.ndarray, Dict[str, Any], str]]: """ Extract spatial background image from AnnData object Spatial images are typically stored in: - adata.uns['spatial'][library_id]['images']['hires'] or 'lowres' - adata.uns['spatial'][library_id]['scalefactors'] Args: adata: AnnData object library_id: Library/sample ID. If None, uses first available. prefer_lowres: If True, prefer lowres image for faster rendering Returns: Tuple of (image_array, scalefactors, image_key) or None if not found """ try: # Check if spatial data exists if 'spatial' not in adata.uns: return None spatial_data = adata.uns['spatial'] # Get library_id if library_id is None: # Use first available library if isinstance(spatial_data, dict) and len(spatial_data) > 0: library_id = list(spatial_data.keys())[0] else: return None if library_id not in spatial_data: return None library_data = spatial_data[library_id] # Get images if 'images' not in library_data: return None images = library_data['images'] # Select image based on preference (lowres is faster) image_key = None if prefer_lowres and 'lowres' in images: img_array = images['lowres'] image_key = 'lowres' elif 'hires' in images: img_array = images['hires'] image_key = 'hires' elif 'lowres' in images: img_array = images['lowres'] image_key = 'lowres' else: return None # Get scalefactors scalefactors = library_data.get('scalefactors', {}) return img_array, scalefactors, image_key except Exception as e: print(f"Warning: Could not extract spatial image: {e}") return None @staticmethod def get_available_libraries(adata) -> list: """Get list of available library IDs with spatial images""" try: if 'spatial' not in adata.uns: return [] return list(adata.uns['spatial'].keys()) except: return [] @staticmethod def has_spatial_image(adata) -> bool: """Check if AnnData has spatial background image""" try: if 'spatial' not in adata.uns: return False spatial_data = adata.uns['spatial'] if not isinstance(spatial_data, dict) or len(spatial_data) == 0: return False # Check first library first_lib = list(spatial_data.keys())[0] lib_data = spatial_data[first_lib] if 'images' not in lib_data: return False images = lib_data['images'] return 'hires' in images or 'lowres' in images except: return False class SpatialPlotter: """Create spatial visualizations for gene expression""" @staticmethod def plot_spatial_gene( spatial_coords: np.ndarray, expression: np.ndarray, gene_name: str, point_size: int = 5, use_log: bool = False, colorscale: str = "Viridis", width: int = 800, height: int = 800, background_image: Optional[Union[np.ndarray, str]] = None, scalefactors: Optional[Dict[str, float]] = None, background_opacity: float = 0.5, ) -> go.Figure: """ Create spatial scatter plot of gene expression Args: spatial_coords: Nx2 array of spatial coordinates expression: N-length array of gene expression values gene_name: Name of the gene point_size: Size of scatter points use_log: Whether to apply log1p transformation to expression colorscale: Plotly colorscale name width: Figure width in pixels height: Figure height in pixels background_image: Background image as numpy array or file path scalefactors: Scale factors from h5ad for coordinate mapping background_opacity: Opacity of background image (0.0-1.0) Returns: Plotly Figure object """ # Prepare expression values expr_values = expression.copy() # Apply log transformation if requested if use_log: expr_values = np.log1p(expr_values) expr_label = f"log1p({gene_name})" else: expr_label = gene_name # Extract coordinates x = spatial_coords[:, 0] y = spatial_coords[:, 1] # Create figure fig = go.Figure() # Add background image if provided if background_image is not None: try: # Handle different input types if isinstance(background_image, str): # File path img = Image.open(background_image) img_array = np.array(img) elif isinstance(background_image, np.ndarray): img_array = background_image else: img_array = None if img_array is not None: # Convert numpy array to PIL Image for Plotly if img_array.dtype == np.float64 or img_array.dtype == np.float32: # Normalize float images to 0-255 img_array = (img_array * 255).astype(np.uint8) img = Image.fromarray(img_array) # Calculate image bounds in spatial coordinate system # The spatial coordinates in adata.obsm['spatial'] are in full-resolution pixel space # The stored image is scaled down by scalefactors img_height, img_width = img_array.shape[:2] # Determine the scale factor used for this image if scalefactors: # Get scale factor based on image_key (passed via scalefactors dict) image_key = scalefactors.get('_image_key', 'hires') if image_key == 'lowres': scale = scalefactors.get('tissue_lowres_scalef', 1.0) else: scale = scalefactors.get('tissue_hires_scalef', 1.0) # Image spans from (0,0) to (img_width/scale, img_height/scale) in spatial coords img_x_min = 0 img_y_min = 0 img_x_max = img_width / scale img_y_max = img_height / scale else: # No scalefactors: fit image to coordinate bounds with padding padding = 0.05 # 5% padding x_range = x.max() - x.min() y_range = y.max() - y.min() img_x_min = x.min() - x_range * padding img_y_min = y.min() - y_range * padding img_x_max = x.max() + x_range * padding img_y_max = y.max() + y_range * padding # Convert to base64 for Plotly (use JPEG for faster encoding) buffered = BytesIO() # Convert RGBA to RGB if needed for JPEG if img.mode == 'RGBA': img_rgb = Image.new('RGB', img.size, (255, 255, 255)) img_rgb.paste(img, mask=img.split()[3]) img = img_rgb img.save(buffered, format="JPEG", quality=85) img_base64 = base64.b64encode(buffered.getvalue()).decode() img_src = f"data:image/jpeg;base64,{img_base64}" # With Y-axis reversed (autorange="reversed"), smaller Y is at top # Image anchor point is top-left, so y should be img_y_min (top of image) fig.add_layout_image( dict( source=img_src, xref="x", yref="y", x=img_x_min, y=img_y_min, # Top of image (smallest Y value) sizex=img_x_max - img_x_min, sizey=img_y_max - img_y_min, sizing="stretch", opacity=background_opacity, layer="below", yanchor="top", ) ) except Exception as e: print(f"Warning: Could not load background image: {e}") # Add scatter plot fig.add_trace( go.Scatter( x=x, y=y, mode="markers", marker=dict( size=point_size, color=expr_values, colorscale=colorscale, showscale=True, colorbar=dict(title=expr_label), line=dict(width=0), ), text=[f"Expression: {val:.2f}" for val in expr_values], hovertemplate="%{text}
" + "X: %{x:.1f}
" + "Y: %{y:.1f}
" + "", ) ) # Update layout fig.update_layout( title=dict( text=f"Spatial Expression: {gene_name}", x=0.5, xanchor="center", font=dict(size=18), ), xaxis=dict( title="Spatial X", showgrid=False, zeroline=False, ), yaxis=dict( title="Spatial Y", showgrid=False, zeroline=False, scaleanchor="x", scaleratio=1, autorange="reversed", # Flip Y-axis to match image coordinate system ), width=width, height=height, hovermode="closest", plot_bgcolor="white", ) return fig @staticmethod def create_overview_plot( spatial_coords: np.ndarray, width: int = 600, height: int = 600, ) -> go.Figure: """ Create overview plot of spatial coordinates (without gene expression) Args: spatial_coords: Nx2 array of spatial coordinates width: Figure width in pixels height: Figure height in pixels Returns: Plotly Figure object """ x = spatial_coords[:, 0] y = spatial_coords[:, 1] fig = go.Figure() fig.add_trace( go.Scatter( x=x, y=y, mode="markers", marker=dict( size=3, color="lightblue", line=dict(width=0), ), hovertemplate="X: %{x:.1f}
Y: %{y:.1f}", ) ) fig.update_layout( title=dict( text="Spatial Overview", x=0.5, xanchor="center", ), xaxis=dict( title="Spatial X", showgrid=False, zeroline=False, ), yaxis=dict( title="Spatial Y", showgrid=False, zeroline=False, scaleanchor="x", scaleratio=1, ), width=width, height=height, plot_bgcolor="white", ) return fig @staticmethod def create_overview_plot_with_background( spatial_coords: np.ndarray, background_image: Optional[np.ndarray] = None, scalefactors: Optional[Dict[str, Any]] = None, width: int = 600, height: int = 600, background_opacity: float = 0.6, ) -> go.Figure: """ Create overview plot of spatial coordinates with optional tissue background Args: spatial_coords: Nx2 array of spatial coordinates background_image: Optional background image as numpy array scalefactors: Scale factors for coordinate mapping width: Figure width in pixels height: Figure height in pixels background_opacity: Opacity of background image Returns: Plotly Figure object """ x = spatial_coords[:, 0] y = spatial_coords[:, 1] fig = go.Figure() # Add background image if provided if background_image is not None: try: img_array = background_image if img_array.dtype == np.float64 or img_array.dtype == np.float32: img_array = (img_array * 255).astype(np.uint8) img = Image.fromarray(img_array) img_height, img_width = img_array.shape[:2] # Calculate image bounds if scalefactors: image_key = scalefactors.get('_image_key', 'hires') if image_key == 'lowres': scale = scalefactors.get('tissue_lowres_scalef', 1.0) else: scale = scalefactors.get('tissue_hires_scalef', 1.0) img_x_min = 0 img_y_min = 0 img_x_max = img_width / scale img_y_max = img_height / scale else: padding = 0.05 x_range = x.max() - x.min() y_range = y.max() - y.min() img_x_min = x.min() - x_range * padding img_y_min = y.min() - y_range * padding img_x_max = x.max() + x_range * padding img_y_max = y.max() + y_range * padding # Convert to base64 buffered = BytesIO() if img.mode == 'RGBA': img_rgb = Image.new('RGB', img.size, (255, 255, 255)) img_rgb.paste(img, mask=img.split()[3]) img = img_rgb img.save(buffered, format="JPEG", quality=85) img_base64 = base64.b64encode(buffered.getvalue()).decode() img_src = f"data:image/jpeg;base64,{img_base64}" fig.add_layout_image( dict( source=img_src, xref="x", yref="y", x=img_x_min, y=img_y_min, sizex=img_x_max - img_x_min, sizey=img_y_max - img_y_min, sizing="stretch", opacity=background_opacity, layer="below", yanchor="top", ) ) except Exception as e: print(f"Warning: Could not add background image: {e}") fig.add_trace( go.Scatter( x=x, y=y, mode="markers", marker=dict( size=3, color="rgba(65, 105, 225, 0.7)", # Royal blue with transparency line=dict(width=0), ), hovertemplate="X: %{x:.1f}
Y: %{y:.1f}", ) ) fig.update_layout( title=dict( text="Spatial Overview", x=0.5, xanchor="center", ), xaxis=dict( title="Spatial X", showgrid=False, zeroline=False, ), yaxis=dict( title="Spatial Y", showgrid=False, zeroline=False, scaleanchor="x", scaleratio=1, autorange="reversed", # Match image coordinate system ), width=width, height=height, plot_bgcolor="white", ) return fig @staticmethod def get_expression_stats(expression: np.ndarray) -> dict: """ Calculate basic statistics for expression values Args: expression: Expression array Returns: Dictionary with statistics """ return { "min": float(np.min(expression)), "max": float(np.max(expression)), "mean": float(np.mean(expression)), "median": float(np.median(expression)), "std": float(np.std(expression)), "non_zero_count": int(np.sum(expression > 0)), "non_zero_percent": float(100 * np.sum(expression > 0) / len(expression)), }