File size: 2,816 Bytes
13144d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import re
import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Hugging Face Hub repo id for the model; override with the MODEL_REPO env var.
MODEL_REPO = os.getenv("MODEL_REPO", "mineself2016/GeneMamba")
# Default token budget used for truncation/padding; override with MAX_LEN.
DEFAULT_MAX_LEN = int(os.getenv("MAX_LEN", "2048"))


def _load_model():
    """Fetch the GeneMamba tokenizer and model from the Hub.

    Returns:
        (tokenizer, model, device) with the model in eval mode on CUDA when
        available, otherwise on CPU.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)
    net = AutoModel.from_pretrained(MODEL_REPO, trust_remote_code=True)
    target = "cuda" if torch.cuda.is_available() else "cpu"
    # .to() and .eval() both return the module, so chaining is equivalent
    # to the separate statements; eval() disables dropout/batch-norm updates.
    net = net.to(target).eval()
    return tok, net, target


# Load once at import time so every Gradio request reuses the same model instance.
tokenizer, model, device = _load_model()
# token string -> id mapping used for manual gene lookup below
vocab = tokenizer.get_vocab()
# Fallback ids when the tokenizer config defines no pad/unk token.
# NOTE(review): hard-coded fallbacks pad=1 / unk=0 assume GeneMamba's vocab
# layout — confirm against the tokenizer config on the Hub.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0


def parse_gene_sequence(raw_text: str) -> list[str]:
    """Split free-form input into gene tokens.

    Delimiters are any run of whitespace, commas, or semicolons; empty
    fragments from leading/trailing delimiters are dropped.

    Args:
        raw_text: User-supplied text, e.g. "ENSG00000000003, ENSG00000000419".

    Returns:
        List of non-empty gene token strings, in input order.
    """
    # \s already matches \n and \t, so the original [\s,;\n\t] class was
    # redundant; splitting on delimiter runs also leaves no surrounding
    # whitespace on tokens, making the per-token strip unnecessary.
    return [tok for tok in re.split(r"[\s,;]+", raw_text) if tok]


def embed_gene_sequence(raw_text: str, max_len: int = DEFAULT_MAX_LEN, normalize: bool = False):
    """Embed a gene-ID sequence with GeneMamba and return a JSON-serializable dict.

    Args:
        raw_text: Gene tokens (e.g. Ensembl IDs) separated by whitespace/commas/semicolons.
        max_len: Token budget; ids are truncated then right-padded to exactly this length.
        normalize: If True, L2-normalize the pooled embedding.

    Returns:
        Dict with the embedding, its dimension, and bookkeeping counts
        (unknown genes are mapped to the unk token and reported back).

    Raises:
        gr.Error: If no gene tokens could be parsed from the input.
    """
    genes = parse_gene_sequence(raw_text)
    if not genes:
        raise gr.Error("Please provide at least one gene token (e.g., ENSG00000000003).")

    # Gradio sliders deliver floats; a float here would make ids[:max_len]
    # raise "slice indices must be integers". Coerce and clamp to >= 1.
    max_len = max(1, int(max_len))

    ids = []
    unknown_genes = []
    for g in genes:
        if g in vocab:
            ids.append(vocab[g])
        else:
            ids.append(unk_id)  # out-of-vocabulary genes map to <unk>
            unknown_genes.append(g)

    # Truncate, then right-pad so the model always sees exactly max_len tokens.
    # NOTE(review): no attention mask is passed, so pad positions may leak into
    # the pooled embedding — confirm the model masks padding internally.
    ids = ids[:max_len]
    if len(ids) < max_len:
        ids.extend([pad_id] * (max_len - len(ids)))

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids)
        emb = outputs.pooled_embedding[0]
        if normalize:
            emb = torch.nn.functional.normalize(emb, p=2, dim=0)
        # detach() is redundant under no_grad; cpu().tolist() makes it JSON-safe.
        emb = emb.cpu().tolist()

    return {
        "model_repo": MODEL_REPO,
        "embedding_dim": len(emb),
        "input_gene_count": len(genes),
        "used_tokens": min(len(genes), max_len),
        "unknown_gene_count": len(unknown_genes),
        "unknown_genes_preview": unknown_genes[:20],
        "embedding": emb,
    }


# Markdown rendered above the Gradio interface, describing the expected input.
DESCRIPTION = """
Input a gene sequence (Ensembl IDs separated by space/comma/newline), then get the GeneMamba pooled embedding.

Examples:
ENSG00000000003 ENSG00000000419 ENSG00000001036
"""


# Gradio UI/API wrapper; inputs map positionally onto
# embed_gene_sequence(raw_text, max_len, normalize).
demo = gr.Interface(
    fn=embed_gene_sequence,
    inputs=[
        gr.Textbox(lines=8, label="Gene sequence (ENSG IDs)", placeholder="ENSG00000000003 ENSG00000000419 ..."),
        # NOTE(review): Slider values are delivered as floats — the handler must
        # coerce max_len to int before using it as a slice bound; verify.
        gr.Slider(64, DEFAULT_MAX_LEN, value=DEFAULT_MAX_LEN, step=64, label="Max sequence length"),
        gr.Checkbox(value=False, label="L2 normalize embedding"),
    ],
    outputs=gr.JSON(label="Embedding Result"),
    title="GeneMamba Embedding API",
    description=DESCRIPTION,
    # NOTE(review): allow_flagging was deprecated in favor of flagging_mode in
    # newer Gradio releases — confirm the pinned Gradio version accepts it.
    allow_flagging="never",
)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()