Spaces:
Running
Running
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from typing import Optional, Tuple, Dict, Any, Union | |
| from PIL import Image | |
| import base64 | |
| from io import BytesIO | |
class SpatialImageExtractor:
    """Extract spatial background images from AnnData objects.

    Every method is a ``staticmethod`` that only reads ``adata.uns['spatial']``.
    Nothing on ``adata`` is modified.
    """

    @staticmethod
    def _resolve_library(spatial_data, library_id: Optional[str]):
        """Return the ``(library_id, library_data)`` pair to use, or ``None``.

        When ``library_id`` is None, the first key of ``spatial_data`` is used.
        That requires ``spatial_data`` to be a non-empty dict.
        """
        if library_id is None:
            if not isinstance(spatial_data, dict) or not spatial_data:
                return None
            library_id = next(iter(spatial_data))
        if library_id not in spatial_data:
            return None
        return library_id, spatial_data[library_id]

    @staticmethod
    def get_spatial_image(
        adata,
        library_id: Optional[str] = None,
        prefer_lowres: bool = True,
    ) -> Optional[Tuple[np.ndarray, Dict[str, Any], str]]:
        """
        Extract the spatial background image from an AnnData object.

        Spatial images are read from:
        - adata.uns['spatial'][library_id]['images']['hires'] or ['lowres']
        - adata.uns['spatial'][library_id]['scalefactors']

        Args:
            adata: AnnData object (only ``adata.uns`` is read).
            library_id: Library/sample ID. If None, the first library is used.
            prefer_lowres: If True, return 'lowres' when present (faster to
                render). Otherwise prefer 'hires'. Either way the other image
                is used as a fallback.

        Returns:
            Tuple of (image_array, scalefactors, image_key), or None if no
            image is found. ``scalefactors`` is a shallow copy; it is empty when
            the library has none.
            Unexpected errors are printed as a warning and yield None.
        """
        try:
            if 'spatial' not in adata.uns:
                return None
            resolved = SpatialImageExtractor._resolve_library(
                adata.uns['spatial'], library_id
            )
            if resolved is None:
                return None
            library_data = resolved[1]
            if 'images' not in library_data:
                return None
            images = library_data['images']

            # Candidate image keys in priority order.
            order = ('lowres', 'hires') if prefer_lowres else ('hires', 'lowres')
            image_key = next((key for key in order if key in images), None)
            if image_key is None:
                return None

            # Copy so callers can add keys (the plotting code reads an
            # injected '_image_key') without mutating adata.uns.
            scalefactors = dict(library_data.get('scalefactors') or {})
            return images[image_key], scalefactors, image_key
        except Exception as e:
            print(f"Warning: Could not extract spatial image: {e}")
            return None

    @staticmethod
    def get_available_libraries(adata) -> list:
        """Return every library ID stored under ``adata.uns['spatial']``.

        Libraries are listed whether or not they contain images. Use
        ``has_spatial_image(adata, library_id)`` to check a specific one.
        Returns ``[]`` when there is no spatial entry or it cannot be read.
        """
        try:
            if 'spatial' not in adata.uns:
                return []
            return list(adata.uns['spatial'].keys())
        except Exception:
            return []

    @staticmethod
    def has_spatial_image(adata, library_id: Optional[str] = None) -> bool:
        """Return True if a library has a 'hires' or 'lowres' background image.

        Args:
            adata: AnnData object (only ``adata.uns`` is read).
            library_id: Library to check. If None, the first library is checked.

        Returns:
            True when an image is present. False when it is missing or the
            spatial metadata cannot be read.
        """
        try:
            if 'spatial' not in adata.uns:
                return False
            resolved = SpatialImageExtractor._resolve_library(
                adata.uns['spatial'], library_id
            )
            if resolved is None:
                return False
            lib_data = resolved[1]
            if 'images' not in lib_data:
                return False
            images = lib_data['images']
            return 'hires' in images or 'lowres' in images
        except Exception:
            return False
class SpatialPlotter:
    """Create spatial visualizations for gene expression.

    All methods are static. Spots are plotted in data coordinates, and an
    optional tissue image is placed as a Plotly layout image in the same
    coordinate space.
    """

    # ------------------------------------------------------------------
    # Background-image helpers (shared by the plotting functions below)
    # ------------------------------------------------------------------

    @staticmethod
    def _load_image_array(background_image) -> Optional[np.ndarray]:
        """Return the background as a numpy array, or None for unsupported types.

        Accepts a file path (str), a PIL image, or a numpy array.
        """
        if isinstance(background_image, str):
            # The context manager closes the file handle. Palette images are
            # expanded to real colours, because np.array() on a 'P' image
            # yields palette indices.
            with Image.open(background_image) as opened:
                if opened.mode == 'P':
                    return np.array(opened.convert('RGBA'))
                return np.array(opened)
        if isinstance(background_image, Image.Image):
            if background_image.mode == 'P':
                return np.array(background_image.convert('RGBA'))
            return np.array(background_image)
        if isinstance(background_image, np.ndarray):
            return background_image
        return None

    @staticmethod
    def _to_uint8(img_array: np.ndarray) -> np.ndarray:
        """Scale float images (expected range 0-1) to uint8. Other dtypes pass through."""
        if np.issubdtype(img_array.dtype, np.floating):
            # Clip so out-of-range values saturate instead of wrapping in the cast.
            return (np.clip(img_array, 0.0, 1.0) * 255).astype(np.uint8)
        return img_array

    @staticmethod
    def _image_bounds(img_width, img_height, scalefactors, x, y):
        """Return ``(x_min, y_min, x_max, y_max)`` of the image in spot coordinates.

        With scalefactors, the stored image is treated as a copy of the spot
        coordinate space scaled by the matching ``tissue_*_scalef``, anchored
        at (0, 0). Without them, the image is stretched over the spot extent
        plus 5% padding.

        Raises:
            ValueError: if the selected scale factor is zero/None.
        """
        if scalefactors:
            # '_image_key' says which stored image this is. It defaults to 'hires'.
            image_key = scalefactors.get('_image_key', 'hires')
            scale_name = (
                'tissue_lowres_scalef' if image_key == 'lowres' else 'tissue_hires_scalef'
            )
            scale = scalefactors.get(scale_name, 1.0)
            if not scale:
                raise ValueError(f"invalid scale factor {scale_name}={scale!r}")
            return 0, 0, img_width / scale, img_height / scale

        padding = 0.05  # 5% padding
        x_min, x_max = x.min(), x.max()
        y_min, y_max = y.min(), y.max()
        # A zero extent (e.g. a single spot) would otherwise collapse the image.
        x_pad = (x_max - x_min) * padding or 1.0
        y_pad = (y_max - y_min) * padding or 1.0
        return x_min - x_pad, y_min - y_pad, x_max + x_pad, y_max + y_pad

    @staticmethod
    def _image_to_jpeg_data_uri(img: "Image.Image", quality: int = 85) -> str:
        """Encode a PIL image as a base64 JPEG data URI (JPEG for faster encoding)."""
        if img.mode in ('RGBA', 'LA', 'PA'):
            # JPEG has no alpha channel, so composite onto a white background.
            rgba = img.convert('RGBA')
            rgb = Image.new('RGB', rgba.size, (255, 255, 255))
            rgb.paste(rgba, mask=rgba.split()[3])
            img = rgb
        elif img.mode not in ('1', 'L', 'RGB'):
            # Other modes (e.g. I;16, F, CMYK) are converted to RGB before saving.
            img = img.convert('RGB')
        buffered = BytesIO()
        img.save(buffered, format="JPEG", quality=quality)
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/jpeg;base64,{img_base64}"

    @staticmethod
    def _add_background_image(
        fig, background_image, scalefactors, x, y, opacity, warning_prefix
    ) -> None:
        """Best-effort: add ``background_image`` below the traces of ``fig``.

        Any failure is printed as ``"{warning_prefix}: {error}"`` and the
        figure is left without an image. A bad image therefore never prevents
        the scatter plot from rendering. Unsupported input types are skipped
        silently.
        """
        try:
            img_array = SpatialPlotter._load_image_array(background_image)
            if img_array is None:
                return
            img_array = SpatialPlotter._to_uint8(img_array)
            img = Image.fromarray(img_array)
            img_height, img_width = img_array.shape[:2]
            x_min, y_min, x_max, y_max = SpatialPlotter._image_bounds(
                img_width, img_height, scalefactors, x, y
            )
            fig.add_layout_image(
                dict(
                    source=SpatialPlotter._image_to_jpeg_data_uri(img),
                    xref="x",
                    yref="y",
                    # The callers reverse the y axis, so the smallest y is at the
                    # top. The top-anchored image starts at y_min and extends down.
                    x=x_min,
                    y=y_min,
                    sizex=x_max - x_min,
                    sizey=y_max - y_min,
                    sizing="stretch",
                    opacity=opacity,
                    layer="below",
                    yanchor="top",
                )
            )
        except Exception as e:
            print(f"{warning_prefix}: {e}")

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    @staticmethod
    def plot_spatial_gene(
        spatial_coords: np.ndarray,
        expression: np.ndarray,
        gene_name: str,
        point_size: int = 5,
        use_log: bool = False,
        colorscale: str = "Viridis",
        width: int = 800,
        height: int = 800,
        background_image: Optional[Union[np.ndarray, str]] = None,
        scalefactors: Optional[Dict[str, float]] = None,
        background_opacity: float = 0.5,
    ) -> go.Figure:
        """
        Create a spatial scatter plot of gene expression.

        Args:
            spatial_coords: Nx2 array of spatial coordinates (column 0 = x, 1 = y).
            expression: N expression values. Shapes (N,), (N, 1) and (1, N)
                are flattened. Objects with ``.toarray()`` (e.g. scipy.sparse)
                are densified first.
            gene_name: Name of the gene (used in the title and colorbar).
            point_size: Size of scatter points.
            use_log: Whether to apply a log1p transformation to expression.
            colorscale: Plotly colorscale name.
            width: Figure width in pixels.
            height: Figure height in pixels.
            background_image: Background image as a numpy array, PIL image or
                file path. Failures to load it are printed and ignored.
            scalefactors: Scale factors from the h5ad for coordinate mapping.
                An optional '_image_key' entry selects the lowres/hires factor.
            background_opacity: Opacity of the background image (0.0-1.0).

        Returns:
            Plotly Figure object with the y axis reversed (image orientation).
        """
        coords = np.asarray(spatial_coords)
        x = coords[:, 0]
        y = coords[:, 1]

        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        expr_values = np.asarray(expression, dtype=float).reshape(-1)
        if use_log:
            expr_values = np.log1p(expr_values)
            expr_label = f"log1p({gene_name})"
        else:
            expr_label = gene_name

        fig = go.Figure()
        if background_image is not None:
            SpatialPlotter._add_background_image(
                fig,
                background_image,
                scalefactors,
                x,
                y,
                background_opacity,
                "Warning: Could not load background image",
            )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=point_size,
                    color=expr_values,
                    colorscale=colorscale,
                    showscale=True,
                    colorbar=dict(title=expr_label),
                    line=dict(width=0),
                ),
                text=[f"Expression: {val:.2f}" for val in expr_values],
                hovertemplate="<b>%{text}</b><br>"
                + "X: %{x:.1f}<br>"
                + "Y: %{y:.1f}<br>"
                + "<extra></extra>",
            )
        )
        fig.update_layout(
            title=dict(
                text=f"Spatial Expression: {gene_name}",
                x=0.5,
                xanchor="center",
                font=dict(size=18),
            ),
            xaxis=dict(
                title="Spatial X",
                showgrid=False,
                zeroline=False,
            ),
            yaxis=dict(
                title="Spatial Y",
                showgrid=False,
                zeroline=False,
                scaleanchor="x",
                scaleratio=1,
                autorange="reversed",  # Flip Y-axis to match image coordinate system
            ),
            width=width,
            height=height,
            hovermode="closest",
            plot_bgcolor="white",
        )
        return fig

    @staticmethod
    def create_overview_plot(
        spatial_coords: np.ndarray,
        width: int = 600,
        height: int = 600,
    ) -> go.Figure:
        """
        Create an overview plot of spatial coordinates (without gene expression).

        Unlike the other plots in this class, the y axis is NOT reversed.

        Args:
            spatial_coords: Nx2 array of spatial coordinates.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            Plotly Figure object.
        """
        coords = np.asarray(spatial_coords)
        x = coords[:, 0]
        y = coords[:, 1]

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=3,
                    color="lightblue",
                    line=dict(width=0),
                ),
                hovertemplate="X: %{x:.1f}<br>Y: %{y:.1f}<extra></extra>",
            )
        )
        fig.update_layout(
            title=dict(
                text="Spatial Overview",
                x=0.5,
                xanchor="center",
            ),
            xaxis=dict(
                title="Spatial X",
                showgrid=False,
                zeroline=False,
            ),
            yaxis=dict(
                title="Spatial Y",
                showgrid=False,
                zeroline=False,
                scaleanchor="x",
                scaleratio=1,
            ),
            width=width,
            height=height,
            plot_bgcolor="white",
        )
        return fig

    @staticmethod
    def create_overview_plot_with_background(
        spatial_coords: np.ndarray,
        background_image: Optional[Union[np.ndarray, str]] = None,
        scalefactors: Optional[Dict[str, Any]] = None,
        width: int = 600,
        height: int = 600,
        background_opacity: float = 0.6,
    ) -> go.Figure:
        """
        Create an overview plot of spatial coordinates with an optional tissue background.

        Args:
            spatial_coords: Nx2 array of spatial coordinates.
            background_image: Optional background image as a numpy array, PIL
                image or file path. Failures to add it are printed and ignored.
            scalefactors: Scale factors for coordinate mapping (optionally with
                an '_image_key' entry selecting the lowres/hires factor).
            width: Figure width in pixels.
            height: Figure height in pixels.
            background_opacity: Opacity of the background image.

        Returns:
            Plotly Figure object with the y axis reversed (image orientation).
        """
        coords = np.asarray(spatial_coords)
        x = coords[:, 0]
        y = coords[:, 1]

        fig = go.Figure()
        if background_image is not None:
            SpatialPlotter._add_background_image(
                fig,
                background_image,
                scalefactors,
                x,
                y,
                background_opacity,
                "Warning: Could not add background image",
            )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=3,
                    color="rgba(65, 105, 225, 0.7)",  # Royal blue with transparency
                    line=dict(width=0),
                ),
                hovertemplate="X: %{x:.1f}<br>Y: %{y:.1f}<extra></extra>",
            )
        )
        fig.update_layout(
            title=dict(
                text="Spatial Overview",
                x=0.5,
                xanchor="center",
            ),
            xaxis=dict(
                title="Spatial X",
                showgrid=False,
                zeroline=False,
            ),
            yaxis=dict(
                title="Spatial Y",
                showgrid=False,
                zeroline=False,
                scaleanchor="x",
                scaleratio=1,
                autorange="reversed",  # Match image coordinate system
            ),
            width=width,
            height=height,
            plot_bgcolor="white",
        )
        return fig

    @staticmethod
    def get_expression_stats(expression: np.ndarray) -> dict:
        """
        Calculate basic statistics for expression values.

        Args:
            expression: Expression values of any shape. All elements are
                counted. Objects with ``.toarray()`` (e.g. scipy.sparse) are
                densified first.

        Returns:
            Dictionary with min, max, mean, median, std, the count of strictly
            positive values ("non_zero_count"), and that count as a percentage
            of all values. For empty input, the statistics are NaN and the
            counts are 0.
        """
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        values = np.asarray(expression)
        total = values.size
        if total == 0:
            nan = float("nan")
            return {
                "min": nan,
                "max": nan,
                "mean": nan,
                "median": nan,
                "std": nan,
                "non_zero_count": 0,
                "non_zero_percent": 0.0,
            }
        positive = int(np.count_nonzero(values > 0))
        return {
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "mean": float(np.mean(values)),
            "median": float(np.median(values)),
            "std": float(np.std(values)),
            "non_zero_count": positive,
            # Divide by the element count, not len(), so 2-D input is handled.
            "non_zero_percent": float(100 * positive / total),
        }


# Module-level alias so both ``SpatialPlotter.get_expression_stats(...)`` and a
# bare ``get_expression_stats(...)`` call resolve to the same implementation.
get_expression_stats = SpatialPlotter.get_expression_stats