# EmbedAMR
# Pranavathiyani G, SASTRA Deemed University
import gradio as gr
import numpy as np
import pandas as pd
import torch
import joblib
import plotly.graph_objects as go
import os
import time
from transformers import AutoTokenizer, AutoModel
# ── Assets ────────────────────────────────────────────────────────────────────
# Precomputed reference data loaded once at import time: a metadata table
# (one row per reference sequence), a .npy array named esm2_pca50 (presumably
# the PCA-reduced reference embeddings; not used in the code visible here),
# and the fitted scaler / PCA / nearest-neighbour objects that
# project_to_umap() applies to new queries.
ASSET_DIR = "app_assets"


def _asset_path(filename):
    """Return the path of *filename* inside ASSET_DIR."""
    return os.path.join(ASSET_DIR, filename)


print("Loading assets...")
df_meta = pd.read_csv(_asset_path("embedamr_metadata.csv"))
pca50_emb = np.load(_asset_path("esm2_pca50.npy"))
# NOTE: joblib.load unpickles its input -- only ever point ASSET_DIR at
# trusted, locally produced files.
scaler = joblib.load(_asset_path("esm2_scaler.pkl"))
pca = joblib.load(_asset_path("esm2_pca50.pkl"))
knn = joblib.load(_asset_path("knn_index.pkl"))
print(f"Ready. {len(df_meta)} sequences loaded.")
# ── ESM2 ──────────────────────────────────────────────────────────────────────
# The ESM2-650M checkpoint is loaded once at import time on the CPU and shared
# by every call to embed_sequence().
ESM2_CHECKPOINT = "facebook/esm2_t33_650M_UR50D"

print("Loading ESM2-650M...")
DEVICE = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(ESM2_CHECKPOINT)
esm2 = AutoModel.from_pretrained(ESM2_CHECKPOINT).to(DEVICE)
# Evaluation mode switches off training-only behaviour (e.g. dropout) so that
# repeated embeddings of the same sequence agree.
esm2.eval()
print("ESM2-650M ready.")
# ── Constants ─────────────────────────────────────────────────────────────────
# The 20 standard one-letter amino-acid codes accepted by validate_sequence().
VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")
# Shortest input (in residues) that validate_sequence() accepts.
MIN_LEN = 30

# Keys of the three mechanism tables below are the values that appear in
# df_meta["primary_mechanism"]; all three tables cover the same keys.

# Marker colour per resistance mechanism in the 3-D landscape.
MECH_COLORS = {
    "antibiotic inactivation": "#2196F3",
    "antibiotic efflux": "#FF9800",
    "antibiotic target alteration": "#4CAF50",
    "antibiotic target protection": "#9C27B0",
    "antibiotic target replacement": "#F44336",
    "reduced permeability to antibiotic": "#795548",
    "resistance by host-dependent nutrient acquisition": "#607D8B",
}

# Short legend label per resistance mechanism.
MECH_LABEL = {
    "antibiotic inactivation": "Inactivation",
    "antibiotic efflux": "Efflux",
    "antibiotic target alteration": "Target alteration",
    "antibiotic target protection": "Target protection",
    "antibiotic target replacement": "Target replacement",
    "reduced permeability to antibiotic": "Reduced permeability",
    "resistance by host-dependent nutrient acquisition": "Host nutrient",
}

# Marker colour per class label Class_A..Class_D (presumably the Ambler
# beta-lactamase classes -- confirm against the metadata column that uses it).
BLA_COLORS = {
    "Class_A": "#E53935",
    "Class_B": "#1E88E5",
    "Class_C": "#43A047",
    "Class_D": "#FB8C00",
}

# One-sentence plain-language explanation per mechanism, shown to users.
# Must cover every key of MECH_LABEL; the host-dependent nutrient entry was
# previously missing.
MECH_EXPLAIN = {
    "antibiotic inactivation": "The protein chemically modifies or destroys the antibiotic.",
    "antibiotic efflux": "The protein pumps antibiotics out of the bacterial cell.",
    "antibiotic target alteration": "The antibiotic target is modified so the drug cannot bind.",
    "antibiotic target protection": "The protein shields the target from the antibiotic.",
    "antibiotic target replacement": "The bacterium switches to an alternative target the drug cannot affect.",
    "reduced permeability to antibiotic": "The cell wall blocks the antibiotic from entering.",
    "resistance by host-dependent nutrient acquisition": "The bacterium takes up nutrients from the host that let it bypass the process the antibiotic blocks.",
}
# ── Sequence validation ───────────────────────────────────────────────────────
def validate_sequence(seq: str):
    """
    Clean and sanity-check a user-supplied protein sequence.

    Accepts a bare sequence or a single FASTA record (one header line that
    starts with ">"). The text is upper-cased, and every non-letter character
    (whitespace, digits, "*", "-", ...) is discarded before the checks run.

    Returns (cleaned_seq, error_message).
    error_message is None if valid. Otherwise cleaned_seq is None and
    error_message is a user-facing explanation.

    An input is rejected when it:
    - is empty or None,
    - holds more than one sequence,
    - is shorter than MIN_LEN,
    - contains letters outside VALID_AA,
    - uses fewer than 5 distinct residues, or
    - is more than 50% a single residue.
    """
    from collections import Counter

    # Treat None like an empty paste instead of crashing on .strip().
    seq = (seq or "").strip().upper()

    # Remove the FASTA header. Several records (or sequence text followed by
    # a header) would otherwise be glued into one chimeric protein and
    # embedded as if it were real, so reject them explicitly.
    lines = seq.splitlines()
    n_headers = sum(1 for line in lines if line.lstrip().startswith(">"))
    if n_headers > 1 or (n_headers == 1 and not seq.startswith(">")):
        return None, (
            "Input appears to contain more than one sequence. "
            "Please submit one protein sequence at a time."
        )
    if n_headers == 1:
        lines = lines[1:]
    seq = "".join(line.strip() for line in lines)

    # Drop anything that is not a letter (spaces, numbering, gap/stop symbols).
    seq = "".join(c for c in seq if c.isalpha())
    if len(seq) == 0:
        return None, "No sequence found. Please paste a protein sequence."
    if len(seq) < MIN_LEN:
        return None, (
            f"Sequence is too short ({len(seq)} amino acids). "
            f"Minimum length is {MIN_LEN} amino acids. "
            f"Please check your input."
        )
    invalid = set(seq) - VALID_AA
    if invalid:
        return None, (
            f"Sequence contains non-standard characters: {', '.join(sorted(invalid))}. "
            f"Please use standard one-letter amino acid codes."
        )
    # Composition check -- catches pasted names, words and gibberish that
    # happen to consist only of valid amino-acid letters.
    counts = Counter(seq)
    top_aa = counts.most_common(1)[0]
    uniq_aa = len(counts)
    if uniq_aa < 5:
        return None, (
            f"Sequence uses only {uniq_aa} different amino acids. "
            f"A real protein typically uses 15 or more. "
            f"Please check your input."
        )
    if top_aa[1] / len(seq) > 0.5:
        return None, (
            f"Single amino acid '{top_aa[0]}' makes up "
            f"{top_aa[1]/len(seq)*100:.0f}% of the sequence. "
            f"This does not look like a real protein. "
            f"Please check your input."
        )
    return seq, None
# ── Embed ─────────────────────────────────────────────────────────────────────
def _mean_residue_embedding(fragment: str):
    """Run ESM2 on *fragment* and mean-pool its per-residue hidden states.

    The first and last token positions (the leading/trailing special tokens
    added by the tokenizer) are excluded from the mean. Returns a 1-D numpy
    array with one value per hidden dimension.
    """
    encoded = tokenizer(fragment, return_tensors="pt",
                        add_special_tokens=True).to(DEVICE)
    with torch.no_grad():
        hidden = esm2(**encoded).last_hidden_state
    return hidden[0, 1:-1, :].mean(dim=0).cpu().numpy()


def embed_sequence(seq: str):
    """
    Return one fixed-length ESM2 embedding for the whole of *seq*.

    Sequences of up to 1022 residues are embedded in a single pass. Longer
    sequences are split into overlapping windows of up to 1022 residues that
    start every 256 residues; the last window ends exactly at the sequence end.
    Each window is mean-pooled separately, and the window vectors are then
    averaged with weights proportional to window length.
    """
    window, stride = 1022, 256
    if len(seq) <= window:
        return _mean_residue_embedding(seq)

    # Collect (begin, finish) slice bounds for every window.
    spans = []
    begin = 0
    while begin < len(seq):
        finish = min(begin + window, len(seq))
        spans.append((begin, finish))
        if finish == len(seq):
            break
        begin += stride

    vectors = [_mean_residue_embedding(seq[a:b]) for a, b in spans]
    widths = [b - a for a, b in spans]
    total = sum(widths)
    return sum((w / total) * v for w, v in zip(widths, vectors))
# ── Project to UMAP ───────────────────────────────────────────────────────────
def project_to_umap(embedding: np.ndarray):
    """
    Place a single query embedding on the precomputed 3-D reference map.

    Steps:
    1. Treat the embedding as one row and transform it with `scaler`, then
       `pca`.
    2. Look up the 5 nearest reference points in `knn`.
    3. Take the inverse-distance-weighted mean of those neighbours'
       (esm2_x3, esm2_y3, esm2_z3) columns in `df_meta` as the query's
       position. (The function name suggests these are UMAP coordinates --
       not verifiable from here.)

    Returns (pca50_q, xyz):
    - pca50_q is the 2-D output of pca.transform for the single query row.
    - xyz is a (1, 3) array.
    """
    reduced = pca.transform(scaler.transform(embedding.reshape(1, -1)))
    neighbour_dists, neighbour_rows = knn.kneighbors(reduced, n_neighbors=5)

    # The small epsilon keeps an exact match (distance 0) from dividing by zero.
    inv = 1.0 / (neighbour_dists[0] + 1e-9)
    inv /= inv.sum()

    neighbours = df_meta.iloc[neighbour_rows[0]]
    coords = [
        (neighbours[col].values * inv).sum()
        for col in ("esm2_x3", "esm2_y3", "esm2_z3")
    ]
    return reduced, np.array([coords])
# ── Build 3D landscape ────────────────────────────────────────────────────────
def build_3d_landscape(color_by="mechanism", query_point=None):
card = df_meta[df_meta["source"] == "CARD"]
wt = df_meta[df_meta["source"] == "WT_PBP"]
wt_upper = wt[wt["esm2_y"] >= -2.64]
wt_lower = wt[wt["esm2_y"] < -2.64]
fig = go.Figure()
if color_by == "mechanism":
# Background -- sequences not in main mechanism list
other = card[~card["primary_mechanism"].isin(MECH_LABEL.keys())]
if len(other) > 0:
fig.add_trace(go.Scatter3d(
x=other["esm2_x3"], y=other["esm2_y3"], z=other["esm2_z3"],
mode="markers", name="Other",
marker=dict(size=1.5, color="#cccccc", opacity=0.15),
hoverinfo="skip"
))
for mech, label in MECH_LABEL.items():
s = card[card["primary_mechanism"] == mech]
if len(s) == 0:
continue
hover = s.apply(lambda r:
f"{r.get('gene_name','?')}
"
f"Organism: {str(r.get('organism','?'))[:45]}
"
f"Mechanism: {r.get('primary_mechanism','?')}
"
f"Drug class: {r.get('primary_drug_class','?')}
"
f"Gene family: {str(r.get('AMR Gene Family','?'))[:45]}",
axis=1)
fig.add_trace(go.Scatter3d(
x=s["esm2_x3"], y=s["esm2_y3"], z=s["esm2_z3"],
mode="markers",
name=f"{label} (n={len(s)})",
marker=dict(
size=2.5,
color=MECH_COLORS.get(mech, "#999999"),
opacity=0.65
),
text=hover,
hovertemplate="%{text}
"
f"Class: {cls}
"
f"Gene family: {str(r.get('AMR Gene Family','?'))[:45]}
"
f"Organism: {str(r.get('organism','?'))[:45]}",
axis=1)
fig.add_trace(go.Scatter3d(
x=s["esm2_x3"], y=s["esm2_y3"], z=s["esm2_z3"],
mode="markers", name=f"{label} (n={len(s)})",
marker=dict(size=3, color=BLA_COLORS[cls], opacity=0.75),
text=hover,
hovertemplate="%{text}
Closest mechanism found
{pred_mech}
Closest drug class
{pred_drug}
These values reflect proximity in embedding space only and should not be interpreted as a resistance classification.
""" else: pred_section = f"""Predicted resistance mechanism
{pred_mech}
{mech_note}
Predicted drug class
{pred_drug}
Confidence
{int(confidence*5)}/5 neighbours agree
({confidence*100:.0f}%)
{tier_label}
Average distance to 5 nearest neighbours: {mean_dist:.3f} | threshold for reliable match: below 0.35
{summary}
Prediction
{pred_section}How this works
ESM2-650M converts the sequence into a 1,280-dimensional vector learned from 250 million protein sequences. That vector is compared against 5,029 curated AMR proteins from CARD using cosine distance (0 = identical, 1 = completely unrelated). The red diamond on the 3D map shows where your sequence lands.
Disclaimer: predictions are based on similarity to known AMR proteins in CARD. Novel resistance mechanisms or proteins outside this database may not be identified correctly.
Embedding time: {embed_time:.1f}s | Sequence length: {seq_len} aa
Results will appear here.
" if not sequence_input or len(sequence_input.strip()) < 2: yield empty_fig, empty_html, None return seq, err = validate_sequence(sequence_input) if err: err_html = f"""