spMetaTME-Atlas / src /backend /flux_utils.py
Surajv's picture
initial commit
31d5c57
import numpy as np
import pandas as pd
import logging
from typing import Optional, Dict, List
logger = logging.getLogger(__name__)
def aggregate_flux_by_pathway(adata, pathway_col: str = "subsystems", aggregation: str = "mean") -> pd.DataFrame:
    """Aggregate reaction fluxes by metabolic pathway.

    Columns of ``adata.X`` are grouped by the label found in
    ``adata.var[pathway_col]`` and each group is reduced row by row.

    Args:
        adata: Object exposing ``X`` (2-D matrix, dense or scipy-sparse),
            ``var`` (DataFrame with one row per column of ``X``) and
            ``obs_names`` (one label per row of ``X``). Presumably an
            AnnData object -- confirm against callers.
        pathway_col: Column of ``adata.var`` holding the pathway labels.
        aggregation: ``"mean"`` or ``"sum"``. Any other value falls back to
            ``"mean"`` and a warning is logged.

    Returns:
        A DataFrame indexed by ``adata.obs_names`` with one column per
        non-missing pathway label, in order of first appearance. If
        ``pathway_col`` is not a column of ``adata.var`` an empty DataFrame
        is returned; if every label is missing the result keeps the
        observation index but has no columns.
    """
    if pathway_col not in adata.var.columns:
        return pd.DataFrame()

    labels = adata.var[pathway_col]
    pathways = [p for p in labels.unique() if pd.notna(p)]
    if not pathways:
        # Nothing to aggregate; building a frame from an empty array with a
        # non-empty index would raise, so return the index-only frame.
        return pd.DataFrame(index=adata.obs_names)

    if aggregation not in ("mean", "sum"):
        logger.warning(
            "Unknown aggregation %r; falling back to 'mean'", aggregation
        )
    use_sum = aggregation == "sum"

    per_pathway = []
    for pathway in pathways:
        # Plain boolean ndarray: accepted by both NumPy and scipy-sparse indexing.
        mask = (labels == pathway).to_numpy()
        block = adata.X[:, mask]
        reduced = block.sum(axis=1) if use_sum else block.mean(axis=1)
        # Sparse reductions return an (n, 1) np.matrix; flatten so every
        # pathway contributes a 1-D vector of length n_obs.
        per_pathway.append(np.asarray(reduced).ravel())

    return pd.DataFrame(
        np.column_stack(per_pathway),
        index=adata.obs_names,
        columns=pathways,
    )
def compute_flux_statistics(adata, groupby: Optional[str] = None) -> Dict:
    """Compute basic per-column flux statistics.

    Args:
        adata: Object exposing ``X`` (2-D matrix, dense or scipy-sparse) and
            ``obs`` (DataFrame with one row per row of ``X``). Presumably an
            AnnData object -- confirm against callers.
        groupby: Optional column of ``adata.obs``. When given and present,
            per-group statistics are added; otherwise it is ignored.

    Returns:
        A dict with 1-D arrays under ``'mean'``, ``'std'`` and ``'variance'``
        (one value per column of ``X``, population statistics). When grouping
        applies, ``'groups'`` maps each non-missing group label to a dict
        with that group's column ``'mean'`` and row ``'count'``.
    """
    flux_data = adata.X

    mean = np.asarray(flux_data.mean(axis=0)).flatten()
    if hasattr(flux_data, "std") and hasattr(flux_data, "var"):
        std = np.asarray(flux_data.std(axis=0)).flatten()
        variance = np.asarray(flux_data.var(axis=0)).flatten()
    else:
        # scipy-sparse matrices provide mean() but not std()/var(); derive the
        # population variance as E[x^2] - E[x]^2 without densifying.
        mean_of_squares = np.asarray(
            flux_data.multiply(flux_data).mean(axis=0)
        ).flatten()
        # Clip tiny negatives caused by floating-point cancellation.
        variance = np.maximum(mean_of_squares - mean ** 2, 0.0)
        std = np.sqrt(variance)

    stats = {
        'mean': mean,
        'std': std,
        'variance': variance,
    }

    if groupby and groupby in adata.obs.columns:
        labels = adata.obs[groupby]
        group_stats = {}
        for group in labels.unique():
            # A missing label never equals itself, so it would select zero
            # rows and yield a NaN mean; skip it instead.
            if pd.isna(group):
                continue
            # Plain boolean ndarray works for dense and sparse row selection.
            mask = (labels == group).to_numpy()
            group_stats[group] = {
                'mean': np.asarray(flux_data[mask].mean(axis=0)).flatten(),
                'count': int(mask.sum()),
            }
        stats['groups'] = group_stats

    return stats