Spaces:
Sleeping
Sleeping
Pranavathiyani G
improve: HTML result card with confidence colors, italic organisms, no tab scroll, bold headers
346815f | # EmbedAMR | |
| # Pranavathiyani G, SASTRA Deemed University | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import joblib | |
| import plotly.graph_objects as go | |
| import os | |
| import time | |
| from transformers import AutoTokenizer, AutoModel | |
# ── Assets ────────────────────────────────────────────────────────────────────
# Directory that holds the precomputed files loaded below.
ASSET_DIR = "app_assets"


def _asset_path(filename):
    """Return the path of *filename* inside ASSET_DIR."""
    return os.path.join(ASSET_DIR, filename)


print("Loading assets...")
# Per-sequence metadata table; later code reads its source, mechanism and
# esm2_* coordinate columns.
df_meta = pd.read_csv(_asset_path("embedamr_metadata.csv"))
pca50_emb = np.load(_asset_path("esm2_pca50.npy"))
# Fitted transformers and neighbour index, applied in that order to a query.
scaler = joblib.load(_asset_path("esm2_scaler.pkl"))
pca = joblib.load(_asset_path("esm2_pca50.pkl"))
knn = joblib.load(_asset_path("knn_index.pkl"))
print(f"Ready. {len(df_meta)} sequences loaded.")
# ── ESM2 ──────────────────────────────────────────────────────────────────────
# Hugging Face checkpoint used for both the tokenizer and the model.
ESM2_CHECKPOINT = "facebook/esm2_t33_650M_UR50D"

print("Loading ESM2-650M...")
DEVICE = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(ESM2_CHECKPOINT)
esm2 = AutoModel.from_pretrained(ESM2_CHECKPOINT)
esm2 = esm2.to(DEVICE)
# Inference only: switch the module to evaluation mode.
esm2.eval()
print("ESM2-650M ready.")
# ── Constants ─────────────────────────────────────────────────────────────────
# Letters accepted by validate_sequence and the minimum accepted length.
VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")
MIN_LEN = 30

# One row per mechanism: (primary_mechanism value, plot colour, legend label).
_MECHANISM_TABLE = (
    ("antibiotic inactivation", "#2196F3", "Inactivation"),
    ("antibiotic efflux", "#FF9800", "Efflux"),
    ("antibiotic target alteration", "#4CAF50", "Target alteration"),
    ("antibiotic target protection", "#9C27B0", "Target protection"),
    ("antibiotic target replacement", "#F44336", "Target replacement"),
    ("reduced permeability to antibiotic", "#795548", "Reduced permeability"),
    ("resistance by host-dependent nutrient acquisition", "#607D8B",
     "Host nutrient"),
)
MECH_COLORS = {mechanism: colour for mechanism, colour, _ in _MECHANISM_TABLE}
MECH_LABEL = {mechanism: label for mechanism, _, label in _MECHANISM_TABLE}

# Marker colour per beta-lactamase class (values of the bla_class column).
BLA_COLORS = dict(
    Class_A="#E53935",
    Class_B="#1E88E5",
    Class_C="#43A047",
    Class_D="#FB8C00",
)

# One-sentence explanation shown under the predicted mechanism.
# NOTE(review): there is no entry for
# "resistance by host-dependent nutrient acquisition"; lookups for it fall
# back to an empty string. Confirm whether that omission is intended.
MECH_EXPLAIN = {
    "antibiotic inactivation":
        "The protein chemically modifies or destroys the antibiotic.",
    "antibiotic efflux":
        "The protein pumps antibiotics out of the bacterial cell.",
    "antibiotic target alteration":
        "The antibiotic target is modified so the drug cannot bind.",
    "antibiotic target protection":
        "The protein shields the target from the antibiotic.",
    "antibiotic target replacement":
        "The bacterium switches to an alternative target the drug cannot affect.",
    "reduced permeability to antibiotic":
        "The cell wall blocks the antibiotic from entering.",
}
# ── Sequence validation ───────────────────────────────────────────────────────
def validate_sequence(seq: str):
    """
    Clean a pasted protein sequence and run basic plausibility checks.

    The input may be plain amino-acid letters or FASTA text. It is
    upper-cased, FASTA header lines are dropped, and every non-alphabetic
    character (whitespace, digits, punctuation) is removed before checking.

    Checks, in order: non-empty, at least MIN_LEN letters, only letters in
    VALID_AA, at least 5 distinct letters, and no single letter making up
    more than half of the sequence.

    Returns (cleaned_seq, error_message).
    error_message is None if valid; cleaned_seq is None otherwise.
    """
    from collections import Counter

    seq = seq.strip().upper()
    if seq.startswith(">"):
        # FASTA input: the first line is the header. Further header lines
        # (multi-record pastes) are dropped too. Each line is stripped
        # *before* the ">" test so an indented header line cannot leak its
        # letters into the sequence.
        body_lines = (line.strip() for line in seq.split("\n")[1:])
        seq = "".join(line for line in body_lines
                      if not line.startswith(">"))
    seq = "".join(ch for ch in seq if ch.isalpha())

    if not seq:
        return None, "No sequence found. Please paste a protein sequence."

    n_res = len(seq)
    if n_res < MIN_LEN:
        return None, (
            f"Sequence is too short ({n_res} amino acids). "
            f"Minimum length is {MIN_LEN} amino acids. "
            f"Please check your input."
        )

    invalid = set(seq) - VALID_AA
    if invalid:
        return None, (
            f"Sequence contains non-standard characters: {', '.join(sorted(invalid))}. "
            f"Please use standard one-letter amino acid codes."
        )

    # Composition check -- catch names and gibberish
    counts = Counter(seq)
    top_letter, top_count = counts.most_common(1)[0]
    uniq_aa = len(counts)
    if uniq_aa < 5:
        return None, (
            f"Sequence uses only {uniq_aa} different amino acids. "
            f"A real protein typically uses 15 or more. "
            f"Please check your input."
        )

    top_fraction = top_count / n_res
    if top_fraction > 0.5:
        return None, (
            f"Single amino acid '{top_letter}' makes up "
            f"{top_fraction*100:.0f}% of the sequence. "
            f"This does not look like a real protein. "
            f"Please check your input."
        )
    return seq, None
# ── Embed ─────────────────────────────────────────────────────────────────────
def embed_sequence(seq: str):
    """
    Return a mean-pooled ESM2 last-hidden-state embedding for *seq*.

    Sequences of up to 1022 residues are embedded in one pass. Longer ones
    are cut into overlapping windows (1022 long, start advancing by 256);
    each window is embedded separately and the window embeddings are
    combined as an average weighted by window length.

    Returns a 1-D NumPy array.
    """
    window, stride = 1022, 256

    def _pooled(fragment):
        """Embed one fragment and average over its residue positions."""
        encoded = tokenizer(fragment, return_tensors="pt",
                            add_special_tokens=True).to(DEVICE)
        with torch.no_grad():
            hidden = esm2(**encoded).last_hidden_state
        # Drop the first and last positions (presumably the special tokens
        # added by the tokenizer) before averaging.
        return hidden[0, 1:-1, :].mean(dim=0).cpu().numpy()

    n_res = len(seq)
    if n_res <= window:
        return _pooled(seq)

    # Collect windows until one reaches the end of the sequence.
    fragments = []
    offset = 0
    while True:
        stop = min(offset + window, n_res)
        fragments.append(seq[offset:stop])
        if stop == n_res:
            break
        offset += stride

    pooled = [_pooled(fragment) for fragment in fragments]
    sizes = [len(fragment) for fragment in fragments]
    total = sum(sizes)
    return sum((size / total) * vec for size, vec in zip(sizes, pooled))
# ── Project to UMAP ───────────────────────────────────────────────────────────
def project_to_umap(embedding: np.ndarray):
    """
    Place a query embedding on the precomputed 3-D map.

    The embedding is passed through the fitted scaler and PCA. The query's
    3-D position is then the inverse-distance-weighted average of the
    esm2_x3 / esm2_y3 / esm2_z3 coordinates of its 5 nearest reference rows.

    Returns (pca_vector, coords): the PCA-transformed query as returned by
    pca.transform, and a (1, 3) array holding the x, y, z position.
    """
    query_pca = pca.transform(scaler.transform(embedding.reshape(1, -1)))
    nbr_dists, nbr_rows = knn.kneighbors(query_pca, n_neighbors=5)

    # Closer neighbours weigh more; the epsilon avoids division by zero for
    # an exact match.
    inv_dist = 1.0 / (nbr_dists[0] + 1e-9)
    inv_dist = inv_dist / inv_dist.sum()

    anchors = df_meta.iloc[nbr_rows[0]]
    # Use 3D coords for query projection too
    position = [
        (anchors[column].values * inv_dist).sum()
        for column in ("esm2_x3", "esm2_y3", "esm2_z3")
    ]
    return query_pca, np.array([position])
# ── Build 3D landscape ────────────────────────────────────────────────────────
def build_3d_landscape(color_by="mechanism", query_point=None):
    """
    Build the 3-D scatter figure of all reference sequences.

    color_by:
        "mechanism" colours CARD rows by primary_mechanism;
        "bla_class" colours them by bla_class with the rest greyed out.
        Any other value adds no reference traces.
    query_point:
        optional (x, y, z) position drawn as a large red diamond.

    Returns a plotly Figure.
    """
    card = df_meta[df_meta["source"] == "CARD"]
    wt = df_meta[df_meta["source"] == "WT_PBP"]
    # The soluble/membrane split uses the esm2_y column (not the plotted
    # esm2_y3) with a fixed cut-off.
    wt_upper = wt[wt["esm2_y"] >= -2.64]
    wt_lower = wt[wt["esm2_y"] < -2.64]
    fig = go.Figure()

    def add_points(frame, **trace_kwargs):
        """Add a marker trace positioned by *frame*'s 3-D coordinates."""
        fig.add_trace(go.Scatter3d(
            x=frame["esm2_x3"], y=frame["esm2_y3"], z=frame["esm2_z3"],
            mode="markers",
            **trace_kwargs
        ))

    def add_wt_zones(size, opacity):
        """Add the two wild-type PBP traces (soluble, then membrane)."""
        add_points(
            wt_upper,
            name=f"WT PBP soluble (n={len(wt_upper)})",
            marker=dict(size=size, color="black", opacity=opacity,
                        symbol="diamond"),
            hoverinfo="skip",
        )
        add_points(
            wt_lower,
            name=f"WT PBP membrane (n={len(wt_lower)})",
            marker=dict(size=size, color="#777777", opacity=opacity,
                        symbol="diamond-open"),
            hoverinfo="skip",
        )

    def mechanism_hover(r):
        """Hover text for one row in the mechanism view."""
        return (
            f"<b>{r.get('gene_name','?')}</b><br>"
            f"Organism: {str(r.get('organism','?'))[:45]}<br>"
            f"Mechanism: {r.get('primary_mechanism','?')}<br>"
            f"Drug class: {r.get('primary_drug_class','?')}<br>"
            f"Gene family: {str(r.get('AMR Gene Family','?'))[:45]}"
        )

    def bla_hover(r, cls):
        """Hover text for one row in the beta-lactamase class view."""
        return (
            f"<b>{r.get('gene_name','?')}</b><br>"
            f"Class: {cls}<br>"
            f"Gene family: {str(r.get('AMR Gene Family','?'))[:45]}<br>"
            f"Organism: {str(r.get('organism','?'))[:45]}"
        )

    if color_by == "mechanism":
        # Background -- sequences not in main mechanism list
        other = card[~card["primary_mechanism"].isin(list(MECH_LABEL))]
        if len(other):
            add_points(
                other,
                name="Other",
                marker=dict(size=1.5, color="#cccccc", opacity=0.15),
                hoverinfo="skip",
            )
        for mech, label in MECH_LABEL.items():
            subset = card[card["primary_mechanism"] == mech]
            if not len(subset):
                continue
            add_points(
                subset,
                name=f"{label} (n={len(subset)})",
                marker=dict(
                    size=2.5,
                    color=MECH_COLORS.get(mech, "#999999"),
                    opacity=0.65,
                ),
                text=subset.apply(mechanism_hover, axis=1),
                hovertemplate="%{text}<extra></extra>",
            )
        # WT PBP sub-zones
        add_wt_zones(size=4, opacity=0.85)
    elif color_by == "bla_class":
        # Rows without a beta-lactamase class form the faint background.
        add_points(
            card[card["bla_class"].isna()],
            name="Other AMR",
            marker=dict(size=1.5, color="#dddddd", opacity=0.1),
            hoverinfo="skip",
        )
        bla_labels = {
            "Class_A": "Class A serine -- TEM/SHV/CTX-M/KPC",
            "Class_B": "Class B metallo -- NDM/VIM/IMP",
            "Class_C": "Class C AmpC -- CMY/PDC/ADC",
            "Class_D": "Class D OXA",
        }
        for cls, label in bla_labels.items():
            subset = card[card["bla_class"] == cls]
            add_points(
                subset,
                name=f"{label} (n={len(subset)})",
                marker=dict(size=3, color=BLA_COLORS[cls], opacity=0.75),
                text=subset.apply(lambda r: bla_hover(r, cls), axis=1),
                hovertemplate="%{text}<extra></extra>",
            )
        add_wt_zones(size=5, opacity=0.9)

    # Query point
    if query_point is not None:
        qx, qy, qz = query_point[0], query_point[1], query_point[2]
        fig.add_trace(go.Scatter3d(
            x=[qx],
            y=[qy],
            z=[qz],
            mode="markers",
            name="Your sequence",
            marker=dict(
                size=10,
                color="#FF1744",
                symbol="diamond",
                opacity=1.0,
                line=dict(color="white", width=2),
            ),
            hovertemplate="Your query sequence<extra></extra>",
        ))

    # All three axes share the same white background and light grid.
    axis_style = dict(backgroundcolor="white", gridcolor="#eeeeee",
                      showbackground=True)
    fig.update_layout(
        scene=dict(
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            zaxis_title="UMAP 3",
            bgcolor="white",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        height=650,
        legend=dict(
            font=dict(size=10, family="Arial"),
            borderwidth=1,
            bgcolor="rgba(255,255,255,0.95)",
            bordercolor="#dddddd",
        ),
        paper_bgcolor="white",
        margin=dict(l=0, r=0, t=10, b=0),
    )
    return fig
# ── Query function ────────────────────────────────────────────────────────────
def format_result_html(tier_label, tier_color, mean_dist, summary,
                       pred_mech, pred_drug, confidence, mech_explain,
                       embed_time, seq_len):
    """Build styled HTML result card.

    tier_label:   heading text of the coloured banner.
    tier_color:   "green", "orange" or "red"; selects the banner colours
                  (anything else falls back to neutral greys). "red" also
                  switches the prediction section to a "closest match"
                  wording without a confidence line.
    mean_dist:    average distance to the nearest neighbours (3 decimals).
    summary:      banner text; inserted as-is, so it may contain markup.
    pred_mech, pred_drug:
                  predicted mechanism and drug class; HTML-escaped.
    confidence:   fraction (0..1) of the 5 neighbours that agree.
    mech_explain: mapping from mechanism to a one-sentence explanation;
                  a missing key yields no explanation.
    embed_time:   embedding time in seconds.
    seq_len:      sequence length in amino acids.

    Returns the card as an HTML string.
    """
    from html import escape

    color_map = {
        "green": "#2e7d32",
        "orange": "#e65100",
        "red": "#c62828",
    }
    bg_map = {
        "green": "#f1f8e9",
        "orange": "#fff3e0",
        "red": "#ffebee",
    }
    hex_color = color_map.get(tier_color, "#333333")
    hex_bg = bg_map.get(tier_color, "#f5f5f5")

    # These values come from the metadata table rather than from this file,
    # so escape them before placing them inside markup.
    mech_text = escape(str(pred_mech), quote=False)
    drug_text = escape(str(pred_drug), quote=False)
    mech_note = escape(str(mech_explain.get(pred_mech, "")), quote=False)

    if tier_color == "red":
        pred_section = f"""
<p style="margin:4px 0"><b>Closest mechanism found</b><br>
<span style="color:#555">{mech_text}</span></p>
<p style="margin:4px 0"><b>Closest drug class</b><br>
<span style="color:#555">{drug_text}</span></p>
<p style="margin:8px 0 4px 0; color:#888; font-size:0.9em">
These values reflect proximity in embedding space only and should not
be interpreted as a resistance classification.</p>
"""
    else:
        # round() rather than int(): truncation would under-report the vote
        # count whenever confidence * 5 lands just below a whole number.
        agreeing = round(confidence * 5)
        pred_section = f"""
<p style="margin:6px 0"><b>Predicted resistance mechanism</b><br>
<span style="font-size:1.05em; color:#1a1a1a">{mech_text}</span><br>
<span style="color:#555; font-size:0.92em">{mech_note}</span></p>
<p style="margin:6px 0"><b>Predicted drug class</b><br>
<span style="font-size:1.05em; color:#1a1a1a">{drug_text}</span></p>
<p style="margin:6px 0"><b>Confidence</b><br>
<span style="font-size:1.05em">{agreeing}/5 neighbours agree
({confidence*100:.0f}%)</span></p>
"""

    card_html = f"""
<div style="font-family: Arial, sans-serif; max-width: 100%;
overflow-y: auto; padding: 4px;">
<div style="background:{hex_bg}; border-left: 5px solid {hex_color};
border-radius:6px; padding:14px 16px; margin-bottom:14px;">
<p style="margin:0 0 6px 0; font-size:1.15em; font-weight:700;
color:{hex_color}; letter-spacing:0.3px">
{tier_label}
</p>
<p style="margin:4px 0; font-size:0.95em; color:#333">
Average distance to 5 nearest neighbours: <b>{mean_dist:.3f}</b>
| threshold for reliable match: below 0.35
</p>
<p style="margin:6px 0 0 0; color:#444">{summary}</p>
</div>
<div style="border:1px solid #e0e0e0; border-radius:6px;
padding:14px 16px; margin-bottom:14px;">
<p style="margin:0 0 10px 0; font-size:1.05em; font-weight:700;
color:#1a1a1a; border-bottom:1px solid #eee; padding-bottom:6px">
Prediction
</p>
{pred_section}
</div>
<div style="background:#f9f9f9; border-radius:6px;
padding:12px 16px; margin-bottom:10px;">
<p style="margin:0 0 6px 0; font-weight:700; color:#1a1a1a">
How this works
</p>
<p style="margin:0; color:#555; font-size:0.92em; line-height:1.6">
ESM2-650M converts the sequence into a 1,280-dimensional vector learned
from 250 million protein sequences. That vector is compared against
5,029 curated AMR proteins from CARD using cosine distance
(0 = identical, 1 = completely unrelated). The red diamond on the
3D map shows where your sequence lands.
</p>
</div>
<p style="margin:6px 0 2px 0; color:#888; font-size:0.85em">
<i>Disclaimer: predictions are based on similarity to known AMR proteins
in CARD. Novel resistance mechanisms or proteins outside this database
may not be identified correctly.</i>
</p>
<p style="margin:2px 0; color:#aaa; font-size:0.82em">
Embedding time: {embed_time:.1f}s |
Sequence length: {seq_len} aa
</p>
</div>
"""
    return card_html
def run_query(sequence_input, color_by):
    """
    Embed a pasted sequence and report its nearest reference neighbours.

    Generator used as a Gradio event handler. Each yield is a
    (figure, html, table) triple:
      * empty or invalid input -> one yield with the plain landscape and a
        placeholder or error card, table None;
      * valid input -> a "loading" yield, then the final landscape with the
        query marker, the result card and the neighbour table.

    sequence_input: raw textbox content (plain letters or FASTA).
    color_by:       colouring mode passed to build_3d_landscape.
    """
    from collections import Counter

    empty_fig = build_3d_landscape(color_by)
    empty_html = "<p style='color:#888; font-family:Arial'>Results will appear here.</p>"
    if not sequence_input or len(sequence_input.strip()) < 2:
        yield empty_fig, empty_html, None
        return

    seq, err = validate_sequence(sequence_input)
    if err:
        err_html = f"""
<div style="background:#ffebee; border-left:5px solid #c62828;
border-radius:6px; padding:12px 16px; font-family:Arial">
<b style="color:#c62828">Input error</b><br>
<span style="color:#333">{err}</span>
</div>
"""
        yield empty_fig, err_html, None
        return

    # Show a progress card while the (slow, CPU-bound) embedding runs.
    loading_html = """
<div style="padding:16px; font-family:Arial; color:#555">
<b>Embedding with ESM2-650M...</b><br>
<span style="font-size:0.9em">About 20-30 seconds on CPU. Please wait.</span>
</div>
"""
    yield empty_fig, loading_html, None

    start = time.time()
    embedding = embed_sequence(seq)
    embed_time = time.time() - start

    pca50_q, coords3d = project_to_umap(embedding)
    qx = float(coords3d[0, 0])
    qy = float(coords3d[0, 1])
    qz = float(coords3d[0, 2])

    # Second lookup with the index's own default neighbour count; these
    # rows fill the results table.
    dists, idxs = knn.kneighbors(pca50_q)
    dists = dists[0]
    idxs = idxs[0]
    neighbors = df_meta.iloc[idxs].copy()
    neighbors["distance"] = dists.round(3)

    top5 = neighbors.head(5)
    top5_mechs = top5["primary_mechanism"].tolist()
    # Majority vote. Counter.most_common orders equal counts by first
    # occurrence, so a tie always goes to the mechanism that appears first
    # in the neighbour list instead of depending on set iteration order.
    pred_mech, votes = Counter(top5_mechs).most_common(1)[0]
    pred_drug = top5["primary_drug_class"].mode()[0]
    confidence = votes / 5
    mean_dist = float(dists[:5].mean())

    if mean_dist > 0.35:
        tier_label = "LOW CONFIDENCE"
        tier_color = "red"
        summary = (
            "This sequence does not closely resemble any known AMR protein in CARD. "
            "The results below show the closest matches found but should not be "
            "interpreted as a resistance prediction. This sequence is likely "
            "not a resistance protein."
        )
    elif mean_dist > 0.20:
        tier_label = "MODERATE CONFIDENCE"
        tier_color = "orange"
        summary = (
            "This sequence resembles known resistance proteins but is not a "
            "close match. Treat as a hypothesis to follow up experimentally."
        )
    else:
        tier_label = "HIGH CONFIDENCE"
        tier_color = "green"
        summary = (
            "This sequence closely resembles known resistance proteins. "
            "Confirm experimentally before drawing conclusions."
        )

    result_html = format_result_html(
        tier_label, tier_color, mean_dist, summary,
        pred_mech, pred_drug, confidence, MECH_EXPLAIN,
        embed_time, len(seq)
    )

    # Table with italic organism names
    col_map = {
        "gene_name": "Gene",
        "organism": "Organism",
        "primary_mechanism": "Mechanism",
        "primary_drug_class": "Drug class",
        "AMR Gene Family": "Gene family",
        "distance": "Distance",
        "source": "Source",
    }
    keep = [c for c in col_map if c in neighbors.columns]
    tbl = neighbors[keep].copy().reset_index(drop=True)
    tbl = tbl.rename(columns=col_map)
    # `keep` treats every column as optional, so only post-process the
    # columns that actually made it into the table.
    if "Gene family" in tbl.columns:
        tbl["Gene family"] = tbl["Gene family"].str[:40]
    if "Organism" in tbl.columns:
        # Italic organism names
        tbl["Organism"] = tbl["Organism"].apply(
            lambda x: f"<i>{str(x)[:40]}</i>")
    tbl.index = tbl.index + 1

    fig = build_3d_landscape(color_by, query_point=(qx, qy, qz))
    yield fig, result_html, tbl
# ── App ───────────────────────────────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks.
CSS = """
body, .gradio-container {
    font-family: Arial, sans-serif !important;
}
.gradio-tabitem {
    overflow-y: hidden !important;
}
.gr-prose p {
    font-size: 0.95rem !important;
    line-height: 1.6 !important;
}
"""

with gr.Blocks(title="EmbedAMR", theme=gr.themes.Default(), css=CSS) as app:
    gr.Markdown("# EmbedAMR")
    gr.Markdown(
        "Exploring the AMR protein embedding landscape with protein language models. "
        "5,029 AMR proteins from CARD v3.3.0 embedded with ESM2-650M and Ankh-base."
        "\n\n*Pranavathiyani G, SASTRA Deemed University, Thanjavur, India*"
    )
    with gr.Tabs():
        # ── Tab 1: Explore ────────────────────────────────────────────────────
        with gr.TabItem("Explore AMR Embedding Landscape"):
            gr.Markdown(
                "Each point is one AMR protein. Points close together have "
                "similar sequence embeddings. Hover to inspect. "
                "Rotate, zoom, and pan freely."
            )
            color_radio = gr.Radio(
                choices=[
                    ("Resistance mechanism", "mechanism"),
                    ("Beta-lactamase class (A/B/C/D) + WT PBP zones", "bla_class")
                ],
                value="mechanism",
                label="Color by"
            )
            loading_msg = gr.Markdown("*Loading AMR Embedding Landscape...*")
            explore_plot = gr.Plot(label="")

            def load_explore(color_by):
                """Hide the loading message and return the landscape figure."""
                return gr.Markdown(visible=False), build_3d_landscape(color_by)

            color_radio.change(
                fn=lambda c: build_3d_landscape(c),
                inputs=color_radio,
                outputs=explore_plot
            )

        # ── Tab 2: Query ──────────────────────────────────────────────────────
        with gr.TabItem("Query Sequence"):
            gr.Markdown(
                "Paste a protein sequence to find its nearest neighbours "
                "in the AMR landscape. ESM2-650M embeds it on CPU -- "
                "about 20-30 seconds."
            )
            seq_input = gr.Textbox(
                label="Protein sequence (FASTA or plain amino acids, min 30 aa)",
                placeholder=">MyProtein\nMSIQHFRVALIPFFAAFCLPVFA...",
                lines=5
            )
            with gr.Row():
                color_radio_q = gr.Radio(
                    choices=[
                        ("Resistance mechanism", "mechanism"),
                        ("Beta-lactamase class", "bla_class")
                    ],
                    value="mechanism",
                    label="Map color by"
                )
                query_btn = gr.Button("Find nearest neighbours",
                                      variant="primary", scale=1)
            query_plot = gr.Plot(label="Your sequence in the AMR landscape")
            result_box = gr.HTML(label="Result and interpretation")
            # Column types and widths follow the column order produced by
            # run_query (Gene, Organism, Mechanism, Drug class, Gene family,
            # Distance, Source); Organism is HTML so it can be italic.
            result_table = gr.Dataframe(
                label="10 nearest AMR neighbours",
                wrap=True,
                datatype=["str", "html", "str", "str", "str", "number", "str"],
                column_widths=["10%", "18%", "16%", "13%", "25%", "8%", "10%"]
            )
            query_btn.click(
                fn=run_query,
                inputs=[seq_input, color_radio_q],
                outputs=[query_plot, result_box, result_table]
            )

        # ── Tab 3: About ──────────────────────────────────────────────────────
        with gr.TabItem("About"):
            # Blank lines separate blocks so that "---" renders as a rule
            # rather than turning the preceding paragraph into a heading.
            gr.Markdown("""
### What is EmbedAMR?

EmbedAMR maps 5,029 AMR proteins from CARD into a shared embedding space
using ESM2-650M and Ankh-base. Query any protein sequence to find its
nearest neighbours using embedding similarity rather than sequence alignment.

---

### Are penicillin-binding proteins the same as AMR proteins?

**No -- and our embeddings confirm this.**

Wild-type PBPs are the *targets* of beta-lactam antibiotics, not resistance
proteins. EmbedAMR correctly places them far from the AMR landscape
(average distance 0.44, above the 0.35 reliability threshold).

Three related groups ARE resistance proteins and score well:

- **Resistance-conferring mutant PBPs** in CARD variant model
- **mecA/mecB/mecC/mecD (PBP2a)**: PBP-derived proteins conferring methicillin resistance
- **Class A/C/D beta-lactamases**: distantly PBP-related enzymes

---

### Methods

**Data**: CARD v3.3.0, 5,029 sequences (homolog + variant models).
187 wild-type PBP controls from UniProt Swiss-Prot (KW-0573, bacteria, reviewed).
12 sequences removed after sequence-level deduplication against CARD.

**Embeddings**: Mean pooling of ESM2-650M last hidden state (1,280 dim).
Sequences longer than 1,022 aa use sliding window (window 1,022, stride 256).

**Projection**: StandardScaler, PCA 50 components (89.7% variance),
UMAP (n_neighbors=30, cosine metric, random_state=42).

**Distance guide**:

- Below 0.20: close match, prediction likely reliable
- 0.20 to 0.35: moderate, treat as hypothesis
- Above 0.35: sequence likely not an AMR protein

---

### Source code

github.com/pranavathiyani/EmbedAMR

### Contact

pranavathiyani@scbt.sastra.edu
SASTRA Deemed University, Thanjavur, Tamil Nadu, India
""")

    # Auto-load landscape on startup. load_explore also hides the
    # "Loading..." message, so both components are listed as outputs.
    app.load(
        fn=lambda: load_explore("mechanism"),
        outputs=[loading_msg, explore_plot]
    )

if __name__ == "__main__":
    app.launch()