import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from typing import Optional, Tuple, Dict, Any, Union
from PIL import Image
import base64
from io import BytesIO
class SpatialImageExtractor:
    """Extract spatial background images from AnnData objects.

    Images are looked up in the layout::

        adata.uns['spatial'][library_id]['images']['hires' | 'lowres']
        adata.uns['spatial'][library_id]['scalefactors']
    """

    @staticmethod
    def get_spatial_image(
        adata,
        library_id: Optional[str] = None,
        prefer_lowres: bool = True,
    ) -> Optional[Tuple[np.ndarray, Dict[str, Any], str]]:
        """
        Extract spatial background image from AnnData object

        Spatial images are typically stored in:
        - adata.uns['spatial'][library_id]['images']['hires'] or 'lowres'
        - adata.uns['spatial'][library_id]['scalefactors']

        Args:
            adata: AnnData object (only ``adata.uns`` is read).
            library_id: Library/sample ID. If None, the first key of
                ``adata.uns['spatial']`` is used.
            prefer_lowres: If True, return the lowres image when present
                (faster to render) and fall back to hires; if False, prefer
                hires and fall back to lowres.

        Returns:
            Tuple of (image_array, scalefactors, image_key), or None if no
            hires/lowres image is found. ``scalefactors`` is a shallow copy
            (empty if absent), so callers may add entries such as
            ``'_image_key'`` without modifying ``adata.uns``. Unexpected
            errors print a warning and also yield None.
        """
        try:
            if 'spatial' not in adata.uns:
                return None
            spatial_data = adata.uns['spatial']

            if library_id is None:
                # Default to the first library, as documented above.
                if not isinstance(spatial_data, dict) or not spatial_data:
                    return None
                library_id = next(iter(spatial_data))
            if library_id not in spatial_data:
                return None
            library_data = spatial_data[library_id]

            if 'images' not in library_data:
                return None
            images = library_data['images']

            # Try the preferred resolution first, then the other one.
            search_order = ('lowres', 'hires') if prefer_lowres else ('hires', 'lowres')
            image_key = next((key for key in search_order if key in images), None)
            if image_key is None:
                return None

            # Copy so callers annotating the result never write into adata.uns;
            # `or {}` also tolerates a stored None.
            scalefactors = dict(library_data.get('scalefactors') or {})
            return images[image_key], scalefactors, image_key
        except Exception as e:
            print(f"Warning: Could not extract spatial image: {e}")
            return None

    @staticmethod
    def get_available_libraries(adata) -> list:
        """Return every library ID stored under ``adata.uns['spatial']``.

        All keys are listed whether or not the library actually holds an
        image; use ``has_spatial_image(adata, library_id)`` to check one.
        Returns an empty list when the entry is missing or unreadable.
        """
        try:
            if 'spatial' not in adata.uns:
                return []
            return list(adata.uns['spatial'].keys())
        except Exception:
            return []

    @staticmethod
    def has_spatial_image(adata, library_id: Optional[str] = None) -> bool:
        """Return True if a 'hires' or 'lowres' image is stored for a library.

        Args:
            adata: AnnData object.
            library_id: Library to check. If None (default), the first
                library under ``adata.uns['spatial']`` is checked, the same
                default ``get_spatial_image`` uses.

        Returns:
            False when the spatial entry, the library, or its images are
            missing, or when anything along the way cannot be read.
        """
        try:
            if 'spatial' not in adata.uns:
                return False
            spatial_data = adata.uns['spatial']
            if not isinstance(spatial_data, dict) or not spatial_data:
                return False
            if library_id is None:
                library_id = next(iter(spatial_data))
            elif library_id not in spatial_data:
                return False
            lib_data = spatial_data[library_id]
            if 'images' not in lib_data:
                return False
            images = lib_data['images']
            return 'hires' in images or 'lowres' in images
        except Exception:
            return False
class SpatialPlotter:
    """Create spatial visualizations for gene expression.

    Every plot uses equal x/y scaling and a reversed y-axis, so spots given
    in image pixel coordinates (y growing downward) line up with any
    background image drawn beneath them.
    """

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_uint8(img_array: np.ndarray) -> np.ndarray:
        """Return img_array in a form ``Image.fromarray`` can encode.

        Float images whose maximum is <= 1 are scaled to 0-255. Float images
        with larger values are assumed to already be on a 0-255 scale.
        Float values are rounded and clipped so out-of-range pixels saturate
        instead of wrapping around. NaN becomes 0. A trailing singleton
        channel axis, as in (H, W, 1), is dropped so the image is treated
        as grayscale. Non-float arrays are returned unchanged apart from
        that squeeze.
        """
        arr = np.asarray(img_array)
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]
        if np.issubdtype(arr.dtype, np.floating):
            arr = np.nan_to_num(arr)
            if arr.size and arr.max() <= 1.0:
                arr = arr * 255.0
            arr = np.clip(np.round(arr), 0, 255).astype(np.uint8)
        return arr

    @staticmethod
    def _load_image_array(background_image: Union[np.ndarray, str]) -> Optional[np.ndarray]:
        """Return background_image as an array, or None for unsupported types.

        A str is treated as a file path and read with PIL. The file is closed
        before returning. Palette ('P') images are expanded to RGBA so the
        array holds colours rather than palette indices.
        """
        if isinstance(background_image, str):
            with Image.open(background_image) as opened:
                if opened.mode == 'P':
                    return np.array(opened.convert('RGBA'))
                return np.array(opened)
        if isinstance(background_image, np.ndarray):
            return background_image
        return None

    @staticmethod
    def _image_bounds(
        img_width: int,
        img_height: int,
        scalefactors: Optional[Dict[str, Any]],
        x: np.ndarray,
        y: np.ndarray,
    ) -> Tuple[float, float, float, float]:
        """Return (x_min, y_min, x_max, y_max) of the image in spot coordinates.

        With scalefactors, spot coordinates are assumed to be in
        full-resolution pixel space, and the stored image a copy scaled by
        ``tissue_lowres_scalef`` or ``tissue_hires_scalef``. Which factor is
        used depends on the ``'_image_key'`` entry (default ``'hires'``). A
        missing factor counts as 1.0. The image then spans (0, 0) to
        (width / scale, height / scale).

        When scalefactors are absent, or the factor is not a positive
        number, the image is stretched over the spot bounding box plus 5%
        padding on each side.
        """
        if scalefactors:
            image_key = scalefactors.get('_image_key', 'hires')
            factor_name = 'tissue_lowres_scalef' if image_key == 'lowres' else 'tissue_hires_scalef'
            try:
                scale = float(scalefactors.get(factor_name, 1.0))
            except (TypeError, ValueError):
                scale = 0.0
            # `scale > 0` is also False for NaN, avoiding a division by zero/NaN.
            if scale > 0:
                return 0.0, 0.0, img_width / scale, img_height / scale

        padding = 0.05  # 5% padding
        x_lo, x_hi = float(np.min(x)), float(np.max(x))
        y_lo, y_hi = float(np.min(y)), float(np.max(y))
        x_pad = (x_hi - x_lo) * padding
        y_pad = (y_hi - y_lo) * padding
        return x_lo - x_pad, y_lo - y_pad, x_hi + x_pad, y_hi + y_pad

    @staticmethod
    def _encode_jpeg(img: "Image.Image") -> str:
        """Return img as a base64 ``data:image/jpeg`` URI (quality 85).

        JPEG is used for faster encoding. JPEG cannot store transparency, so
        images with an alpha band ('RGBA', 'LA') are composited onto white.
        Any other mode besides 'RGB' and 'L' is converted to RGB.
        """
        if img.mode in ('RGBA', 'LA'):
            rgba = img.convert('RGBA')
            flattened = Image.new('RGB', rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.split()[3])
            img = flattened
        elif img.mode not in ('RGB', 'L'):
            img = img.convert('RGB')
        buffered = BytesIO()
        img.save(buffered, format="JPEG", quality=85)
        encoded = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/jpeg;base64,{encoded}"

    @staticmethod
    def _add_background_image(
        fig: "go.Figure",
        background_image: Union[np.ndarray, str],
        scalefactors: Optional[Dict[str, Any]],
        x: np.ndarray,
        y: np.ndarray,
        opacity: float,
        warning_prefix: str,
    ) -> None:
        """Add background_image to fig as a layout image below the traces.

        This is best-effort. Any failure prints ``f"{warning_prefix}: {error}"``
        and leaves the figure without a background image. Unsupported input
        types are skipped silently.
        """
        try:
            img_array = SpatialPlotter._load_image_array(background_image)
            if img_array is None:
                return
            img_array = SpatialPlotter._to_uint8(img_array)
            img_height, img_width = img_array.shape[:2]
            x_min, y_min, x_max, y_max = SpatialPlotter._image_bounds(
                img_width, img_height, scalefactors, x, y
            )
            img_src = SpatialPlotter._encode_jpeg(Image.fromarray(img_array))
            # The y-axis is reversed, so the smallest y is at the top. The
            # image is anchored by its top-left corner at (x_min, y_min).
            fig.add_layout_image(
                dict(
                    source=img_src,
                    xref="x",
                    yref="y",
                    x=x_min,
                    y=y_min,
                    sizex=x_max - x_min,
                    sizey=y_max - y_min,
                    sizing="stretch",
                    opacity=opacity,
                    layer="below",
                    yanchor="top",
                )
            )
        except Exception as e:
            print(f"{warning_prefix}: {e}")

    @staticmethod
    def _axes_layout() -> Dict[str, Dict[str, Any]]:
        """Return the xaxis/yaxis layout shared by every plot in this class."""
        return dict(
            xaxis=dict(
                title="Spatial X",
                showgrid=False,
                zeroline=False,
            ),
            yaxis=dict(
                title="Spatial Y",
                showgrid=False,
                zeroline=False,
                scaleanchor="x",
                scaleratio=1,
                autorange="reversed",  # Flip Y-axis to match image coordinate system
            ),
        )

    @staticmethod
    def _overview_figure(
        spatial_coords: np.ndarray,
        marker_color: str,
        width: int,
        height: int,
        background_image: Optional[Union[np.ndarray, str]] = None,
        scalefactors: Optional[Dict[str, Any]] = None,
        background_opacity: float = 0.6,
    ) -> "go.Figure":
        """Build the 'Spatial Overview' scatter shared by both overview plots."""
        coords = np.asarray(spatial_coords)
        x = coords[:, 0]
        y = coords[:, 1]

        fig = go.Figure()
        if background_image is not None:
            SpatialPlotter._add_background_image(
                fig,
                background_image,
                scalefactors,
                x,
                y,
                background_opacity,
                "Warning: Could not add background image",
            )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=3,
                    color=marker_color,
                    line=dict(width=0),
                ),
                # <extra></extra> hides the secondary "trace 0" hover label.
                hovertemplate="X: %{x:.1f}<br>Y: %{y:.1f}<extra></extra>",
            )
        )
        fig.update_layout(
            title=dict(
                text="Spatial Overview",
                x=0.5,
                xanchor="center",
            ),
            **SpatialPlotter._axes_layout(),
            width=width,
            height=height,
            plot_bgcolor="white",
        )
        return fig

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def plot_spatial_gene(
        spatial_coords: np.ndarray,
        expression: np.ndarray,
        gene_name: str,
        point_size: int = 5,
        use_log: bool = False,
        colorscale: str = "Viridis",
        width: int = 800,
        height: int = 800,
        background_image: Optional[Union[np.ndarray, str]] = None,
        scalefactors: Optional[Dict[str, Any]] = None,
        background_opacity: float = 0.5,
    ) -> "go.Figure":
        """
        Create spatial scatter plot of gene expression

        Args:
            spatial_coords: Nx2 array of spatial coordinates (x, y)
            expression: N values of gene expression (an (N, 1) column is
                flattened)
            gene_name: Name of the gene
            point_size: Size of scatter points
            use_log: Whether to apply log1p transformation to expression
            colorscale: Plotly colorscale name
            width: Figure width in pixels
            height: Figure height in pixels
            background_image: Background image as numpy array or file path
            scalefactors: Scale factors from h5ad for coordinate mapping; an
                optional '_image_key' entry ('hires'/'lowres') selects which
                factor applies to the image
            background_opacity: Opacity of background image (0.0-1.0)

        Returns:
            Plotly Figure object. If the background image cannot be loaded, a
            warning is printed and the figure is drawn without it.
        """
        # Flatten so (N, 1) columns still format per-spot in the hover text.
        expr_values = np.asarray(expression).ravel()
        if use_log:
            expr_values = np.log1p(expr_values)
            expr_label = f"log1p({gene_name})"
        else:
            expr_label = gene_name

        coords = np.asarray(spatial_coords)
        x = coords[:, 0]
        y = coords[:, 1]

        fig = go.Figure()
        if background_image is not None:
            SpatialPlotter._add_background_image(
                fig,
                background_image,
                scalefactors,
                x,
                y,
                background_opacity,
                "Warning: Could not load background image",
            )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                marker=dict(
                    size=point_size,
                    color=expr_values,
                    colorscale=colorscale,
                    showscale=True,
                    colorbar=dict(title=expr_label),
                    line=dict(width=0),
                ),
                text=[f"Expression: {val:.2f}" for val in expr_values],
                hovertemplate="%{text}<br>X: %{x:.1f}<br>Y: %{y:.1f}<extra></extra>",
            )
        )
        fig.update_layout(
            title=dict(
                text=f"Spatial Expression: {gene_name}",
                x=0.5,
                xanchor="center",
                font=dict(size=18),
            ),
            **SpatialPlotter._axes_layout(),
            width=width,
            height=height,
            hovermode="closest",
            plot_bgcolor="white",
        )
        return fig

    @staticmethod
    def create_overview_plot(
        spatial_coords: np.ndarray,
        width: int = 600,
        height: int = 600,
    ) -> "go.Figure":
        """
        Create overview plot of spatial coordinates (without gene expression)

        The y-axis is reversed, as in the other plots of this class, so the
        tissue has the same orientation everywhere.

        Args:
            spatial_coords: Nx2 array of spatial coordinates
            width: Figure width in pixels
            height: Figure height in pixels

        Returns:
            Plotly Figure object
        """
        return SpatialPlotter._overview_figure(spatial_coords, "lightblue", width, height)

    @staticmethod
    def create_overview_plot_with_background(
        spatial_coords: np.ndarray,
        background_image: Optional[Union[np.ndarray, str]] = None,
        scalefactors: Optional[Dict[str, Any]] = None,
        width: int = 600,
        height: int = 600,
        background_opacity: float = 0.6,
    ) -> "go.Figure":
        """
        Create overview plot of spatial coordinates with optional tissue background

        Args:
            spatial_coords: Nx2 array of spatial coordinates
            background_image: Optional background image as numpy array or
                file path
            scalefactors: Scale factors for coordinate mapping (see
                plot_spatial_gene)
            width: Figure width in pixels
            height: Figure height in pixels
            background_opacity: Opacity of background image

        Returns:
            Plotly Figure object. If the background image cannot be added, a
            warning is printed and the figure is drawn without it.
        """
        return SpatialPlotter._overview_figure(
            spatial_coords,
            "rgba(65, 105, 225, 0.7)",  # Royal blue with transparency
            width,
            height,
            background_image=background_image,
            scalefactors=scalefactors,
            background_opacity=background_opacity,
        )

    @staticmethod
    def get_expression_stats(expression: np.ndarray) -> dict:
        """
        Calculate basic statistics for expression values

        Args:
            expression: Expression array (any shape; all elements are used)

        Returns:
            Dictionary with min, max, mean, median, std, non_zero_count, and
            non_zero_percent. A value counts as "non-zero" only if it is > 0.
            For empty input, the float statistics are NaN and both non-zero
            fields are 0.
        """
        values = np.asarray(expression).ravel()
        if values.size == 0:
            nan = float("nan")
            return {
                "min": nan,
                "max": nan,
                "mean": nan,
                "median": nan,
                "std": nan,
                "non_zero_count": 0,
                "non_zero_percent": 0.0,
            }
        non_zero = int(np.count_nonzero(values > 0))
        return {
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "mean": float(np.mean(values)),
            "median": float(np.median(values)),
            "std": float(np.std(values)),
            "non_zero_count": non_zero,
            "non_zero_percent": float(100 * non_zero / values.size),
        }