# EmbedAMR
# Pranavathiyani G, SASTRA Deemed University
"""Gradio app for exploring and querying an AMR protein embedding landscape.

Importing this module has side effects: it reads the precomputed assets in
``app_assets`` and loads the ESM2-650M tokenizer and model onto the CPU.
"""

import html
import os
import time
from collections import Counter

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModel

# ── Assets ────────────────────────────────────────────────────────────────────
ASSET_DIR = "app_assets"

print("Loading assets...")
df_meta = pd.read_csv(os.path.join(ASSET_DIR, "embedamr_metadata.csv"))
pca50_emb = np.load(os.path.join(ASSET_DIR, "esm2_pca50.npy"))
# NOTE(review): joblib.load unpickles arbitrary objects; these files must only
# ever come from a trusted source.
scaler = joblib.load(os.path.join(ASSET_DIR, "esm2_scaler.pkl"))
pca = joblib.load(os.path.join(ASSET_DIR, "esm2_pca50.pkl"))
knn = joblib.load(os.path.join(ASSET_DIR, "knn_index.pkl"))
print(f"Ready. {len(df_meta)} sequences loaded.")

# ── ESM2 ──────────────────────────────────────────────────────────────────────
print("Loading ESM2-650M...")
DEVICE = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm2 = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(DEVICE)
esm2.eval()
print("ESM2-650M ready.")

# ── Constants ─────────────────────────────────────────────────────────────────
VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")
MIN_LEN = 30

ESM2_MAX_RESIDUES = 1022   # longest sequence embedded in a single pass
WINDOW_STRIDE = 256        # step between sliding windows for longer sequences
TOP_K = 5                  # neighbours used for projection and voting
N_TABLE_NEIGHBOURS = 10    # neighbours listed in the result table
HIGH_CONF_DIST = 0.20      # mean distance at or below this -> high confidence
LOW_CONF_DIST = 0.35       # mean distance above this -> low confidence
WT_PBP_SPLIT_Y = -2.64     # esm2_y cut separating the two WT PBP sub-zones

MECH_COLORS = {
    "antibiotic inactivation": "#2196F3",
    "antibiotic efflux": "#FF9800",
    "antibiotic target alteration": "#4CAF50",
    "antibiotic target protection": "#9C27B0",
    "antibiotic target replacement": "#F44336",
    "reduced permeability to antibiotic": "#795548",
    "resistance by host-dependent nutrient acquisition": "#607D8B",
}

MECH_LABEL = {
    "antibiotic inactivation": "Inactivation",
    "antibiotic efflux": "Efflux",
    "antibiotic target alteration": "Target alteration",
    "antibiotic target protection": "Target protection",
    "antibiotic target replacement": "Target replacement",
    "reduced permeability to antibiotic": "Reduced permeability",
    "resistance by host-dependent nutrient acquisition": "Host nutrient",
}

BLA_COLORS = {
    "Class_A": "#E53935",
    "Class_B": "#1E88E5",
    "Class_C": "#43A047",
    "Class_D": "#FB8C00",
}

BLA_LABELS = {
    "Class_A": "Class A serine -- TEM/SHV/CTX-M/KPC",
    "Class_B": "Class B metallo -- NDM/VIM/IMP",
    "Class_C": "Class C AmpC -- CMY/PDC/ADC",
    "Class_D": "Class D OXA",
}

MECH_EXPLAIN = {
    "antibiotic inactivation":
        "The protein chemically modifies or destroys the antibiotic.",
    "antibiotic efflux":
        "The protein pumps antibiotics out of the bacterial cell.",
    "antibiotic target alteration":
        "The antibiotic target is modified so the drug cannot bind.",
    "antibiotic target protection":
        "The protein shields the target from the antibiotic.",
    "antibiotic target replacement":
        "The bacterium switches to an alternative target the drug cannot affect.",
    "reduced permeability to antibiotic":
        "The cell wall blocks the antibiotic from entering.",
}

# Metadata column -> header shown in the neighbour table (order is display order).
TABLE_COLUMNS = {
    "gene_name": "Gene",
    "organism": "Organism",
    "primary_mechanism": "Mechanism",
    "primary_drug_class": "Drug class",
    "AMR Gene Family": "Gene family",
    "distance": "Distance",
    "source": "Source",
}


# ── Sequence validation ───────────────────────────────────────────────────────
def validate_sequence(seq: str):
    """Clean and sanity-check a pasted protein sequence.

    Accepts plain residues or FASTA text. The input is upper-cased, FASTA
    header lines are dropped, and every non-alphabetic character is removed
    before the checks run.

    Returns:
        ``(cleaned_seq, None)`` when the sequence is accepted, otherwise
        ``(None, error_message)``.
    """
    seq = seq.strip().upper()
    if seq.startswith(">"):
        # Strip each line *before* testing for ">" so an indented extra header
        # is dropped too instead of leaking into the sequence.
        body = (line.strip() for line in seq.split("\n")[1:])
        seq = "".join(line for line in body if not line.startswith(">"))
    seq = "".join(c for c in seq if c.isalpha())

    if not seq:
        return None, "No sequence found. Please paste a protein sequence."

    if len(seq) < MIN_LEN:
        return None, (
            f"Sequence is too short ({len(seq)} amino acids). "
            f"Minimum length is {MIN_LEN} amino acids. "
            f"Please check your input."
        )

    invalid = set(seq) - VALID_AA
    if invalid:
        return None, (
            f"Sequence contains non-standard characters: {', '.join(sorted(invalid))}. "
            f"Please use standard one-letter amino acid codes."
        )

    # Composition check -- catch names and gibberish
    counts = Counter(seq)
    top_residue, top_count = counts.most_common(1)[0]
    uniq_aa = len(counts)

    if uniq_aa < 5:
        return None, (
            f"Sequence uses only {uniq_aa} different amino acids. "
            f"A real protein typically uses 15 or more. "
            f"Please check your input."
        )

    if top_count / len(seq) > 0.5:
        return None, (
            f"Single amino acid '{top_residue}' makes up "
            f"{top_count/len(seq)*100:.0f}% of the sequence. "
            f"This does not look like a real protein. "
            f"Please check your input."
        )

    return seq, None


# ── Embed ─────────────────────────────────────────────────────────────────────
def _embed_window(residues: str) -> np.ndarray:
    """Run ESM2 on one stretch of residues and mean-pool its hidden states."""
    inputs = tokenizer(residues, return_tensors="pt",
                       add_special_tokens=True).to(DEVICE)
    with torch.no_grad():
        out = esm2(**inputs)
    # Positions 0 and -1 are excluded from pooling; with add_special_tokens=True
    # they are presumably the tokenizer's start/end tokens, not residues.
    return out.last_hidden_state[0, 1:-1, :].mean(dim=0).cpu().numpy()


def embed_sequence(seq: str):
    """Return the mean-pooled ESM2 embedding of ``seq`` as a NumPy vector.

    Sequences of up to ``ESM2_MAX_RESIDUES`` residues are embedded in one
    pass. Longer sequences are cut into overlapping windows of that size,
    ``WINDOW_STRIDE`` apart, and the window embeddings are averaged with
    weights proportional to window length.
    """
    if len(seq) <= ESM2_MAX_RESIDUES:
        return _embed_window(seq)

    chunks = []
    start = 0
    while start < len(seq):
        end = min(start + ESM2_MAX_RESIDUES, len(seq))
        chunks.append(seq[start:end])
        if end == len(seq):
            break
        start += WINDOW_STRIDE

    lengths = [len(chunk) for chunk in chunks]
    embs = [_embed_window(chunk) for chunk in chunks]
    total = sum(lengths)
    return sum((n / total) * e for n, e in zip(lengths, embs))


# ── Project to UMAP ───────────────────────────────────────────────────────────
def project_to_umap(embedding: np.ndarray):
    """Place a raw embedding in the precomputed 3D landscape.

    The embedding is scaled and PCA-reduced with the fitted ``scaler`` and
    ``pca``; its 3D position is the inverse-distance-weighted mean of the
    ``esm2_x3/y3/z3`` coordinates of its ``TOP_K`` nearest neighbours.

    Returns:
        ``(pca50_q, coords)`` where ``pca50_q`` is the PCA-space query row and
        ``coords`` is a ``(1, 3)`` array of landscape coordinates.
    """
    scaled = scaler.transform(embedding.reshape(1, -1))
    pca50_q = pca.transform(scaled)
    dists, idxs = knn.kneighbors(pca50_q, n_neighbors=TOP_K)
    dists = dists[0]
    idxs = idxs[0]
    # The epsilon keeps an exact match (distance 0) from dividing by zero.
    weights = 1.0 / (dists + 1e-9)
    weights /= weights.sum()
    nbrs = df_meta.iloc[idxs]
    # Use 3D coords for query projection too
    x3 = (nbrs["esm2_x3"].values * weights).sum()
    y3 = (nbrs["esm2_y3"].values * weights).sum()
    z3 = (nbrs["esm2_z3"].values * weights).sum()
    return pca50_q, np.array([[x3, y3, z3]])


# ── Build 3D landscape ────────────────────────────────────────────────────────
def _mechanism_hover(row):
    """Hover text for one protein in the mechanism view (<br> = line break)."""
    return (
        f"{row.get('gene_name','?')}<br>"
        f"Organism: {str(row.get('organism','?'))[:45]}<br>"
        f"Mechanism: {row.get('primary_mechanism','?')}<br>"
        f"Drug class: {row.get('primary_drug_class','?')}<br>"
        f"Gene family: {str(row.get('AMR Gene Family','?'))[:45]}"
    )


def _bla_hover(row, cls):
    """Hover text for one protein in the beta-lactamase class view."""
    return (
        f"{row.get('gene_name','?')}<br>"
        f"Class: {cls}<br>"
        f"Gene family: {str(row.get('AMR Gene Family','?'))[:45]}<br>"
        f"Organism: {str(row.get('organism','?'))[:45]}"
    )


def _add_wt_pbp_traces(fig, wt_upper, wt_lower, size, opacity):
    """Add the two wild-type PBP sub-zone traces to ``fig``."""
    zones = (
        (wt_upper, "WT PBP soluble", "black", "diamond"),
        (wt_lower, "WT PBP membrane", "#777777", "diamond-open"),
    )
    for rows, label, color, symbol in zones:
        fig.add_trace(go.Scatter3d(
            x=rows["esm2_x3"], y=rows["esm2_y3"], z=rows["esm2_z3"],
            mode="markers",
            name=f"{label} (n={len(rows)})",
            marker=dict(size=size, color=color, opacity=opacity, symbol=symbol),
            hoverinfo="skip"
        ))


def build_3d_landscape(color_by="mechanism", query_point=None):
    """Build the 3D scatter figure of the embedding landscape.

    Args:
        color_by: ``"mechanism"`` colours CARD proteins by resistance
            mechanism; ``"bla_class"`` colours them by beta-lactamase class.
            Any other value draws no protein traces.
        query_point: optional ``(x, y, z)`` drawn as a large red diamond.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    card = df_meta[df_meta["source"] == "CARD"]
    wt = df_meta[df_meta["source"] == "WT_PBP"]
    # The sub-zone split uses the "esm2_y" column, while the points themselves
    # are drawn with the esm2_x3/y3/z3 columns.
    wt_upper = wt[wt["esm2_y"] >= WT_PBP_SPLIT_Y]
    wt_lower = wt[wt["esm2_y"] < WT_PBP_SPLIT_Y]

    fig = go.Figure()

    if color_by == "mechanism":
        # Background -- sequences not in main mechanism list
        other = card[~card["primary_mechanism"].isin(MECH_LABEL.keys())]
        if len(other) > 0:
            fig.add_trace(go.Scatter3d(
                x=other["esm2_x3"], y=other["esm2_y3"], z=other["esm2_z3"],
                mode="markers",
                name="Other",
                marker=dict(size=1.5, color="#cccccc", opacity=0.15),
                hoverinfo="skip"
            ))

        for mech, label in MECH_LABEL.items():
            s = card[card["primary_mechanism"] == mech]
            if len(s) == 0:
                continue
            fig.add_trace(go.Scatter3d(
                x=s["esm2_x3"], y=s["esm2_y3"], z=s["esm2_z3"],
                mode="markers",
                name=f"{label} (n={len(s)})",
                marker=dict(
                    size=2.5,
                    color=MECH_COLORS.get(mech, "#999999"),
                    opacity=0.65
                ),
                text=s.apply(_mechanism_hover, axis=1),
                hovertemplate="%{text}"
            ))

        # WT PBP sub-zones
        _add_wt_pbp_traces(fig, wt_upper, wt_lower, size=4, opacity=0.85)

    elif color_by == "bla_class":
        bg = card[card["bla_class"].isna()]
        fig.add_trace(go.Scatter3d(
            x=bg["esm2_x3"], y=bg["esm2_y3"], z=bg["esm2_z3"],
            mode="markers",
            name="Other AMR",
            marker=dict(size=1.5, color="#dddddd", opacity=0.1),
            hoverinfo="skip"
        ))

        for cls, label in BLA_LABELS.items():
            s = card[card["bla_class"] == cls]
            # Same guard as the mechanism view: apply(axis=1) on an empty
            # frame yields a DataFrame, not a Series of hover strings.
            if len(s) == 0:
                continue
            fig.add_trace(go.Scatter3d(
                x=s["esm2_x3"], y=s["esm2_y3"], z=s["esm2_z3"],
                mode="markers",
                name=f"{label} (n={len(s)})",
                marker=dict(size=3, color=BLA_COLORS[cls], opacity=0.75),
                text=s.apply(lambda r, cls=cls: _bla_hover(r, cls), axis=1),
                hovertemplate="%{text}"
            ))

        _add_wt_pbp_traces(fig, wt_upper, wt_lower, size=5, opacity=0.9)

    # Query point
    if query_point is not None:
        fig.add_trace(go.Scatter3d(
            x=[query_point[0]],
            y=[query_point[1]],
            z=[query_point[2]],
            mode="markers",
            name="Your sequence",
            marker=dict(
                size=10, color="#FF1744",
                symbol="diamond", opacity=1.0,
                line=dict(color="white", width=2)
            ),
            hovertemplate="Your query sequence"
        ))

    axis_style = dict(backgroundcolor="white", gridcolor="#eeeeee",
                      showbackground=True)
    fig.update_layout(
        scene=dict(
            xaxis_title="UMAP 1",
            yaxis_title="UMAP 2",
            zaxis_title="UMAP 3",
            bgcolor="white",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
        height=650,
        legend=dict(
            font=dict(size=10, family="Arial"),
            borderwidth=1,
            bgcolor="rgba(255,255,255,0.95)",
            bordercolor="#dddddd"
        ),
        paper_bgcolor="white",
        margin=dict(l=0, r=0, t=10, b=0)
    )
    return fig


# ── Query function ────────────────────────────────────────────────────────────
def format_result_html(tier_label, tier_color, mean_dist, summary,
                       pred_mech, pred_drug, confidence,
                       mech_explain, embed_time, seq_len):
    """Build styled HTML result card.

    Args:
        tier_label: heading text, e.g. "HIGH CONFIDENCE".
        tier_color: "green", "orange" or "red"; selects the accent and
            background colours (unknown values fall back to neutral greys).
        mean_dist: mean distance to the nearest neighbours.
        summary: one-paragraph interpretation shown under the heading.
        pred_mech / pred_drug: majority mechanism and drug class.
        confidence: fraction of the ``TOP_K`` neighbours that agree.
        mech_explain: mapping from mechanism to a one-line explanation.
        embed_time: embedding time in seconds.
        seq_len: query length in amino acids.

    Returns:
        An HTML string for a ``gr.HTML`` component.
    """
    color_map = {
        "green": "#2e7d32",
        "orange": "#e65100",
        "red": "#c62828",
    }
    bg_map = {
        "green": "#f1f8e9",
        "orange": "#fff3e0",
        "red": "#ffebee",
    }
    hex_color = color_map.get(tier_color, "#333333")
    hex_bg = bg_map.get(tier_color, "#f5f5f5")
    mech_note = mech_explain.get(pred_mech, "")

    if tier_color == "red":
        pred_section = (
            f"<p><b>Closest mechanism found</b><br>{pred_mech}</p>"
            f"<p><b>Closest drug class</b><br>{pred_drug}</p>"
            "<p><i>These values reflect proximity in embedding space only and "
            "should not be interpreted as a resistance classification.</i></p>"
        )
    else:
        # round() rather than int(): the fraction is a float, and truncation
        # would under-report if it ever landed just below a whole number.
        pred_section = (
            f"<p><b>Predicted resistance mechanism</b><br>{pred_mech}<br>"
            f"{mech_note}</p>"
            f"<p><b>Predicted drug class</b><br>{pred_drug}</p>"
            f"<p><b>Confidence</b><br>"
            f"{round(confidence * TOP_K)}/{TOP_K} neighbours agree "
            f"({confidence*100:.0f}%)</p>"
        )

    return (
        f'<div style="border-left:6px solid {hex_color}; background:{hex_bg}; '
        f'padding:14px 18px; font-family:Arial, sans-serif;">'
        f'<h3 style="color:{hex_color}; margin:0 0 6px 0;">{tier_label}</h3>'
        f"<p>Average distance to {TOP_K} nearest neighbours: {mean_dist:.3f} "
        f"&nbsp;|&nbsp; threshold for reliable match: "
        f"below {LOW_CONF_DIST:.2f}</p>"
        f"<p>{summary}</p>"
        "<h4>Prediction</h4>"
        f"{pred_section}"
        "<h4>How this works</h4>"
        "<p>ESM2-650M converts the sequence into a 1,280-dimensional vector "
        "learned from 250 million protein sequences. That vector is compared "
        "against 5,029 curated AMR proteins from CARD using cosine distance "
        "(0 = identical, 1 = completely unrelated). The red diamond on the 3D "
        "map shows where your sequence lands.</p>"
        "<p><b>Disclaimer:</b> predictions are based on similarity to known "
        "AMR proteins in CARD. Novel resistance mechanisms or proteins outside "
        "this database may not be identified correctly.</p>"
        f"<p>Embedding time: {embed_time:.1f}s &nbsp;|&nbsp; "
        f"Sequence length: {seq_len} aa</p>"
        "</div>"
    )


def _majority(values, default="Unknown"):
    """Return ``(winner, votes)`` for the most common non-null value.

    Ties go to the value seen first, i.e. the nearest neighbour when
    ``values`` is ordered by distance. Returns ``(default, 0)`` when every
    value is null.
    """
    present = [v for v in values if pd.notna(v)]
    if not present:
        return default, 0
    return Counter(present).most_common(1)[0]


def run_query(sequence_input, color_by):
    """Embed a pasted sequence and report its nearest AMR neighbours.

    This is a generator so the UI can show progress: it yields
    ``(figure, result_html, table)`` once with a loading message before the
    embedding runs and once with the final result. Empty or invalid input
    yields a single message and stops.
    """
    empty_fig = build_3d_landscape(color_by)
    empty_html = "<p>Results will appear here.</p>"

    if not sequence_input or len(sequence_input.strip()) < 2:
        yield empty_fig, empty_html, None
        return

    seq, err = validate_sequence(sequence_input)
    if err:
        err_html = (
            '<div style="border-left:6px solid #c62828; background:#ffebee; '
            'padding:12px 16px;">'
            f"<b>Input error</b><br>{html.escape(err)}"
            "</div>"
        )
        yield empty_fig, err_html, None
        return

    loading_html = (
        '<div style="padding:12px 16px;">'
        "<b>Embedding with ESM2-650M...</b><br>"
        "About 20-30 seconds on CPU. Please wait."
        "</div>"
    )
    yield empty_fig, loading_html, None

    start = time.time()
    embedding = embed_sequence(seq)
    embed_time = time.time() - start

    pca50_q, coords3d = project_to_umap(embedding)
    qx, qy, qz = (float(v) for v in coords3d[0])

    # Ask for the table size explicitly instead of relying on whatever default
    # the pickled index was fitted with.
    dists, idxs = knn.kneighbors(pca50_q, n_neighbors=N_TABLE_NEIGHBOURS)
    dists = dists[0]
    idxs = idxs[0]

    neighbors = df_meta.iloc[idxs].copy()
    neighbors["distance"] = dists.round(3)

    top_k = neighbors.head(TOP_K)
    pred_mech, votes = _majority(top_k["primary_mechanism"].tolist())
    pred_drug, _ = _majority(top_k["primary_drug_class"].tolist())
    confidence = votes / TOP_K
    mean_dist = float(dists[:TOP_K].mean())

    if mean_dist > LOW_CONF_DIST:
        tier_label = "LOW CONFIDENCE"
        tier_color = "red"
        summary = (
            "This sequence does not closely resemble any known AMR protein in CARD. "
            "The results below show the closest matches found but should not be "
            "interpreted as a resistance prediction. This sequence is likely "
            "not a resistance protein."
        )
    elif mean_dist > HIGH_CONF_DIST:
        tier_label = "MODERATE CONFIDENCE"
        tier_color = "orange"
        summary = (
            "This sequence resembles known resistance proteins but is not a "
            "close match. Treat as a hypothesis to follow up experimentally."
        )
    else:
        tier_label = "HIGH CONFIDENCE"
        tier_color = "green"
        summary = (
            "This sequence closely resembles known resistance proteins. "
            "Confirm experimentally before drawing conclusions."
        )

    result_html = format_result_html(
        tier_label, tier_color, mean_dist, summary,
        pred_mech, pred_drug, confidence,
        MECH_EXPLAIN, embed_time, len(seq)
    )

    # Table with italic organism names
    keep = [c for c in TABLE_COLUMNS if c in neighbors.columns]
    tbl = neighbors[keep].copy().reset_index(drop=True)
    tbl = tbl.rename(columns=TABLE_COLUMNS)
    # Only post-process columns that survived the `keep` filter above.
    if "Gene family" in tbl.columns:
        tbl["Gene family"] = tbl["Gene family"].str[:40]
    if "Organism" in tbl.columns:
        # Italic organism names (the column is rendered with datatype "html")
        tbl["Organism"] = tbl["Organism"].apply(
            lambda x: f"<i>{html.escape(str(x)[:40])}</i>")
    tbl.index = tbl.index + 1

    fig = build_3d_landscape(color_by, query_point=(qx, qy, qz))
    yield fig, result_html, tbl


# ── App ───────────────────────────────────────────────────────────────────────
CSS = """
body, .gradio-container { font-family: Arial, sans-serif !important; }
.gradio-tabitem { overflow-y: hidden !important; }
.gr-prose p { font-size: 0.95rem !important; line-height: 1.6 !important; }
"""

ABOUT_MARKDOWN = """
### What is EmbedAMR?

EmbedAMR maps 5,029 AMR proteins from CARD into a shared embedding space
using ESM2-650M and Ankh-base. Query any protein sequence to find its nearest
neighbours using embedding similarity rather than sequence alignment.

---

### Are penicillin-binding proteins the same as AMR proteins?

**No -- and our embeddings confirm this.**

Wild-type PBPs are the *targets* of beta-lactam antibiotics, not resistance
proteins. EmbedAMR correctly places them far from the AMR landscape (average
distance 0.44, above the 0.35 reliability threshold).

Three related groups ARE resistance proteins and score well:

- **Resistance-conferring mutant PBPs** in CARD variant model
- **mecA/mecB/mecC/mecD (PBP2a)**: PBP-derived proteins conferring methicillin resistance
- **Class A/C/D beta-lactamases**: distantly PBP-related enzymes

---

### Methods

**Data**: CARD v3.3.0, 5,029 sequences (homolog + variant models). 187
wild-type PBP controls from UniProt Swiss-Prot (KW-0573, bacteria, reviewed).
12 sequences removed after sequence-level deduplication against CARD.

**Embeddings**: Mean pooling of ESM2-650M last hidden state (1,280 dim).
Sequences longer than 1,022 aa use sliding window (window 1,022, stride 256).

**Projection**: StandardScaler, PCA 50 components (89.7% variance), UMAP
(n_neighbors=30, cosine metric, random_state=42).

**Distance guide**:

- Below 0.20: close match, prediction likely reliable
- 0.20 to 0.35: moderate, treat as hypothesis
- Above 0.35: sequence likely not an AMR protein

---

### Source code

github.com/pranavathiyani/EmbedAMR

### Contact

pranavathiyani@scbt.sastra.edu

SASTRA Deemed University, Thanjavur, Tamil Nadu, India
"""

with gr.Blocks(title="EmbedAMR", theme=gr.themes.Default(), css=CSS) as app:

    gr.Markdown("# EmbedAMR")
    gr.Markdown(
        "Exploring the AMR protein embedding landscape with protein language models. "
        "5,029 AMR proteins from CARD v3.3.0 embedded with ESM2-650M and Ankh-base."
        "\n\n*Pranavathiyani G, SASTRA Deemed University, Thanjavur, India*"
    )

    with gr.Tabs():

        # ── Tab 1: Explore ────────────────────────────────────────────────────
        with gr.TabItem("Explore AMR Embedding Landscape"):
            gr.Markdown(
                "Each point is one AMR protein. Points close together have "
                "similar sequence embeddings. Hover to inspect. "
                "Rotate, zoom, and pan freely."
            )
            color_radio = gr.Radio(
                choices=[
                    ("Resistance mechanism", "mechanism"),
                    ("Beta-lactamase class (A/B/C/D) + WT PBP zones", "bla_class")
                ],
                value="mechanism",
                label="Color by"
            )
            loading_msg = gr.Markdown("*Loading AMR Embedding Landscape...*")
            explore_plot = gr.Plot(label="")

            def load_explore(color_by):
                """Hide the loading note and return the landscape figure."""
                return gr.Markdown(visible=False), build_3d_landscape(color_by)

            color_radio.change(
                fn=build_3d_landscape,
                inputs=color_radio,
                outputs=explore_plot
            )

        # ── Tab 2: Query ──────────────────────────────────────────────────────
        with gr.TabItem("Query Sequence"):
            gr.Markdown(
                "Paste a protein sequence to find its nearest neighbours "
                "in the AMR landscape. ESM2-650M embeds it on CPU -- "
                "about 20-30 seconds."
            )
            seq_input = gr.Textbox(
                label="Protein sequence (FASTA or plain amino acids, min 30 aa)",
                placeholder=">MyProtein\nMSIQHFRVALIPFFAAFCLPVFA...",
                lines=5
            )
            with gr.Row():
                color_radio_q = gr.Radio(
                    choices=[
                        ("Resistance mechanism", "mechanism"),
                        ("Beta-lactamase class", "bla_class")
                    ],
                    value="mechanism",
                    label="Map color by"
                )
                query_btn = gr.Button("Find nearest neighbours",
                                      variant="primary", scale=1)

            query_plot = gr.Plot(label="Your sequence in the AMR landscape")
            result_box = gr.HTML(label="Result and interpretation")
            result_table = gr.Dataframe(
                label="10 nearest AMR neighbours",
                wrap=True,
                datatype=["str", "html", "str", "str", "str", "number", "str"],
                column_widths=["10%", "18%", "16%", "13%", "25%", "8%", "10%"]
            )

            query_btn.click(
                fn=run_query,
                inputs=[seq_input, color_radio_q],
                outputs=[query_plot, result_box, result_table]
            )

        # ── Tab 3: About ──────────────────────────────────────────────────────
        with gr.TabItem("About"):
            gr.Markdown(ABOUT_MARKDOWN)

    # Auto-load landscape on startup. Routing through load_explore also hides
    # the "Loading..." note, which otherwise stays on screen forever.
    app.load(
        fn=load_explore,
        inputs=color_radio,
        outputs=[loading_msg, explore_plot]
    )

if __name__ == "__main__":
    app.launch()