# CRISPRtool / app.py
# Uploader: LfOreVEr
# Commit message: Upload app.py with huggingface_hub
# Commit: 3773bb2 (verified)
import os
import cas9att
import cas9attvcf
#import cas12
import cas12lstm
import cas12lstmvcf
import pandas as pd
import streamlit as st
import plotly.graph_objs as go
from pathlib import Path
import zipfile
import io
#import gtracks
#import subprocess
import cyvcf2
# Title and documentation, rendered from a local markdown file (HTML allowed).
# The encoding is explicit so the page does not depend on the server's locale.
st.markdown(Path('crisprTool.md').read_text(encoding='utf-8'), unsafe_allow_html=True)
st.divider()
# CRISPR systems the user can choose between.
CRISPR_MODELS = ['Cas9', 'Cas12']
selected_model = st.selectbox('Select CRISPR system:', CRISPR_MODELS, key='selected_model')
# Pretrained weight files: Cas9 multi-head-attention model and Cas12 (Cpf1) BiLSTM model.
cas9att_path = 'cas9_model/Cas9_MultiHeadAttention_weights.h5'
cas12lstm_path = 'cas12_model/BiLSTM_Cpf1_weights.h5'
@st.cache_data
def parse_gene_annotations(file_path):
    """Map each approved gene symbol to its Ensembl gene ID.

    The file is read as tab-delimited text (decoded as GBK, undecodable bytes
    dropped). Its first line must name the 'Approved symbol' and
    'Ensembl gene ID' columns. The result is cached by Streamlit across reruns.

    Args:
        file_path: Path to the annotation file.

    Returns:
        dict mapping symbol -> Ensembl ID. Rows too short to contain both
        columns are skipped; a repeated symbol keeps its last value.

    Raises:
        ValueError: if either required column is missing from the header.
    """
    with open(file_path, 'r', encoding='gbk', errors='ignore') as handle:
        header = handle.readline().strip().split('\t')
        symbol_col = header.index('Approved symbol')
        ensembl_col = header.index('Ensembl gene ID')
        # A row must have at least this many fields to contain both columns.
        min_fields = max(symbol_col, ensembl_col) + 1
        rows = (raw_line.strip().split('\t') for raw_line in handle)
        return {fields[symbol_col]: fields[ensembl_col]
                for fields in rows if len(fields) >= min_fields}
# Symbol -> Ensembl gene ID lookup built from the HUGO gene annotation file.
gene_annotations = parse_gene_annotations('Human_genes_HUGO_02242024_annotation.txt')
# All known symbols, used to populate the gene-symbol selectbox.
gene_symbol_list = list(gene_annotations)
# ---------------------------------------------------------------------------
# Prediction UI shared by every CRISPR system and target set.
# ---------------------------------------------------------------------------

# Nuclease name shown in the results header for each CRISPR system.
NUCLEASE_NAMES = {'Cas9': 'SpCas9', 'Cas12': 'AsCas12a'}
# Per CRISPR system: (reference predictor module, VCF-aware predictor module, weights path).
# Each module provides process_gene, generate_genbank_file_from_df and create_csv_from_df.
MODEL_BACKENDS = {
    'Cas9': (cas9att, cas9attvcf, cas9att_path),
    'Cas12': (cas12lstm, cas12lstmvcf, cas12lstm_path),
}
NORMAL_MODE = 'Normal'
MUTATION_MODE = 'Mutation related to MDA-MB-231'
# VCF (SNPs/indels, per its file name) consulted in mutation mode.
MDA_MB_231_VCF = 'SRR25934512.filter.snps.indels.vcf.gz'
# Field layout of the prediction tuples returned by each process_gene flavour.
NORMAL_COLUMNS = ["Chr", "Start Pos", "End Pos", "Strand", "Transcript", "Exon", "Target", "gRNA", "Prediction"]
MUTATION_COLUMNS = ["Gene Symbol", "Chr", "Strand", "Target Start", "Transcript",
                    "Exon", "Target", "gRNA", "Prediction", "Is Mutation"]
# Position of "Prediction" in both layouts; guides are ranked by this value.
PREDICTION_INDEX = 8


def clean_up_old_files(gene_symbol):
    """Remove the GenBank, BED and CSV files previously written for ``gene_symbol``.

    Files that do not exist are skipped.
    """
    for path in (f"{gene_symbol}_crispr_targets.gb",
                 f"{gene_symbol}_crispr_targets.bed",
                 f"{gene_symbol}_crispr_predictions.csv"):
        if os.path.exists(path):
            os.remove(path)


def _run_prediction(model, mode, gene_symbol):
    """Score every guide for ``gene_symbol`` and store the ranked results in session state.

    Writes 'on_target_results_all' (normal mode) or 'full_results' (mutation
    mode), plus 'on_target_results' (top 10), 'gene_sequence', 'exons',
    'prediction_made' and 'results_context'.
    """
    predictor, vcf_predictor, weights_path = MODEL_BACKENDS[model]
    if mode == MUTATION_MODE:
        # Opened only when a prediction is requested, not on every Streamlit
        # rerun; the local reader is unreferenced once this function returns.
        vcf_reader = cyvcf2.VCF(MDA_MB_231_VCF)
        predictions, gene_sequence, exons = vcf_predictor.process_gene(gene_symbol, vcf_reader, weights_path)
        all_results_key = 'full_results'
    else:
        predictions, gene_sequence, exons = predictor.process_gene(gene_symbol, weights_path)
        all_results_key = 'on_target_results_all'
    ranked = sorted(predictions, key=lambda row: row[PREDICTION_INDEX], reverse=True)
    st.session_state[all_results_key] = ranked
    st.session_state['on_target_results'] = ranked[:10]  # top 10 predictions
    st.session_state['gene_sequence'] = gene_sequence
    st.session_state['exons'] = exons
    st.session_state['prediction_made'] = True
    # Record what produced these results (set last, after everything is stored)
    # so they are never rendered under a different model, mode or gene.
    st.session_state['results_context'] = (model, mode, gene_symbol)


def _show_result_header(gene_symbol, nuclease):
    """Render the Genome / Gene / Nuclease summary row above the results table."""
    ensembl_id = gene_annotations.get(gene_symbol, 'Unknown')  # Ensembl ID or 'Unknown'
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("**Genome**")
        st.markdown("Homo sapiens")
    with col2:
        st.markdown("**Gene**")
        st.markdown(f"{gene_symbol} : {ensembl_id} (primary)")
    with col3:
        st.markdown("**Nuclease**")
        st.markdown(nuclease)


def _build_top5_figure(top_predictions, exons):
    """Return a Plotly figure with exon bars and the five best guides as strand markers.

    ``top_predictions`` are tuples in NORMAL_COLUMNS order, best first;
    ``exons`` is a sequence of mappings with numeric 'start' and 'end' keys.
    """
    fig = go.Figure()
    EXON_BASE = 0  # Y-axis base of the exon bars
    EXON_HEIGHT = 0.02  # how 'tall' the exon bars appear
    for exon in exons:
        exon_start, exon_end = exon['start'], exon['end']
        fig.add_trace(go.Bar(
            x=[(exon_start + exon_end) / 2],
            y=[EXON_HEIGHT],
            width=[exon_end - exon_start],
            base=EXON_BASE,
            marker_color='rgba(128, 0, 128, 0.5)',
            name='Exon'
        ))
    VERTICAL_GAP = 0.2  # vertical distance between consecutive ranks
    MAX_STRAND_Y = 0.1  # Y of the rank-1 marker for + strand guides
    MIN_STRAND_Y = -0.1  # Y of the rank-1 marker for - strand guides
    for i, prediction in enumerate(top_predictions[:5], start=1):
        chrom, start, end, strand, transcript, _exon, target, gRNA, prediction_score = prediction
        # Treat '1', '+' and an integer 1 as the forward strand.
        is_plus = str(strand) in ('1', '+')
        midpoint = (int(start) + int(end)) / 2
        # NOTE(review): + strand ranks step down from 0.1 and - strand ranks step
        # up from -0.1, so from rank 2 on markers sit at the other strand's tick
        # value. Confirm this layout is intended.
        if is_plus:
            y_value = MAX_STRAND_Y - (i - 1) * VERTICAL_GAP
        else:
            y_value = MIN_STRAND_Y + (i - 1) * VERTICAL_GAP
        fig.add_trace(go.Scatter(
            x=[midpoint],
            y=[y_value],
            mode='markers+text',
            marker=dict(symbol='triangle-up' if is_plus else 'triangle-down', size=12),
            text=f"Rank: {i}",
            hoverinfo='text',
            hovertext=f"Rank: {i}<br>Chromosome: {chrom}<br>Target Sequence: {target}<br>gRNA: {gRNA}<br>Start: {start}<br>End: {end}<br>Strand: {'+' if is_plus else '-'}<br>Transcript: {transcript}<br>Prediction: {prediction_score:.4f}",
        ))
    fig.update_layout(
        title='Top 5 gRNA Sequences by Prediction Score',
        xaxis_title='Genomic Position',
        yaxis_title='Strand',
        yaxis=dict(tickvals=[MAX_STRAND_Y, MIN_STRAND_Y], ticktext=['+', '-']),
        showlegend=False,
        hovermode='x unified',
    )
    return fig


def _offer_download(exporter, df_full, gene_sequence, gene_symbol):
    """Write GenBank and CSV exports for ``gene_symbol`` and offer them as one ZIP download.

    The files are written to the working directory under gene-symbol-based
    names, the same names clean_up_old_files removes.
    NOTE(review): concurrent sessions using the same gene share these paths.
    """
    genbank_file_path = f"{gene_symbol}_crispr_targets.gb"
    csv_file_path = f"{gene_symbol}_crispr_predictions.csv"
    exporter.generate_genbank_file_from_df(df_full, gene_sequence, gene_symbol, genbank_file_path)
    exporter.create_csv_from_df(df_full, csv_file_path)
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        zip_file.write(genbank_file_path)
        zip_file.write(csv_file_path)
    # getvalue() returns the whole buffer regardless of the current position.
    st.download_button(
        label="Download GenBank and CSV files as ZIP",
        data=zip_buffer.getvalue(),
        file_name=f"{gene_symbol}_files.zip",
        mime="application/zip"
    )


def _render_prediction_page(model, mode):
    """Render gene selection, the prediction button, results and downloads for one model/mode."""
    gene_symbol = st.selectbox('Enter a Gene Symbol:', [''] + gene_symbol_list, key='gene_symbol',
                               format_func=lambda x: x if x else "")
    # When a different gene is chosen, delete files exported for the previous one.
    previous_symbol = st.session_state['current_gene_symbol']
    if gene_symbol and gene_symbol != previous_symbol:
        if previous_symbol:
            clean_up_old_files(previous_symbol)
        st.session_state['current_gene_symbol'] = gene_symbol

    predict_button = st.button(f'Go {model} on-target prediction!')
    if predict_button and gene_symbol:
        with st.spinner('Predicting... Please wait'):
            _run_prediction(model, mode, gene_symbol)
        st.success('Prediction completed!')

    # Only render results produced for exactly this model, mode and gene. Results
    # from another selection have a different tuple layout or belong to another
    # gene; rendering them previously crashed or mislabelled the output.
    if st.session_state.get('results_context') != (model, mode, gene_symbol):
        return
    if not st.session_state.get('on_target_results'):
        return

    mutation_mode = mode == MUTATION_MODE
    predictor, vcf_predictor, _ = MODEL_BACKENDS[model]
    _show_result_header(gene_symbol, NUCLEASE_NAMES[model])

    results_key = 'full_results' if mutation_mode else 'on_target_results_all'
    try:
        df_full = pd.DataFrame(st.session_state[results_key],
                               columns=MUTATION_COLUMNS if mutation_mode else NORMAL_COLUMNS)
    except ValueError as e:
        st.error(f"DataFrame creation error: {e}")
        print(st.session_state[results_key])  # log the problematic rows for debugging
        # Without a valid table there is nothing consistent to plot or export.
        return
    st.dataframe(df_full)

    if not mutation_mode:
        st.plotly_chart(_build_top5_figure(st.session_state['on_target_results'],
                                           st.session_state.get('exons', [])))

    if st.session_state.get('gene_sequence'):
        _offer_download(vcf_predictor if mutation_mode else predictor,
                        df_full, st.session_state['gene_sequence'], gene_symbol)


if selected_model in MODEL_BACKENDS:
    # Radio button so exactly one target set is selected at a time.
    target_selection = st.radio(
        "Select either Normal or Mutation related to MDA-MB-231:",
        (NORMAL_MODE, MUTATION_MODE),
        key='target_selection'
    )
    if 'current_gene_symbol' not in st.session_state:
        st.session_state['current_gene_symbol'] = ""
    _render_prediction_page(selected_model, target_selection)