Spaces:
Sleeping
Sleeping
| """ | |
| GLUE model training workflow for multi-omics data integration. | |
| This MCP Server provides 4 tools: | |
| 1. glue_configure_datasets: Configure RNA-seq and ATAC-seq datasets for GLUE model training | |
| 2. glue_train_model: Train GLUE model for multi-omics integration | |
| 3. glue_check_integration_consistency: Evaluate integration quality with consistency scores | |
| 4. glue_generate_embeddings: Generate cell and feature embeddings from trained GLUE model | |
| All tools extracted from `gao-lab/GLUE/docs/training.ipynb`. | |
| """ | |
| import os | |
| from datetime import datetime | |
| from itertools import chain | |
| from pathlib import Path | |
| # Standard imports | |
| from typing import Annotated, Any, Literal | |
| # Domain-specific imports | |
| import anndata as ad | |
| import matplotlib.pyplot as plt | |
| import networkx as nx | |
| import numpy as np | |
| import pandas as pd | |
| import scanpy as sc | |
| import scglue | |
| import seaborn as sns | |
| from fastmcp import FastMCP | |
| from matplotlib import rcParams | |
# ---------------------------------------------------------------------------
# Filesystem layout
# ---------------------------------------------------------------------------
# The project root is taken to be three directory levels above this file.
PROJECT_ROOT = Path(__file__).parent.parent.parent.resolve()
DEFAULT_INPUT_DIR = PROJECT_ROOT.joinpath("tmp", "inputs")
DEFAULT_OUTPUT_DIR = PROJECT_ROOT.joinpath("tmp", "outputs")

# Environment variables, when set, take precedence over the in-repo defaults.
INPUT_DIR = Path(os.getenv("TRAINING_INPUT_DIR", DEFAULT_INPUT_DIR))
OUTPUT_DIR = Path(os.getenv("TRAINING_OUTPUT_DIR", DEFAULT_OUTPUT_DIR))

# Both directories are created at import time so the tools can write at once.
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Captured once, when the module is imported; used in default output prefixes.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# ---------------------------------------------------------------------------
# MCP server
# ---------------------------------------------------------------------------
# NOTE(review): no tool registration (e.g. a decorator) is visible on the
# functions defined below in this view of the file - confirm they are attached
# to this instance somewhere.
training_mcp = FastMCP(name="training")

# ---------------------------------------------------------------------------
# Plot defaults
# ---------------------------------------------------------------------------
# The DPI values are set before scglue applies its publication parameters;
# the figure size is set afterwards. The order is kept deliberately.
rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})
scglue.plot.set_publication_params()
rcParams["figure.figsize"] = (4, 4)
| def glue_configure_datasets( | |
| # Primary data inputs | |
| rna_path: Annotated[ | |
| str | None, "Path to preprocessed RNA-seq data file with extension .h5ad" | |
| ] = None, | |
| atac_path: Annotated[ | |
| str | None, "Path to preprocessed ATAC-seq data file with extension .h5ad" | |
| ] = None, | |
| guidance_path: Annotated[ | |
| str | None, "Path to guidance graph file with extension .graphml.gz" | |
| ] = None, | |
| # Configuration parameters with tutorial defaults | |
| prob_model: Annotated[ | |
| Literal["NB", "ZINB", "ZIP"], "Probabilistic generative model" | |
| ] = "NB", | |
| use_highly_variable: Annotated[bool, "Use only highly variable features"] = True, | |
| rna_use_layer: Annotated[ | |
| str | None, "RNA data layer to use (None uses .X)" | |
| ] = "counts", | |
| rna_use_rep: Annotated[str, "RNA preprocessing embedding to use"] = "X_pca", | |
| atac_use_rep: Annotated[str, "ATAC preprocessing embedding to use"] = "X_lsi", | |
| out_prefix: Annotated[str | None, "Output file prefix"] = None, | |
| ) -> dict: | |
| """ | |
| Configure RNA-seq and ATAC-seq datasets for GLUE model training. | |
| Input is preprocessed RNA/ATAC h5ad files and guidance graph, output is configured h5ad files and HVF-filtered guidance graph. | |
| """ | |
| # Input file validation | |
| if rna_path is None: | |
| raise ValueError("Path to RNA-seq data file must be provided") | |
| if atac_path is None: | |
| raise ValueError("Path to ATAC-seq data file must be provided") | |
| if guidance_path is None: | |
| raise ValueError("Path to guidance graph file must be provided") | |
| # File existence validation | |
| rna_file = Path(rna_path) | |
| if not rna_file.exists(): | |
| raise FileNotFoundError(f"RNA-seq file not found: {rna_path}") | |
| atac_file = Path(atac_path) | |
| if not atac_file.exists(): | |
| raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}") | |
| guidance_file = Path(guidance_path) | |
| if not guidance_file.exists(): | |
| raise FileNotFoundError(f"Guidance graph file not found: {guidance_path}") | |
| # Load data | |
| rna = ad.read_h5ad(rna_path) | |
| atac = ad.read_h5ad(atac_path) | |
| guidance = nx.read_graphml(guidance_path) | |
| # Configure datasets | |
| scglue.models.configure_dataset( | |
| rna, | |
| prob_model, | |
| use_highly_variable=use_highly_variable, | |
| use_layer=rna_use_layer, | |
| use_rep=rna_use_rep, | |
| ) | |
| scglue.models.configure_dataset( | |
| atac, prob_model, use_highly_variable=use_highly_variable, use_rep=atac_use_rep | |
| ) | |
| # Extract subgraph with highly variable features | |
| guidance_hvf = guidance.subgraph( | |
| chain( | |
| rna.var.query("highly_variable").index, | |
| atac.var.query("highly_variable").index, | |
| ) | |
| ).copy() | |
| # Note: anndata drops None values during save/load, but scglue's configure_dataset | |
| # creates these fields. We preserve them by converting None to a special string marker. | |
| for adata in [rna, atac]: | |
| if "__scglue__" in adata.uns: | |
| config = adata.uns["__scglue__"] | |
| # Convert None values to string markers that will survive serialization | |
| for key in [ | |
| "batches", | |
| "use_batch", | |
| "use_cell_type", | |
| "cell_types", | |
| "use_dsc_weight", | |
| "use_layer", | |
| ]: | |
| if key in config and config[key] is None: | |
| config[key] = "__none__" | |
| # Save configured datasets and HVF guidance graph | |
| if out_prefix is None: | |
| out_prefix = f"glue_configured_{timestamp}" | |
| rna_output = OUTPUT_DIR / f"{out_prefix}_rna_configured.h5ad" | |
| atac_output = OUTPUT_DIR / f"{out_prefix}_atac_configured.h5ad" | |
| guidance_hvf_output = OUTPUT_DIR / f"{out_prefix}_guidance_hvf.graphml.gz" | |
| rna.write(str(rna_output), compression="gzip") | |
| atac.write(str(atac_output), compression="gzip") | |
| nx.write_graphml(guidance_hvf, str(guidance_hvf_output)) | |
| # Return standardized format | |
| return { | |
| "message": f"Configured datasets with {len(rna.var.query('highly_variable'))} RNA and {len(atac.var.query('highly_variable'))} ATAC HVFs", | |
| "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb", | |
| "artifacts": [ | |
| { | |
| "description": "Configured RNA-seq data", | |
| "path": str(rna_output.resolve()), | |
| }, | |
| { | |
| "description": "Configured ATAC-seq data", | |
| "path": str(atac_output.resolve()), | |
| }, | |
| { | |
| "description": "HVF-filtered guidance graph", | |
| "path": str(guidance_hvf_output.resolve()), | |
| }, | |
| ], | |
| } | |
| def glue_train_model( | |
| # Primary data inputs | |
| rna_path: Annotated[ | |
| str | None, "Path to configured RNA-seq data file with extension .h5ad" | |
| ] = None, | |
| atac_path: Annotated[ | |
| str | None, "Path to configured ATAC-seq data file with extension .h5ad" | |
| ] = None, | |
| guidance_hvf_path: Annotated[ | |
| str | None, | |
| "Path to HVF-filtered guidance graph file with extension .graphml.gz", | |
| ] = None, | |
| # Training parameters | |
| training_dir: Annotated[ | |
| str | None, "Directory to store model snapshots and training logs" | |
| ] = None, | |
| out_prefix: Annotated[str | None, "Output file prefix"] = None, | |
| ) -> dict: | |
| """ | |
| Train GLUE model for multi-omics integration. | |
| Input is configured RNA/ATAC h5ad files and HVF guidance graph, output is trained GLUE model. | |
| """ | |
| # Input file validation | |
| if rna_path is None: | |
| raise ValueError("Path to configured RNA-seq data file must be provided") | |
| if atac_path is None: | |
| raise ValueError("Path to configured ATAC-seq data file must be provided") | |
| if guidance_hvf_path is None: | |
| raise ValueError("Path to HVF-filtered guidance graph file must be provided") | |
| # File existence validation | |
| rna_file = Path(rna_path) | |
| if not rna_file.exists(): | |
| raise FileNotFoundError(f"RNA-seq file not found: {rna_path}") | |
| atac_file = Path(atac_path) | |
| if not atac_file.exists(): | |
| raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}") | |
| guidance_hvf_file = Path(guidance_hvf_path) | |
| if not guidance_hvf_file.exists(): | |
| raise FileNotFoundError( | |
| f"Guidance HVF graph file not found: {guidance_hvf_path}" | |
| ) | |
| # Load data | |
| rna = ad.read_h5ad(rna_path) | |
| atac = ad.read_h5ad(atac_path) | |
| guidance_hvf = nx.read_graphml(guidance_hvf_path) | |
| # Convert string markers back to None for scglue compatibility | |
| for adata in [rna, atac]: | |
| if "__scglue__" in adata.uns: | |
| config = adata.uns["__scglue__"] | |
| for key in [ | |
| "batches", | |
| "use_batch", | |
| "use_cell_type", | |
| "cell_types", | |
| "use_dsc_weight", | |
| "use_layer", | |
| ]: | |
| if key in config and config[key] == "__none__": | |
| config[key] = None | |
| # Set training directory | |
| if training_dir is None: | |
| if out_prefix is None: | |
| out_prefix = f"glue_model_{timestamp}" | |
| training_dir = str(OUTPUT_DIR / f"{out_prefix}_training") | |
| # Create training directory | |
| Path(training_dir).mkdir(parents=True, exist_ok=True) | |
| # Train GLUE model | |
| glue = scglue.models.fit_SCGLUE( | |
| {"rna": rna, "atac": atac}, guidance_hvf, fit_kws={"directory": training_dir} | |
| ) | |
| # Save trained model | |
| if out_prefix is None: | |
| out_prefix = f"glue_model_{timestamp}" | |
| model_output = OUTPUT_DIR / f"{out_prefix}.dill" | |
| glue.save(str(model_output)) | |
| # Return standardized format | |
| return { | |
| "message": "GLUE model training completed successfully", | |
| "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb", | |
| "artifacts": [ | |
| {"description": "Trained GLUE model", "path": str(model_output.resolve())}, | |
| { | |
| "description": "Training logs directory", | |
| "path": str(Path(training_dir).resolve()), | |
| }, | |
| ], | |
| } | |
| def glue_check_integration_consistency( | |
| # Primary data inputs | |
| model_path: Annotated[ | |
| str | None, "Path to trained GLUE model file with extension .dill" | |
| ] = None, | |
| rna_path: Annotated[ | |
| str | None, "Path to configured RNA-seq data file with extension .h5ad" | |
| ] = None, | |
| atac_path: Annotated[ | |
| str | None, "Path to configured ATAC-seq data file with extension .h5ad" | |
| ] = None, | |
| guidance_hvf_path: Annotated[ | |
| str | None, | |
| "Path to HVF-filtered guidance graph file with extension .graphml.gz", | |
| ] = None, | |
| out_prefix: Annotated[str | None, "Output file prefix"] = None, | |
| ) -> dict: | |
| """ | |
| Evaluate integration quality with consistency scores across metacell granularities. | |
| Input is trained model, RNA/ATAC data, and HVF guidance graph, output is consistency scores table and plot. | |
| """ | |
| # Input file validation | |
| if model_path is None: | |
| raise ValueError("Path to trained GLUE model file must be provided") | |
| if rna_path is None: | |
| raise ValueError("Path to configured RNA-seq data file must be provided") | |
| if atac_path is None: | |
| raise ValueError("Path to configured ATAC-seq data file must be provided") | |
| if guidance_hvf_path is None: | |
| raise ValueError("Path to HVF-filtered guidance graph file must be provided") | |
| # File existence validation | |
| model_file = Path(model_path) | |
| if not model_file.exists(): | |
| raise FileNotFoundError(f"Model file not found: {model_path}") | |
| rna_file = Path(rna_path) | |
| if not rna_file.exists(): | |
| raise FileNotFoundError(f"RNA-seq file not found: {rna_path}") | |
| atac_file = Path(atac_path) | |
| if not atac_file.exists(): | |
| raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}") | |
| guidance_hvf_file = Path(guidance_hvf_path) | |
| if not guidance_hvf_file.exists(): | |
| raise FileNotFoundError( | |
| f"Guidance HVF graph file not found: {guidance_hvf_path}" | |
| ) | |
| # Load data | |
| glue = scglue.models.load_model(model_path) | |
| rna = ad.read_h5ad(rna_path) | |
| atac = ad.read_h5ad(atac_path) | |
| guidance_hvf = nx.read_graphml(guidance_hvf_path) | |
| # Convert string markers back to None for scglue compatibility | |
| for adata in [rna, atac]: | |
| if "__scglue__" in adata.uns: | |
| config = adata.uns["__scglue__"] | |
| for key in [ | |
| "batches", | |
| "use_batch", | |
| "use_cell_type", | |
| "cell_types", | |
| "use_dsc_weight", | |
| "use_layer", | |
| ]: | |
| if key in config and config[key] == "__none__": | |
| config[key] = None | |
| # Compute integration consistency | |
| dx = scglue.models.integration_consistency( | |
| glue, {"rna": rna, "atac": atac}, guidance_hvf | |
| ) | |
| # Save consistency scores | |
| if out_prefix is None: | |
| out_prefix = f"glue_consistency_{timestamp}" | |
| consistency_table = OUTPUT_DIR / f"{out_prefix}_scores.csv" | |
| dx.to_csv(str(consistency_table), index=False) | |
| # Generate consistency plot | |
| plt.figure(figsize=(4, 4)) | |
| ax = sns.lineplot(x="n_meta", y="consistency", data=dx) | |
| ax.axhline(y=0.05, c="darkred", ls="--") | |
| plt.xlabel("Number of metacells") | |
| plt.ylabel("Consistency score") | |
| plt.tight_layout() | |
| consistency_plot = OUTPUT_DIR / f"{out_prefix}_plot.png" | |
| plt.savefig(str(consistency_plot), dpi=300, bbox_inches="tight") | |
| plt.close() | |
| # Return standardized format | |
| return { | |
| "message": f"Integration consistency computed (range: {dx['consistency'].min():.3f}-{dx['consistency'].max():.3f})", | |
| "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb", | |
| "artifacts": [ | |
| { | |
| "description": "Consistency scores table", | |
| "path": str(consistency_table.resolve()), | |
| }, | |
| { | |
| "description": "Consistency plot", | |
| "path": str(consistency_plot.resolve()), | |
| }, | |
| ], | |
| } | |
| def glue_generate_embeddings( | |
| # Primary data inputs | |
| model_path: Annotated[ | |
| str | None, "Path to trained GLUE model file with extension .dill" | |
| ] = None, | |
| rna_path: Annotated[ | |
| str | None, "Path to configured RNA-seq data file with extension .h5ad" | |
| ] = None, | |
| atac_path: Annotated[ | |
| str | None, "Path to configured ATAC-seq data file with extension .h5ad" | |
| ] = None, | |
| guidance_hvf_path: Annotated[ | |
| str | None, | |
| "Path to HVF-filtered guidance graph file with extension .graphml.gz", | |
| ] = None, | |
| # Visualization parameters with tutorial defaults | |
| color_vars: Annotated[list, "Variables to color UMAP by"] = ["cell_type", "domain"], | |
| out_prefix: Annotated[str | None, "Output file prefix"] = None, | |
| ) -> dict: | |
| """ | |
| Generate cell and feature embeddings from trained GLUE model and visualize alignment. | |
| Input is trained model and RNA/ATAC data, output is h5ad files with embeddings and UMAP visualization. | |
| """ | |
| # Input file validation | |
| if model_path is None: | |
| raise ValueError("Path to trained GLUE model file must be provided") | |
| if rna_path is None: | |
| raise ValueError("Path to configured RNA-seq data file must be provided") | |
| if atac_path is None: | |
| raise ValueError("Path to configured ATAC-seq data file must be provided") | |
| if guidance_hvf_path is None: | |
| raise ValueError("Path to HVF-filtered guidance graph file must be provided") | |
| # File existence validation | |
| model_file = Path(model_path) | |
| if not model_file.exists(): | |
| raise FileNotFoundError(f"Model file not found: {model_path}") | |
| rna_file = Path(rna_path) | |
| if not rna_file.exists(): | |
| raise FileNotFoundError(f"RNA-seq file not found: {rna_path}") | |
| atac_file = Path(atac_path) | |
| if not atac_file.exists(): | |
| raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}") | |
| guidance_hvf_file = Path(guidance_hvf_path) | |
| if not guidance_hvf_file.exists(): | |
| raise FileNotFoundError( | |
| f"Guidance HVF graph file not found: {guidance_hvf_path}" | |
| ) | |
| # Load data | |
| glue = scglue.models.load_model(model_path) | |
| rna = ad.read_h5ad(rna_path) | |
| atac = ad.read_h5ad(atac_path) | |
| guidance_hvf = nx.read_graphml(guidance_hvf_path) | |
| # Convert string markers back to None for scglue compatibility | |
| for adata in [rna, atac]: | |
| if "__scglue__" in adata.uns: | |
| config = adata.uns["__scglue__"] | |
| for key in [ | |
| "batches", | |
| "use_batch", | |
| "use_cell_type", | |
| "cell_types", | |
| "use_dsc_weight", | |
| "use_layer", | |
| ]: | |
| if key in config and config[key] == "__none__": | |
| config[key] = None | |
| # Generate cell embeddings | |
| rna.obsm["X_glue"] = glue.encode_data("rna", rna) | |
| atac.obsm["X_glue"] = glue.encode_data("atac", atac) | |
| # Generate feature embeddings | |
| feature_embeddings = glue.encode_graph(guidance_hvf) | |
| feature_embeddings = pd.DataFrame(feature_embeddings, index=glue.vertices) | |
| rna.varm["X_glue"] = feature_embeddings.reindex(rna.var_names).to_numpy() | |
| atac.varm["X_glue"] = feature_embeddings.reindex(atac.var_names).to_numpy() | |
| # Create combined dataset for visualization | |
| combined = ad.concat([rna, atac]) | |
| # Generate UMAP visualization | |
| sc.pp.neighbors(combined, use_rep="X_glue", metric="cosine") | |
| sc.tl.umap(combined) | |
| sc.pl.umap(combined, color=color_vars, wspace=0.65) | |
| # Save UMAP plot | |
| if out_prefix is None: | |
| out_prefix = f"glue_embeddings_{timestamp}" | |
| umap_plot = OUTPUT_DIR / f"{out_prefix}_umap.png" | |
| plt.savefig(str(umap_plot), dpi=300, bbox_inches="tight") | |
| plt.close() | |
| # Save h5ad files with embeddings | |
| rna_output = OUTPUT_DIR / f"{out_prefix}_rna_emb.h5ad" | |
| atac_output = OUTPUT_DIR / f"{out_prefix}_atac_emb.h5ad" | |
| guidance_hvf_output = OUTPUT_DIR / f"{out_prefix}_guidance_hvf.graphml.gz" | |
| rna.write(str(rna_output), compression="gzip") | |
| atac.write(str(atac_output), compression="gzip") | |
| nx.write_graphml(guidance_hvf, str(guidance_hvf_output)) | |
| # Return standardized format | |
| return { | |
| "message": f"Generated embeddings for {rna.n_obs} RNA and {atac.n_obs} ATAC cells", | |
| "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb", | |
| "artifacts": [ | |
| { | |
| "description": "RNA data with embeddings", | |
| "path": str(rna_output.resolve()), | |
| }, | |
| { | |
| "description": "ATAC data with embeddings", | |
| "path": str(atac_output.resolve()), | |
| }, | |
| { | |
| "description": "HVF guidance graph", | |
| "path": str(guidance_hvf_output.resolve()), | |
| }, | |
| {"description": "UMAP visualization", "path": str(umap_plot.resolve())}, | |
| ], | |
| } | |