# fmt: off
import csv
import os
import subprocess
import sys
import tempfile

import pandas as pd
import py3Dmol
import requests
import streamlit as st

from models.polybert import polymer2psmiles
# Streamlit writes usage-stats/config data under its config dir, which can be
# unwritable in containers; point it at a writable temp location instead.
streamlit_dir = os.environ.setdefault('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')

# Make sure the config directory exists before writing into it.
os.makedirs(streamlit_dir, exist_ok=True)

# Drop a config.toml that disables usage stats and CORS/XSRF protection,
# but never overwrite one the user already has.
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        f.write("""[browser]
gatherUsageStats = false
[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
# fmt: on
# Map one-letter amino-acid codes to their three-letter PDB residue names.
aa2resn = dict(zip(
    "ACDEFGHIKLMNPQRSTVWY",
    ("ALA CYS ASP GLU PHE GLY HIS ILE LYS LEU "
     "MET ASN PRO GLN ARG SER THR VAL TRP TYR").split(),
))
# Fancy header
# Render the app banner as raw HTML: colored title, subtitle, and a divider.
st.markdown("""
<div style='text-align: center;'>
<h1 style='color:#377EB9;font-size:2.5em;'>🧬 Plastic Degradation Predictor</h1>
<h3 style='color:#4DAE48;'>Predict the degradability of plastics using protein sequences and polymer SMILES</h3>
</div>
<hr style='border:1px solid #974F9F;'>
""", unsafe_allow_html=True)
# Short usage hint shown under the banner.
st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")
# Load polymer names from the bundled CSV and pair each with its SMILES
# string; polymers without a SMILES entry are excluded from the dropdown.
polymer_csv = os.path.join(os.path.dirname(__file__), 'data/polymer2tok.csv')
with open(polymer_csv, newline='') as handle:
    polymer_options = [
        f"{row['polymer']} | {smiles}"
        for row in csv.DictReader(handle)
        if (smiles := polymer2psmiles.get(row['polymer'], ''))
    ]
# Choose how the protein is provided: fetch by UniProt accession, or paste raw.
input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])
if input_type == "UniProt ID":
    uniprot_id = st.text_input("Enter UniProt ID", "P69905")
    sequence = ""
    if uniprot_id:
        # Fetch the FASTA record from UniProt's REST API.
        fasta_url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        response = requests.get(fasta_url)
        if response.status_code == 200:
            # Drop the FASTA header line; join the remaining sequence lines.
            sequence = "".join(response.text.split("\n")[1:])
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            st.error("Failed to fetch sequence from UniProt.")
else:
    sequence = st.text_area("Paste protein sequence",
                            "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")
# Dropdown entries look like "name | SMILES"; keep only the polymer name.
polymer = st.selectbox("Select polymer", polymer_options)
if '|' in polymer:
    selected_polymer = polymer.split('|')[0].strip()
else:
    selected_polymer = polymer

# Paths/identifiers handed to the prediction subprocess.
ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"
# Run the prediction pipeline when the user clicks the button.
if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Write a one-row input CSV for predict.py.
        # NOTE(review): values are written raw; fine for sequences/polymer
        # names without commas or quotes — switch to the csv module if that
        # assumption ever breaks.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as tmp:
            tmp.write("sequence,polymer\n")
            tmp.write(f"{sequence},{selected_polymer}\n")
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        # Use sys.executable so the subprocess runs under the same interpreter
        # as this app; a bare "python" may resolve to a different environment.
        result = subprocess.run([
            sys.executable, "src/predict.py",
            "--ckpt", ckpt,
            "--plm", plm,
            "--csv", tmp_path,
            "--output", output_path,
            "--attn"
        ], capture_output=True, text=True)
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            # Success banner with the predicted class and probability.
            st.markdown(f"""
<div style='background: linear-gradient(90deg, #377EB9 0%, #4DAE48 100%); padding: 1.5em; border-radius: 12px; color: white; margin-bottom: 1em;'>
<h2 style='margin:0;'><span style='font-size:18pt'>✅</span> Prediction Complete!</h2>
<p style='font-size:12pt;'>Your input has been processed. See the results below:</p>
<p style='font-size:12pt;'>Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})</p>
</div>
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(
                index=False), file_name="predictions.csv", type="primary")
            # Show top-N attention residues if the attention dump exists.
            attn_dir = os.path.join(os.path.dirname(output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                import torch
                attn = torch.load(attn_path)
                # attn[0][0]: shape (num_heads, seq_len, seq_len) or (1, seq_len, seq_len)
                attn_matrix = attn[0][0] if isinstance(
                    attn[0], (list, tuple)) else attn[0]
                if attn_matrix.ndim == 3:
                    # Average over attention heads.
                    attn_matrix = attn_matrix.mean(0)
                # Per-residue importance: total attention each residue receives.
                residue_scores = attn_matrix.sum(0).cpu().numpy()
                topN = min(10, len(residue_scores))
                top_idx = residue_scores.argsort()[::-1][:topN]
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx + 1,  # 1-based residue numbering
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)
# If the input was a UniProt ID, try to download its AlphaFold structure.
structure_path = None
if input_type == "UniProt ID" and uniprot_id:
    # BUG FIX: the previous URL ("AF-{id}-2-F1-model_v6.cif") does not match
    # AlphaFoldDB's file naming scheme (AF-<accession>-F1-model_v4.cif) and
    # always 404'd; it also disagreed with the local filename built below.
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
    # If attention is available, precompute the top residues to highlight.
    # NOTE(review): highlight_residues is computed but not consumed anywhere
    # visible here — the viewer block recomputes its own ranking. Confirm
    # whether this can be removed or should be passed along.
    highlight_residues = None
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    if os.path.exists(attn_path):
        import torch
        attn = torch.load(attn_path)
        attn_matrix = attn[0][0] if isinstance(
            attn[0], (list, tuple)) else attn[0]
        if attn_matrix.ndim == 3:
            # Average over attention heads.
            attn_matrix = attn_matrix.mean(0)
        residue_scores = attn_matrix.sum(0).cpu().numpy()
        topN = min(10, len(residue_scores))
        top_idx = residue_scores.argsort()[::-1][:topN]
        # 1-based residue numbers for structure selections.
        highlight_residues = [int(i + 1) for i in top_idx]
    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        r = requests.get(af_url)
        if r.status_code == 200:
            with open(structure_path, "wb") as f:
                f.write(r.content)
            st.success(
                f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning(
                "AlphaFoldDB structure not found for this UniProt ID.")
    except Exception as e:
        # Best-effort download: surface the error but keep the app running.
        st.warning(f"AlphaFoldDB download error: {e}")
# Render the downloaded AlphaFold structure with the highest-attention
# residues highlighted in red and labelled.
if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
    st.markdown("### 3D Structure Visualization (stmol)")
    import torch
    from stmol import showmol
    saved = torch.load(attn_path)
    matrix = saved[0][0] if isinstance(
        saved[0], (list, tuple)) else saved[0]
    if matrix.ndim == 3:
        # Collapse attention heads by averaging.
        matrix = matrix.mean(0)
    scores = matrix.sum(0).cpu().numpy()
    n_top = min(10, len(scores))
    ranked = scores.argsort()[::-1][:n_top]
    # NOTE(review): these labels are built but never handed to the viewer.
    labels = [
        f"{sequence[i]}{i+1}: {scores[i]:.4g}" for i in ranked]
    with open(structure_path, "r") as cif_file:
        cif_data = cif_file.read()
    viewer = py3Dmol.view(width=600, height=400)
    viewer.addModel(cif_data, "cif")
    viewer.setStyle({"cartoon": {"color": "lightgray"}})
    for position in ranked:
        residue_number = int(position + 1)  # 1-based residue numbering
        viewer.setStyle(
            {"resi": residue_number}, {
                "cartoon": {"color": "red"}})
        viewer.addResLabels(
            {"resi": residue_number},
            {
                "font": 'Arial', "fontColor": 'black',
                "showBackground": False, "screenOffset": {"x": 0, "y": 0}})
    viewer.zoomTo()
    showmol(viewer, height=600, width='100%')
# --- Footer: License and References ---
# Rendered as raw HTML; the leading "---" is markdown for a horizontal rule.
st.markdown("""
---
<h4>License</h4>
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)<br>
<a href='https://creativecommons.org/licenses/by-nc-sa/4.0/' target='_blank'>View full license details</a><br>
""", unsafe_allow_html=True)