# NOTE: the lines "Spaces: Sleeping" were a Hugging Face Space page-status
# scrape artifact, not part of the program; kept here as a comment only.
| import streamlit as st | |
| import pandas as pd | |
| from io import StringIO | |
| from predictor import load_model, predict_from_df | |
| from Bio import SeqIO | |
| import torch | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# ==============================
# Page configuration
# ==============================
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown("""
Upload a **CSV** file with columns `Peptide` and `HLA`,
or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
""")

# ==============================
# Global path settings
# ==============================
# Persistent-storage locations; on HF Spaces, /data is the writable volume.
CACHE_DIR = "/data/phla_cache"   # ESM embedding cache
MODEL_DIR = "/app/src"           # where model.pt is looked up / downloaded to
UPLOAD_DIR = "/data/uploads"     # user uploads are written here

for d in [CACHE_DIR, MODEL_DIR, UPLOAD_DIR]:
    os.makedirs(d, exist_ok=True)

# Environment variables (route all model and ESM caches into /data).
# NOTE(review): these are assigned AFTER `torch` and `huggingface_hub` were
# imported above; any library that reads its cache env var at import time may
# already have captured the defaults — confirm, or set these before the imports.
os.environ["HF_HOME"] = "/data/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"
os.environ["TORCH_HOME"] = "/data/huggingface"
os.environ["ESM_CACHE_DIR"] = CACHE_DIR
# ==============================
# Model loading (lazy + cached)
# ==============================
@st.cache_resource(show_spinner=False)
def load_model_cached():
    """Load the trained model once per process and reuse it across reruns.

    The original version was described as "cached" but reloaded the model on
    every Streamlit rerun; ``st.cache_resource`` now memoizes the result.

    Looks for ``model.pt`` in ``MODEL_DIR`` first; if it is missing, downloads
    it from the dedicated Hugging Face *model* repo (not the Space repo).

    Returns:
        tuple: ``(model, device)`` as produced by ``predictor.load_model``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    local_path = os.path.join(MODEL_DIR, "model.pt")
    if not os.path.exists(local_path):
        st.warning("🔄 Model not found locally. Downloading from Hugging Face model repo...")
        # Use a separate model repo rather than the Space repo so the Space
        # itself stays small and the weights can be versioned independently.
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap-model",
            filename="model.pt",
            cache_dir=MODEL_DIR,
        )
    model, device = load_model(local_path, device=device)
    return model, device
# ==============================
# File upload (written safely under /data/uploads)
# ==============================
uploaded_file = st.file_uploader("📤 Upload CSV or FASTA", type=["csv", "fasta"])

if uploaded_file:
    # Strip any directory components before building the destination path:
    # a crafted filename such as "../../etc/x" must not escape UPLOAD_DIR.
    safe_name = os.path.basename(uploaded_file.name).replace(" ", "_")
    temp_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # ==============================
    # File parsing
    # ==============================
    if safe_name.endswith(".csv"):
        df = pd.read_csv(temp_path)
    else:
        # FASTA: the record id may encode the HLA allele before a "|",
        # e.g. "HLA-A*02:01|anything"; otherwise the allele is unknown.
        rows = []
        for rec in SeqIO.parse(temp_path, "fasta"):
            header = rec.id
            seq = str(rec.seq)
            if "|" in header:
                hla, _ = header.split("|", 1)
            else:
                hla = "HLA-Unknown"
            rows.append([seq, hla])
        df = pd.DataFrame(rows, columns=["Peptide", "HLA"])

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())

    # ==============================
    # Prediction (model loaded lazily)
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("🔄 Loading model (this may take ~1 min first time)..."):
            model, device = load_model_cached()
        with st.spinner("Running inference..."):
            result_df = predict_from_df(df, model)
        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))

        # ==============================
        # Download results
        # ==============================
        # Renamed from `csv`: that name shadows the stdlib `csv` module.
        csv_bytes = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv_bytes,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )
# ==============================
# Debug / data check (optional)
# ==============================
if st.sidebar.button("📁 List /data files"):
    # Flatten the /data tree into a list of full file paths for inspection.
    listing = [
        os.path.join(root, name)
        for root, _, filenames in os.walk("/data")
        for name in filenames
    ]
    st.sidebar.write(listing)