# ayyuce's picture
# Update app.py
# aa88038 verified
# Raw
# History Blame Contribute Delete
# 14.7 kB
import os
import urllib.request
import tarfile
import shutil
import pandas as pd
from flask import Flask, render_template, request, jsonify, url_for
from werkzeug.utils import secure_filename
import matplotlib
matplotlib.use('Agg')
import scAnalysis.sc_io as io
import scAnalysis.preprocessing as pp
import scAnalysis.quality_control as qc
import scAnalysis.cell_cycle as cc
import scAnalysis.batch_correction as bc
import scAnalysis.dimensionality as dim
import scAnalysis.clustering as cl
import scAnalysis.trajectory as traj
import scAnalysis.differential as diff
import scAnalysis.enrichment as enrich
import scAnalysis.visualization as vis
import scAnalysis.interactive_viz as iviz
import scAnalysis.imputation as imp
import scAnalysis.grn_inference as grn
# Flask application object; the routes below register against it.
app = Flask(__name__)

# Working directories for uploaded/downloaded inputs and generated results.
# NOTE(review): both paths are relative to the current working directory;
# the routes build result links with url_for('static', ...), so this assumes
# the process is started from the application directory - confirm.
app.config.update(
    UPLOAD_FOLDER='./static/uploads',
    RESULTS_FOLDER='./static/results',
)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)

# Built-in datasets: short name -> URL of a gzipped tarball to download.
DATASETS = {
    "pbmc3k": "https://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/pbmc3k_filtered_gene_bc_matrices.tar.gz",
    "pbmc5k": "https://cf.10xgenomics.com/samples/cell-exp/3.0.2/5k_pbmc_v3/5k_pbmc_v3_filtered_feature_bc_matrix.tar.gz",
    "heart_atlas": "https://cf.10xgenomics.com/samples/cell-exp/3.0.0/heart_10k_v3/heart_10k_v3_filtered_feature_bc_matrix.tar.gz",
    "mouse_brain": "https://cf.10xgenomics.com/samples/cell-exp/3.0.0/neuron_10k_v3/neuron_10k_v3_filtered_feature_bc_matrix.tar.gz",
    "lung_tumor": "https://cf.10xgenomics.com/samples/cell-exp/3.0.0/nsclc_10k_v3/nsclc_10k_v3_filtered_feature_bc_matrix.tar.gz",
}
def download_and_extract(dataset_name):
    """Download and unpack a built-in dataset, returning its matrix directory.

    The archive for ``dataset_name`` is fetched from ``DATASETS`` and unpacked
    under ``UPLOAD_FOLDER/<dataset_name>``. If that directory already exists
    the download is skipped and the existing copy is reused.

    Args:
        dataset_name: Key into ``DATASETS``.

    Returns:
        The first directory (in ``os.walk`` order) under the extraction path
        that contains ``matrix.mtx`` or ``matrix.mtx.gz``; the extraction path
        itself if no such directory is found.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
        OSError: On download, extraction or filesystem failures
            (``urllib.error.URLError`` is an ``OSError`` subclass).
        tarfile.TarError: If the archive is invalid or rejected by the
            extraction filter.
    """
    url = DATASETS[dataset_name]
    upload_dir = app.config['UPLOAD_FOLDER']
    filepath = os.path.join(upload_dir, f"{dataset_name}.tar.gz")
    extract_path = os.path.join(upload_dir, dataset_name)

    if not os.path.exists(extract_path):
        # Extract into a staging directory and rename it only on success, so
        # a failed download/extraction never leaves a half-populated
        # ``extract_path`` that later calls would mistake for a valid cache.
        staging_path = extract_path + ".partial"
        shutil.rmtree(staging_path, ignore_errors=True)
        try:
            req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with urllib.request.urlopen(req) as response, open(filepath, "wb") as out_file:
                shutil.copyfileobj(response, out_file)
            with tarfile.open(filepath, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would escape the target directory
                    # (absolute paths, "..", unsafe links).
                    tar.extractall(path=staging_path, filter="data")
                else:
                    # Older interpreters without extraction filters.
                    tar.extractall(path=staging_path)
            os.replace(staging_path, extract_path)
        finally:
            # Best-effort cleanup: drop any leftover staging directory and the
            # downloaded archive, which is never read again after extraction.
            shutil.rmtree(staging_path, ignore_errors=True)
            try:
                os.remove(filepath)
            except OSError:
                pass  # archive may not have been created

    # Archives nest the matrix files at different depths; locate them.
    for root, _dirs, files in os.walk(extract_path):
        if 'matrix.mtx' in files or 'matrix.mtx.gz' in files:
            return root
    return extract_path
@app.route('/')
def index():
    """Render the landing page, passing the built-in dataset names to the template."""
    dataset_names = DATASETS.keys()
    return render_template('index.html', datasets=dataset_names)
@app.route('/run_pipeline', methods=['POST'])
def run_pipeline():
    """Run the analysis pipeline described by the posted form and report results.

    Stages, in order: load data (uploaded file or built-in dataset), QC
    metrics, optional doublet removal, cell/gene filtering, normalization,
    optional imputation, cell-cycle scoring, highly-variable genes, scaling,
    PCA and neighbors, optional batch correction, optional GRN inference,
    optional UMAP/t-SNE/PHATE, clustering, diffusion map and pseudotime,
    differential expression, static plots, optional interactive plots, and
    finally writing the processed data to ``processed_data.h5ad``.

    Plotting and other optional steps are best-effort: a failure is printed
    and the pipeline continues. Any other failure aborts the run.

    Returns:
        A JSON response. On success:
        ``{"status": "success", "outputs": [...], "download": <url>}`` where
        each output has ``type``, ``url`` and ``title`` (plus ``icon`` for the
        GRN file). On failure: ``{"status": "error", "message": <str>}``.
        Both are returned with the default HTTP status.
    """
    try:
        # ---- Load data ----
        dataset_choice = request.form.get('dataset')
        if dataset_choice == 'custom':
            upload = request.files.get('custom_file')
            if upload is None or not upload.filename:
                raise ValueError("No custom file was uploaded.")
            filename = secure_filename(upload.filename)
            if not filename:
                # secure_filename() can strip a name down to nothing; saving
                # would then target the upload directory itself.
                raise ValueError("The uploaded file name is not usable.")
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            upload.save(file_path)
            data = io.read_h5ad(file_path) if filename.endswith('.h5ad') else io.read_csv(file_path)
        elif dataset_choice in DATASETS:
            data_path = download_and_extract(dataset_choice)
            data = io.read_10x_mtx(data_path)
        else:
            raise ValueError(f"Unknown dataset: {dataset_choice!r}")
        data.var.index = io._make_unique(data.var.index.values)

        res_dir = app.config['RESULTS_FOLDER']
        outputs = []

        # ---- Quality control ----
        # NOTE(review): "MT-" looks like the human mitochondrial gene prefix;
        # confirm how this behaves when organism is not human.
        pp.calculate_qc_metrics(data, qc_vars=["MT-"])
        try:
            qc_path = os.path.join(res_dir, 'qc_violin.png')
            vis.plot_qc_violin(data, save=qc_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/qc_violin.png'), 'title': 'QC Metrics'})
        except Exception as e:
            print("QC Plot failed:", e)

        if request.form.get('run_scrublet') == 'true':
            db_rate = float(request.form.get('doublet_rate', 0.06))
            qc.scrublet(data, expected_doublet_rate=db_rate)
            # Keep only cells not flagged as doublets.
            data = data[~data.obs['predicted_doublet'].astype(bool), :]

        data = pp.filter_cells(
            data,
            min_genes=int(request.form.get('min_genes', 200)),
            max_genes=int(request.form.get('max_genes', 2500)),
            max_pct_mito=float(request.form.get('max_mito', 5.0)),
        )
        data = pp.filter_genes(data, min_cells=int(request.form.get('min_cells', 3)))

        # ---- Normalization ----
        norm_method = request.form.get('norm_method', 'total')
        if norm_method == 'total':
            pp.normalize_total(data, target_sum=1e4)
            pp.log1p(data)
        elif norm_method == 'scran':
            pp.normalize_scran_pooling(data, target_sum=1e4)
            pp.log1p(data)
        elif norm_method == 'sctransform':
            pp.normalize_sctransform(data)

        # ---- Imputation (optional) ----
        if request.form.get('run_imputation') == 'true':
            imp_method = request.form.get('imp_method', 'wnid')
            imp_pcs = int(request.form.get('imp_pcs', 30))
            imp_k = int(request.form.get('imp_k', 7))
            if imp_method == 'wnid':
                imp_thresh = float(request.form.get('imp_thresh', 0.72))
                imp.impute_wnid(data, k=imp_k, dropout_thresh=imp_thresh, n_pcs=imp_pcs)
            elif imp_method == 'knn':
                imp.impute_knn_smooth(data, k=imp_k, n_pcs=imp_pcs)
            elif imp_method == 'diffusion':
                imp.impute_diffusion(data, t=3, n_pcs=imp_pcs, use_prebuilt_graph=False)

        # ---- Cell cycle and feature selection ----
        organism = request.form.get('organism', 'human')
        cc.score_cell_cycle(data, organism=organism)
        n_hvg = int(request.form.get('n_hvg', 2000))
        pp.highly_variable_genes(data, n_top_genes=n_hvg)
        try:
            hvg_path = os.path.join(res_dir, 'highest_expr_genes.png')
            vis.plot_highest_expr_genes(data, save=hvg_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/highest_expr_genes.png'), 'title': 'Highest Expressed'})
        except Exception as e:
            print("Highest expressed genes plot failed:", e)

        # Snapshot the data before scaling; differential expression below
        # is run with use_raw=True.
        data.raw = data.copy()
        pp.scale(data, max_value=10)

        # ---- PCA and neighbor graph ----
        n_pcs = int(request.form.get('n_pcs', 50))
        n_neighbors = int(request.form.get('n_neighbors', 15))
        dim.run_pca(data, n_components=n_pcs)
        dim.neighbors(data, n_neighbors=n_neighbors, n_pcs=min(40, n_pcs))

        # ---- Batch correction (optional) ----
        batch_key = request.form.get('batch_key', '')
        if request.form.get('run_batch') == 'true' and batch_key in data.obs.columns:
            b_algo = request.form.get('batch_algo', 'harmony')
            if b_algo == 'harmony':
                bc.harmony_integrate(data, batch_key=batch_key)
            elif b_algo == 'combat':
                bc.combat(data, batch_key=batch_key)
            elif b_algo == 'mnn':
                # MNN takes one dataset per batch and returns a new object.
                batches = data.obs[batch_key].unique()
                dataset_list = [data[data.obs[batch_key] == b].copy() for b in batches]
                data = bc.mnn_correct(dataset_list, batch_key=batch_key)

        # ---- GRN inference (optional) ----
        if request.form.get('run_grn') == 'true':
            tf_input = request.form.get('tf_list', '')
            tf_list = [tf.strip() for tf in tf_input.split(',') if tf.strip()]
            if tf_list:
                try:
                    df_grn = grn.infer_grn_ridge(data, tf_list=tf_list, top_n_edges=5000)
                    grn_path = os.path.join(res_dir, 'grn_edges.csv')
                    df_grn.to_csv(grn_path, index=False)
                    outputs.append({'type': 'file', 'url': url_for('static', filename='results/grn_edges.csv'), 'title': 'Download GRN Edges', 'icon': 'fa-project-diagram'})
                except Exception as e:
                    print("GRN Inference failed:", e)

        # ---- Dimensionality reduction (optional embeddings) ----
        if request.form.get('run_umap') == 'true':
            dim.run_umap(data, min_dist=float(request.form.get('umap_min_dist', 0.5)))
        if request.form.get('run_tsne') == 'true':
            dim.run_tsne(data, perplexity=float(request.form.get('tsne_perplex', 30.0)))
        if request.form.get('run_phate') == 'true':
            try:
                dim.run_phate(data)
            except Exception as e:
                print("PHATE failed:", e)

        # ---- Clustering and trajectory ----
        clust_algo = request.form.get('clustering', 'leiden')
        res_k = float(request.form.get('resolution', 1.0))
        if clust_algo == 'leiden':
            cl.cluster_leiden(data, resolution=res_k, key_added="cluster")
        elif clust_algo == 'louvain':
            cl.cluster_louvain(data, resolution=res_k, key_added="cluster")
        elif clust_algo == 'kmeans':
            # Count-based methods derive the cluster count from the resolution.
            cl.cluster_kmeans(data, n_clusters=int(res_k*10), key_added="cluster")
        elif clust_algo == 'hierarchical':
            cl.cluster_hierarchical(data, n_clusters=int(res_k*10), key_added="cluster")
        elif clust_algo == 'spectral':
            cl.cluster_spectral(data, n_clusters=int(res_k*10), key_added="cluster")
        if 'cluster' not in data.obs.columns:
            # Everything below needs cluster labels; fail with a clear message
            # instead of a bare KeyError further down.
            raise ValueError(f"Unknown clustering algorithm: {clust_algo!r}")

        dim.run_diffmap(data)
        # The first cluster label encountered is used as the trajectory root
        # and as the group for the volcano plot.
        first_cluster = data.obs['cluster'].unique()[0]
        root_strat = request.form.get('root_strategy', 'extreme')
        root_idx = traj.select_root_cell(data, cluster_key='cluster', root_cluster=first_cluster, strategy=root_strat)
        traj.diffusion_pseudotime(data, root_cell=root_idx, n_branchings=int(request.form.get('branches', 0)))

        diff_method = request.form.get('diff_method', 't-test')
        diff.rank_genes_groups(data, groupby='cluster', method=diff_method, use_raw=True)

        # ---- Static visualization ----
        if 'X_umap' in data.obsm:
            umap_path = os.path.join(res_dir, 'umap_clusters.png')
            vis.plot_umap(data, color="cluster", title="UMAP (Clusters)", save=umap_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/umap_clusters.png'), 'title': 'UMAP Clustering'})
            phase_path = os.path.join(res_dir, 'umap_phase.png')
            vis.plot_umap(data, color="phase", title="Cell Cycle Phase", save=phase_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/umap_phase.png'), 'title': 'UMAP Cell Cycle'})
        if 'X_tsne' in data.obsm:
            tsne_path = os.path.join(res_dir, 'tsne_clusters.png')
            vis.plot_tsne(data, color="cluster", title="t-SNE", save=tsne_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/tsne_clusters.png'), 'title': 't-SNE Clustering'})
        try:
            volcano_path = os.path.join(res_dir, 'volcano_plot.png')
            vis.volcano_plot(data, group=first_cluster, save=volcano_path)
            outputs.append({'type': 'image', 'url': url_for('static', filename='results/volcano_plot.png'), 'title': f'Volcano (Cluster {first_cluster})'})
        except Exception as e:
            print("Volcano plot failed:", e)

        # Marker plots only use the markers present in this dataset.
        canonical_markers = ["CD3D", "CD14", "CD19", "MS4A1", "GNLY", "LYZ", "FCER1A", "CST3", "CD8A"]
        valid_markers = [g for g in canonical_markers if g in data.var.index]
        if valid_markers:
            try:
                dot_path = os.path.join(res_dir, 'dotplot.png')
                vis.plot_dotplot(data, var_names=valid_markers, groupby='cluster', save=dot_path)
                outputs.append({'type': 'image', 'url': url_for('static', filename='results/dotplot.png'), 'title': 'Marker Dotplot'})
                heat_path = os.path.join(res_dir, 'heatmap.png')
                vis.plot_heatmap(data, var_names=valid_markers, groupby='cluster', save=heat_path)
                outputs.append({'type': 'image', 'url': url_for('static', filename='results/heatmap.png'), 'title': 'Marker Heatmap'})
            except Exception as e:
                print("Marker dotplot/heatmap failed:", e)

        # ---- Interactive visualization (optional) ----
        hover_data_cols = ['phase', 'total_counts', 'dpt_pseudotime']
        if request.form.get('int_umap') == 'true' and 'X_umap' in data.obsm:
            try:
                html_umap = os.path.join(res_dir, 'interactive_umap.html')
                iviz.interactive_embedding(data, basis='X_umap', color='cluster', hover_data=hover_data_cols, title="Interactive UMAP", save_html=html_umap)
                outputs.append({'type': 'html', 'url': url_for('static', filename='results/interactive_umap.html'), 'title': 'Interactive UMAP'})
            except Exception as e:
                print("Interactive UMAP failed:", e)
        if request.form.get('int_3d_pca') == 'true' and 'X_pca' in data.obsm:
            try:
                html_pca = os.path.join(res_dir, 'interactive_3d_pca.html')
                iviz.interactive_3d_embedding(data, basis='X_pca', color='cluster', dimensions=[0, 1, 2], save_html=html_pca)
                outputs.append({'type': 'html', 'url': url_for('static', filename='results/interactive_3d_pca.html'), 'title': '3D PCA'})
            except Exception as e:
                print("Interactive PCA failed:", e)
        if request.form.get('int_tsne') == 'true' and 'X_tsne' in data.obsm:
            try:
                html_tsne = os.path.join(res_dir, 'interactive_tsne.html')
                iviz.interactive_embedding(data, basis='X_tsne', color='cluster', hover_data=hover_data_cols, title="Interactive t-SNE", save_html=html_tsne)
                outputs.append({'type': 'html', 'url': url_for('static', filename='results/interactive_tsne.html'), 'title': 'Interactive t-SNE'})
            except Exception as e:
                print("Interactive t-SNE failed:", e)
        if request.form.get('int_violin') == 'true':
            try:
                html_violin = os.path.join(res_dir, 'interactive_violin.html')
                iviz.interactive_violin(data, keys=['n_genes_by_counts', 'total_counts'], groupby='cluster', save_html=html_violin)
                outputs.append({'type': 'html', 'url': url_for('static', filename='results/interactive_violin.html'), 'title': 'Interactive QC Violins'})
            except Exception as e:
                print("Interactive Violin failed:", e)
        if request.form.get('int_heatmap') == 'true':
            try:
                if valid_markers:
                    html_heat = os.path.join(res_dir, 'interactive_heatmap.html')
                    iviz.interactive_heatmap(data, var_names=valid_markers, groupby='cluster', save_html=html_heat)
                    outputs.append({'type': 'html', 'url': url_for('static', filename='results/interactive_heatmap.html'), 'title': 'Interactive Heatmap'})
            except Exception as e:
                print("Interactive Heatmap failed:", e)

        # ---- Persist the processed dataset ----
        output_h5ad = os.path.join(res_dir, 'processed_data.h5ad')
        io.write_h5ad(data, output_h5ad)
        return jsonify({"status": "success", "outputs": outputs, "download": url_for('static', filename='results/processed_data.h5ad')})
    except Exception as e:
        # Request boundary: log the full traceback server-side and report the
        # message to the client as JSON.
        import traceback
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)})
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser. Keep it opt-in through the
    # FLASK_DEBUG environment variable instead of hard-coding it on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, port=5000)