import csv
import os
import subprocess
import sys
import tempfile

import pandas as pd
import py3Dmol
import requests
import streamlit as st

from models.polybert import polymer2psmiles

# Default the Streamlit config directory to /tmp/.streamlit unless the
# environment already sets STREAMLIT_CONFIG_DIR, then make sure it exists.
streamlit_dir = os.environ.setdefault('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Write a default config.toml the first time only; an existing file is kept.
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    default_config = "\n".join([
        "[browser]",
        "gatherUsageStats = false",
        "",
        "[server]",
        "headless = true",
        "enableCORS = false",
        "enableXsrfProtection = false",
        "",
    ])
    with open(config_path, 'w') as config_file:
        config_file.write(default_config)

# One-letter amino-acid codes mapped to their three-letter residue names
# (the 20 standard amino acids, in alphabetical order of the one-letter code).
aa2resn = dict(
    A='ALA', C='CYS', D='ASP', E='GLU', F='PHE',
    G='GLY', H='HIS', I='ILE', K='LYS', L='LEU',
    M='MET', N='ASN', P='PRO', Q='GLN', R='ARG',
    S='SER', T='THR', V='VAL', W='TRP', Y='TYR',
)

# Page banner: centered title and subtitle followed by a horizontal rule.
# Rendered as raw HTML, hence unsafe_allow_html.
header_html = """
<div style='text-align: center;'>
<h1 style='color:#377EB9;font-size:2.5em;'>🧬 Plastic Degradation Predictor</h1>
<h3 style='color:#4DAE48;'>Predict the degradability of plastics using protein sequences and polymer SMILES</h3>
</div>
<hr style='border:1px solid #974F9F;'>
"""
st.markdown(header_html, unsafe_allow_html=True)

# Short usage hint shown under the banner.
st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")

# Build the selectbox entries ("<name> | <pSMILES>") from the polymer table
# that sits next to this script. Polymers with no entry (or an empty entry)
# in polymer2psmiles are left out.
polymer_csv = os.path.join(os.path.dirname(__file__), 'data/polymer2tok.csv')
with open(polymer_csv, newline='') as csv_file:
    polymer_names = [record['polymer'] for record in csv.DictReader(csv_file)]

name_smiles_pairs = (
    (name, polymer2psmiles.get(name, '')) for name in polymer_names
)
polymer_options = [
    f"{name} | {smiles}" for name, smiles in name_smiles_pairs if smiles
]

# Let the user either look a protein up by UniProt accession or paste a raw
# sequence. Either way `sequence` ends up as a single whitespace-free string
# (empty when nothing usable was provided).
input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])

if input_type == "UniProt ID":
    # Strip stray spaces so they do not end up in the request URL.
    uniprot_id = st.text_input("Enter UniProt ID", "P69905").strip()
    sequence = ""
    if uniprot_id:
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        try:
            # A timeout keeps an unresponsive server from hanging the app.
            resp = requests.get(url, timeout=30)
        except requests.RequestException:
            # Network failures are reported below like any other failed fetch.
            resp = None
        if resp is not None and resp.status_code == 200:
            fasta = resp.text
            # Drop the FASTA header line and join the wrapped residue lines;
            # splitlines()/strip() also cope with CRLF line endings.
            sequence = "".join(
                line.strip() for line in fasta.splitlines()[1:])
        if sequence:
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            # Covers HTTP errors, network errors and empty response bodies.
            st.error("Failed to fetch sequence from UniProt.")
else:
    pasted_sequence = st.text_area(
        "Paste protein sequence",
        "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")
    # Remove spaces and line breaks from multi-line pastes.
    sequence = "".join(pasted_sequence.split())

# Polymer choice. Options are formatted as "<name> | <pSMILES>"; only the
# name part is kept for the prediction input.
polymer = st.selectbox("Select polymer", polymer_options)
if polymer is None:
    # selectbox yields None when there is nothing to select (empty option
    # list); use an empty name so the "provide both" check reports it
    # instead of raising a TypeError here.
    selected_polymer = ""
else:
    selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer

# Model checkpoint and protein language model name handed to src/predict.py.
ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"

def _remove_if_exists(path):
    """Delete *path* on a best-effort basis; a missing or locked file is ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


# Run src/predict.py on the chosen sequence/polymer pair and show the result
# together with the residues that received the highest attention.
if st.button("Predict degradation", type="primary"):
    # Collapse all whitespace so a multi-line paste becomes one sequence.
    query_sequence = "".join(sequence.split())
    if not query_sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        attn_dir = os.path.join(
            os.path.dirname(output_path), "predictions.attn")
        attn_path = os.path.join(attn_dir, "0.pt")
        # The output locations are fixed, so results of an earlier run must
        # be cleared first; otherwise they could be shown as the outcome of
        # this run.
        _remove_if_exists(output_path)
        _remove_if_exists(attn_path)

        # csv.writer quotes fields that contain commas or quotes; for plain
        # values the bytes written match a hand-built "a,b\n" line.
        with tempfile.NamedTemporaryFile(
                delete=False, suffix=".csv", mode="w", newline="") as tmp:
            writer = csv.writer(tmp, lineterminator="\n")
            writer.writerow(["sequence", "polymer"])
            writer.writerow([query_sequence, selected_polymer])
            tmp_path = tmp.name

        st.write("Running prediction...")
        try:
            # sys.executable runs the script with the interpreter (and
            # therefore the environment) this app itself is running in.
            result = subprocess.run([
                sys.executable, "src/predict.py",
                "--ckpt", ckpt,
                "--plm", plm,
                "--csv", tmp_path,
                "--output", output_path,
                "--attn"
            ], capture_output=True, text=True)
        finally:
            # The temp file was created with delete=False; clean it up here.
            _remove_if_exists(tmp_path)

        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""
<div style='background: linear-gradient(90deg, #377EB9 0%, #4DAE48 100%); padding: 1.5em; border-radius: 12px; color: white; margin-bottom: 1em;'>
<h2 style='margin:0;'><span style='font-size:18pt'>✅</span> Prediction Complete!</h2>
<p style='font-size:12pt;'>Your input has been processed. See the results below:</p>
<p style='font-size:12pt;'>Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})</p>
</div>
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(
                index=False), file_name="predictions.csv", type="primary")

            if os.path.exists(attn_path):
                import torch
                # NOTE(review): torch.load unpickles the file, and attn_path
                # lives in the shared temp directory — only safe while that
                # directory is trusted.
                attn = torch.load(attn_path)
                # First entry, unwrapping one extra level of list/tuple nesting.
                attn_matrix = attn[0][0] if isinstance(
                    attn[0], (list, tuple)) else attn[0]
                # A 3-D tensor is averaged over its leading axis.
                if attn_matrix.ndim == 3:
                    attn_matrix = attn_matrix.mean(0)
                # Column sums give one score per position.
                residue_scores = attn_matrix.sum(0).cpu().numpy()
                topN = min(10, len(residue_scores))
                top_idx = residue_scores.argsort()[::-1][:topN]
                st.markdown(f"**Top {topN} high-attention residues:**")
                # NOTE(review): assumes attention position i corresponds to
                # sequence[i] — confirm against src/predict.py. Positions
                # beyond the sequence are shown as "?" instead of raising.
                amino_acids = [
                    query_sequence[i] if i < len(query_sequence) else "?"
                    for i in top_idx
                ]
                st.write(pd.DataFrame({
                    "Amino Acid": amino_acids,
                    "Residue Index": top_idx+1,
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)

# Download the AlphaFold model for the selected UniProt entry and, when an
# attention file is present in the temp directory, render the structure with
# the highest-attention residues highlighted in red.
structure_path = None

if input_type == "UniProt ID" and uniprot_id:
    # NOTE(review): the URL requests "-2-F1-model_v6" while the local cache
    # name below says "F1-model_v4" — confirm which file/version is intended.
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-2-F1-model_v6.cif"

    # Load the attention file once; the 1-based residue numbers derived from
    # it drive the highlighting further down.
    highlight_residues = None
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    if os.path.exists(attn_path):
        import torch
        # NOTE(review): torch.load unpickles the file, and attn_path lives in
        # the shared temp directory — only safe while that directory is
        # trusted.
        attn = torch.load(attn_path)
        # First entry, unwrapping one extra level of list/tuple nesting.
        attn_matrix = attn[0][0] if isinstance(
            attn[0], (list, tuple)) else attn[0]
        # A 3-D tensor is averaged over its leading axis.
        if attn_matrix.ndim == 3:
            attn_matrix = attn_matrix.mean(0)
        # Column sums give one score per position; keep the ten largest.
        residue_scores = attn_matrix.sum(0).cpu().numpy()
        topN = min(10, len(residue_scores))
        top_idx = residue_scores.argsort()[::-1][:topN]
        highlight_residues = [int(i + 1) for i in top_idx]

    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        # A timeout keeps an unresponsive server from hanging the app.
        r = requests.get(af_url, timeout=30)
        if r.status_code == 200:
            with open(structure_path, "wb") as structure_file:
                structure_file.write(r.content)
            st.success(
                f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning(
                "AlphaFoldDB structure not found for this UniProt ID.")
    except (requests.RequestException, OSError) as e:
        # Network problems and failures writing the file are both non-fatal.
        st.warning(f"AlphaFoldDB download error: {e}")

    # Render only when both the attention scores and a structure file (fresh
    # or left over from an earlier download of the same ID) are available.
    if highlight_residues is not None and os.path.exists(structure_path):
        st.markdown("### 3D Structure Visualization (stmol)")
        from stmol import showmol
        with open(structure_path, "r") as cif_file:
            cif_data = cif_file.read()
        view = py3Dmol.view(width=600, height=400)
        view.addModel(cif_data, "cif")
        # Whole chain in light gray; highlighted residues are recolored below.
        view.setStyle({"cartoon": {"color": "lightgray"}})
        for resi_num in highlight_residues:
            view.setStyle(
                {"resi": resi_num},
                {"cartoon": {"color": "red"}})
            view.addResLabels(
                {"resi": resi_num},
                {
                    "font": 'Arial', "fontColor": 'black',
                    "showBackground": False,
                    "screenOffset": {"x": 0, "y": 0}})
        view.zoomTo()
        showmol(view, height=600, width='100%')

# Footer: license notice with a link to the full CC BY-NC-SA 4.0 text.
license_html = """
---
<h4>License</h4>
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)<br>
<a href='https://creativecommons.org/licenses/by-nc-sa/4.0/' target='_blank'>View full license details</a><br>
"""
st.markdown(license_html, unsafe_allow_html=True)