Spaces:
Running
Running
| """ | |
| TISSUE (Transcript Imputation with Spatial Single-cell Uncertainty Estimation) tutorial implementations. | |
| This MCP Server provides 6 tools: | |
| 1. predict_spatial_gene_expression: Predict spatial gene expression using paired spatial and scRNA-seq data | |
| 2. calibrate_uncertainties_and_prediction_intervals: Use TISSUE to calibrate uncertainties and obtain prediction intervals | |
| 3. multiple_imputation_hypothesis_testing: Hypothesis testing with TISSUE multiple imputation framework | |
| 4. tissue_cell_filtering_for_supervised_learning: TISSUE cell filtering for supervised learning applications | |
| 5. tissue_cell_filtering_for_pca: TISSUE cell filtering for PCA, clustering and visualization | |
| 6. tissue_weighted_pca: TISSUE-WPCA (weighted principal component analysis) | |
| All tools extracted from TISSUE/README.md. | |
| """ | |
| # Standard imports | |
| from typing import Annotated, Literal, Any | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
| import os | |
| from fastmcp import FastMCP | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
| import anndata as ad | |
| # Import TISSUE modules | |
| import tissue.main | |
| import tissue.downstream | |
| # scikit-learn imports | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.metrics import accuracy_score, roc_auc_score, adjusted_rand_score | |
| from sklearn.cluster import KMeans | |
# Persistent storage root. The original note states that Hugging Face Spaces
# guarantees /data is writable and persistent there.
BASE_DIR = Path("/data")
DEFAULT_INPUT_DIR = BASE_DIR / "tmp_inputs"
DEFAULT_OUTPUT_DIR = BASE_DIR / "tmp_outputs"

# TISSUE_INPUT_DIR / TISSUE_OUTPUT_DIR environment variables override the defaults.
INPUT_DIR = Path(os.environ.get("TISSUE_INPUT_DIR", DEFAULT_INPUT_DIR))
OUTPUT_DIR = Path(os.environ.get("TISSUE_OUTPUT_DIR", DEFAULT_OUTPUT_DIR))

# Create both directories up front; exist_ok makes this a no-op on restart.
for _dir in (INPUT_DIR, OUTPUT_DIR):
    _dir.mkdir(parents=True, exist_ok=True)
del _dir

# Evaluated once, at import time: every use of this value within one server
# process sees the same stamp.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# MCP server instance.
# NOTE(review): no @tissue_mcp.tool registration is visible in this section;
# confirm the tool functions are registered elsewhere.
tissue_mcp = FastMCP(name="tissue_readme")
def predict_spatial_gene_expression(
    spatial_count_path: Annotated[str, "Path to spatial count matrix file (tab-delimited text format). The header should include gene names and rows should be cells."],
    locations_path: Annotated[str, "Path to spatial locations file (tab-delimited text format). Should contain x and y coordinates for each cell."],
    scrna_count_path: Annotated[str, "Path to scRNA-seq count matrix file (tab-delimited text format). The header should include gene names and rows should be cells."],
    target_gene: Annotated[str, "Target gene name to predict (must be present in both datasets)"] = "plp1",
    prediction_method: Annotated[Literal["spage", "tangram", "harmony"], "Method for spatial gene expression prediction"] = "spage",
    n_folds: Annotated[int, "Number of cross-validation folds for prediction"] = 10,
    n_pv: Annotated[int, "Number of principal components for SpaGE method"] = 10,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Predict spatial gene expression using paired spatial and scRNA-seq data with TISSUE.

    Loads the paired datasets with ``tissue.main.load_paired_datasets`` and
    lower-cases all gene names. It then restricts the spatial data to genes shared
    with the (normalized) scRNA-seq reference and holds ``target_gene`` out of the
    spatial data. Finally it calls ``tissue.main.predict_gene_expression`` for that
    gene.

    Outputs written to ``OUTPUT_DIR``:
    - an actual-vs-predicted spatial figure;
    - a per-cell CSV;
    - the processed AnnData.

    The held-out measured expression is stored in ``adata.uns["target_expression"]``
    (gene name in ``adata.uns["target_gene"]``). This lets a later calibration step
    compare predictions against ground truth.

    Returns:
        dict with a status ``message``, a ``reference`` URL and an ``artifacts``
        list of ``{"description", "path"}`` entries.

    Raises:
        ValueError: if ``target_gene`` is not among the genes shared by the
            spatial and scRNA-seq datasets.
    """
    # Stamp per call: a module-level stamp would make repeated calls within one
    # server process overwrite each other's output files.
    if out_prefix is None:
        out_prefix = f"tissue_prediction_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def _log_clipped(values):
        """Clamp negatives to 0, log1p-transform, and cap at the 95th percentile (display only)."""
        values = np.log1p(np.clip(values, 0, None))
        cap = np.percentile(values, 95)
        return np.where(values > cap, cap, values)

    # Load paired datasets
    adata, RNAseq_adata = tissue.main.load_paired_datasets(
        spatial_count_path, locations_path, scrna_count_path
    )

    # Match gene names case-insensitively by lower-casing both indexes.
    adata.var_names = [x.lower() for x in adata.var_names]
    RNAseq_adata.var_names = [x.lower() for x in RNAseq_adata.var_names]

    # Normalize (but do not standardize) the scRNA-seq reference.
    tissue.main.preprocess_data(RNAseq_adata, standardize=False, normalize=True)

    # Keep only genes measured in both datasets.
    gene_names = np.intersect1d(adata.var_names, RNAseq_adata.var_names)
    adata = adata[:, gene_names].copy()

    target_gene_lower = target_gene.lower()
    if target_gene_lower not in adata.var_names:
        raise ValueError(
            f"Target gene '{target_gene}' not found in the genes shared by the "
            f"spatial and scRNA-seq data"
        )

    # Hold the target gene out as a dense 1-D vector; X may be a sparse matrix,
    # which would break flatten/boolean assignment/log1p further down.
    target_col = adata[:, target_gene_lower].X
    if hasattr(target_col, "toarray"):
        target_col = target_col.toarray()
    target_expn = np.array(target_col).ravel()
    adata = adata[:, [gene for gene in gene_names if gene != target_gene_lower]].copy()

    # Keep the ground truth with the data so it survives the h5ad round trip.
    adata.uns["target_expression"] = target_expn
    adata.uns["target_gene"] = target_gene_lower

    # Predict the held-out gene
    tissue.main.predict_gene_expression(
        adata, RNAseq_adata, [target_gene_lower],
        method=prediction_method, n_folds=n_folds, n_pv=n_pv
    )

    pred_key = f"{prediction_method}_predicted_expression"
    predicted = adata.obsm[pred_key][target_gene_lower].values.copy()
    coords = adata.obsm["spatial"]

    # Side-by-side actual / predicted maps; always release the figure.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        for ax, values, title in ((ax1, target_expn, "Actual"), (ax2, predicted, "Predicted")):
            ax.axis("off")
            im = ax.scatter(coords[:, 0], coords[:, 1],
                            s=1, c=_log_clipped(values), rasterized=True)
            ax.set_title(title, fontsize=12)
            cbar = fig.colorbar(im, ax=ax)
            cbar.ax.get_yaxis().labelpad = 15
            cbar.ax.set_ylabel("Log Expression", rotation=270)
        fig.suptitle(f"{prediction_method.upper()} Prediction", fontsize=16)
        fig.tight_layout()
        fig_path = OUTPUT_DIR / f"{out_prefix}_spatial_prediction.png"
        fig.savefig(fig_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # Per-cell table of raw (untransformed) actual and predicted values.
    results_df = pd.DataFrame({
        "cell_id": range(len(adata.obs)),
        "x_coord": coords[:, 0],
        "y_coord": coords[:, 1],
        "actual_expression": target_expn,
        "predicted_expression": adata.obsm[pred_key][target_gene_lower].values,
    })
    results_path = OUTPUT_DIR / f"{out_prefix}_prediction_results.csv"
    results_df.to_csv(results_path, index=False)

    # Save processed AnnData for downstream tools
    adata_path = OUTPUT_DIR / f"{out_prefix}_processed_adata.h5ad"
    adata.write_h5ad(adata_path)

    return {
        "message": f"Spatial gene expression prediction completed for {target_gene}",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "Spatial prediction visualization",
                "path": str(fig_path.resolve())
            },
            {
                "description": "Prediction results table",
                "path": str(results_path.resolve())
            },
            {
                "description": "Processed AnnData object",
                "path": str(adata_path.resolve())
            }
        ]
    }
def calibrate_uncertainties_and_prediction_intervals(
    adata_path: Annotated[str, "Path to processed AnnData file from predict_spatial_gene_expression"],
    target_gene: Annotated[str, "Target gene name for visualization"] = "plp1",
    prediction_method: Annotated[str, "Prediction method used (spage, tangram, harmony)"] = "spage",
    n_neighbors: Annotated[int, "Number of neighbors for spatial graph construction"] = 15,
    grouping_method: Annotated[Literal["kmeans_gene_cell", "kmeans_gene", "kmeans_cell"], "Method for stratified grouping"] = "kmeans_gene_cell",
    k: Annotated[int, "Number of gene groups for calibration"] = 4,
    k2: Annotated[int, "Number of cell groups for calibration"] = 2,
    alpha_level: Annotated[float, "Alpha level for prediction intervals (1-alpha coverage)"] = 0.23,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Use TISSUE to calibrate uncertainties and obtain prediction intervals for spatial predictions.

    Runs, in order:
    - ``tissue.main.build_spatial_graph``;
    - ``conformalize_spatial_uncertainty``;
    - ``conformalize_prediction_interval``, using all ``adata.var_names`` as
      calibration genes.

    It then plots, for ``target_gene``:
    - the absolute imputation error, when ground truth for that gene is stored in
      ``adata.uns["target_expression"]``; otherwise the predicted expression;
    - the prediction-interval width.

    Plots are log1p-transformed and capped at the 95th percentile. The CSV stores
    the raw predicted value, interval bounds and interval width (hi - lo).

    Returns:
        dict with a status ``message``, a ``reference`` URL and an ``artifacts``
        list of ``{"description", "path"}`` entries.

    Raises:
        ValueError: if the prediction key is missing from ``adata.obsm`` or
            ``target_gene`` is not one of its columns.
    """
    # Stamp per call so repeated calls do not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"tissue_calibration_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    def _log_clipped(values):
        """Clamp negatives to 0, log1p-transform, and cap at the 95th percentile (display only)."""
        values = np.log1p(np.clip(values, 0, None))
        cap = np.percentile(values, 95)
        return np.where(values > cap, cap, values)

    adata = ad.read_h5ad(adata_path)
    target_gene_lower = target_gene.lower()
    m = prediction_method
    pred_key = f"{m}_predicted_expression"

    # Validate inputs before the (potentially slow) calibration steps.
    if pred_key not in adata.obsm:
        raise ValueError(
            f"'{pred_key}' not found in adata.obsm; was the data produced with "
            f"prediction_method='{m}'?"
        )
    if target_gene_lower not in adata.obsm[pred_key].columns:
        raise ValueError(f"Target gene '{target_gene}' not found in '{pred_key}'")

    tissue.main.build_spatial_graph(adata, method="fixed_radius", n_neighbors=n_neighbors)
    tissue.main.conformalize_spatial_uncertainty(
        adata, pred_key, calib_genes=adata.var_names,
        grouping_method=grouping_method, k=k, k2=k2
    )
    tissue.main.conformalize_prediction_interval(
        adata, pred_key, calib_genes=adata.var_names, alpha_level=alpha_level
    )

    predicted = adata.obsm[pred_key][target_gene_lower].to_numpy()
    pi_lower = adata.obsm[f"{pred_key}_lo"][target_gene_lower].to_numpy()
    pi_upper = adata.obsm[f"{pred_key}_hi"][target_gene_lower].to_numpy()
    pi_width = pi_upper - pi_lower  # raw width; only the plot is log-transformed

    # Use stored ground truth only if it belongs to this gene and these cells.
    target_expn = adata.uns.get("target_expression")
    if target_expn is not None:
        target_expn = np.ravel(np.asarray(target_expn))
        stored_gene = str(adata.uns.get("target_gene", target_gene_lower))
        if stored_gene != target_gene_lower or target_expn.shape[0] != adata.n_obs:
            target_expn = None

    coords = adata.obsm["spatial"]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Left panel: imputation error if ground truth is available, else prediction.
        if target_expn is not None:
            left_values = np.abs(target_expn - predicted)
            left_title = f"Imputation Error {target_gene_lower}"
            left_label = "Log Absolute Error"
        else:
            left_values = predicted
            left_title = f"Predicted Expression {target_gene_lower}"
            left_label = "Log Expression"
        ax1.axis("off")
        im1 = ax1.scatter(coords[:, 0], coords[:, 1],
                          s=1, c=_log_clipped(left_values), rasterized=True)
        ax1.set_title(left_title, fontsize=12)
        cbar1 = fig.colorbar(im1, ax=ax1)
        cbar1.ax.get_yaxis().labelpad = 15
        cbar1.ax.set_ylabel(left_label, rotation=270)

        # Right panel: prediction interval width.
        ax2.axis("off")
        im2 = ax2.scatter(coords[:, 0], coords[:, 1],
                          s=1, c=_log_clipped(pi_width), rasterized=True)
        ax2.set_title(f"PI Width {target_gene_lower}", fontsize=12)
        cbar2 = fig.colorbar(im2, ax=ax2)
        cbar2.ax.get_yaxis().labelpad = 15
        cbar2.ax.set_ylabel("Log PI Width", rotation=270)

        fig.suptitle(m.upper(), fontsize=16)
        fig.tight_layout()
        fig_path = OUTPUT_DIR / f"{out_prefix}_prediction_intervals.png"
        fig.savefig(fig_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # Save calibrated data
    calibrated_path = OUTPUT_DIR / f"{out_prefix}_calibrated_adata.h5ad"
    adata.write_h5ad(calibrated_path)

    # Save prediction intervals (raw values, not the plotted transform)
    intervals_df = pd.DataFrame({
        "cell_id": range(len(adata.obs)),
        "x_coord": coords[:, 0],
        "y_coord": coords[:, 1],
        f"{target_gene_lower}_predicted": predicted,
        f"{target_gene_lower}_pi_lower": pi_lower,
        f"{target_gene_lower}_pi_upper": pi_upper,
        f"{target_gene_lower}_pi_width": pi_width,
    })
    intervals_path = OUTPUT_DIR / f"{out_prefix}_prediction_intervals.csv"
    intervals_df.to_csv(intervals_path, index=False)

    return {
        "message": f"Uncertainty calibration and prediction intervals completed (α={alpha_level})",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "Prediction intervals visualization",
                "path": str(fig_path.resolve())
            },
            {
                "description": "Calibrated AnnData object",
                "path": str(calibrated_path.resolve())
            },
            {
                "description": "Prediction intervals data",
                "path": str(intervals_path.resolve())
            }
        ]
    }
def multiple_imputation_hypothesis_testing(
    adata_path: Annotated[str, "Path to calibrated AnnData file from calibrate_uncertainties_and_prediction_intervals"],
    prediction_method: Annotated[str, "Prediction method used (spage, tangram, harmony)"] = "spage",
    condition_key: Annotated[str, "Key in adata.obs for condition labels"] = "condition",
    group1: Annotated[str, "First group label for comparison"] = "A",
    group2: Annotated[str, "Second group label for comparison"] = "B",
    n_imputations: Annotated[int, "Number of multiple imputations to use"] = 10,
    test_method: Annotated[Literal["ttest", "spatialde", "wilcoxon_greater", "wilcoxon_less"], "Statistical test method"] = "ttest",
    target_gene: Annotated[str, "Target gene for reporting results"] = "plp1",
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Perform hypothesis testing with TISSUE multiple imputation framework for differential gene expression.

    If ``condition_key`` is missing from ``adata.obs``, the first half of the cells
    (by position) is labelled ``group1`` and the rest ``group2``. The spatial
    condition map is plotted, then ``tissue.downstream.multiple_imputation_testing``
    is run.

    Results are read from ``adata.uns["<method>_<group1>_<group2>_tstat"]`` and
    ``..._pvalue``. Every gene column present in both result tables is written to a
    CSV, sorted by p-value.

    Returns:
        dict with a status ``message`` (number of genes with p < 0.05 and the
        target gene's statistics), a ``reference`` URL and an ``artifacts`` list.
    """
    # Stamp per call so repeated calls do not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"tissue_hypothesis_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    adata = ad.read_h5ad(adata_path)
    target_gene_lower = target_gene.lower()

    # Create condition labels if they don't exist (half/half split, as in the tutorial)
    if condition_key not in adata.obs.columns:
        adata.obs[condition_key] = [group1 if i < round(adata.shape[0] / 2) else group2
                                    for i in range(adata.shape[0])]

    # Plot the two condition groups in space; always release the figure.
    coords = adata.obsm["spatial"]
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for group, color in ((group1, "tab:red"), (group2, "tab:blue")):
            mask = (adata.obs[condition_key] == group).to_numpy()
            ax.scatter(coords[mask, 0], coords[mask, 1], c=color, s=3, label=group)
        ax.legend(loc="best")
        ax.set_title("Condition Groups for Hypothesis Testing")
        condition_fig_path = OUTPUT_DIR / f"{out_prefix}_conditions.png"
        fig.savefig(condition_fig_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    pred_key = f"{prediction_method}_predicted_expression"
    tissue.downstream.multiple_imputation_testing(
        adata, pred_key,
        calib_genes=adata.var_names,
        condition=condition_key,
        group1=group1,
        group2=group2,
        n_imputations=n_imputations,
        test=test_method
    )

    tstat_key = f"{prediction_method}_{group1}_{group2}_tstat"
    pvalue_key = f"{prediction_method}_{group1}_{group2}_pvalue"
    tstats = adata.uns[tstat_key]
    pvalues = adata.uns[pvalue_key]

    # Iterate the tested genes (result columns) rather than adata.var_names.
    # predict_spatial_gene_expression drops the held-out target gene from
    # var_names, so iterating var_names would never report it.
    columns = ["gene", "t_statistic", "p_value", "significant_05", "significant_01"]
    rows = []
    for gene in tstats.columns:
        if gene not in pvalues.columns:
            continue
        tstat = tstats[gene].values[0]
        pval = pvalues[gene].values[0]
        rows.append({
            "gene": gene,
            "t_statistic": tstat,
            "p_value": pval,
            "significant_05": pval < 0.05,
            "significant_01": pval < 0.01,
        })
    # Explicit columns keep sort_values working even when no gene was tested.
    results_df = pd.DataFrame(rows, columns=columns).sort_values("p_value")

    results_path = OUTPUT_DIR / f"{out_prefix}_hypothesis_test_results.csv"
    results_df.to_csv(results_path, index=False)

    target_results = results_df[results_df["gene"] == target_gene_lower]
    if not target_results.empty:
        target_tstat = target_results.iloc[0]["t_statistic"]
        target_pval = target_results.iloc[0]["p_value"]
        target_message = f"Target gene {target_gene}: t-stat={target_tstat:.5f}, p={target_pval:.5f}"
    else:
        target_message = f"Target gene {target_gene} not found in results"

    n_significant = (results_df["p_value"] < 0.05).sum()

    return {
        "message": f"Hypothesis testing completed: {n_significant} significant genes (p<0.05). {target_message}",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "Condition groups visualization",
                "path": str(condition_fig_path.resolve())
            },
            {
                "description": "Hypothesis test results",
                "path": str(results_path.resolve())
            }
        ]
    }
def tissue_cell_filtering_for_supervised_learning(
    adata_path: Annotated[str, "Path to calibrated AnnData file from calibrate_uncertainties_and_prediction_intervals"],
    prediction_method: Annotated[str, "Prediction method used (spage, tangram, harmony)"] = "spage",
    condition_key: Annotated[str, "Key in adata.obs for condition labels"] = "condition",
    group1: Annotated[str, "First group label"] = "A",
    group2: Annotated[str, "Second group label"] = "B",
    filter_proportion: Annotated[str | float, "Proportion of cells to filter ('otsu' for automatic or float 0-1)"] = "otsu",
    train_test_split: Annotated[float, "Proportion for training set"] = 0.8,
    random_seed: Annotated[int, "Random seed for reproducibility"] = 444,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Apply TISSUE cell filtering for supervised learning to improve classifier performance.

    Steps:
    1. Filter cells with ``tissue.downstream.detect_uncertain_cells``. The input is
       the per-gene prediction-interval widths (``*_hi - *_lo``), stratified by
       condition.
    2. Split the remaining cells randomly into train/test sets.
    3. Train an L1 logistic regression on standardized predicted expression.
    4. Report test accuracy and ROC-AUC.

    If ``condition_key`` is missing from ``adata.obs``, the first half of the cells
    is labelled ``group1`` and the rest ``group2``.

    ROC-AUC scores "not ``group1``" as the positive class (``group2`` in the
    two-group case), using the matching ``predict_proba`` column from
    ``model.classes_``. It is NaN when the test set contains a single class.

    Returns:
        dict with a status ``message``, a ``reference`` URL and an ``artifacts``
        list of ``{"description", "path"}`` entries.

    Raises:
        ValueError: if ``train_test_split`` is not strictly between 0 and 1, or
            ``group1`` has no cells in the training split.
    """
    # Progress output goes to stderr: with the stdio MCP transport, stdout is
    # reserved for protocol messages.
    import sys

    # Stamp per call so repeated calls do not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"tissue_supervised_learning_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    if not 0 < train_test_split < 1:
        raise ValueError(f"train_test_split must be strictly between 0 and 1, got {train_test_split}")
    # MCP clients may send a numeric proportion as a string (e.g. "0.2").
    if isinstance(filter_proportion, str):
        try:
            filter_proportion = float(filter_proportion)
        except ValueError:
            pass  # keep non-numeric values such as "otsu" for TISSUE to interpret

    adata = ad.read_h5ad(adata_path)

    # Create condition labels if they don't exist
    if condition_key not in adata.obs.columns:
        adata.obs[condition_key] = [group1 if i < round(adata.shape[0] / 2) else group2
                                    for i in range(adata.shape[0])]

    # Uncertainty = prediction-interval width for every predicted gene.
    pred_expression_key = f"{prediction_method}_predicted_expression"
    X_uncertainty = (adata.obsm[f"{pred_expression_key}_hi"].values
                     - adata.obsm[f"{pred_expression_key}_lo"].values)

    keep_idxs = tissue.downstream.detect_uncertain_cells(
        X_uncertainty,
        proportion=filter_proportion,
        stratification=adata.obs[condition_key].values
    )
    adata_filtered = adata[adata.obs_names[keep_idxs], :].copy()

    print(f"Before TISSUE cell filtering: {adata.shape}", file=sys.stderr)
    print(f"After TISSUE cell filtering: {adata_filtered.shape}", file=sys.stderr)
    balance_df = pd.DataFrame(
        np.unique(adata_filtered.obs[condition_key], return_counts=True),
        index=["Group", "Number of Cells"]
    )
    print(f"Label balance after filtering:\n{balance_df}", file=sys.stderr)

    # A private RandomState yields the same draws as np.random.seed + np.random.choice
    # without mutating the process-wide NumPy RNG.
    rng = np.random.RandomState(random_seed)
    n_cells = adata_filtered.shape[0]
    all_idxs = np.arange(n_cells)
    train_idxs = rng.choice(all_idxs, round(n_cells * train_test_split), replace=False)
    test_idxs = np.setdiff1d(all_idxs, train_idxs)  # sorted remainder, O(n log n)

    expression = adata_filtered.obsm[pred_expression_key].values
    labels = adata_filtered.obs[condition_key].to_numpy()
    train_data, train_labels = expression[train_idxs, :], labels[train_idxs]
    test_data, test_labels = expression[test_idxs, :], labels[test_idxs]

    # Scale on the training set only, then fit an L1 logistic regression.
    scaler = StandardScaler()
    train_data_scaled = scaler.fit_transform(train_data)
    model = LogisticRegression(penalty="l1", solver="liblinear", random_state=random_seed)
    model.fit(train_data_scaled, train_labels)

    test_data_scaled = scaler.transform(test_data)
    pred_test = model.predict(test_data_scaled)
    pred_test_proba = model.predict_proba(test_data_scaled)

    accuracy = accuracy_score(test_labels, pred_test)

    # predict_proba columns follow model.classes_ (sorted), so column 1 is not
    # necessarily group2. Score P(not group1) against the matching 0/1 labels.
    classes = list(model.classes_)
    if group1 not in classes:
        raise ValueError(f"Group '{group1}' has no cells in the training split; cannot score ROC-AUC")
    positive_proba = 1.0 - pred_test_proba[:, classes.index(group1)]
    y_true = (test_labels != group1).astype(int)
    if np.unique(y_true).size > 1:
        roc_auc = roc_auc_score(y_true, positive_proba)
    else:
        roc_auc = float("nan")  # ROC-AUC is undefined for a single-class test set

    results_df = pd.DataFrame({
        "metric": ["cells_before_filtering", "cells_after_filtering", "cells_filtered_out",
                   "train_size", "test_size", "accuracy_score", "roc_auc_score"],
        "value": [adata.shape[0], adata_filtered.shape[0], adata.shape[0] - adata_filtered.shape[0],
                  len(train_idxs), len(test_idxs), accuracy, roc_auc]
    })
    results_path = OUTPUT_DIR / f"{out_prefix}_supervised_learning_results.csv"
    results_df.to_csv(results_path, index=False)

    filtered_path = OUTPUT_DIR / f"{out_prefix}_filtered_adata.h5ad"
    adata_filtered.write_h5ad(filtered_path)

    # cell_id is the position within the filtered AnnData.
    predictions_df = pd.DataFrame({
        "cell_id": test_idxs,
        "true_label": test_labels,
        "predicted_label": pred_test,
        "prediction_probability": positive_proba,
    })
    predictions_path = OUTPUT_DIR / f"{out_prefix}_test_predictions.csv"
    predictions_df.to_csv(predictions_path, index=False)

    return {
        "message": f"TISSUE cell filtering for supervised learning completed. Accuracy: {accuracy:.3f}, ROC-AUC: {roc_auc:.3f}",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "Supervised learning results",
                "path": str(results_path.resolve())
            },
            {
                "description": "Filtered AnnData object",
                "path": str(filtered_path.resolve())
            },
            {
                "description": "Test set predictions",
                "path": str(predictions_path.resolve())
            }
        ]
    }
def tissue_cell_filtering_for_pca(
    adata_path: Annotated[str, "Path to calibrated AnnData file from calibrate_uncertainties_and_prediction_intervals"],
    prediction_method: Annotated[str, "Prediction method used (spage, tangram, harmony)"] = "spage",
    condition_key: Annotated[str, "Key in adata.obs for condition labels"] = "condition",
    group1: Annotated[str, "First group label"] = "A",
    group2: Annotated[str, "Second group label"] = "B",
    filter_proportion: Annotated[str | float, "Proportion of cells to filter ('otsu' for automatic or float 0-1)"] = "otsu",
    n_components: Annotated[int, "Number of principal components"] = 15,
    n_clusters: Annotated[int, "Number of clusters for K-means"] = 2,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Apply TISSUE cell filtering for PCA-based clustering and visualization tasks.

    Steps:
    1. Call ``tissue.downstream.filtered_PCA``, which returns the kept cell indices.
    2. Read the filtered PCs from
       ``adata.uns["<method>_predicted_expression_PC<n>_filtered_"]``.
    3. Plot PC1 vs PC2 by condition.
    4. Run K-means on all PCs and score the clusters against the condition labels
       with the adjusted Rand index.

    If ``condition_key`` is missing from ``adata.obs``, the first half of the cells
    is labelled ``group1`` and the rest ``group2``.

    Returns:
        dict with a status ``message``, a ``reference`` URL and an ``artifacts``
        list of ``{"description", "path"}`` entries.

    Raises:
        ValueError: if the number of PCA rows does not match the number of kept cells.
    """
    # Progress output goes to stderr: with the stdio MCP transport, stdout is
    # reserved for protocol messages.
    import sys

    # Stamp per call so repeated calls do not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"tissue_pca_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # MCP clients may send a numeric proportion as a string (e.g. "0.2").
    if isinstance(filter_proportion, str):
        try:
            filter_proportion = float(filter_proportion)
        except ValueError:
            pass  # keep non-numeric values such as "otsu" for TISSUE to interpret

    adata = ad.read_h5ad(adata_path)

    # Create condition labels if they don't exist
    if condition_key not in adata.obs.columns:
        adata.obs[condition_key] = [group1 if i < round(adata.shape[0] / 2) else group2
                                    for i in range(adata.shape[0])]

    keep_idxs = tissue.downstream.filtered_PCA(
        adata,
        prediction_method,
        proportion=filter_proportion,
        stratification=adata.obs[condition_key].values,
        n_components=n_components,
        return_keep_idxs=True
    )
    # Subset to the kept cells so labels line up with the filtered PCs.
    adata_filtered = adata[adata.obs_names[keep_idxs], :].copy()

    pc_key = f"{prediction_method}_predicted_expression_PC{n_components}_filtered_"
    PC_reduced = np.asarray(adata.uns[pc_key]).copy()
    if PC_reduced.shape[0] != adata_filtered.n_obs:
        raise ValueError(
            f"'{pc_key}' has {PC_reduced.shape[0]} rows but {adata_filtered.n_obs} cells were kept"
        )
    print(f"PCA reduced data shape: {PC_reduced.shape}", file=sys.stderr)

    # 2D PCA plot coloured by condition; always release the figure.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.set_title("TISSUE-Filtered PCA")
        for group, color in ((group1, "tab:red"), (group2, "tab:blue")):
            mask = (adata_filtered.obs[condition_key] == group).to_numpy()
            ax.scatter(PC_reduced[mask, 0], PC_reduced[mask, 1],
                       c=color, s=3, label=group, alpha=0.7)
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        ax.set_xlabel("PC 1")
        ax.set_ylabel("PC 2")
        pca_fig_path = OUTPUT_DIR / f"{out_prefix}_filtered_pca.png"
        fig.savefig(pca_fig_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # K-means on all principal components, scored against condition labels.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(PC_reduced)
    ari_score = adjusted_rand_score(adata_filtered.obs[condition_key], clusters)
    print(f"Adjusted Rand Index: {ari_score}", file=sys.stderr)

    # Name columns from the actual PC count rather than assuming n_components.
    pca_results_df = pd.DataFrame(PC_reduced, columns=[f"PC{i+1}" for i in range(PC_reduced.shape[1])])
    pca_results_df["cell_id"] = adata_filtered.obs_names
    pca_results_df["condition"] = adata_filtered.obs[condition_key].values
    pca_results_df["kmeans_cluster"] = clusters
    pca_results_path = OUTPUT_DIR / f"{out_prefix}_pca_results.csv"
    pca_results_df.to_csv(pca_results_path, index=False)

    clustering_metrics_df = pd.DataFrame({
        "metric": ["n_cells_before_filtering", "n_cells_after_filtering", "n_components",
                   "n_clusters", "adjusted_rand_index"],
        "value": [adata.shape[0], adata_filtered.shape[0], n_components, n_clusters, ari_score]
    })
    metrics_path = OUTPUT_DIR / f"{out_prefix}_clustering_metrics.csv"
    clustering_metrics_df.to_csv(metrics_path, index=False)

    adata_filtered.obsm["X_pca_tissue_filtered"] = PC_reduced
    adata_filtered.obs["kmeans_cluster"] = clusters
    filtered_path = OUTPUT_DIR / f"{out_prefix}_pca_filtered_adata.h5ad"
    adata_filtered.write_h5ad(filtered_path)

    return {
        "message": f"TISSUE-filtered PCA completed. ARI score: {ari_score:.3f} with {n_clusters} clusters",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "TISSUE-filtered PCA visualization",
                "path": str(pca_fig_path.resolve())
            },
            {
                "description": "PCA results with clustering",
                "path": str(pca_results_path.resolve())
            },
            {
                "description": "Clustering performance metrics",
                "path": str(metrics_path.resolve())
            },
            {
                "description": "PCA-filtered AnnData object",
                "path": str(filtered_path.resolve())
            }
        ]
    }
def tissue_weighted_pca(
    adata_path: Annotated[str, "Path to calibrated AnnData file from calibrate_uncertainties_and_prediction_intervals"],
    prediction_method: Annotated[str, "Prediction method used (spage, tangram, harmony)"] = "spage",
    condition_key: Annotated[str, "Key in adata.obs for condition labels"] = "condition",
    group1: Annotated[str, "First group label"] = "A",
    group2: Annotated[str, "Second group label"] = "B",
    pca_method: Annotated[Literal["wpca", "standard"], "PCA method to use"] = "wpca",
    weighting: Annotated[Literal["inverse_pi_width", "uniform"], "Weighting scheme for WPCA"] = "inverse_pi_width",
    replace_inf: Annotated[Literal["max", "zero"], "How to handle infinite weights"] = "max",
    binarize: Annotated[float, "Proportion for weight binarization"] = 0.2,
    binarize_ratio: Annotated[float, "Ratio between high and low weights"] = 10,
    n_components: Annotated[int, "Number of principal components"] = 15,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Perform TISSUE-WPCA (weighted principal component analysis) using uncertainty-based weights.

    Calls ``tissue.downstream.weighted_PCA`` with the given options and reads the
    result from ``adata.obsm["<method>_predicted_expression_PC<n>_"]``. It then
    plots PC1 vs PC2 by condition.

    Outputs saved to ``OUTPUT_DIR``:
    - the PCs;
    - the parameters used;
    - the AnnData, with the PCs copied to ``obsm["X_wpca_tissue"]``.

    If ``condition_key`` is missing from ``adata.obs``, the first half of the cells
    is labelled ``group1`` and the rest ``group2``.

    Returns:
        dict with a status ``message``, a ``reference`` URL and an ``artifacts``
        list of ``{"description", "path"}`` entries.
    """
    # Stamp per call so repeated calls do not overwrite each other's outputs.
    if out_prefix is None:
        out_prefix = f"tissue_wpca_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    adata = ad.read_h5ad(adata_path)

    # Create condition labels if they don't exist
    if condition_key not in adata.obs.columns:
        adata.obs[condition_key] = [group1 if i < round(adata.shape[0] / 2) else group2
                                    for i in range(adata.shape[0])]

    tissue.downstream.weighted_PCA(
        adata, prediction_method,
        pca_method=pca_method,
        weighting=weighting,
        replace_inf=replace_inf,
        binarize=binarize,
        binarize_ratio=binarize_ratio,
        n_components=n_components
    )

    wpca_key = f"{prediction_method}_predicted_expression_PC{n_components}_"
    X_pc = np.asarray(adata.obsm[wpca_key])

    # PC1 vs PC2 coloured by condition; always release the figure.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.set_title("TISSUE Weighted PCA")
        for group, color in ((group1, "tab:red"), (group2, "tab:blue")):
            mask = (adata.obs[condition_key] == group).to_numpy()
            ax.scatter(X_pc[mask, 0], X_pc[mask, 1],
                       c=color, s=3, label=group, alpha=0.7)
        ax.set_xlabel("PC 1")
        ax.set_ylabel("PC 2")
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        wpca_fig_path = OUTPUT_DIR / f"{out_prefix}_weighted_pca.png"
        fig.savefig(wpca_fig_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # Name columns from the actual PC count rather than assuming n_components.
    wpca_results_df = pd.DataFrame(X_pc, columns=[f"WPC{i+1}" for i in range(X_pc.shape[1])])
    wpca_results_df["cell_id"] = adata.obs_names
    wpca_results_df["condition"] = adata.obs[condition_key].values
    wpca_results_path = OUTPUT_DIR / f"{out_prefix}_wpca_results.csv"
    wpca_results_df.to_csv(wpca_results_path, index=False)

    params_df = pd.DataFrame({
        "parameter": ["pca_method", "weighting", "replace_inf", "binarize",
                      "binarize_ratio", "n_components"],
        "value": [pca_method, weighting, replace_inf, binarize, binarize_ratio, n_components]
    })
    params_path = OUTPUT_DIR / f"{out_prefix}_wpca_parameters.csv"
    params_df.to_csv(params_path, index=False)

    adata.obsm["X_wpca_tissue"] = X_pc
    wpca_adata_path = OUTPUT_DIR / f"{out_prefix}_wpca_adata.h5ad"
    adata.write_h5ad(wpca_adata_path)

    return {
        "message": f"TISSUE weighted PCA completed with {weighting} weighting and {n_components} components",
        "reference": "https://github.com/sunericd/TISSUE/README.md",
        "artifacts": [
            {
                "description": "TISSUE weighted PCA visualization",
                "path": str(wpca_fig_path.resolve())
            },
            {
                "description": "Weighted PCA results",
                "path": str(wpca_results_path.resolve())
            },
            {
                "description": "WPCA parameters used",
                "path": str(params_path.resolve())
            },
            {
                "description": "WPCA AnnData object",
                "path": str(wpca_adata_path.resolve())
            }
        ]
    }