File size: 3,986 Bytes
ff645fd
78f28d5
 
 
 
 
8721078
 
78f28d5
8721078
 
 
78f28d5
 
 
 
 
 
 
 
374d0bb
 
 
 
 
 
 
 
 
 
 
9ad4750
 
 
374d0bb
9ad4750
8721078
374d0bb
8721078
78f28d5
374d0bb
78f28d5
374d0bb
9370b5b
8721078
374d0bb
 
8721078
374d0bb
 
 
8721078
 
 
78f28d5
 
 
8721078
374d0bb
8721078
374d0bb
8721078
78f28d5
374d0bb
 
8721078
 
 
 
 
 
374d0bb
8721078
78f28d5
 
8721078
78f28d5
 
 
 
 
 
 
 
 
 
 
 
8721078
374d0bb
8721078
78f28d5
374d0bb
 
 
 
78f28d5
 
 
 
 
8721078
 
 
78f28d5
 
 
 
 
 
374d0bb
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os

# ==============================
# Global paths and cache environment
# ==============================
CACHE_DIR = "/data/phla_cache"
MODEL_DIR = "/app/src"
UPLOAD_DIR = "/data/uploads"

# Redirect model / ESM caches to /data.
# These variables are exported BEFORE the ML libraries below are imported:
# huggingface_hub resolves its cache root from HF_HOME once, at import time,
# so assigning it after the import does not move the cache. Other libraries
# (transformers, and whatever reads ESM_CACHE_DIR) may behave the same way.
os.environ["HF_HOME"] = "/data/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"
os.environ["TORCH_HOME"] = "/data/huggingface"
os.environ["ESM_CACHE_DIR"] = CACHE_DIR

# Imports intentionally follow the environment setup above (hence E402).
from io import StringIO  # noqa: E402

import pandas as pd  # noqa: E402
import streamlit as st  # noqa: E402
import torch  # noqa: E402
from Bio import SeqIO  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

from predictor import load_model, predict_from_df  # noqa: E402

# ==============================
# Page configuration
# ==============================
# set_page_config stays the first Streamlit call executed by the script.
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")

st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown("""
Upload a **CSV** file with columns `Peptide` and `HLA`,  
or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
""")

# Make sure every working directory exists before anything reads or writes it.
for d in [CACHE_DIR, MODEL_DIR, UPLOAD_DIR]:
    os.makedirs(d, exist_ok=True)

# ==============================
# Model loading (lazy + cached)
# ==============================
@st.cache_resource
def load_model_cached():
    """Load the prediction model, downloading the weights if they are missing.

    The function is wrapped in ``st.cache_resource``, so the loaded model is
    reused across script reruns instead of being rebuilt on every interaction.

    The weights are looked up at ``MODEL_DIR/model.pt``. When that file does
    not exist it is fetched from the Hugging Face *model* repository (not the
    Space repository) straight into ``MODEL_DIR``, so the next cold start finds
    it at the path checked here and skips the download.

    Returns:
        tuple: ``(model, device)`` as returned by ``load_model``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    local_path = os.path.join(MODEL_DIR, "model.pt")

    if not os.path.exists(local_path):
        st.warning("🔄 Model not found locally. Downloading from Hugging Face model repo...")
        # ``local_dir`` (rather than ``cache_dir``) writes the file to
        # MODEL_DIR/model.pt. With ``cache_dir`` the file ends up under
        # MODEL_DIR/models--<org>--<repo>/snapshots/<rev>/, so the existence
        # check above could never succeed on later starts.
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap-model",  # dedicated model repository
            filename="model.pt",
            local_dir=MODEL_DIR,
        )

    model, device = load_model(local_path, device=device)
    return model, device


# ==============================
# File upload (persisted under /data/uploads)
# ==============================
uploaded_file = st.file_uploader("📤 Upload CSV or FASTA", type=["csv", "fasta"])

if uploaded_file:
    # The file name comes from the client. Keep only its final path component
    # so the write below can never resolve outside UPLOAD_DIR, then replace
    # spaces to get a shell-friendly name.
    safe_name = os.path.basename(uploaded_file.name).replace(" ", "_")
    temp_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # ==============================
    # File parsing
    # ==============================
    # Compare the extension case-insensitively so "DATA.CSV" is not routed to
    # the FASTA parser by mistake.
    if safe_name.lower().endswith(".csv"):
        df = pd.read_csv(temp_path)
    else:
        seqs = []
        for rec in SeqIO.parse(temp_path, "fasta"):
            # A header of the form "<HLA>|<rest>" carries the HLA type before
            # the first "|"; headers without "|" get a placeholder value.
            hla, sep, _ = rec.id.partition("|")
            seqs.append([str(rec.seq), hla if sep else "HLA-Unknown"])
        df = pd.DataFrame(seqs, columns=["Peptide", "HLA"])

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())

    # ==============================
    # Prediction (model is loaded lazily on first use)
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("🔄 Loading model (this may take ~1 min first time)..."):
            model, device = load_model_cached()

        with st.spinner("Running inference..."):
            result_df = predict_from_df(df, model)

        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))

        # ==============================
        # Download results
        # ==============================
        # NOTE(review): clicking the download button triggers a script rerun in
        # which st.button returns False, so this results section likely
        # disappears after the click - consider keeping result_df in
        # st.session_state. Confirm against the deployed Streamlit version.
        csv = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )

# ==============================
# Debug helper: show every file stored under /data (optional)
# ==============================
if st.sidebar.button("📁 List /data files"):
    # Walk the tree top-down and collect the full path of each regular file.
    files = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk("/data")
        for filename in filenames
    ]
    st.sidebar.write(files)