import os
import re

import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Hugging Face repo to load; override via the MODEL_REPO env var.
MODEL_REPO = os.getenv("MODEL_REPO", "mineself2016/GeneMamba")
# Upper bound on tokenized sequence length; override via MAX_LEN.
DEFAULT_MAX_LEN = int(os.getenv("MAX_LEN", "2048"))


def _load_model():
    """Load the GeneMamba tokenizer and model, placing the model on GPU if available.

    Returns:
        tuple: (tokenizer, model, device) where device is "cuda" or "cpu".
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()  # inference only: disables dropout / training-mode layers
    return tokenizer, model, device


# Loaded once at import time so every request reuses the same weights.
tokenizer, model, device = _load_model()
vocab = tokenizer.get_vocab()
# Fallback ids (pad=1, unk=0) are used when the tokenizer declares no special
# tokens — NOTE(review): confirm these match the GeneMamba tokenizer's actual
# conventions; a wrong pad id silently corrupts every padded sequence.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0


def parse_gene_sequence(raw_text: str) -> list:
    """Split free-form text into gene tokens on whitespace, commas, or semicolons.

    Args:
        raw_text: User-entered text, e.g. "ENSG00000000003, ENSG00000000419".

    Returns:
        List of non-empty gene token strings, in input order.
    """
    return [t.strip() for t in re.split(r"[\s,;\n\t]+", raw_text) if t.strip()]


def embed_gene_sequence(raw_text: str, max_len: int = DEFAULT_MAX_LEN, normalize: bool = False) -> dict:
    """Embed a gene sequence with GeneMamba and return a JSON-serializable result.

    Args:
        raw_text: Gene tokens (Ensembl IDs) separated by space/comma/newline.
        max_len: Length the id sequence is truncated / right-padded to.
        normalize: If True, L2-normalize the pooled embedding vector.

    Returns:
        Dict with the embedding, its dimension, and bookkeeping counts
        (input genes, tokens actually used, unknown genes).

    Raises:
        gr.Error: If the input contains no gene tokens.
    """
    genes = parse_gene_sequence(raw_text)
    if not genes:
        raise gr.Error("Please provide at least one gene token (e.g., ENSG00000000003).")

    # Gradio sliders can deliver floats; slicing and list multiplication below
    # require an int. Clamp to >= 1 so padding arithmetic never goes negative.
    max_len = max(1, int(max_len))

    ids = []
    unknown_genes = []
    for g in genes:
        if g in vocab:
            ids.append(vocab[g])
        else:
            # Out-of-vocabulary genes map to UNK but are reported back to the caller.
            ids.append(unk_id)
            unknown_genes.append(g)

    # Truncate, then right-pad to exactly max_len tokens.
    ids = ids[:max_len]
    if len(ids) < max_len:
        ids = ids + [pad_id] * (max_len - len(ids))

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        # NOTE(review): pad positions are not masked out here; confirm the
        # remote-code model either ignores padding internally or accepts an
        # attention mask that should be passed alongside input_ids.
        outputs = model(input_ids=input_ids)
    # Assumes the remote-code model exposes `.pooled_embedding` with shape
    # (batch, dim) — TODO confirm against the GeneMamba model card.
    emb = outputs.pooled_embedding[0]
    if normalize:
        emb = torch.nn.functional.normalize(emb, p=2, dim=0)
    emb = emb.detach().cpu().tolist()

    return {
        "model_repo": MODEL_REPO,
        "embedding_dim": len(emb),
        "input_gene_count": len(genes),
        "used_tokens": min(len(genes), max_len),
        "unknown_gene_count": len(unknown_genes),
        "unknown_genes_preview": unknown_genes[:20],
        "embedding": emb,
    }


DESCRIPTION = """ Input a gene sequence (Ensembl IDs separated by space/comma/newline), then get the GeneMamba pooled embedding. 
Examples: ENSG00000000003 ENSG00000000419 ENSG00000001036 """

demo = gr.Interface(
    fn=embed_gene_sequence,
    inputs=[
        gr.Textbox(lines=8, label="Gene sequence (ENSG IDs)", placeholder="ENSG00000000003 ENSG00000000419 ..."),
        gr.Slider(64, DEFAULT_MAX_LEN, value=DEFAULT_MAX_LEN, step=64, label="Max sequence length"),
        gr.Checkbox(value=False, label="L2 normalize embedding"),
    ],
    outputs=gr.JSON(label="Embedding Result"),
    title="GeneMamba Embedding API",
    description=DESCRIPTION,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()