import gradio as gr import os import io import zipfile import tempfile import csv import datetime from pathlib import Path from typing import Optional, Tuple, List, Dict import numpy as np import plotly.graph_objects as go from utils.loader import H5adLoader from utils.validator import AnnDataValidator from utils.plot import SpatialPlotter, SpatialImageExtractor from utils.data_source_manager import DataSourceManager class SpatialViewer: """Main application class for spatial transcriptomics viewer""" # Default demo dataset to load on startup DEFAULT_DEMO = "Cerebellum-MALDI-MSI.h5ad" def __init__(self): self.data_manager = DataSourceManager() self.current_source = None def load_default_demo(self) -> Tuple[str, Optional[gr.Plot], gr.update, gr.update, str]: """ Load default demo dataset on app startup Returns: Tuple of (status, overview_plot, selector_update, row_visibility, dataset_info) """ demo_path = Path("data") / self.DEFAULT_DEMO if not demo_path.exists(): return ( "Demo dataset not found. Please load data manually.", None, gr.update(), gr.update(visible=False), "No dataset loaded" ) try: adata = H5adLoader.load_from_source(str(demo_path)) # Validate data is_valid, errors = AnnDataValidator.validate(adata) if not is_valid: return ( "Demo dataset validation failed: " + "; ".join(errors), None, gr.update(), gr.update(visible=False), "No dataset loaded" ) # Add to data manager source_id = self.data_manager.add_source( name=self.DEFAULT_DEMO, source_type="demo", source_path=str(demo_path), adata=adata ) # Create overview plot spatial_coords = adata.obsm["spatial"] overview_fig = SpatialPlotter.create_overview_plot(spatial_coords) status = ( f"β Auto-loaded demo dataset!\n" f"- Dataset: {self.DEFAULT_DEMO}\n" f"- Observations (spots/cells): {adata.n_obs:,}\n" f"- Variables (genes): {adata.n_vars:,}\n" f"- Spatial coordinates: {spatial_coords.shape}\n" f"\nReady to visualize gene expression. Switch to 'Visualize Gene' tab." ) # Dataset selector update choices = self.data_manager.get_source_choices() selector_update = gr.update( choices=choices, value=self.data_manager.current_id, visible=True ) # Dataset info for Visualize tab current_source = self.data_manager.get_current_source() dataset_info = f"π Current: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)" return status, overview_fig, selector_update, gr.update(visible=True), dataset_info except Exception as e: return ( f"Failed to load demo dataset: {str(e)}", None, gr.update(), gr.update(visible=False), "No dataset loaded" ) def load_data( self, source_type: str, demo_dataset: Optional[str] = None, url: Optional[str] = None, file_path: Optional[str] = None ) -> Tuple[str, Optional[gr.Plot], gr.update]: """ Load h5ad data from various sources Now supports ZIP files containing multiple h5ad files Args: source_type: Type of source ('demo', 'url', 'upload') demo_dataset: Selected demo dataset name (if source_type is 'demo') url: URL to h5ad file (if source_type is 'url') file_path: Path to uploaded file (if source_type is 'upload') Returns: Tuple of (status_message, overview_plot, dataset_selector_update) """ try: # Determine source if source_type == "demo": if not demo_dataset: return "Please select a demo dataset.", None, gr.update() demo_path = Path("data") / demo_dataset if not demo_path.exists(): return f"Demo dataset not found: {demo_dataset}", None, gr.update() source = str(demo_path) display_name = demo_dataset elif source_type == "url": if not url or url.strip() == "": return "Please provide a valid URL.", None, gr.update() source = url.strip() display_name = source.split("/")[-1] or "URL Dataset" elif source_type == "upload": if not file_path: return "Please upload a file.", None, gr.update() source = file_path display_name = Path(file_path).name else: return f"Unknown source type: {source_type}", None, gr.update() # Load data loaded_data = H5adLoader.load_from_source(source) # Handle multiple datasets (from ZIP file) if isinstance(loaded_data, list): # Multiple h5ad files loaded from ZIP status_messages = [] loaded_count = 0 for idx, adata in enumerate(loaded_data): # Validate each dataset is_valid, errors = AnnDataValidator.validate(adata) if not is_valid: status_messages.append( f"Dataset {idx + 1} validation failed:\n" + "\n".join(f" - {e}" for e in errors) ) continue # Add to data manager file_name = f"{display_name} - Part {idx + 1}" source_id = self.data_manager.add_source( name=file_name, source_type=source_type, source_path=source, adata=adata ) loaded_count += 1 if loaded_count == 0: return "No valid datasets found in ZIP file.\n" + "\n".join(status_messages), None, gr.update() # Get current (latest loaded) dataset current_source = self.data_manager.get_current_source() spatial_coords = current_source.adata.obsm["spatial"] overview_fig = SpatialPlotter.create_overview_plot(spatial_coords) status = ( f"Successfully loaded {loaded_count} dataset(s) from ZIP file!\n\n" f"Current dataset: {current_source.name}\n" f"- Observations (spots/cells): {current_source.n_obs:,}\n" f"- Variables (genes): {current_source.n_vars:,}\n" f"- Spatial coordinates: {spatial_coords.shape}\n" f"\nUse the dataset selector above to switch between datasets.\n" f"Ready to visualize gene expression." ) else: # Single h5ad file adata = loaded_data # Validate data is_valid, errors = AnnDataValidator.validate(adata) if not is_valid: error_msg = "Validation errors:\n" + "\n".join(f"- {e}" for e in errors) return error_msg, None, gr.update() # Add to data manager source_id = self.data_manager.add_source( name=display_name, source_type=source_type, source_path=source, adata=adata ) # Create overview plot spatial_coords = adata.obsm["spatial"] overview_fig = SpatialPlotter.create_overview_plot(spatial_coords) status = ( f"Successfully loaded data!\n" f"- Dataset: {display_name}\n" f"- Observations (spots/cells): {adata.n_obs:,}\n" f"- Variables (genes): {adata.n_vars:,}\n" f"- Spatial coordinates: {spatial_coords.shape}\n" f"\nReady to visualize gene expression." ) # Update dataset selector choices = self.data_manager.get_source_choices() selector_update = gr.update( choices=choices, value=self.data_manager.current_id, visible=True ) return status, overview_fig, selector_update except Exception as e: return f"Error loading data: {str(e)}", None, gr.update() def switch_dataset(self, source_id: str) -> Tuple[str, Optional[gr.Plot]]: """ Switch to a different loaded dataset Args: source_id: ID of the dataset to switch to Returns: Tuple of (info_message, overview_plot) """ if not source_id: return "No dataset selected.", None success = self.data_manager.set_current(source_id) if not success: return f"Dataset not found: {source_id}", None current_source = self.data_manager.get_current_source() spatial_coords = current_source.adata.obsm["spatial"] overview_fig = SpatialPlotter.create_overview_plot(spatial_coords) info = current_source.get_info() return info, overview_fig def visualize_gene( self, gene_name: str, point_size: int = 5, use_log: bool = True, colorscale: str = "Viridis", show_background: bool = False, background_opacity: float = 0.5, ) -> Tuple[str, Optional[gr.Plot], str, str]: """ Visualize gene expression in spatial context """ current_source = self.data_manager.get_current_source() if current_source is None: return "β Please load data first.", None, "", "" if current_source.adata is None: return "β Dataset registered but not yet loaded. Please select it in the 'Select Dataset' tab first.", None, "", "" if not gene_name or gene_name.strip() == "": return "β Please enter a gene name.", None, "", "" gene_name = gene_name.strip() try: adata = current_source.adata # Get gene expression expression = AnnDataValidator.get_gene_expression(adata, gene_name) # Get spatial coordinates spatial_coords = adata.obsm["spatial"] # Extract background image from h5ad if requested background_image = None scalefactors = None bg_status = "" if show_background: result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True) if result is not None: background_image, scalefactors, image_key = result # Pass image_key to scalefactors so plot knows which scale to use scalefactors = dict(scalefactors) # Make a copy scalefactors['_image_key'] = image_key bg_status = f" (with {image_key} tissue background)" else: bg_status = " (no background image in h5ad)" # Create plot fig = SpatialPlotter.plot_spatial_gene( spatial_coords=spatial_coords, expression=expression, gene_name=gene_name, point_size=point_size, use_log=use_log, colorscale=colorscale, background_image=background_image, scalefactors=scalefactors, background_opacity=background_opacity, ) # Get statistics stats = SpatialPlotter.get_expression_stats(expression) stats_text = ( f"Expression Statistics for {gene_name}:\n" f"- Min: {stats['min']:.4f}\n" f"- Max: {stats['max']:.4f}\n" f"- Mean: {stats['mean']:.4f}\n" f"- Median: {stats['median']:.4f}\n" f"- Std Dev: {stats['std']:.4f}\n" f"- Non-zero: {stats['non_zero_count']:,} ({stats['non_zero_percent']:.1f}%)" ) # Current dataset info dataset_info = f"Current dataset: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)" return f"Successfully visualized gene: {gene_name}{bg_status}", fig, stats_text, dataset_info except ValueError as e: return str(e), None, "", "" except Exception as e: return f"Error visualizing gene: {str(e)}", None, "", "" def check_spatial_image_available(self) -> bool: """Check if current dataset has spatial background image""" current_source = self.data_manager.get_current_source() if current_source is None or current_source.adata is None: return False return SpatialImageExtractor.has_spatial_image(current_source.adata) def get_gene_suggestions(self, limit: int = 100) -> list: """Get list of available genes for autocomplete""" current_source = self.data_manager.get_current_source() if current_source is None or current_source.adata is None: return [] return AnnDataValidator.get_gene_list(current_source.adata, limit=limit) def get_current_dataset_info(self) -> str: """Get formatted info string for current dataset""" current_source = self.data_manager.get_current_source() if current_source is None: return "No dataset loaded. Please load data first." if current_source.adata is None: return f"π Current: {current_source.name}\n(Not yet loaded)" return f"π Current: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)" def get_all_genes(self) -> List[str]: """Get full list of genes for autocomplete dropdown""" current_source = self.data_manager.get_current_source() if current_source is None or current_source.adata is None: return [] return list(current_source.adata.var_names) def search_genes(self, query: str, limit: int = 50) -> List[str]: """ Search genes by prefix or substring match """ current_source = self.data_manager.get_current_source() if current_source is None or current_source.adata is None: return [] if not query or query.strip() == "": # Return first N genes if no query return list(current_source.adata.var_names[:limit]) query = query.strip().upper() all_genes = list(current_source.adata.var_names) # First: exact prefix matches (prioritized) prefix_matches = [g for g in all_genes if g.upper().startswith(query)] # Second: substring matches (lower priority) substring_matches = [g for g in all_genes if query in g.upper() and g not in prefix_matches] # Combine and limit results = prefix_matches + substring_matches return results[:limit] def get_adata_summary(self) -> str: """ Get detailed summary of current AnnData object Returns: Formatted string with h5ad file details """ current_source = self.data_manager.get_current_source() if current_source is None: return "No dataset loaded" if current_source.adata is None: return f"π **{current_source.name}**\n\n*Dataset registered but not yet loaded. Select it in the list to load.*" adata = current_source.adata lines = [] lines.append(f"π **{current_source.name}**") lines.append("") # Basic info lines.append("### π Dimensions") lines.append(f"- Observations (cells/spots): **{adata.n_obs:,}**") lines.append(f"- Variables (features): **{adata.n_vars:,}**") # Spatial coordinates if "spatial" in adata.obsm: spatial_shape = adata.obsm["spatial"].shape lines.append(f"- Spatial coordinates: **{spatial_shape}**") lines.append("") # Variables info (first 5) lines.append("### 𧬠Variables (first 5)") var_names = list(adata.var_names[:5]) lines.append(f"`{', '.join(var_names)}`") if adata.n_vars > 5: lines.append(f"... and {adata.n_vars - 5:,} more") lines.append("") # obsm keys if len(adata.obsm.keys()) > 0: lines.append("### π obsm (embeddings)") for key in list(adata.obsm.keys())[:5]: shape = adata.obsm[key].shape lines.append(f"- `{key}`: {shape}") # obsp keys if hasattr(adata, 'obsp') and len(adata.obsp.keys()) > 0: lines.append("") lines.append("### π obsp (pairwise)") for key in list(adata.obsp.keys())[:3]: lines.append(f"- `{key}`") # uns keys if len(adata.uns.keys()) > 0: lines.append("") lines.append("### π¦ uns (unstructured)") uns_keys = list(adata.uns.keys())[:6] lines.append(f"`{', '.join(uns_keys)}`") if len(adata.uns.keys()) > 6: lines.append(f"... and {len(adata.uns.keys()) - 6} more") # Check for spatial image lines.append("") lines.append("### πΌοΈ Spatial Image") if SpatialImageExtractor.has_spatial_image(adata): libs = SpatialImageExtractor.get_available_libraries(adata) lines.append(f"β Available (libraries: {', '.join(libs)})") else: lines.append("β Not available") return "\n".join(lines) def get_local_h5ad_files(self) -> List[str]: """Get list of h5ad files in the data folder""" data_dir = Path("data") if not data_dir.exists(): return [] return [f.name for f in data_dir.glob("*.h5ad")] def create_overview_with_background(self) -> Optional[go.Figure]: """Create spatial overview plot with tissue background if available""" current_source = self.data_manager.get_current_source() if current_source is None or current_source.adata is None: return None adata = current_source.adata spatial_coords = adata.obsm["spatial"] # Try to get background image background_image = None scalefactors = None result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True) if result is not None: background_image, scalefactors, image_key = result scalefactors = dict(scalefactors) scalefactors['_image_key'] = image_key # Create overview plot with background return SpatialPlotter.create_overview_plot_with_background( spatial_coords=spatial_coords, background_image=background_image, scalefactors=scalefactors, ) def parse_variables_list(self, input_text: str) -> Tuple[List[str], List[str], List[str]]: """ Parse comma/space/newline separated variables list Args: input_text: Raw input text with variable names Returns: Tuple of (found_features, not_found_features, all_parsed) """ current_source = self.data_manager.get_current_source() if current_source is None: return [], [], [] if not input_text or input_text.strip() == "": return [], [], [] # Parse: split by comma, space, newline, tab import re raw_items = re.split(r'[,\s\n\t]+', input_text.strip()) all_parsed = [item.strip() for item in raw_items if item.strip()] # Check which features exist in dataset available_genes = set(current_source.adata.var_names) found_features = [g for g in all_parsed if g in available_genes] not_found_features = [g for g in all_parsed if g not in available_genes] return found_features, not_found_features, all_parsed def batch_visualize( self, variables_text: str, point_size: int = 5, use_log: bool = True, colorscale: str = "Viridis", show_background: bool = False, background_opacity: float = 0.5, progress=gr.Progress(track_tqdm=True), ) -> Tuple[str, Optional[str], str, str]: """ Perform batch visualization for multiple features Args: variables_text: Comma/space/newline separated feature names point_size, use_log, colorscale, show_background, background_opacity: Plot settings progress: Gradio progress tracker Returns: Tuple of (status, zip_file_path, summary_report, stats_csv) """ current_source = self.data_manager.get_current_source() if current_source is None: return "β No dataset loaded. Please load data first.", None, "", "" found_features, not_found_features, all_parsed = self.parse_variables_list(variables_text) if not found_features: return f"β No valid features found in dataset.\nParsed: {', '.join(all_parsed)}", None, "", "" # Prepare output adata = current_source.adata spatial_coords = adata.obsm["spatial"] # Get background image if needed background_image = None scalefactors = None if show_background: result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True) if result is not None: background_image, scalefactors, image_key = result scalefactors = dict(scalefactors) scalefactors['_image_key'] = image_key # Create temp directory for outputs temp_dir = tempfile.mkdtemp(prefix="batch_viz_") # Track results stats_records = [] successful_plots = [] failed_features = [] # Generate plots total = len(found_features) for idx, gene_name in enumerate(found_features): progress((idx + 1) / total, desc=f"Processing {gene_name} ({idx + 1}/{total})") try: # Get expression expression = AnnDataValidator.get_gene_expression(adata, gene_name) # Create plot fig = SpatialPlotter.plot_spatial_gene( spatial_coords=spatial_coords, expression=expression, gene_name=gene_name, point_size=point_size, use_log=use_log, colorscale=colorscale, background_image=background_image, scalefactors=scalefactors, background_opacity=background_opacity, ) # Save as PNG png_path = os.path.join(temp_dir, f"{gene_name}.png") fig.write_image(png_path, scale=2) successful_plots.append((gene_name, png_path)) # Get statistics stats = SpatialPlotter.get_expression_stats(expression) stats['feature'] = gene_name stats_records.append(stats) except Exception as e: failed_features.append((gene_name, str(e))) # Generate summary report report_lines = [ "# Batch Visualization Report", f"Dataset: {current_source.name}", f"Total cells/spots: {current_source.n_obs:,}", f"Total features: {current_source.n_vars:,}", "", "## Settings", f"- Point Size: {point_size}", f"- Log Transform: {use_log}", f"- Color Scale: {colorscale}", f"- Background: {show_background}", "", "## Results Summary", f"- Total requested: {len(all_parsed)}", f"- Found in dataset: {len(found_features)}", f"- Successfully visualized: {len(successful_plots)}", f"- Failed: {len(failed_features)}", "", ] if not_found_features: report_lines.append("## Not Found Features") for feat in not_found_features: report_lines.append(f"- {feat}") report_lines.append("") if failed_features: report_lines.append("## Failed Features") for feat, err in failed_features: report_lines.append(f"- {feat}: {err}") report_lines.append("") report_lines.append("## Successfully Visualized Features") for feat, _ in successful_plots: report_lines.append(f"- {feat}") report_text = "\n".join(report_lines) # Save report report_path = os.path.join(temp_dir, "report.md") with open(report_path, "w") as f: f.write(report_text) # Save statistics CSV stats_csv_path = os.path.join(temp_dir, "expression_statistics.csv") if stats_records: with open(stats_csv_path, "w", newline="") as f: fieldnames = ['feature', 'min', 'max', 'mean', 'median', 'std', 'non_zero_count', 'non_zero_percent'] writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(stats_records) # Create ZIP file zip_path = os.path.join(temp_dir, "batch_visualization.zip") with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf: # Add images for gene_name, png_path in successful_plots: zf.write(png_path, f"images/{gene_name}.png") # Add report zf.write(report_path, "report.md") # Add stats CSV if stats_records: zf.write(stats_csv_path, "expression_statistics.csv") # Format stats for display stats_display = "Feature | Min | Max | Mean | Non-zero %\n" stats_display += "--- | --- | --- | --- | ---\n" for rec in stats_records: stats_display += f"{rec['feature']} | {rec['min']:.4f} | {rec['max']:.4f} | {rec['mean']:.4f} | {rec['non_zero_percent']:.1f}%\n" status = f"β Batch visualization complete!\n- Generated: {len(successful_plots)} plots\n- Failed: {len(failed_features)}" return status, zip_path, report_text, stats_display def create_interface(): """Create Gradio interface""" viewer = SpatialViewer() # Custom CSS custom_css = """ .duplicate-notice { background: linear-gradient(135deg, #fff8e1 0%, #ffecb3 100%); color: #3e2723; border: 1px solid #ffc107; border-radius: 8px; padding: 12px 16px; margin: 12px 0; font-size: 0.95rem; line-height: 1.5; } .duplicate-notice b { color: #e65100; } @media (prefers-color-scheme: dark) { .duplicate-notice { background: linear-gradient(135deg, rgba(50,40,20,0.9) 0%, rgba(40,30,10,0.9) 100%); color: #ffffff; border-color: #ffc107; } .duplicate-notice b { color: #ffd54f; } } .file-browser { background: linear-gradient(180deg, #f8f9fa 0%, #e9ecef 100%); border: 1px solid #dee2e6; border-radius: 8px; padding: 12px; } @media (prefers-color-scheme: dark) { .file-browser { background: linear-gradient(180deg, #2d2d2d 0%, #1a1a1a 100%); border-color: #444; } } .data-info-panel { background: linear-gradient(180deg, #e3f2fd 0%, #bbdefb 100%); border: 1px solid #90caf9; border-radius: 8px; padding: 12px; } @media (prefers-color-scheme: dark) { .data-info-panel { background: linear-gradient(180deg, rgba(33,150,243,0.15) 0%, rgba(33,150,243,0.05) 100%); border-color: #1976d2; } } .control-panel { background: linear-gradient(180deg, #f5f5f5 0%, #eeeeee 100%); border: 1px solid #e0e0e0; border-radius: 8px; padding: 16px; } @media (prefers-color-scheme: dark) { .control-panel { background: linear-gradient(180deg, #2a2a2a 0%, #1f1f1f 100%); border-color: #444; } } """ with gr.Blocks( title="Spatial Omics Viewer", theme=gr.themes.Soft(), css=custom_css, ) as app: gr.Markdown( """ # π¬ Spatial Omics Viewer Visualize spatial expression from .h5ad files (AnnData format)