molcrawl-rna-celltype-gpt2-medium

Model Description

GPT-2 medium (345M parameters) fine-tuned on cell-type-specific RNA data, represented as sequences of ENSEMBL gene IDs, starting from the pre-trained molcrawl-rna-gpt2-medium model.

  • Model Type: gpt2
  • Data Type: RNA
  • Training Date: 2026-04-24
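
GPT-2 medium has 24 transformer blocks with hidden size 1024. As a rough check on parameter figures, the sketch below counts weights from those hyperparameters. It assumes the standard Hugging Face GPT-2 layout (fused QKV projection, 4x-wide MLP, output head tied to the token embedding) and a 1024-position context. This checkpoint's actual vocabulary size and n_positions are not stated here, so treat them as assumptions and read them from its config.json.

```python
def gpt2_param_count(vocab_size, n_positions=1024, n_embd=1024, n_layer=24):
    """Parameter count of a GPT-2 LM whose output head is tied to the token embedding.

    Defaults are GPT-2 medium's hyperparameters (assumed for this checkpoint).
    """
    d = n_embd
    embeddings = vocab_size * d + n_positions * d   # token + learned position tables
    layer_norm = 2 * d                              # scale and bias
    attention = (d * 3 * d + 3 * d) + (d * d + d)   # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)     # expand to 4d, project back to d
    per_block = 2 * layer_norm + attention + mlp    # two LayerNorms per block
    # The tied LM head reuses the token embedding, so it adds no parameters.
    return embeddings + n_layer * per_block + layer_norm  # + final LayerNorm


print(f"{gpt2_param_count(50257):,}")  # GPT-2's original BPE vocabulary size
```

With GPT-2's original 50,257-token vocabulary this prints 354,823,168, the usual count for GPT-2 medium. The embedding table grows by n_embd weights per vocabulary entry, so a gene-ID vocabulary of a different size shifts the total accordingly.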

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("kojima-lab/molcrawl-rna-celltype-gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("kojima-lab/molcrawl-rna-celltype-gpt2-medium")

# Generate next gene-id tokens (RNA gene-list model)
prompt = "ENSG00000000003 ENSG00000000005 ENSG00000000419"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        eos_token_id=None,  # HF config.json has legacy eos_token_id=0; disable early stop
        pad_token_id=0,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
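
With do_sample=True and temperature=0.8, each new token is drawn from softmax(logits / 0.8), which sharpens the distribution toward the top-scoring genes. With eos_token_id=None, nothing triggers an early stop, so generation runs for the full max_new_tokens. The pure-Python loop below is a simplified sketch of that behavior, not the transformers implementation (which, unless configured otherwise, also applies top-k filtering, among other things). next_logits and toy_model are hypothetical stand-ins for the model.

```python
import math
import random


def sample_with_temperature(logits, temperature, rng):
    """Draw one token id from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(x - top) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def generate(next_logits, prompt_ids, max_new_tokens, temperature=0.8,
             eos_token_id=None, seed=0):
    """Toy autoregressive sampling loop (a sketch, not transformers' generate)."""
    rng = random.Random(seed)
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = sample_with_temperature(next_logits(ids), temperature, rng)
        ids.append(token)
        if eos_token_id is not None and token == eos_token_id:
            break  # early stop happens only when an eos id is configured
    return ids


def toy_model(ids):
    """Hypothetical 3-token "model" that always favors id 0 (the legacy eos id)."""
    return [2.0, 1.0, 0.5]


print(len(generate(toy_model, [1, 2], max_new_tokens=5)))  # eos disabled: 2 + 5 = 7
```

If eos_token_id=0 were passed instead, the same loop would stop the first time id 0 is sampled. For this toy model that has roughly a 69% chance per step at temperature 0.8, so it usually ends after a token or two. This is the early cutoff that the Usage example avoids by passing eos_token_id=None.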

Source Code

Training pipeline, configuration files, and data preparation scripts are available in the MolCrawl GitHub repository: https://github.com/mmai-framework-lab/MolCrawl

License

This model is released under the Apache-2.0 license.

Citation

If you use this model, please cite:

@misc{molcrawl_rna_celltype_gpt2_medium,
  title={molcrawl-rna-celltype-gpt2-medium},
  author={{RIKEN}},
  year={2026},
  publisher={{Hugging Face}},
  url={{https://huggingface.co/kojima-lab/molcrawl-rna-celltype-gpt2-medium}}
}

Example Output

Inference test (CPU). This GPT-2 model uses a WordLevel gene-ID vocabulary: each input position is one ENSEMBL gene ID. Encode a prefix of gene IDs with convert_tokens_to_ids, run a forward pass, and take the highest-scoring logit at the last position as the predicted next gene.
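
A WordLevel vocabulary is a plain lookup table from whole tokens (here, gene IDs) to integer IDs, with no sub-word splitting. The toy sketch below mimics the encode and decode lookups in pure Python. The vocabulary, the special-token names [PAD] and [UNK], and whitespace splitting of string prompts are assumptions for illustration; the real table lives in the repository's tokenizer files. The sketch also shows why unrecognized genes are worth checking for: with an unknown-token fallback they silently collapse onto one ID.

```python
def encode(tokens, vocab, unk_token="[UNK]"):
    """Look up each whole token; unknown tokens fall back to the unk id."""
    unk_id = vocab.get(unk_token)
    ids = [vocab.get(token, unk_id) for token in tokens]
    if None in ids:
        raise KeyError("token not in vocabulary and no unknown token defined")
    return ids


def decode(ids, vocab):
    """Reverse lookup from id to token string."""
    id_to_token = {i: token for token, i in vocab.items()}
    return [id_to_token[i] for i in ids]


# Hypothetical toy vocabulary (the real one is far larger).
vocab = {"[PAD]": 0, "[UNK]": 1, "ENSG00000000003": 2, "ENSG00000000005": 3}

prompt = "ENSG00000000003 ENSG00000000005 ENSG00000001167"
ids = encode(prompt.split(), vocab)  # assumed whitespace split: one token per gene ID
print(ids)                           # [2, 3, 1]: the last gene is not in the toy vocab
print(decode(ids, vocab))            # ['ENSG00000000003', 'ENSG00000000005', '[UNK]']
```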

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

REPO_ID = "kojima-lab/molcrawl-rna-celltype-gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForCausalLM.from_pretrained(REPO_ID)
model.eval()

# Prefix sequence of ENSEMBL gene IDs (Geneformer-style ranked input)
genes = [
    "ENSG00000000003",
    "ENSG00000000005",
    "ENSG00000001167",
    "ENSG00000002586",
]
ids = tokenizer.convert_tokens_to_ids(genes)
input_ids = torch.tensor([ids])

with torch.no_grad():
    outputs = model(input_ids=input_ids)

# Next-token (next-gene) prediction
next_id = outputs.logits[0, -1].argmax(dim=-1).item()
next_gene = tokenizer.convert_ids_to_tokens([next_id])[0]
print(f"Predicted next gene: {next_gene}")
# => Predicted next gene: PLB1
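
The prefix above is described as Geneformer-style ranked input, meaning a cell's genes ordered from most to least expressed. This card does not spell out MolCrawl's exact preprocessing, so the sketch below follows the rank-value encoding described for Geneformer, as an assumption. Each expressed gene's count is divided by a per-gene normalization factor (Geneformer uses the gene's nonzero median expression across its training corpus), and genes are sorted by the result, highest first. The function name, the medians table, and the 1024 length cap are hypothetical. Check them against the data preparation scripts in the MolCrawl repository.

```python
def rank_value_encode(counts, gene_medians, max_len=1024):
    """Return one cell's expressed genes, highest normalized expression first.

    counts:       {ensembl_id: raw count} for a single cell
    gene_medians: {ensembl_id: per-gene normalization factor} (hypothetical input)
    max_len:      length cap; 1024 assumes a GPT-2-sized context (check config.json)
    """
    # Per-cell depth normalization multiplies every score in the cell by the
    # same constant, so it cannot change the order and is skipped here.
    scores = {
        gene: count / gene_medians[gene]
        for gene, count in counts.items()
        if count > 0 and gene in gene_medians
    }
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:max_len]


# Made-up toy counts and medians for four of the genes used in the example above.
counts = {"ENSG00000000003": 10, "ENSG00000000005": 0,
          "ENSG00000001167": 5, "ENSG00000002586": 20}
medians = {"ENSG00000000003": 1.0, "ENSG00000000005": 1.0,
           "ENSG00000001167": 0.25, "ENSG00000002586": 4.0}
print(rank_value_encode(counts, medians))
# ['ENSG00000001167', 'ENSG00000000003', 'ENSG00000002586']
```

ENSG00000002586 has the largest raw count but ranks last because its typical level (median 4.0) is high. Scaling by the per-gene factor lets a usually quiet gene such as ENSG00000001167 rise to the top, and the unexpressed gene is dropped. The resulting list can be passed to tokenizer.convert_tokens_to_ids as in the example above.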

Model Files

Weights are stored in Safetensors format: 0.4B parameters, tensor type F32.