| import os |
| import re |
| import torch |
| import gradio as gr |
| from transformers import AutoModel, AutoTokenizer |
|
|
# Hugging Face repo to load; override with the MODEL_REPO environment variable.
MODEL_REPO = os.getenv("MODEL_REPO", "mineself2016/GeneMamba")
# Pad/truncate target for tokenized sequences; override with the MAX_LEN environment variable.
DEFAULT_MAX_LEN = int(os.getenv("MAX_LEN", "2048"))
|
|
|
|
def _load_model():
    """Fetch the GeneMamba tokenizer and model, placing the model on GPU if available.

    Returns:
        (tokenizer, model, device): loaded tokenizer, the model in eval mode,
        and the device string ("cuda" or "cpu") the model was moved to.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True).to(device)
    model.eval()
    return tokenizer, model, device
|
|
|
|
# Load once at import time so every Gradio request reuses the same model instance.
tokenizer, model, device = _load_model()
# token -> id mapping used for the manual gene lookup in embed_gene_sequence.
vocab = tokenizer.get_vocab()
# Fallback ids (pad=1, unk=0) apply only when the tokenizer declares none.
# NOTE(review): confirm these hard-coded fallbacks match GeneMamba's actual special-token ids.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0
|
|
|
|
def parse_gene_sequence(raw_text: str) -> list[str]:
    """Split free-form user text into gene tokens.

    Any run of whitespace, commas, or semicolons is a delimiter; empty
    fragments (from leading/trailing delimiters) are dropped.

    Args:
        raw_text: user-entered text such as "ENSG1, ENSG2\\nENSG3".

    Returns:
        Non-empty gene token strings, in input order.
    """
    # \s already matches \n and \t, so the original [\s,;\n\t]+ class was
    # redundant; splitting on the whole delimiter run also makes a per-token
    # strip() a no-op, so it is dropped.
    return [tok for tok in re.split(r"[\s,;]+", raw_text) if tok]
|
|
|
|
def embed_gene_sequence(raw_text: str, max_len: int = DEFAULT_MAX_LEN, normalize: bool = False):
    """Embed a delimited gene list with GeneMamba and return the pooled vector.

    Args:
        raw_text: gene tokens (e.g. Ensembl IDs) separated by spaces, commas,
            semicolons, or newlines.
        max_len: pad/truncate target for the token sequence. Cast to int below
            because a Gradio slider may deliver it as a float, which would
            break slicing and the padding arithmetic.
        normalize: if True, L2-normalize the pooled embedding.

    Returns:
        dict with the embedding plus bookkeeping: model repo, embedding
        dimension, input/unknown gene counts, and a preview (first 20) of
        unknown genes.

    Raises:
        gr.Error: if the input contains no gene tokens.
    """
    genes = parse_gene_sequence(raw_text)
    if not genes:
        raise gr.Error("Please provide at least one gene token (e.g., ENSG00000000003).")

    # Robustness: Gradio sliders can hand back floats even for integer ranges.
    max_len = int(max_len)

    ids = []
    unknown_genes = []
    for gene in genes:
        gene_id = vocab.get(gene)  # single lookup instead of `in` + indexing
        if gene_id is None:
            unknown_genes.append(gene)
            gene_id = unk_id
        ids.append(gene_id)

    # Truncate to max_len, then right-pad with pad_id up to exactly max_len.
    ids = ids[:max_len]
    if len(ids) < max_len:
        ids.extend([pad_id] * (max_len - len(ids)))

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    # NOTE(review): no attention_mask is passed, so pad positions reach the
    # model as-is — confirm GeneMamba's pooling ignores pad tokens.
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        emb = outputs.pooled_embedding[0]  # model-specific pooled output field
        if normalize:
            emb = torch.nn.functional.normalize(emb, p=2, dim=0)
        emb = emb.detach().cpu().tolist()

    return {
        "model_repo": MODEL_REPO,
        "embedding_dim": len(emb),
        "input_gene_count": len(genes),
        "used_tokens": min(len(genes), max_len),
        "unknown_gene_count": len(unknown_genes),
        "unknown_genes_preview": unknown_genes[:20],
        "embedding": emb,
    }
|
|
|
|
# User-facing markdown rendered above the Gradio form (the string itself is UI text).
DESCRIPTION = """
Input a gene sequence (Ensembl IDs separated by space/comma/newline), then get the GeneMamba pooled embedding.

Examples:
ENSG00000000003 ENSG00000000419 ENSG00000001036
"""
|
|
|
|
# Single-endpoint Gradio UI: textbox + length slider + normalize checkbox -> JSON result.
demo = gr.Interface(
    fn=embed_gene_sequence,
    inputs=[
        gr.Textbox(lines=8, label="Gene sequence (ENSG IDs)", placeholder="ENSG00000000003 ENSG00000000419 ..."),
        # Slider range 64..DEFAULT_MAX_LEN in steps of 64, defaulting to the maximum.
        gr.Slider(64, DEFAULT_MAX_LEN, value=DEFAULT_MAX_LEN, step=64, label="Max sequence length"),
        gr.Checkbox(value=False, label="L2 normalize embedding"),
    ],
    outputs=gr.JSON(label="Embedding Result"),
    title="GeneMamba Embedding API",
    description=DESCRIPTION,
    # NOTE(review): allow_flagging is deprecated in newer Gradio releases
    # (renamed flagging_mode) — confirm against the pinned Gradio version.
    allow_flagging="never",
)
|
|
|
|
if __name__ == "__main__":
    # Start the local Gradio server (blocking call).
    demo.launch()
|
|