# GeneMamba — hf_space/app.py
# Gradio Space exposing GeneMamba pooled embeddings for gene-ID sequences.
import os
import re
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer
# Hugging Face model repo to load; override with the MODEL_REPO env var.
MODEL_REPO = os.getenv("MODEL_REPO", "mineself2016/GeneMamba")
# Default (and slider maximum) sequence length for truncation/padding;
# override with the MAX_LEN env var.
DEFAULT_MAX_LEN = int(os.getenv("MAX_LEN", "2048"))
def _load_model():
    """Load the GeneMamba tokenizer and model from MODEL_REPO.

    Moves the model to CUDA when available (CPU otherwise) and switches
    it to eval mode.

    Returns:
        (tokenizer, model, device): loaded tokenizer, the model already
        placed on ``device``, and the device string ("cuda" or "cpu").
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)
    net = AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True)
    net = net.to(target_device)
    net.eval()
    return tok, net, target_device
# Load once at import time so every request reuses the same model instance.
tokenizer, model, device = _load_model()
# token -> id mapping; genes are looked up directly instead of going
# through the tokenizer's encode path.
vocab = tokenizer.get_vocab()
# NOTE(review): the fallbacks (pad=1, unk=0) assume the GeneMamba vocab's
# special-token layout — confirm against the tokenizer config.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0
def parse_gene_sequence(raw_text: str) -> list[str]:
    """Split free-form user text into a list of gene tokens.

    Tokens may be separated by any mix of whitespace (spaces, tabs,
    newlines), commas, or semicolons; empty fragments are dropped.

    Args:
        raw_text: Raw textbox contents, e.g. "ENSG00000000003, ENSG00000000419".

    Returns:
        Non-empty gene token strings, in input order.
    """
    # \s already matches \n and \t, so the original class [\s,;\n\t] was
    # redundant; because every whitespace char is a delimiter, fragments
    # can never contain whitespace and the per-token strip() was a no-op.
    return [tok for tok in re.split(r"[\s,;]+", raw_text) if tok]
def embed_gene_sequence(raw_text: str, max_len: int = DEFAULT_MAX_LEN, normalize: bool = False):
    """Tokenize a gene sequence and return the GeneMamba pooled embedding.

    Args:
        raw_text: Gene IDs separated by whitespace/commas/semicolons.
        max_len: Maximum token count; the input is truncated and then
            right-padded with the PAD id to exactly this length.
        normalize: If True, L2-normalize the embedding vector.

    Returns:
        dict with the embedding list, its dimensionality, and bookkeeping
        about how many input genes were unknown to the vocabulary.

    Raises:
        gr.Error: If no gene token could be parsed from the input.
    """
    # Gradio sliders deliver numbers as floats; the slicing and padding
    # below require an int (ids[:2048.0] raises TypeError). Clamp to >= 1.
    max_len = max(1, int(max_len))
    genes = parse_gene_sequence(raw_text)
    if not genes:
        raise gr.Error("Please provide at least one gene token (e.g., ENSG00000000003).")
    ids = []
    unknown_genes = []
    for g in genes:
        if g in vocab:
            ids.append(vocab[g])
        else:
            # Fall back to the UNK id, but record the gene for the caller.
            ids.append(unk_id)
            unknown_genes.append(g)
    # Truncate to max_len, then right-pad with the PAD id.
    ids = ids[:max_len]
    if len(ids) < max_len:
        ids = ids + [pad_id] * (max_len - len(ids))
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    # NOTE(review): no attention_mask is passed, so PAD tokens presumably
    # enter the pooling — confirm against the remote model's pooling code.
    emb = outputs.pooled_embedding[0]
    if normalize:
        emb = torch.nn.functional.normalize(emb, p=2, dim=0)
    emb = emb.detach().cpu().tolist()
    return {
        "model_repo": MODEL_REPO,
        "embedding_dim": len(emb),
        "input_gene_count": len(genes),
        "used_tokens": min(len(genes), max_len),
        "unknown_gene_count": len(unknown_genes),
        "unknown_genes_preview": unknown_genes[:20],
        "embedding": emb,
    }
# Markdown description rendered below the app title (user-facing text;
# the string body is part of the UI output).
DESCRIPTION = """
Input a gene sequence (Ensembl IDs separated by space/comma/newline), then get the GeneMamba pooled embedding.
Examples:
ENSG00000000003 ENSG00000000419 ENSG00000001036
"""
# Wire the embedding function into a Gradio UI; on Spaces this also
# exposes it as an HTTP API endpoint.
demo = gr.Interface(
    fn=embed_gene_sequence,
    inputs=[
        gr.Textbox(lines=8, label="Gene sequence (ENSG IDs)", placeholder="ENSG00000000003 ENSG00000000419 ..."),
        # Slider values arrive as numbers; embed_gene_sequence truncates/pads to this length.
        gr.Slider(64, DEFAULT_MAX_LEN, value=DEFAULT_MAX_LEN, step=64, label="Max sequence length"),
        gr.Checkbox(value=False, label="L2 normalize embedding"),
    ],
    outputs=gr.JSON(label="Embedding Result"),
    title="GeneMamba Embedding API",
    description=DESCRIPTION,
    # NOTE(review): allow_flagging was deprecated and renamed to
    # flagging_mode in newer Gradio releases — confirm the Space's
    # pinned gradio version still accepts this keyword.
    allow_flagging="never",
)
if __name__ == "__main__":
    demo.launch()