# GLUE-Agent-MCP / src/tools/training.py
# Uploaded by: dmannk
# Upload folder using huggingface_hub
# Commit: d048db9 (verified)
"""
GLUE model training workflow for multi-omics data integration.
This MCP Server provides 4 tools:
1. glue_configure_datasets: Configure RNA-seq and ATAC-seq datasets for GLUE model training
2. glue_train_model: Train GLUE model for multi-omics integration
3. glue_check_integration_consistency: Evaluate integration quality with consistency scores
4. glue_generate_embeddings: Generate cell and feature embeddings from trained GLUE model
All tools extracted from `gao-lab/GLUE/docs/training.ipynb`.
"""
import os
from datetime import datetime
from itertools import chain
from pathlib import Path
# Standard imports
from typing import Annotated, Any, Literal
# Domain-specific imports
import anndata as ad
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
import scglue
import seaborn as sns
from fastmcp import FastMCP
from matplotlib import rcParams
# Project structure
PROJECT_ROOT = Path(__file__).parent.parent.parent.resolve()
DEFAULT_INPUT_DIR = PROJECT_ROOT / "tmp" / "inputs"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "tmp" / "outputs"
# Input/output locations can be overridden through environment variables.
INPUT_DIR = Path(os.environ.get("TRAINING_INPUT_DIR", DEFAULT_INPUT_DIR))
OUTPUT_DIR = Path(os.environ.get("TRAINING_OUTPUT_DIR", DEFAULT_OUTPUT_DIR))
# Ensure directories exist (done once, at import time)
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Import-time timestamp. NOTE: this value is fixed for the lifetime of the
# process, so it does not by itself make outputs of repeated tool calls unique.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# MCP server instance
training_mcp = FastMCP(name="training")
# Plot parameters. The scglue publication preset is applied FIRST: it
# presumably delegates to scanpy's set_figure_params, which assigns
# figure.dpi / savefig.dpi itself, so explicit overrides must come after it
# to take effect.
scglue.plot.set_publication_params()
rcParams["figure.dpi"] = 300
rcParams["savefig.dpi"] = 300
rcParams["figure.figsize"] = (4, 4)
@training_mcp.tool
def glue_configure_datasets(
    # Primary data inputs
    rna_path: Annotated[
        str | None, "Path to preprocessed RNA-seq data file with extension .h5ad"
    ] = None,
    atac_path: Annotated[
        str | None, "Path to preprocessed ATAC-seq data file with extension .h5ad"
    ] = None,
    guidance_path: Annotated[
        str | None, "Path to guidance graph file with extension .graphml.gz"
    ] = None,
    # Configuration parameters with tutorial defaults
    prob_model: Annotated[
        Literal["NB", "ZINB", "ZIP"], "Probabilistic generative model"
    ] = "NB",
    use_highly_variable: Annotated[bool, "Use only highly variable features"] = True,
    rna_use_layer: Annotated[
        str | None, "RNA data layer to use (None uses .X)"
    ] = "counts",
    rna_use_rep: Annotated[str, "RNA preprocessing embedding to use"] = "X_pca",
    atac_use_rep: Annotated[str, "ATAC preprocessing embedding to use"] = "X_lsi",
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Configure RNA-seq and ATAC-seq datasets for GLUE model training.
    Input is preprocessed RNA/ATAC h5ad files and guidance graph, output is configured h5ad files and HVF-filtered guidance graph.

    The guidance graph is restricted to highly variable features when
    ``use_highly_variable`` is True, and to all RNA/ATAC features otherwise.

    Raises:
        ValueError: if any of the input paths is not provided.
        FileNotFoundError: if any of the input files does not exist.
    """
    # Input file validation
    if rna_path is None:
        raise ValueError("Path to RNA-seq data file must be provided")
    if atac_path is None:
        raise ValueError("Path to ATAC-seq data file must be provided")
    if guidance_path is None:
        raise ValueError("Path to guidance graph file must be provided")
    # File existence validation
    if not Path(rna_path).exists():
        raise FileNotFoundError(f"RNA-seq file not found: {rna_path}")
    if not Path(atac_path).exists():
        raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}")
    if not Path(guidance_path).exists():
        raise FileNotFoundError(f"Guidance graph file not found: {guidance_path}")

    # Load data
    rna = ad.read_h5ad(rna_path)
    atac = ad.read_h5ad(atac_path)
    guidance = nx.read_graphml(guidance_path)

    # Configure datasets (same calls as the GLUE training tutorial)
    scglue.models.configure_dataset(
        rna,
        prob_model,
        use_highly_variable=use_highly_variable,
        use_layer=rna_use_layer,
        use_rep=rna_use_rep,
    )
    scglue.models.configure_dataset(
        atac, prob_model, use_highly_variable=use_highly_variable, use_rep=atac_use_rep
    )

    def _model_features(adata: ad.AnnData) -> pd.Index:
        # Features selected consistently with the `use_highly_variable` flag
        # passed to configure_dataset. Querying "highly_variable"
        # unconditionally would fail when the column is absent, or silently
        # drop features when HVF selection is disabled.
        if use_highly_variable:
            return adata.var.query("highly_variable").index
        return adata.var_names

    rna_features = _model_features(rna)
    atac_features = _model_features(atac)

    # Extract the guidance subgraph covering the selected features
    guidance_hvf = guidance.subgraph(chain(rna_features, atac_features)).copy()

    # anndata drops None values during save/load, but scglue's configure_dataset
    # creates these fields. Preserve them by converting None to a string marker
    # that survives serialization (converted back to None after loading).
    none_marker_keys = (
        "batches",
        "use_batch",
        "use_cell_type",
        "cell_types",
        "use_dsc_weight",
        "use_layer",
    )
    for adata in (rna, atac):
        if "__scglue__" in adata.uns:
            config = adata.uns["__scglue__"]
            for key in none_marker_keys:
                if key in config and config[key] is None:
                    config[key] = "__none__"

    # Save configured datasets and HVF guidance graph. The default prefix is
    # timestamped per call so repeated invocations in one server process do
    # not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"glue_configured_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    rna_output = OUTPUT_DIR / f"{out_prefix}_rna_configured.h5ad"
    atac_output = OUTPUT_DIR / f"{out_prefix}_atac_configured.h5ad"
    guidance_hvf_output = OUTPUT_DIR / f"{out_prefix}_guidance_hvf.graphml.gz"
    rna.write(str(rna_output), compression="gzip")
    atac.write(str(atac_output), compression="gzip")
    nx.write_graphml(guidance_hvf, str(guidance_hvf_output))

    feature_label = "HVFs" if use_highly_variable else "features"
    # Return standardized format
    return {
        "message": f"Configured datasets with {len(rna_features)} RNA and {len(atac_features)} ATAC {feature_label}",
        "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb",
        "artifacts": [
            {
                "description": "Configured RNA-seq data",
                "path": str(rna_output.resolve()),
            },
            {
                "description": "Configured ATAC-seq data",
                "path": str(atac_output.resolve()),
            },
            {
                "description": "HVF-filtered guidance graph",
                "path": str(guidance_hvf_output.resolve()),
            },
        ],
    }
@training_mcp.tool
def glue_train_model(
    # Primary data inputs
    rna_path: Annotated[
        str | None, "Path to configured RNA-seq data file with extension .h5ad"
    ] = None,
    atac_path: Annotated[
        str | None, "Path to configured ATAC-seq data file with extension .h5ad"
    ] = None,
    guidance_hvf_path: Annotated[
        str | None,
        "Path to HVF-filtered guidance graph file with extension .graphml.gz",
    ] = None,
    # Training parameters
    training_dir: Annotated[
        str | None, "Directory to store model snapshots and training logs"
    ] = None,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Train GLUE model for multi-omics integration.
    Input is configured RNA/ATAC h5ad files and HVF guidance graph, output is trained GLUE model.

    Raises:
        ValueError: if any of the input paths is not provided.
        FileNotFoundError: if any of the input files does not exist.
    """
    # Input file validation
    if rna_path is None:
        raise ValueError("Path to configured RNA-seq data file must be provided")
    if atac_path is None:
        raise ValueError("Path to configured ATAC-seq data file must be provided")
    if guidance_hvf_path is None:
        raise ValueError("Path to HVF-filtered guidance graph file must be provided")
    # File existence validation
    if not Path(rna_path).exists():
        raise FileNotFoundError(f"RNA-seq file not found: {rna_path}")
    if not Path(atac_path).exists():
        raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}")
    if not Path(guidance_hvf_path).exists():
        raise FileNotFoundError(
            f"Guidance HVF graph file not found: {guidance_hvf_path}"
        )

    # Load data
    rna = ad.read_h5ad(rna_path)
    atac = ad.read_h5ad(atac_path)
    guidance_hvf = nx.read_graphml(guidance_hvf_path)

    # Convert "__none__" string markers back to None for scglue compatibility.
    # Only scalar strings are compared: list-valued fields (e.g. batches) may
    # load as numpy arrays, and `array == "__none__"` is elementwise, which
    # makes a plain `if` on it ambiguous.
    for adata in (rna, atac):
        if "__scglue__" in adata.uns:
            config = adata.uns["__scglue__"]
            for key, value in list(config.items()):
                if isinstance(value, str) and value == "__none__":
                    config[key] = None

    # Resolve the output prefix once, timestamped per call so repeated
    # invocations in one server process do not overwrite each other.
    if out_prefix is None:
        out_prefix = f"glue_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    if training_dir is None:
        training_dir = str(OUTPUT_DIR / f"{out_prefix}_training")
    Path(training_dir).mkdir(parents=True, exist_ok=True)

    # Train GLUE model; snapshots/logs are written to training_dir
    glue = scglue.models.fit_SCGLUE(
        {"rna": rna, "atac": atac}, guidance_hvf, fit_kws={"directory": training_dir}
    )

    # Save trained model
    model_output = OUTPUT_DIR / f"{out_prefix}.dill"
    glue.save(str(model_output))

    # Return standardized format
    return {
        "message": "GLUE model training completed successfully",
        "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb",
        "artifacts": [
            {"description": "Trained GLUE model", "path": str(model_output.resolve())},
            {
                "description": "Training logs directory",
                "path": str(Path(training_dir).resolve()),
            },
        ],
    }
@training_mcp.tool
def glue_check_integration_consistency(
    # Primary data inputs
    model_path: Annotated[
        str | None, "Path to trained GLUE model file with extension .dill"
    ] = None,
    rna_path: Annotated[
        str | None, "Path to configured RNA-seq data file with extension .h5ad"
    ] = None,
    atac_path: Annotated[
        str | None, "Path to configured ATAC-seq data file with extension .h5ad"
    ] = None,
    guidance_hvf_path: Annotated[
        str | None,
        "Path to HVF-filtered guidance graph file with extension .graphml.gz",
    ] = None,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Evaluate integration quality with consistency scores across metacell granularities.
    Input is trained model, RNA/ATAC data, and HVF guidance graph, output is consistency scores table and plot.

    Raises:
        ValueError: if any of the input paths is not provided.
        FileNotFoundError: if any of the input files does not exist.
    """
    # Input file validation
    if model_path is None:
        raise ValueError("Path to trained GLUE model file must be provided")
    if rna_path is None:
        raise ValueError("Path to configured RNA-seq data file must be provided")
    if atac_path is None:
        raise ValueError("Path to configured ATAC-seq data file must be provided")
    if guidance_hvf_path is None:
        raise ValueError("Path to HVF-filtered guidance graph file must be provided")
    # File existence validation
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not Path(rna_path).exists():
        raise FileNotFoundError(f"RNA-seq file not found: {rna_path}")
    if not Path(atac_path).exists():
        raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}")
    if not Path(guidance_hvf_path).exists():
        raise FileNotFoundError(
            f"Guidance HVF graph file not found: {guidance_hvf_path}"
        )

    # Load data
    glue = scglue.models.load_model(model_path)
    rna = ad.read_h5ad(rna_path)
    atac = ad.read_h5ad(atac_path)
    guidance_hvf = nx.read_graphml(guidance_hvf_path)

    # Convert "__none__" string markers back to None for scglue compatibility.
    # Only scalar strings are compared: list-valued fields (e.g. batches) may
    # load as numpy arrays, and `array == "__none__"` is elementwise, which
    # makes a plain `if` on it ambiguous.
    for adata in (rna, atac):
        if "__scglue__" in adata.uns:
            config = adata.uns["__scglue__"]
            for key, value in list(config.items()):
                if isinstance(value, str) and value == "__none__":
                    config[key] = None

    # Compute integration consistency
    dx = scglue.models.integration_consistency(
        glue, {"rna": rna, "atac": atac}, guidance_hvf
    )

    # Save consistency scores. The default prefix is timestamped per call so
    # repeated invocations in one server process do not overwrite each other.
    if out_prefix is None:
        out_prefix = f"glue_consistency_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    consistency_table = OUTPUT_DIR / f"{out_prefix}_scores.csv"
    dx.to_csv(str(consistency_table), index=False)

    # Generate consistency plot on an explicit figure; always close it so a
    # plotting failure does not leak figures in the long-running server.
    consistency_plot = OUTPUT_DIR / f"{out_prefix}_plot.png"
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        sns.lineplot(x="n_meta", y="consistency", data=dx, ax=ax)
        # Reference threshold line at 0.05, as drawn in the GLUE training tutorial
        ax.axhline(y=0.05, c="darkred", ls="--")
        ax.set_xlabel("Number of metacells")
        ax.set_ylabel("Consistency score")
        fig.tight_layout()
        fig.savefig(str(consistency_plot), dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # Return standardized format
    return {
        "message": f"Integration consistency computed (range: {dx['consistency'].min():.3f}-{dx['consistency'].max():.3f})",
        "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb",
        "artifacts": [
            {
                "description": "Consistency scores table",
                "path": str(consistency_table.resolve()),
            },
            {
                "description": "Consistency plot",
                "path": str(consistency_plot.resolve()),
            },
        ],
    }
@training_mcp.tool
def glue_generate_embeddings(
    # Primary data inputs
    model_path: Annotated[
        str | None, "Path to trained GLUE model file with extension .dill"
    ] = None,
    rna_path: Annotated[
        str | None, "Path to configured RNA-seq data file with extension .h5ad"
    ] = None,
    atac_path: Annotated[
        str | None, "Path to configured ATAC-seq data file with extension .h5ad"
    ] = None,
    guidance_hvf_path: Annotated[
        str | None,
        "Path to HVF-filtered guidance graph file with extension .graphml.gz",
    ] = None,
    # Visualization parameters with tutorial defaults
    color_vars: Annotated[
        list | None,
        "Variables to color UMAP by (defaults to ['cell_type', 'domain'])",
    ] = None,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Generate cell and feature embeddings from trained GLUE model and visualize alignment.
    Input is trained model and RNA/ATAC data, output is h5ad files with embeddings and UMAP visualization.

    Raises:
        ValueError: if any of the input paths is not provided.
        FileNotFoundError: if any of the input files does not exist.
    """
    # Use a fresh list per call instead of a shared mutable default
    if color_vars is None:
        color_vars = ["cell_type", "domain"]

    # Input file validation
    if model_path is None:
        raise ValueError("Path to trained GLUE model file must be provided")
    if rna_path is None:
        raise ValueError("Path to configured RNA-seq data file must be provided")
    if atac_path is None:
        raise ValueError("Path to configured ATAC-seq data file must be provided")
    if guidance_hvf_path is None:
        raise ValueError("Path to HVF-filtered guidance graph file must be provided")
    # File existence validation
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not Path(rna_path).exists():
        raise FileNotFoundError(f"RNA-seq file not found: {rna_path}")
    if not Path(atac_path).exists():
        raise FileNotFoundError(f"ATAC-seq file not found: {atac_path}")
    if not Path(guidance_hvf_path).exists():
        raise FileNotFoundError(
            f"Guidance HVF graph file not found: {guidance_hvf_path}"
        )

    # Load data
    glue = scglue.models.load_model(model_path)
    rna = ad.read_h5ad(rna_path)
    atac = ad.read_h5ad(atac_path)
    guidance_hvf = nx.read_graphml(guidance_hvf_path)

    # Convert "__none__" string markers back to None for scglue compatibility.
    # Only scalar strings are compared: list-valued fields (e.g. batches) may
    # load as numpy arrays, and `array == "__none__"` is elementwise, which
    # makes a plain `if` on it ambiguous.
    for adata in (rna, atac):
        if "__scglue__" in adata.uns:
            config = adata.uns["__scglue__"]
            for key, value in list(config.items()):
                if isinstance(value, str) and value == "__none__":
                    config[key] = None

    # Generate cell embeddings
    rna.obsm["X_glue"] = glue.encode_data("rna", rna)
    atac.obsm["X_glue"] = glue.encode_data("atac", atac)

    # Generate feature embeddings, indexed by graph vertex so they can be
    # aligned to each modality's var_names
    feature_embeddings = pd.DataFrame(
        glue.encode_graph(guidance_hvf), index=glue.vertices
    )
    rna.varm["X_glue"] = feature_embeddings.reindex(rna.var_names).to_numpy()
    atac.varm["X_glue"] = feature_embeddings.reindex(atac.var_names).to_numpy()

    # Create combined dataset and compute a joint UMAP on the GLUE embedding
    combined = ad.concat([rna, atac])
    sc.pp.neighbors(combined, use_rep="X_glue", metric="cosine")
    sc.tl.umap(combined)

    # Output prefix is timestamped per call so repeated invocations in one
    # server process do not overwrite each other.
    if out_prefix is None:
        out_prefix = f"glue_embeddings_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Plot and save the UMAP. show=False keeps scanpy from displaying the
    # figure (which on interactive backends can release it) before saving;
    # the figure is always closed to avoid leaks in the long-running server.
    umap_plot = OUTPUT_DIR / f"{out_prefix}_umap.png"
    try:
        sc.pl.umap(combined, color=color_vars, wspace=0.65, show=False)
        plt.savefig(str(umap_plot), dpi=300, bbox_inches="tight")
    finally:
        plt.close()

    # Save h5ad files with embeddings (and a copy of the guidance graph)
    rna_output = OUTPUT_DIR / f"{out_prefix}_rna_emb.h5ad"
    atac_output = OUTPUT_DIR / f"{out_prefix}_atac_emb.h5ad"
    guidance_hvf_output = OUTPUT_DIR / f"{out_prefix}_guidance_hvf.graphml.gz"
    rna.write(str(rna_output), compression="gzip")
    atac.write(str(atac_output), compression="gzip")
    nx.write_graphml(guidance_hvf, str(guidance_hvf_output))

    # Return standardized format
    return {
        "message": f"Generated embeddings for {rna.n_obs} RNA and {atac.n_obs} ATAC cells",
        "reference": "https://github.com/gao-lab/GLUE/blob/master/docs/training.ipynb",
        "artifacts": [
            {
                "description": "RNA data with embeddings",
                "path": str(rna_output.resolve()),
            },
            {
                "description": "ATAC data with embeddings",
                "path": str(atac_output.resolve()),
            },
            {
                "description": "HVF guidance graph",
                "path": str(guidance_hvf_output.resolve()),
            },
            {"description": "UMAP visualization", "path": str(umap_plot.resolve())},
        ],
    }