# AVP-Pro / app.py
# Wwwy1031's picture
# Update app.py
# 349aef5 verified
import time
import traceback
import io
import pandas as pd
import streamlit as st
# -----------------------------------------------------------------------------
# 1. PAGE CONFIGURATION & CSS
# -----------------------------------------------------------------------------
# Page setup is issued before any other Streamlit output is rendered.
st.set_page_config(page_title="AVP-Pro", layout="wide", page_icon="🧬")

GITHUB_URL = "https://github.com/wendy1031/AVP-Pro"
LAB_URL = "http://www.jcu-qiulab.com"

# Global stylesheet: hides Streamlit's main menu and footer, and defines the
# classes used by the header, intro box, buttons and result card below.
_PAGE_CSS = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.header-container {
text-align: center;
padding-bottom: 10px;
margin-bottom: 20px;
}
.main-title {
font-size: 3.5rem;
font-weight: 700;
color: #2C3E50;
margin-bottom: 0px;
line-height: 1.2;
}
.sub-title {
font-size: 1.5rem;
color: #5D6D7E;
margin-top: 5px;
font-weight: 300;
}
.copyright-info {
font-size: 1rem;
color: #2E86C1;
font-weight: 600;
margin-top: 10px;
margin-bottom: 20px;
}
.copyright-info a {
text-decoration: none;
color: #2E86C1;
}
.intro-box {
background-color: #f0f2f6;
padding: 15px;
border-radius: 10px;
margin-bottom: 15px;
text-align: center;
font-size: 1.05rem;
line-height: 1.6;
color: #31333F;
}
div.stButton > button {
border-radius: 8px;
font-weight: 600;
height: 3em;
}
div.stButton > button[kind="primary"] {
background-color: #2E86C1;
border: none;
transition: transform 0.2s;
}
div.stButton > button[kind="primary"]:hover {
transform: scale(1.02);
}
.result-card {
background-color: #f8f9fa;
padding: 25px;
border-radius: 12px;
border-left: 6px solid #2E86C1;
box-shadow: 0 4px 6px rgba(0,0,0,0.05);
margin-top: 20px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# -----------------------------------------------------------------------------
# 2. HELPER FUNCTIONS
# -----------------------------------------------------------------------------
VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")


def parse_fasta(raw_text: str):
    """Parse FASTA-formatted or plain-text sequence input.

    The input is treated as FASTA when its first non-blank line starts with
    ``>``; otherwise every line is concatenated into one sequence named
    ``"Input_Sequence"``.

    Args:
        raw_text: Text pasted by the user or read from an uploaded file.
            ``None`` and blank strings are accepted.

    Returns:
        A list of ``(header, sequence)`` tuples in input order. Sequences are
        upper-cased and have all whitespace removed. Headers are stripped of
        the leading ``>`` and surrounding whitespace; a record whose header is
        empty is named ``"Sequence_<n>"`` (n = 1-based position of its header
        line among all headers). Records without any sequence are skipped.
    """
    if raw_text is None:
        return []

    # A UTF-8 byte-order mark in front of the first ">" would otherwise make
    # a FASTA file look like plain text.
    text = raw_text.lstrip("\ufeff")
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if not lines:
        return []

    def _normalise(parts):
        # str.split() with no argument drops every kind of whitespace
        # (spaces, tabs, ...), not just the space character.
        return "".join("".join(parts).split()).upper()

    # Plain-text input: everything is one sequence.
    if not lines[0].startswith(">"):
        sequence = _normalise(lines)
        return [("Input_Sequence", sequence)] if sequence else []

    records = []
    header = None
    chunks = []
    header_count = 0
    for line in lines:
        if line.startswith(">"):
            # Flush the previous record. `is not None` (rather than a
            # truthiness test) is the explicit "a header has been seen" check.
            if header is not None:
                sequence = _normalise(chunks)
                if sequence:
                    records.append((header, sequence))
            header_count += 1
            # A bare ">" gets a generated name so its sequence is not dropped.
            header = line[1:].strip() or f"Sequence_{header_count}"
            chunks = []
        else:
            chunks.append(line)

    # Flush the final record.
    if header is not None:
        sequence = _normalise(chunks)
        if sequence:
            records.append((header, sequence))
    return records
# Masses (Da) used for the molecular-weight estimate, one entry per residue
# letter. One water is subtracted per peptide bond in simple_seq_report.
_AA_WEIGHTS = {
    "A": 89.09, "C": 121.15, "D": 133.10, "E": 147.13, "F": 165.19,
    "G": 75.07, "H": 155.16, "I": 131.17, "K": 146.19, "L": 131.17,
    "M": 149.21, "N": 132.12, "P": 115.13, "Q": 146.15, "R": 174.20,
    "S": 105.09, "T": 119.12, "V": 117.15, "W": 204.23, "Y": 181.19,
}
_WATER_WEIGHT = 18.015
_HYDROPHOBIC_AA = frozenset("AILMFWVPG")
_POSITIVE_AA = frozenset("KRH")
_NEGATIVE_AA = frozenset("DE")


def simple_seq_report(seq: str) -> dict:
    """Compute a few rough physico-chemical descriptors of a peptide.

    Args:
        seq: Peptide sequence in upper-case one-letter code. Characters that
            are not one of the 20 letters in the weight table contribute to
            ``length`` but are ignored everywhere else.

    Returns:
        A dict with:
            ``length``: ``len(seq)``.
            ``mw_est``: sum of residue masses minus one water per bond
                between recognised residues (``0.0`` if none are recognised).
            ``frac_hydrophobic``: count of A/I/L/M/F/W/V/P/G divided by
                ``length`` (``0.0`` for an empty sequence).
            ``net_charge_est``: count of K/R/H minus count of D/E.
    """
    length = len(seq)
    counts = {aa: seq.count(aa) for aa in _AA_WEIGHTS}

    # Subtract water only for bonds between residues whose mass was actually
    # added; using len(seq) here would drive the estimate negative for input
    # containing unrecognised characters.
    recognised = sum(counts.values())
    if recognised:
        residue_mass = sum(n * _AA_WEIGHTS[aa] for aa, n in counts.items())
        mw = residue_mass - (recognised - 1) * _WATER_WEIGHT
    else:
        mw = 0.0

    hydrophobic_total = sum(counts[aa] for aa in _HYDROPHOBIC_AA)
    frac_hydro = hydrophobic_total / length if length else 0.0

    net_charge = (
        sum(counts[aa] for aa in _POSITIVE_AA)
        - sum(counts[aa] for aa in _NEGATIVE_AA)
    )

    return {
        "length": length,
        "mw_est": mw,
        "frac_hydrophobic": frac_hydro,
        "net_charge_est": net_charge,
    }
@st.cache_resource(show_spinner=False)
def get_predictor():
    """Build the AVP predictor and return it.

    The function is wrapped in ``st.cache_resource`` so Streamlit keeps the
    returned object and hands it back on later calls instead of constructing
    a new one. The engine module is imported inside the function, so it is
    only loaded the first time a predictor is actually requested.
    """
    from avp_web.engine import AVPPredictor

    predictor = AVPPredictor()
    return predictor
def read_file_content(fname):
    """Return the text of *fname* decoded as UTF-8, or ``None`` if unreadable.

    Args:
        fname: Path of the file to read.

    Returns:
        The full file content as a string, or ``None`` when the file cannot
        be opened or read (missing, is a directory, permission denied, ...)
        or is not valid UTF-8.
    """
    try:
        with open(fname, "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, ValueError):
        # OSError: file-system problems. ValueError: covers UnicodeDecodeError
        # and malformed paths (e.g. embedded NUL). Anything else is a
        # programming error and is deliberately allowed to propagate.
        return None
# --- Session State Initialization ---
# "seq_input" holds the current sequence text. "uploader_key" is used as the
# widget key of the file uploader; it starts at 0 and is only created here if
# it does not exist yet, so values survive script reruns.
for _state_name, _initial_value in (("seq_input", ""), ("uploader_key", 0)):
    if _state_name not in st.session_state:
        st.session_state[_state_name] = _initial_value
# --- Callbacks ---
def load_example():
    """Button callback: store the sample FASTA record as the current input."""
    sample_record = ">Sample_AVP\nSISCSRGLVCLLPRLTNESGNDRFDS"
    st.session_state["seq_input"] = sample_record
def clear_input():
    """Button callback: empty the sequence input and reset the file uploader."""
    # Step 1: wipe the stored sequence text.
    st.session_state["seq_input"] = ""
    # Step 2: bump the uploader's key. The uploader widget is created with
    # this value as its key, so a new value yields a fresh, empty uploader
    # and the previously uploaded file is discarded.
    next_key = st.session_state["uploader_key"] + 1
    st.session_state["uploader_key"] = next_key
# -----------------------------------------------------------------------------
# 3. MAIN UI LAYOUT
# -----------------------------------------------------------------------------
# Header: title, subtitle and a copyright line linking to LAB_URL.
_HEADER_HTML = f"""
<div class="header-container">
<div class="main-title">AVP-Pro</div>
<div class="sub-title">Deep Learning for Antiviral Peptide Discovery</div>
<div class="copyright-info">
© 2025 AVP-Pro Team | <a href="{LAB_URL}" target="_blank">JCU Qiu Lab</a>
</div>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Intro: short description of the tool, rendered inside the "intro-box" class.
_INTRO_HTML = """
<div class="intro-box">
<b>AVP-Pro</b> is a deep learning framework designed to identify antiviral peptides (AVPs) with high precision.<br>
It integrates adaptive feature fusion and contrastive learning to better capture sequence dependencies.
Accurate identification of AVPs is critical for accelerating novel drug development.
</div>
"""
st.markdown(_INTRO_HTML, unsafe_allow_html=True)
# Architecture: collapsible panel showing the framework figure, centred by
# placing it in the wide middle column of a 1:8:1 layout.
with st.expander("🧩 View AVP-Pro Architecture (Click to expand)"):
    try:
        _pad_left, figure_col, _pad_right = st.columns([1, 8, 1])
        with figure_col:
            st.image(
                "framework.png",
                caption="Figure 1: The overall architecture of AVP-Pro.",
                use_container_width=True,
            )
    except Exception:
        # Any failure while rendering the figure is reported as a missing image.
        st.warning("Framework image not found.")

# Vertical spacer before the input section.
st.write("")
# --- Input Section ---
st.markdown("### ⚡ Online Prediction")
st.markdown("Copy and paste your protein sequence here, or upload a FASTA file for batch prediction.")
col_txt, col_btn = st.columns([3, 1])
with col_txt:
    st.caption("Supports FASTA format or raw peptide sequence.")
with col_btn:
    st.button("📝 Click to load Sample AVP", on_click=load_example, use_container_width=True)

# Text area: shows the stored input and writes user edits back to it.
txt_val = st.text_area(
    label="Sequence Input",
    value=st.session_state["seq_input"],
    height=150,
    placeholder=">Seq1\nSISCSRGLVCLLPRLTNESGNDRFDS",
    label_visibility="collapsed"
)
if txt_val != st.session_state["seq_input"]:
    st.session_state["seq_input"] = txt_val

# File uploader. Its key comes from session state so that clear_input can
# reset the widget by changing the key.
uploaded_file = st.file_uploader(
    "Or upload a sequence file (.txt / .fasta) to predict:",
    type=["txt", "fasta"],
    key=str(st.session_state["uploader_key"])
)

# File loading. The uploader keeps returning the same file on every rerun, so
# the file is copied into the input only ONCE per upload. Re-applying it each
# rerun (as a plain content comparison would) overwrites any manual edit and
# the "load sample" button for as long as the file stays in the uploader.
if uploaded_file is not None:
    upload_token = (
        st.session_state["uploader_key"],
        getattr(uploaded_file, "file_id", None),
        getattr(uploaded_file, "name", None),
        getattr(uploaded_file, "size", None),
    )
    already_loaded = (
        "loaded_upload_token" in st.session_state
        and st.session_state["loaded_upload_token"] == upload_token
    )
    if not already_loaded:
        try:
            # "utf-8-sig" also strips a leading byte-order mark, which would
            # otherwise hide the ">" of the first FASTA header.
            file_text = uploaded_file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.error("⚠️ Could not read the uploaded file: it is not valid UTF-8 text.")
        else:
            st.session_state["loaded_upload_token"] = upload_token
            # Rerun only when the content actually changes the input, so the
            # text area is redrawn with the file content.
            if file_text != st.session_state["seq_input"]:
                st.session_state["seq_input"] = file_text
                st.rerun()
elif "loaded_upload_token" in st.session_state:
    # File removed: forget it so that uploading it again loads it again.
    del st.session_state["loaded_upload_token"]

# Buttons
b1, b2, b3 = st.columns([1, 1, 4])
with b1:
    run_btn = st.button("Run Prediction", type="primary", use_container_width=True)
with b2:
    st.button("Clear Input", on_click=clear_input, use_container_width=True)
# -----------------------------------------------------------------------------
# 4. PREDICTION LOGIC
# -----------------------------------------------------------------------------
# Runs only on the script pass triggered by the "Run Prediction" button.
if run_btn:
    raw_input = st.session_state["seq_input"]
    sequences = parse_fasta(raw_input)
    if not sequences:
        st.error("⚠️ No valid sequence found. Please input a sequence or upload a file.")
    else:
        # Top-level boundary: any failure in model loading or prediction is
        # shown to the user together with the traceback.
        try:
            predictor = get_predictor()

            # --- SINGLE SEQUENCE MODE ---
            if len(sequences) == 1:
                header, seq = sequences[0]
                invalid_chars = [c for c in seq if c not in VALID_AA]
                if len(seq) < 5:
                    st.warning("Sequence too short (must be >= 5 AA).")
                elif invalid_chars:
                    st.error(f"Invalid characters in sequence: {sorted(set(invalid_chars))}")
                else:
                    with st.spinner("Predicting..."):
                        res = predictor.predict(seq)
                        rep = simple_seq_report(seq)

                    st.markdown('<div class="result-card">', unsafe_allow_html=True)
                    st.markdown(f"#### 🧬 Prediction Result: {header}")
                    col_res1, col_res2 = st.columns(2)
                    with col_res1:
                        is_avp = res["prediction"] == 1
                        label_text = "AVP (Antiviral Peptide)" if is_avp else "Non-AVP"
                        label_color = "#28a745" if is_avp else "#6c757d"
                        st.markdown(f"**Predicted Label:** <span style='color:{label_color}; font-size:1.4rem; font-weight:bold'>{label_text}</span>", unsafe_allow_html=True)
                        st.markdown(f"**Probability:** `{res['prob_avp']:.4f}`")
                        st.markdown(f"**Threshold:** `{res['threshold']:.3f}`")
                    with col_res2:
                        st.markdown("**Sequence Properties:**")
                        st.text(f"Length : {rep['length']}")
                        st.text(f"Est. MW : {rep['mw_est']:.2f}")
                        st.text(f"Net Charge : {rep['net_charge_est']}")
                    st.markdown("</div>", unsafe_allow_html=True)

            # --- BATCH MODE ---
            else:
                total = len(sequences)
                st.info(f"📂 Detected **{total}** sequences. Starting batch prediction...")
                results_list = []
                progress_bar = st.progress(0)
                for idx, (header, seq) in enumerate(sequences, start=1):
                    progress_bar.progress(idx / total, text=f"Processing {idx}/{total}")

                    # Batch mode predicts on the sequence with non-standard
                    # characters removed. Record that in the row so the user
                    # can see the model did not score the submitted text.
                    clean_seq = "".join(c for c in seq if c in VALID_AA)
                    removed = len(seq) - len(clean_seq)
                    removal_note = f"Removed {removed} invalid character(s)" if removed else ""

                    if len(clean_seq) < 5:
                        note = "Length < 5"
                        if removal_note:
                            note = f"{note}; {removal_note}"
                        # Same keys as a successful row, so the table and CSV
                        # have a stable set of columns.
                        results_list.append({
                            "Header": header,
                            "Sequence": clean_seq,
                            "Label": "Error",
                            "Probability": 0,
                            "Threshold": None,
                            "Note": note,
                        })
                        continue

                    res = predictor.predict(clean_seq)
                    label = "AVP" if res["prediction"] == 1 else "Non-AVP"
                    results_list.append({
                        "Header": header,
                        "Sequence": clean_seq,
                        "Label": label,
                        "Probability": round(res["prob_avp"], 4),
                        "Threshold": round(res["threshold"], 3),
                        "Note": removal_note,
                    })

                progress_bar.empty()
                st.success("✅ Batch prediction complete!")
                df = pd.DataFrame(results_list)
                st.dataframe(df, use_container_width=True)
                csv = df.to_csv(index=False).encode('utf-8')
                st.download_button(
                    label="📥 Download Results as CSV",
                    data=csv,
                    file_name="avp_Pro_results.csv",
                    mime="text/csv",
                    type="primary"
                )
        except Exception:
            st.error("An error occurred during prediction.")
            st.code(traceback.format_exc())
# -----------------------------------------------------------------------------
# 5. DATASET DOWNLOAD SECTION
# -----------------------------------------------------------------------------
# Two spacer lines and a horizontal rule separate this section from the
# prediction area above.
for _ in range(2):
    st.write("")
st.markdown("---")
st.subheader("📂 Download Benchmark Datasets")
st.caption("Benchmark datasets used in AVP-Pro study.")

d_col1, d_col2, d_col3 = st.columns([1, 1, 2])
avp_content = read_file_content("AVP.txt")
non_avp_content = read_file_content("non_AVP.txt")

# One column per dataset: a download button when the file content is
# available (non-empty), otherwise a warning naming the file.
for _column, _dataset_name, _dataset_text in (
    (d_col1, "AVP.txt", avp_content),
    (d_col2, "non_AVP.txt", non_avp_content),
):
    with _column:
        if _dataset_text:
            st.download_button(
                f"📥 Download {_dataset_name}",
                _dataset_text,
                _dataset_name,
                use_container_width=True,
            )
        else:
            st.warning(f"{_dataset_name} missing")

# Trailing spacer lines at the bottom of the page.
for _ in range(2):
    st.write("")