# StriMap — src/streamlit_app.py
# Streamlit front-end for the Peptide–HLA binding predictor.
import streamlit as st
import pandas as pd
from io import StringIO
from predictor import load_model, predict_from_df
from Bio import SeqIO
import torch
import os
from huggingface_hub import hf_hub_download
# ==============================
# Page configuration (must be the first Streamlit call in the script)
# ==============================
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown("""
Upload a **CSV** file with columns `Peptide` and `HLA`,
or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
""")

# ==============================
# Global path setup
# ==============================
# On Hugging Face Spaces, /data is the persistent writable volume.
CACHE_DIR = "/data/phla_cache"   # ESM embedding / model cache
MODEL_DIR = "/app/src"           # location where model.pt is expected locally
UPLOAD_DIR = "/data/uploads"     # sanitized copies of user uploads
for d in [CACHE_DIR, MODEL_DIR, UPLOAD_DIR]:
    os.makedirs(d, exist_ok=True)

# Environment variables (ensure all model and ESM caches are written to /data).
# NOTE(review): these are assigned AFTER torch/huggingface_hub were imported
# above; any library that reads them at import time may have already captured
# the defaults — confirm this ordering is intentional.
os.environ["HF_HOME"] = "/data/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"
os.environ["TORCH_HOME"] = "/data/huggingface"
os.environ["ESM_CACHE_DIR"] = CACHE_DIR
# ==============================
# 模型加载函数(延迟加载 + 缓存)
# ==============================
# ==============================
# Model loading (lazy + cached)
# ==============================
@st.cache_resource
def load_model_cached():
    """Load the predictor model once per Streamlit server process.

    Prefers a local checkpoint at ``MODEL_DIR/model.pt``; if absent,
    downloads ``model.pt`` from the dedicated Hugging Face model repo.

    Returns:
        tuple: ``(model, device)`` as produced by ``predictor.load_model``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    local_path = os.path.join(MODEL_DIR, "model.pt")
    if not os.path.exists(local_path):
        st.warning("🔄 Model not found locally. Downloading from Hugging Face model repo...")
        # Use the model repo (not the Space repo). Download into the /data
        # cache so the checkpoint persists across restarts and we never write
        # into the application directory, which may be read-only on Spaces.
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap-model",
            filename="model.pt",
            cache_dir=CACHE_DIR,
        )
    model, device = load_model(local_path, device=device)
    return model, device
# ==============================
# 上传文件(安全写入 /data/uploads)
# ==============================
# ==============================
# File upload (written safely to /data/uploads)
# ==============================
# Also accept the common ".fa" FASTA extension.
uploaded_file = st.file_uploader("📤 Upload CSV or FASTA", type=["csv", "fasta", "fa"])
if uploaded_file:
    # Sanitize the filename and persist the upload to disk so that
    # pandas / Biopython can read it by path.
    safe_name = uploaded_file.name.replace(" ", "_")
    temp_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # ==============================
    # File parsing
    # ==============================
    # Lowercase before the extension test so "DATA.CSV" is parsed as CSV
    # rather than falling through to the FASTA branch.
    if safe_name.lower().endswith(".csv"):
        df = pd.read_csv(temp_path)
        # Fail early with a clear message instead of an opaque KeyError later.
        missing = [c for c in ("Peptide", "HLA") if c not in df.columns]
        if missing:
            st.error(f"❌ CSV is missing required column(s): {', '.join(missing)}")
            st.stop()
    else:
        # FASTA: an optional HLA type may be encoded before a '|' in the
        # record id, e.g. ">HLA-A*02:01|peptide_1".
        seqs = []
        for rec in SeqIO.parse(temp_path, "fasta"):
            header = rec.id
            seq = str(rec.seq)
            if "|" in header:
                hla, _ = header.split("|", 1)
            else:
                hla = "HLA-Unknown"
            seqs.append([seq, hla])
        df = pd.DataFrame(seqs, columns=["Peptide", "HLA"])

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())

    # ==============================
    # Prediction (model loaded lazily on first click)
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("🔄 Loading model (this may take ~1 min first time)..."):
            model, device = load_model_cached()
        with st.spinner("Running inference..."):
            result_df = predict_from_df(df, model)
        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))

        # ==============================
        # Download results
        # ==============================
        csv = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )
# ==============================
# Debug / data check (optional)
# ==============================
if st.sidebar.button("📁 List /data files"):
    # Flatten the /data tree into a list of full paths for quick inspection.
    files = [
        os.path.join(root, name)
        for root, _, filenames in os.walk("/data")
        for name in filenames
    ]
    st.sidebar.write(files)