File size: 9,889 Bytes
21f308b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a35beb
 
21f308b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#fmt: off
import csv
import os
import subprocess
import sys
import tempfile

import pandas as pd
import py3Dmol
import requests
import streamlit as st

from models.polybert import polymer2psmiles

# Fix for permission error - point Streamlit at a writable config
# location and disable usage-stats collection before the app starts.
if 'STREAMLIT_CONFIG_DIR' not in os.environ:
    os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit'

# Create the Streamlit config directory if it doesn't exist.
streamlit_dir = os.environ.get('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Write config.toml (usage stats off, headless server) only when absent,
# so an existing user configuration is never overwritten.
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    # Explicit encoding avoids locale-dependent defaults (e.g. on Windows).
    with open(config_path, 'w', encoding='utf-8') as f:
        f.write("""[browser]
gatherUsageStats = false

[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
# fmt: on

# Map one-letter amino-acid codes to their three-letter residue names,
# built from two parallel tables kept in the same (alphabetical) order.
_ONE_LETTER = "ACDEFGHIKLMNPQRSTVWY"
_THREE_LETTER = (
    "ALA CYS ASP GLU PHE GLY HIS ILE LYS LEU "
    "MET ASN PRO GLN ARG SER THR VAL TRP TYR"
).split()
aa2resn = dict(zip(_ONE_LETTER, _THREE_LETTER))

# Fancy header rendered as raw HTML (hence unsafe_allow_html).
_HEADER_HTML = """
<div style='text-align: center;'>
    <h1 style='color:#377EB9;font-size:2.5em;'>🧬 Plastic Degradation Predictor</h1>
    <h3 style='color:#4DAE48;'>Predict the degradability of plastics using protein sequences and polymer SMILES</h3>
</div>
<hr style='border:1px solid #974F9F;'>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)

st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")


# Load polymer names and SMILES for the dropdown.
# Only polymers that have a SMILES entry in polymer2psmiles are listed;
# rows without one are silently skipped.
polymer_csv = os.path.join(os.path.dirname(
    __file__), 'data/polymer2tok.csv')
polymer_options = []
# Explicit encoding so the CSV parses the same way on every platform.
with open(polymer_csv, newline='', encoding='utf-8') as f:
    for row in csv.DictReader(f):
        name = row['polymer']
        smiles = polymer2psmiles.get(name, '')
        if smiles:
            # Display format "name | SMILES"; split back apart on selection.
            polymer_options.append(f"{name} | {smiles}")

input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])

if input_type == "UniProt ID":
    uniprot_id = st.text_input("Enter UniProt ID", "P69905")
    sequence = ""
    if uniprot_id:
        # Fetch the FASTA record from UniProt's REST API.
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        # A timeout keeps the UI responsive if UniProt is unreachable;
        # without one, requests can block indefinitely.
        resp = requests.get(url, timeout=30)
        if resp.status_code == 200:
            fasta = resp.text
            # Drop the FASTA header line; join the remaining sequence lines.
            sequence = "".join(fasta.split("\n")[1:])
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            st.error("Failed to fetch sequence from UniProt.")
else:
    sequence = st.text_area("Paste protein sequence",
                            "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")

polymer = st.selectbox("Select polymer", polymer_options)
# Options are "name | SMILES"; recover the bare polymer name.
selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer


# Arguments forwarded to src/predict.py below: the fine-tuned model
# checkpoint path and the protein-language-model name (ESM-2 650M,
# judging by the identifier) used to embed the input sequence.
ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"

def _load_residue_attention(attn_path, top_n=10):
    """Load a saved attention tensor and rank residues by attention.

    Returns ``(residue_scores, top_idx)``: a 1-D array of the summed
    attention per residue, and the 0-based indices of the at most
    ``top_n`` highest-scoring residues, best first.
    """
    import torch  # deferred: torch only needed when an attention file exists

    # map_location keeps loading working on CPU-only hosts even when
    # the tensors were saved from a GPU run.
    attn = torch.load(attn_path, map_location="cpu")
    # attn[0][0]: shape (num_heads, seq_len, seq_len) or (1, seq_len, seq_len)
    attn_matrix = attn[0][0] if isinstance(attn[0], (list, tuple)) else attn[0]
    if attn_matrix.ndim == 3:
        attn_matrix = attn_matrix.mean(0)  # average over heads
    # Score each residue by the total attention it receives.
    residue_scores = attn_matrix.sum(0).cpu().numpy()
    top_n = min(top_n, len(residue_scores))
    top_idx = residue_scores.argsort()[::-1][:top_n]
    return residue_scores, top_idx


if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Write the model input as a one-row CSV.  csv.writer quotes any
        # field containing commas; the previous raw f-string write would
        # have produced a malformed row for such inputs.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv",
                                         mode="w", newline="") as tmp:
            writer = csv.writer(tmp)
            writer.writerow(["sequence", "polymer"])
            writer.writerow([sequence, selected_polymer])
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        # sys.executable runs the same interpreter (and environment) as
        # this app, instead of whatever "python" happens to be on PATH.
        result = subprocess.run([
            sys.executable, "src/predict.py",
            "--ckpt", ckpt,
            "--plm", plm,
            "--csv", tmp_path,
            "--output", output_path,
            "--attn"
        ], capture_output=True, text=True)
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""
<div style='background: linear-gradient(90deg, #377EB9 0%, #4DAE48 100%); padding: 1.5em; border-radius: 12px; color: white; margin-bottom: 1em;'>
    <h2 style='margin:0;'><span style='font-size:18pt'>✅</span> Prediction Complete!</h2>
    <p style='font-size:12pt;'>Your input has been processed. See the results below:</p>
    <p style='font-size:12pt;'>Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})</p>
</div>
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(
                index=False), file_name="predictions.csv", type="primary")

            # Show top-N attention residues if attention file exists
            attn_dir = os.path.join(os.path.dirname(
                output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                residue_scores, top_idx = _load_residue_attention(attn_path)
                topN = len(top_idx)
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx + 1,  # 1-based for display
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)

        # If the input was a UniProt ID, try to download its AlphaFold structure.
        structure_path = None

        if input_type == "UniProt ID" and uniprot_id:
            # Canonical AlphaFoldDB file name: AF-<accession>-F1-model_v4.cif.
            # The previous URL ("...-2-F1-model_v6.cif") did not follow the
            # database's naming scheme and never resolved.
            af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
            # Attention file path for the 3D view below.
            attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")

            structure_path = os.path.join(
                tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
            try:
                # Bound the request so the UI cannot hang on a dead server.
                r = requests.get(af_url, timeout=30)
                if r.status_code == 200:
                    with open(structure_path, "wb") as f:
                        f.write(r.content)
                    st.success(
                        f"AlphaFold structure downloaded: {structure_path}")
                else:
                    st.warning(
                        "AlphaFoldDB structure not found for this UniProt ID.")
            except Exception as e:
                st.warning(f"AlphaFoldDB download error: {e}")

        if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
            st.markdown("### 3D Structure Visualization (stmol)")
            from stmol import showmol
            residue_scores, top_idx = _load_residue_attention(attn_path)
            with open(structure_path, "r") as cif_file:
                cif_data = cif_file.read()
            view = py3Dmol.view(width=600, height=400)
            view.addModel(cif_data, "cif")
            # Gray cartoon overall; top-attention residues highlighted in red.
            view.setStyle({"cartoon": {"color": "lightgray"}})
            for idx in top_idx:
                resi_num = int(idx + 1)  # structure residue numbering is 1-based
                view.setStyle(
                    {"resi": resi_num}, {
                        "cartoon": {"color": "red"}})
                view.addResLabels(
                    {"resi": resi_num},
                    {
                        "font": 'Arial', "fontColor": 'black',
                        "showBackground": False, "screenOffset": {"x": 0, "y": 0}})
            view.zoomTo()
            showmol(view, height=600, width='100%')


# --- Footer: License and References ---
# Rendered as raw HTML/Markdown (hence unsafe_allow_html).
_FOOTER_HTML = """
---
<h4>License</h4>
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)<br>
<a href='https://creativecommons.org/licenses/by-nc-sa/4.0/' target='_blank'>View full license details</a><br>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)