OpenMTgene / app.py
yashm's picture
Update app.py
3a08bc5 verified
import streamlit as st
import os
import io
import logging
from Bio import SeqIO, Phylo
from Bio.SeqRecord import SeqRecord
import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as sch
from scipy.stats import chi2
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import plotly.graph_objects as go
st.set_page_config(page_title="GeneBank Genie", layout="wide")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# ---------------------------
# Core Classes and Functions
# ---------------------------
class GenBankParser:
def __init__(self, text_io):
self.text_io = text_io
self.records = []
def load_records(self):
# Always seek to start (for re-reads)
self.text_io.seek(0)
self.records = list(SeqIO.parse(self.text_io, "genbank"))
return self.records
class AnalysisEngine:
@staticmethod
def perform_pca(features, n_components=2):
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
pca = PCA(n_components=n_components)
pca_result = pca.fit_transform(features_scaled)
return pca_result, pca.explained_variance_ratio_
@staticmethod
def duplicate_feature(X):
return np.hstack([X, X])
@staticmethod
def compute_mahalanobis(points):
mean_vec = np.mean(points, axis=0)
cov_matrix = np.cov(points, rowvar=False)
inv_cov_matrix = np.linalg.inv(cov_matrix)
diff = points - mean_vec
md_squared = np.sum(diff.dot(inv_cov_matrix) * diff, axis=1)
return np.sqrt(md_squared)
def extract_gene_sequences(records, selected_gene):
gene_seqs = []
for rec in records:
for feature in rec.features:
if feature.type.lower() in ["gene", "cds"]:
if "gene" in feature.qualifiers:
gene_name = feature.qualifiers["gene"][0]
if gene_name.upper() == selected_gene.upper():
seq = feature.extract(rec.seq)
new_rec = SeqRecord(seq, id=rec.id, description=rec.annotations.get("organism", ""))
gene_seqs.append(new_rec)
break
return gene_seqs
def extract_full_sequences(records):
full_seqs = []
for rec in records:
new_rec = SeqRecord(rec.seq, id=rec.id, description=rec.description)
full_seqs.append(new_rec)
return full_seqs
def save_fasta_download(sequences, file_label):
fasta_io = io.StringIO()
SeqIO.write(sequences, fasta_io, "fasta")
fasta_io.seek(0)
st.download_button(
label=f"Download {file_label} as FASTA",
data=fasta_io.getvalue(),
file_name=f"{file_label}.fasta",
mime="text/plain"
)
# ---------------------------
# Streamlit UI
# ---------------------------
st.title("GeneBank Genie (Streamlit Version)")
st.markdown("""
A versatile tool for the analysis of GenBank records.
**Version 1.0 – © Dr. Yash Munnalal Gupta**
""")
with st.expander("ℹ️ About GeneBank Genie"):
st.markdown("""
**GeneBank Genie** is a comprehensive tool for analyzing GenBank files.
It includes modules for general analysis, gene analysis,
taxonomic visualization, additional visualizations, dendrogram analysis, and sequence extraction.
""")
# Sidebar: GenBank File Upload
st.sidebar.header("Step 1. Upload GenBank File")
gb_file = st.sidebar.file_uploader("Upload a GenBank file (.gb, .gbk)", type=["gb", "gbk"])
if gb_file is not None:
# Always wrap the file as text for Biopython
gb_file.seek(0)
gb_bytes = gb_file.read()
if isinstance(gb_bytes, bytes):
gb_text = gb_bytes.decode("utf-8")
else:
gb_text = gb_bytes
gb_text_io = io.StringIO(gb_text)
parser = GenBankParser(gb_text_io)
try:
records = parser.load_records()
except Exception as e:
st.error(f"Error parsing GenBank file: {e}")
records = []
else:
st.sidebar.success(f"Loaded {len(records)} GenBank records.")
else:
records = []
st.warning("Upload a GenBank file to get started.")
if records:
# Sidebar: Taxonomy Level Selection
st.sidebar.header("Step 2. Set Analysis Options")
tax_levels = []
for rec in records:
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) > len(tax_levels):
tax_levels = taxonomy
if tax_levels:
max_level = len(tax_levels) - 1
tax_level_index = st.sidebar.number_input("Taxonomy Level Index", min_value=0, max_value=max_level, value=max_level)
else:
tax_level_index = 0
color_palette = st.sidebar.selectbox("Color Palette", ["tab20", "viridis", "plasma", "inferno", "magma", "cividis"])
# Main Tabs
tabs = st.tabs([
"General Analysis", "Gene Analysis", "Sankey Diagram",
"Additional Visualizations", "Dendrogram Analysis", "Sequence Extraction"
])
# --- General Analysis ---
with tabs[0]:
st.header("General Analysis")
# Populate tax groups for selection
tax_groups = set()
for rec in records:
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) > tax_level_index:
tax_groups.add(taxonomy[tax_level_index])
elif taxonomy:
tax_groups.add(taxonomy[-1])
tax_groups = sorted(list(tax_groups))
selected_tax_group = st.selectbox("Selected Taxonomic Group", tax_groups)
default_marker = st.text_input("Default Marker", "o")
outlier_marker = st.text_input("Outlier Marker", "D")
run_general = st.button("Run General Analysis")
summary_general = st.empty()
if run_general:
# Compute features
data = []
for rec in records:
seq = str(rec.seq).upper()
total_length = len(seq)
if total_length == 0:
continue
countA = seq.count("A")
countC = seq.count("C")
countG = seq.count("G")
countT = seq.count("T")
propA = countA / total_length
propC = countC / total_length
propG = countG / total_length
propT = countT / total_length
gc_content = (countG + countC) / total_length
gene_count = sum(1 for f in rec.features if f.type.lower() == 'gene')
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) > tax_level_index:
chosen_tax = taxonomy[tax_level_index]
elif taxonomy:
chosen_tax = taxonomy[-1]
else:
chosen_tax = "Unknown"
organism = rec.annotations.get("organism", "Unknown")
data.append([propA, propC, propG, propT, total_length, gc_content, gene_count,
organism, " | ".join(taxonomy), chosen_tax])
df = pd.DataFrame(data, columns=['A','C','G','T','SeqLength','GC','GeneCount',
'Organism','Full_Taxonomy','TaxLevel'])
# PCA
features_nuc = df[['A','C','G','T']].values
features_add = df[['SeqLength','GC','GeneCount']].values
features_combined = df[['A','C','G','T','SeqLength','GC','GeneCount']].values
features_seq = df[['SeqLength']].values
features_gc = df[['GC']].values
features_gene = df[['GeneCount']].values
pca_nuc, var_nuc = AnalysisEngine.perform_pca(features_nuc)
pca_add, var_add = AnalysisEngine.perform_pca(features_add)
pca_comb, var_comb = AnalysisEngine.perform_pca(features_combined)
pca_seq, var_seq = AnalysisEngine.perform_pca(AnalysisEngine.duplicate_feature(features_seq))
pca_gc, var_gc = AnalysisEngine.perform_pca(AnalysisEngine.duplicate_feature(features_gc))
pca_gene, var_gene = AnalysisEngine.perform_pca(AnalysisEngine.duplicate_feature(features_gene))
df['PC1_nuc'] = pca_nuc[:, 0]; df['PC2_nuc'] = pca_nuc[:, 1]
df['PC1_add'] = pca_add[:, 0]; df['PC2_add'] = pca_add[:, 1]
df['PC1_comb'] = pca_comb[:, 0]; df['PC2_comb'] = pca_comb[:, 1]
df['PC1_seq'] = pca_seq[:, 0]; df['PC2_seq'] = pca_seq[:, 1]
df['PC1_gc'] = pca_gc[:, 0]; df['PC2_gc'] = pca_gc[:, 1]
df['PC1_gene'] = pca_gene[:, 0]; df['PC2_gene'] = pca_gene[:, 1]
# Outlier detection
group_mask = df['TaxLevel'] == selected_tax_group
df.loc[:, 'Outlier'] = False
group_df = df[group_mask]
outlier_txt = ""
if not group_df.empty:
group_points = group_df[['PC1_comb', 'PC2_comb']].values
distances = AnalysisEngine.compute_mahalanobis(group_points)
threshold = np.sqrt(chi2.ppf(0.95, df=2))
outlier_flags = distances > threshold
df.loc[group_mask, 'Outlier'] = outlier_flags
outlier_species = df[(df['TaxLevel'] == selected_tax_group) & (df['Outlier'] == True)]['Organism']
if not outlier_species.empty:
unique_species = outlier_species.unique()
outlier_txt = f"Outlier species in {selected_tax_group}: {', '.join(unique_species)}"
else:
outlier_txt = f"No outlier species detected in {selected_tax_group}."
else:
outlier_txt = f"No records found for {selected_tax_group}."
summary_general.info(outlier_txt)
st.subheader("PCA Plots")
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
axes = axes.flatten()
unique_tax = df['TaxLevel'].unique()
cmap = plt.cm.get_cmap(color_palette, len(unique_tax))
color_dict = {tax: cmap(i) for i, tax in enumerate(unique_tax)}
for i, (col1, col2, title, var, pca_data) in enumerate([
('PC1_nuc', 'PC2_nuc', f"Nucleotide PCA (var: {var_nuc[0]:.2f}, {var_nuc[1]:.2f})", var_nuc, pca_nuc),
('PC1_seq', 'PC2_seq', f"SeqLength PCA (var: {var_seq[0]:.2f}, {var_seq[1]:.2f})", var_seq, pca_seq),
('PC1_gc', 'PC2_gc', f"GC PCA (var: {var_gc[0]:.2f}, {var_gc[1]:.2f})", var_gc, pca_gc),
('PC1_gene', 'PC2_gene', f"Gene Count PCA (var: {var_gene[0]:.2f}, {var_gene[1]:.2f})", var_gene, pca_gene),
('PC1_add', 'PC2_add', f"Additional PCA (var: {var_add[0]:.2f}, {var_add[1]:.2f})", var_add, pca_add),
('PC1_comb', 'PC2_comb', f"Combined PCA (var: {var_comb[0]:.2f}, {var_comb[1]:.2f})", var_comb, pca_comb)
]):
for tax in unique_tax:
subset = df[df['TaxLevel'] == tax]
marker = outlier_marker if (col1 == 'PC1_comb' and tax == selected_tax_group and subset['Outlier'].any()) else default_marker
axes[i].scatter(subset[col1], subset[col2], label=tax, alpha=0.7, marker=marker)
axes[i].set_title(title)
axes[i].set_xlabel("PC1")
axes[i].set_ylabel("PC2")
axes[i].legend(fontsize=8)
st.pyplot(fig)
st.dataframe(df.head(), use_container_width=True)
csv = df.to_csv(index=False).encode('utf-8')
st.download_button("Download PCA Features CSV", data=csv, file_name="pca_features.csv", mime='text/csv')
# --- Gene Analysis ---
with tabs[1]:
st.header("Gene Analysis")
# Get all gene names
gene_set = set()
for rec in records:
for feature in rec.features:
if feature.type.lower() in ["gene", "cds"]:
if "gene" in feature.qualifiers:
gene_set.add(feature.qualifiers["gene"][0].upper())
gene_list = sorted(list(gene_set))
selected_gene = st.selectbox("Select Gene", gene_list)
selected_tax_group_gene = st.text_input("Selected Taxonomic Group (for outlier detection)", tax_groups[0] if tax_groups else "")
run_gene = st.button("Run Gene Analysis")
summary_gene = st.empty()
if run_gene:
gene_data = []
for rec in records:
gene_seq = None
for feature in rec.features:
if feature.type.lower() in ["gene", "cds"]:
if "gene" in feature.qualifiers:
gene_name = feature.qualifiers["gene"][0]
if gene_name.upper() == selected_gene.upper():
gene_seq = str(feature.extract(rec.seq)).upper()
break
if gene_seq is None or len(gene_seq) == 0:
continue
gene_length = len(gene_seq)
countA = gene_seq.count("A")
countC = gene_seq.count("C")
countG = gene_seq.count("G")
countT = gene_seq.count("T")
propA = countA / gene_length
propC = countC / gene_length
propG = countG / gene_length
propT = countT / gene_length
gc_content = (countG + countC) / gene_length
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) > tax_level_index:
chosen_tax = taxonomy[tax_level_index]
elif taxonomy:
chosen_tax = taxonomy[-1]
else:
chosen_tax = "Unknown"
organism = rec.annotations.get("organism", "Unknown")
gene_data.append([propA, propC, propG, propT, gene_length, gc_content,
organism, " | ".join(taxonomy), chosen_tax])
df_gene = pd.DataFrame(gene_data, columns=["A", "C", "G", "T", "GeneLength", "GC",
"Organism", "Full_Taxonomy", "TaxLevel"])
features = df_gene[["A", "C", "G", "T", "GeneLength", "GC"]].values
pca_result, var_gene = AnalysisEngine.perform_pca(features)
df_gene["PC1"] = pca_result[:, 0]
df_gene["PC2"] = pca_result[:, 1]
group_mask = df_gene["TaxLevel"] == selected_tax_group_gene
df_gene["Outlier"] = False
group_df = df_gene[group_mask]
outlier_txt = ""
if not group_df.empty:
group_points = group_df[["PC1", "PC2"]].values
distances = AnalysisEngine.compute_mahalanobis(group_points)
threshold = np.sqrt(chi2.ppf(0.95, df=2))
outlier_flags = distances > threshold
df_gene.loc[group_mask, "Outlier"] = outlier_flags
outlier_species = df_gene[(df_gene["TaxLevel"] == selected_tax_group_gene) & (df_gene["Outlier"] == True)]["Organism"]
if not outlier_species.empty:
unique_species = outlier_species.unique()
outlier_txt = f"Outlier species in {selected_tax_group_gene}: {', '.join(unique_species)}"
else:
outlier_txt = f"No outlier species detected in {selected_tax_group_gene}."
else:
outlier_txt = f"No records found for {selected_tax_group_gene}."
summary_gene.info(outlier_txt)
st.subheader("Gene PCA Scatter Plot")
fig, ax = plt.subplots(figsize=(8, 6))
unique_tax = df_gene["TaxLevel"].unique()
cmap = plt.cm.get_cmap(color_palette, len(unique_tax))
color_dict = {tax: cmap(i) for i, tax in enumerate(unique_tax)}
for tax in unique_tax:
subset = df_gene[df_gene["TaxLevel"] == tax]
marker = outlier_marker if (tax == selected_tax_group_gene and subset["Outlier"].any()) else default_marker
ax.scatter(subset["PC1"], subset["PC2"], label=tax, alpha=0.7, marker=marker)
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.legend(fontsize=8)
st.pyplot(fig)
st.dataframe(df_gene.head(), use_container_width=True)
csv_gene = df_gene.to_csv(index=False).encode('utf-8')
st.download_button("Download Gene PCA Features CSV", data=csv_gene, file_name="selected_gene_pca_features.csv", mime='text/csv')
# --- Sankey Diagram ---
with tabs[2]:
st.header("Sankey Diagram")
start_level = st.number_input("Start Level (for taxonomy)", min_value=0, value=0)
run_sankey = st.button("Run Sankey Diagram")
if run_sankey:
import collections
link_counts = collections.defaultdict(int)
nodes_set = set()
for rec in records:
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) <= start_level:
continue
for i in range(start_level, len(taxonomy) - 1):
source_node = f"L{i}-{taxonomy[i]}"
target_node = f"L{i+1}-{taxonomy[i+1]}"
nodes_set.add(source_node)
nodes_set.add(target_node)
link_counts[(source_node, target_node)] += 1
def sort_key(label):
try:
level = int(label.split("-")[0][1:])
except:
level = 9999
return (level, label)
nodes_list = sorted(list(nodes_set), key=sort_key)
node_to_index = {node: i for i, node in enumerate(nodes_list)}
sources = []
targets = []
values = []
for (src, tgt), count in link_counts.items():
sources.append(node_to_index[src])
targets.append(node_to_index[tgt])
values.append(count)
fig = go.Figure(data=[go.Sankey(
node=dict(
pad=15,
thickness=20,
line=dict(color="black", width=0.5),
label=nodes_list,
color="blue"
),
link=dict(
source=sources,
target=targets,
value=values
))])
fig.update_layout(title_text=f"Sankey Diagram of Taxonomy Flow (Starting at Level {start_level})", font_size=10)
st.plotly_chart(fig, use_container_width=True)
# --- Additional Visualizations ---
with tabs[3]:
st.header("Additional Visualizations")
if st.button("Show Correlation Heatmap (A, C, G, T, SeqLength, GC, GeneCount)"):
data = []
for rec in records:
seq = str(rec.seq).upper()
total_length = len(seq)
if total_length == 0:
continue
countA = seq.count("A")
countC = seq.count("C")
countG = seq.count("G")
countT = seq.count("T")
propA = countA / total_length
propC = countC / total_length
propG = countG / total_length
propT = countT / total_length
gc_content = (countG + countC) / total_length
gene_count = sum(1 for f in rec.features if f.type.lower() == 'gene')
taxonomy = rec.annotations.get("taxonomy", [])
organism = rec.annotations.get("organism", "Unknown")
data.append([propA, propC, propG, propT, total_length, gc_content, gene_count, organism])
df = pd.DataFrame(data, columns=['A','C','G','T','SeqLength','GC','GeneCount','Organism'])
plt.figure(figsize=(8, 6))
corr_matrix = df[['A','C','G','T','SeqLength','GC','GeneCount']].corr()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm')
st.pyplot(plt.gcf())
n_clusters = st.number_input("Number of Clusters for KMeans", min_value=2, value=2)
if st.button("Run KMeans Clustering"):
data = []
for rec in records:
seq = str(rec.seq).upper()
total_length = len(seq)
if total_length == 0:
continue
countA = seq.count("A")
countC = seq.count("C")
countG = seq.count("G")
countT = seq.count("T")
propA = countA / total_length
propC = countC / total_length
propG = countG / total_length
propT = countT / total_length
gc_content = (countG + countC) / total_length
gene_count = sum(1 for f in rec.features if f.type.lower() == 'gene')
taxonomy = rec.annotations.get("taxonomy", [])
organism = rec.annotations.get("organism", "Unknown")
data.append([propA, propC, propG, propT, total_length, gc_content, gene_count, organism])
df = pd.DataFrame(data, columns=['A','C','G','T','SeqLength','GC','GeneCount','Organism'])
features_combined = df[['A','C','G','T','SeqLength','GC','GeneCount']].values
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
clusters = kmeans.fit_predict(features_combined)
df['KMeans_Cluster'] = clusters
sil_score = silhouette_score(features_combined, clusters)
st.info(f"Silhouette score: {sil_score:.2f}")
st.write("KMeans Cluster Assignment (first 10 rows):")
st.dataframe(df[['Organism','KMeans_Cluster']].head(10))
# PCA
pca_comb, _ = AnalysisEngine.perform_pca(features_combined)
df['PC1_comb'] = pca_comb[:, 0]
df['PC2_comb'] = pca_comb[:, 1]
plt.figure(figsize=(8, 6))
sns.scatterplot(x='PC1_comb', y='PC2_comb', hue='KMeans_Cluster', data=df, palette='tab10', s=80)
plt.title("K-Means Clustering (Combined PCA Projection)")
st.pyplot(plt.gcf())
if st.button("Show Pair Plot (A, C, G, T)"):
data = []
for rec in records:
seq = str(rec.seq).upper()
total_length = len(seq)
if total_length == 0:
continue
countA = seq.count("A")
countC = seq.count("C")
countG = seq.count("G")
countT = seq.count("T")
propA = countA / total_length
propC = countC / total_length
propG = countG / total_length
propT = countT / total_length
data.append([propA, propC, propG, propT])
df = pd.DataFrame(data, columns=['A','C','G','T'])
sns.pairplot(df, diag_kind='kde')
st.pyplot(plt.gcf())
# --- Dendrogram Analysis ---
with tabs[4]:
st.header("Dendrogram Analysis")
feature_option = st.selectbox(
"Select Feature Set for Dendrogram",
["Nucleotide Composition", "Additional Features", "Combined Features"]
)
if st.button("Run Dendrogram"):
data = []
for rec in records:
seq = str(rec.seq).upper()
total_length = len(seq)
if total_length == 0:
continue
countA = seq.count("A")
countC = seq.count("C")
countG = seq.count("G")
countT = seq.count("T")
propA = countA / total_length
propC = countC / total_length
propG = countG / total_length
propT = countT / total_length
gc_content = (countG + countC) / total_length
gene_count = sum(1 for f in rec.features if f.type.lower() == 'gene')
taxonomy = rec.annotations.get("taxonomy", [])
if len(taxonomy) > tax_level_index:
chosen_tax = taxonomy[tax_level_index]
elif taxonomy:
chosen_tax = taxonomy[-1]
else:
chosen_tax = "Unknown"
data.append([propA, propC, propG, propT, total_length, gc_content, gene_count, chosen_tax])
df = pd.DataFrame(data, columns=['A','C','G','T','SeqLength','GC','GeneCount','TaxLevel'])
if feature_option == "Nucleotide Composition":
features = df[['A', 'C', 'G', 'T']].values
elif feature_option == "Additional Features":
features = df[['SeqLength', 'GC', 'GeneCount']].values
else:
features = df[['A', 'C', 'G', 'T', 'SeqLength', 'GC', 'GeneCount']].values
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
linkage_matrix = sch.linkage(features_scaled, method='ward')
plt.figure(figsize=(12, 8))
sch.dendrogram(linkage_matrix, labels=df['TaxLevel'].values, leaf_rotation=90)
plt.title("Dendrogram Analysis")
plt.xlabel("Taxonomic Level")
plt.ylabel("Euclidean Distance")
st.pyplot(plt.gcf())
# --- Sequence Extraction ---
with tabs[5]:
st.header("Sequence Extraction")
st.subheader("Full Sequences")
if st.button("Download All Full Sequences as FASTA"):
full_seqs = extract_full_sequences(records)
save_fasta_download(full_seqs, "full_sequences")
st.subheader("Gene Sequences")
gene_list = sorted(list({feature.qualifiers["gene"][0].upper()
for rec in records for feature in rec.features
if feature.type.lower() in ["gene", "cds"] and "gene" in feature.qualifiers}))
selected_gene_extract = st.selectbox("Select Gene for Extraction", gene_list)
if st.button("Download Selected Gene Sequences as FASTA"):
gene_seqs = extract_gene_sequences(records, selected_gene_extract)
if gene_seqs:
save_fasta_download(gene_seqs, f"{selected_gene_extract}_extracted")
else:
st.warning(f"No sequences found for gene {selected_gene_extract}.")