maom's picture
Update app.py
e3c2f3a verified
import numpy as np
import pandas as pd
import datasets
import streamlit as st
from streamlit_cytoscapejs import st_cytoscapejs
import networkx as nx
st.set_page_config(layout='wide')

# Optional settings can be supplied in the URL query string so that a given
# network view can be linked to directly; otherwise defaults are used.
query_params = st.query_params


def _query_param_or_default(key, default):
    """Return the query-string value stored under ``key``, or ``default`` when absent."""
    if key in query_params.keys():
        return query_params[key]
    return default


# The URL carries comma-separated IDs; the text area below shows one ID per
# line, so switch the separator to "\n".
input_gene_ids = _query_param_or_default(
    "gene_ids", "B9J08_000884,B9J08_004112").replace(",", "\n")
# Both values stay strings here; they are parsed next to their input widgets.
coexp_score_threshold = _query_param_or_default("coexp_score_threshold", "0.85")
max_per_gene = _query_param_or_default("max_per_gene", "25")
# Page header: project description, citation, links, and usage instructions.
st.markdown("""
# CaurisCEN Network
**CaurisCEN** is a co-expression network for *Candida auris* built on 577 RNA-seq runs across 2 96-well plate formats in 3 biological replicates.
A pair of genes is said to be co-expressed when their expression is correlated across different conditions, which
often indicates that the genes are involved in similar processes.
To Cite:
Rapala JR, MJ O'Meara, TR O'Meara
CaurisCEN: A Co-Expression Network for Candida auris
* Code available at https://github.com/maomlab/CalCEN/tree/master/vignettes/CaurisCEN
* Full network and dataset: https://huggingface.co/datasets/maomlab/CaurisCEN
## Plot a network for a set of genes
Enter one ``B9J08_######`` gene_id per row to seed the network
""")
@st.cache_data
def _load_cauriscen_table(name):
    """Load one configuration of the maomlab/CaurisCEN dataset as a DataFrame.

    Each configuration ``name`` lives under ``{name}/data`` and only its
    'train' split is used. Cached because Streamlit re-executes this whole
    script on every widget interaction; without the cache the dataset would be
    re-loaded and re-converted to pandas on each rerun.
    """
    return datasets.load_dataset(
        path = "maomlab/CaurisCEN",
        name = name,
        data_dir = f"{name}/data")['train'].to_pandas()


# Per-gene annotations, looked up by locus_tag_old further down.
gene_metadata = _load_cauriscen_table("gene_metadata")
# Co-expression hits with feature_name_1, feature_name_2 and score columns.
top_coexp_hits = _load_cauriscen_table("top_coexp_hits_general")
# Input widgets: seed gene IDs in col1, network parameters in col2. col3 holds
# the download button further down; `padding` is an unused spacer column.
col1, col2, col3, padding = st.columns(spec = [0.2, 0.2, 0.2, 0.4])
with col1:
    input_gene_ids = st.text_area(
        label = "Gene IDs",
        value = str(input_gene_ids),
        height = 130,
        help = "B9J08 Gene IDs e.g. B9J08_000884")
with col2:
    coexp_score_threshold = st.text_input(
        label = "Co-expression threshold [0-1]",
        value = str(coexp_score_threshold),
        help = "Default: 0.85")
    # Parse and range-check together. On failure, show the error and halt
    # this run. Continuing would either compare a str with a number
    # (TypeError) or build the network from an invalid threshold.
    try:
        coexp_score_threshold = float(coexp_score_threshold)
    except ValueError:
        threshold_ok = False
    else:
        threshold_ok = 0 <= coexp_score_threshold <= 1  # also rejects NaN
    if not threshold_ok:
        st.error(f"Co-expression threshold should be a number between 0 and 1, instead it is '{coexp_score_threshold}'")
        st.stop()

    max_per_gene = st.text_input(
        label = "Max per gene",
        value = str(max_per_gene),
        help = "Default: 25")
    try:
        max_per_gene = int(max_per_gene)
    except ValueError:
        max_per_gene_ok = False
    else:
        max_per_gene_ok = max_per_gene > 0
    if not max_per_gene_ok:
        st.error(f"Max per gene should be a number greater than 0, instead it is '{max_per_gene}'")
        st.stop()
##################################
# Parse and check the user input #
##################################
# One gene ID per line. Blank lines are ignored, and a repeated ID is kept
# once (first occurrence wins) so each seed becomes exactly one node.
seed_gene_ids = list(dict.fromkeys(
    line.strip() for line in input_gene_ids.split("\n") if line.strip()))
if not seed_gene_ids:
    # pd.concat below raises on an empty list; stop early with a hint instead.
    st.info("Enter at least one gene ID to seed the network.")
    st.stop()

# For each seed, keep its co-expression partners scoring above the threshold,
# capped at the max_per_gene highest-scoring ones.
neighbors = []
for seed_gene_id in seed_gene_ids:
    hits = top_coexp_hits[
        (top_coexp_hits.feature_name_1 == seed_gene_id) &
        (top_coexp_hits.score > coexp_score_threshold)]
    # Stable sort: a no-op if the table is already ordered by descending
    # score, but guarantees the truncation keeps the strongest hits.
    hits = hits.sort_values("score", ascending=False, kind="stable").head(max_per_gene)
    neighbors.append(hits)
neighbors = pd.concat(neighbors)

# Neighbors are hit partners that are not themselves seeds. A gene in both
# lists would otherwise become two nodes with the same ID. Sorting makes the
# node order (and therefore gene_index) reproducible across reruns.
neighbor_gene_ids = sorted(set(neighbors.feature_name_2) - set(seed_gene_ids))
gene_ids = seed_gene_ids + neighbor_gene_ids
gene_types = ['seed'] * len(seed_gene_ids) + ['neighbor'] * len(neighbor_gene_ids)
# Collect per-node annotations, aligned index-for-index with gene_ids.
# gene_metadata is matched on its locus_tag_old column.
old_locus_tags = []
gene_names = []
sacch_orthologs = []
descriptions = []
for gene_id in gene_ids:
    # The node is always identified by its gene ID, which is also what the
    # edge endpoints (feature_name_1/2) refer to.
    old_locus_tags.append(gene_id)
    matches = gene_metadata.loc[gene_metadata["locus_tag_old"] == gene_id]
    if matches.empty:
        # Keep the node, but without annotations, so every list stays the
        # same length as gene_ids.
        st.error(f"Unable to locate Gene ID: {gene_id} in the gene metadata, it should be of the form 'B9J08_######'")
        gene_names.append(None)
        sacch_orthologs.append(None)
        descriptions.append(None)
        continue
    # If several metadata rows share the locus tag, the first one is used.
    record = matches.iloc[0]
    gene_names.append(record["gene_name"])
    sacch_orthologs.append(record["sacch_ortholog"])
    descriptions.append(record["description"])
# Debug trace to the server log: all of the per-node lists below must have
# the same length, otherwise the DataFrame constructor raises.
node_info_report = f"""
Constructing node_info
seed_gene_ids: {len(seed_gene_ids)},
neighbor_gene_ids: {len(neighbor_gene_ids)},
gene_index: {len(gene_ids)},
locus_tag_old: {len(old_locus_tags)},
gene_types: {len(gene_types)},
gene_name: {len(gene_names)},
sacc_ortholog: {len(sacch_orthologs)},
descriptions: {len(descriptions)}
"""
print(node_info_report)

# One row per node; gene_index is the row position and serves as the
# node key in the NetworkX graph built later.
node_columns = (
    ("gene_index", range(len(gene_ids))),
    ("locus_tag_old", old_locus_tags),
    ("gene_type", gene_types),
    ("gene_name", gene_names),
    ("sacch_ortholog", sacch_orthologs),
    ("description", descriptions),
)
node_info = pd.DataFrame(dict(node_columns))
# Attach node metadata to both endpoints of every edge (inner joins, so an
# edge whose endpoint is missing from node_info is dropped). The first merge
# adds node_info's columns unsuffixed for feature_name_1. The second merge
# collides on those names, so pandas labels them _a (feature_name_1 side) and
# _b (feature_name_2 side).
neighbors = (
    neighbors
    .merge(node_info, left_on="feature_name_1", right_on="locus_tag_old")
    .merge(
        node_info,
        left_on="feature_name_2",
        right_on="locus_tag_old",
        suffixes=("_a", "_b"),
    )
)
################################
# Use NetworkX to layout graph #
################################
# note I think CytoscapeJS can layout graphs
# but I'm unsure how to do it through the streamlit-cytoscapejs interface :(
# Show the merged edge table (edges plus both endpoints' metadata) on the page.
st.write(neighbors)
# Nodes are keyed by gene_index so positions can be matched back to node_info
# rows. Genes without any edge never enter G and are placed separately later.
G = nx.Graph()
G.add_weighted_edges_from(
    zip(neighbors["gene_index_a"], neighbors["gene_index_b"], neighbors["score"]))
# Fixed seed: Streamlit re-runs this script on every widget interaction, and
# an unseeded spring layout would rearrange the picture each time.
layout = nx.spring_layout(G, seed=0)
node_color_lut = {
    "seed" : "#4866F0", # blue
    "neighbor" : "#F0C547" # gold
}
# Map spring_layout coordinates onto the canvas: scale by 600 and shift so the
# layout origin lands at (750, 750).
layout_scale = 600
layout_center = 1500 / 2


def _edge_width(score):
    """Return the drawn edge width for a co-expression score; stronger hits are thicker."""
    if score > 0.98:
        return 20
    if score > 0.93:
        return 15
    if score > 0.90:
        return 10
    if score > 0.88:
        return 8
    return 5


elements = []
# Nodes without edges have no layout position; put them on a grid of 8 per
# row near the top-left of the canvas instead.
singleton_index = 0
for node in node_info.itertuples(index=False):
    if node.gene_index in layout:
        x, y = layout[node.gene_index]
        layout_x = x * layout_scale + layout_center
        layout_y = y * layout_scale + layout_center
    else:
        layout_x = (singleton_index % 8) * 150 + 100
        layout_y = np.floor(singleton_index / 8) * 50 + 30
        singleton_index += 1
    # Fall back to the locus tag when there is no gene name. pd.notna also
    # catches NaN, which `is not None` misses and which is not valid JSON.
    label = node.gene_name if pd.notna(node.gene_name) else node.locus_tag_old
    elements.append({
        "data": {
            "id": node.locus_tag_old,
            "label": label,
            "color": node_color_lut[node.gene_type]},
        "position": {
            "x" : layout_x,
            "y" : layout_y}})

# Edges reference nodes by locus tag (the node "id" above).
for source, target, score in zip(
        neighbors["feature_name_1"], neighbors["feature_name_2"], neighbors["score"]):
    elements.append({
        "data" : {
            "source" : source,
            "target" : target,
            "width" : _edge_width(score)}})
with col3:
    st.text('') # help alignment with input box
    # Offer the merged edge table (edges plus both endpoints' metadata) as a
    # tab-separated file.
    st.download_button(
        label = "Download as TSV",
        data = neighbors.to_csv(sep = '\t').encode('utf-8'),
        file_name = "CaurisCEN_network.tsv",
        mime = "text/tab-separated-values")
##########################################################
# Cytoscape stylesheet: nodes are labelled rectangles filled with the colour
# carried in each element's data; edges take their width from their data.
node_style = {
    "width": 140,
    "height": 30,
    "shape": "rectangle",
    "label": "data(label)",
    # NOTE(review): "labelFontSize" does not appear to be a standard
    # Cytoscape.js style property ("font-size" is); confirm it has any effect.
    "labelFontSize": 100,
    "background-color": "data(color)",
    "text-halign": "center",
    "text-valign": "center",
}
edge_style = {"width": "data(width)"}
stylesheet = [
    {"selector": "node", "style": node_style},
    {"selector": "edge", "style": edge_style},
]
st.title("CaurisCEN Network")
# Render the interactive network. The component's return value is kept as
# clicked_elements (presumably the user's current selection; TODO confirm).
cytoscape_options = dict(
    elements=elements,
    stylesheet=stylesheet,
    width=1000,
    height=1000,
    key="1",
)
clicked_elements = st_cytoscapejs(**cytoscape_options)