"""Streamlit front-end for the Peptide–HLA binding predictor.

Accepts a CSV (columns ``Peptide`` and ``HLA``) or a FASTA file, runs the
cached model over the parsed dataframe, and offers the predictions as a
CSV download. All caches and uploads are kept under /data so the app's
writable state survives container rebuilds.
"""

import streamlit as st
import pandas as pd
from io import StringIO
from predictor import load_model, predict_from_df
from Bio import SeqIO
import torch
import os
from huggingface_hub import hf_hub_download

# ==============================
# Page configuration
# ==============================
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown("""
Upload a **CSV** file with columns `Peptide` and `HLA`,
or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
""")

# ==============================
# Global path setup
# ==============================
CACHE_DIR = "/data/phla_cache"
MODEL_DIR = "/app/src"
UPLOAD_DIR = "/data/uploads"
for d in [CACHE_DIR, MODEL_DIR, UPLOAD_DIR]:
    os.makedirs(d, exist_ok=True)

# Environment variables (route all model and ESM caches to /data).
# NOTE(review): some versions of transformers/huggingface_hub read these at
# import time; since the imports above run first, these may be set too late —
# confirm, or export them in the container environment instead.
os.environ["HF_HOME"] = "/data/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"
os.environ["TORCH_HOME"] = "/data/huggingface"
os.environ["ESM_CACHE_DIR"] = CACHE_DIR


# ==============================
# Model loading (lazy + cached)
# ==============================
@st.cache_resource
def load_model_cached():
    """Load the prediction model once and cache it across Streamlit reruns.

    Prefers a local checkpoint at MODEL_DIR/model.pt; otherwise downloads it
    from the dedicated Hugging Face *model* repo (not the Space repo).

    Returns:
        (model, device): whatever ``predictor.load_model`` returns.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    local_path = os.path.join(MODEL_DIR, "model.pt")
    if not os.path.exists(local_path):
        # BUG FIX: this string literal was broken across a physical line in
        # the original source, which is a SyntaxError in Python.
        st.warning("🔄 Model not found locally. Downloading from Hugging Face model repo...")
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap-model",  # dedicated model repo
            filename="model.pt",
            cache_dir=MODEL_DIR,
        )
    model, device = load_model(local_path, device=device)
    return model, device


# ==============================
# File upload (safely written to /data/uploads)
# ==============================
uploaded_file = st.file_uploader("📤 Upload CSV or FASTA", type=["csv", "fasta"])

if uploaded_file:
    # SECURITY: basename() strips any directory components a malicious
    # client could embed in the filename; spaces are normalized as before.
    safe_name = os.path.basename(uploaded_file.name).replace(" ", "_")
    temp_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # ==============================
    # File parsing
    # ==============================
    if safe_name.endswith(".csv"):
        df = pd.read_csv(temp_path)
    else:
        rows = []
        for rec in SeqIO.parse(temp_path, "fasta"):
            header = rec.id
            seq = str(rec.seq)
            # Headers of the form "<HLA>|<rest>" carry the allele before the
            # first "|"; anything else falls back to a placeholder allele.
            if "|" in header:
                hla, _ = header.split("|", 1)
            else:
                hla = "HLA-Unknown"
            rows.append([seq, hla])
        df = pd.DataFrame(rows, columns=["Peptide", "HLA"])

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())

    # ==============================
    # Prediction (model loaded lazily on first click)
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("🔄 Loading model (this may take ~1 min first time)..."):
            model, device = load_model_cached()
        with st.spinner("Running inference..."):
            # NOTE(review): `device` is not forwarded here — presumably
            # predict_from_df reads it off the model; confirm in predictor.py.
            result_df = predict_from_df(df, model)
        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))

        # ==============================
        # Results download
        # ==============================
        csv = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )

# ==============================
# Debug / data check (optional)
# ==============================
if st.sidebar.button("📁 List /data files"):
    files = []
    for root, _, filenames in os.walk("/data"):
        for f in filenames:
            files.append(os.path.join(root, f))
    st.sidebar.write(files)