|
|
---
tags:
- biology
- genomics
- dna-compression
- causal-language-modeling
- gpt2
license: apache-2.0
datasets:
- dnabert-2
library_name: transformers
pipeline_tag: text-generation
---
|
|
|
|
|
# DNAGPT2: Genomic Large Language Model for Compression and Analysis
|
|
|
|
|
**DNAGPT2** is a family of autoregressive (decoder-only) transformer models trained on genomic DNA sequences. The models follow the GPT-2 architecture and are trained from scratch on a multi-species genome dataset.
|
|
|
|
|
## Model Details
|
|
|
|
|
- **Model Type:** Causal Language Model (Decoder-only Transformer)
- **Architecture:** GPT-2 Small
- **Parameters:** ~86 Million
- **Layers:** 12
- **Heads:** 12
- **Embedding Dimension:** 768
- **Context Window:** 1,024 tokens
- **Vocabulary Sizes:** Models are available with vocabulary sizes of 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, and 8192.
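
For illustration, these hyperparameters correspond to a standard GPT-2 Small shape. A minimal sketch of a matching `transformers` configuration follows; the `vocab_size=128` is just one of the available variants, and each published checkpoint ships its own config:

```python
from transformers import GPT2Config, GPT2LMHeadModel

# GPT-2 Small shape as listed above; vocab_size varies per model variant.
config = GPT2Config(
    vocab_size=128,    # one of 16, 32, 64, 128, ..., 8192
    n_positions=1024,  # context window
    n_embd=768,        # embedding dimension
    n_layer=12,        # transformer layers
    n_head=12,         # attention heads
)

model = GPT2LMHeadModel(config)
print(f"Parameters: {model.num_parameters() / 1e6:.1f}M")
```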
|
|
|
|
|
## Intended Use
|
|
|
|
|
These models are designed for:

1. **DNA Compression:** Used in conjunction with Arithmetic Encoding (AE) to compress genomic sequences (a toy sketch of this pairing appears at the end of this section).
2. **Sequence Modeling:** Next-token prediction for DNA sequences.
|
|
|
|
|
**Input:** Raw DNA sequences containing the characters `A`, `C`, `G`, `T`.

**Output:** Logits/probabilities for the next token in the sequence.
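
To make the compression use case concrete, here is a toy sketch of how the model's next-token distribution drives an arithmetic coder. It is illustrative only: floating-point intervals lose precision after a few dozen symbols, whereas practical coders use integer arithmetic with renormalization.

```python
import torch

def toy_arithmetic_encode(step_probs: list[torch.Tensor], symbols: list[int]):
    """Narrow the interval [low, high) by each symbol's probability mass.

    step_probs[i] is the model's predicted distribution just before
    symbols[i] is emitted. Any number inside the final interval identifies
    the sequence; its length in bits is roughly -sum(log2 p(symbol)).
    """
    low, high = 0.0, 1.0
    for probs, s in zip(step_probs, symbols):
        cum = torch.cumsum(probs, dim=0)
        width = high - low
        cum_hi = cum[s].item()
        cum_lo = cum_hi - probs[s].item()
        low, high = low + width * cum_lo, low + width * cum_hi
    return low, high
```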
|
|
|
|
|
## Training Data
|
|
|
|
|
The models were pretrained on the dataset provided by the authors of **DNABERT-2**.

- **Composition:** 135 genomes covering Vertebrata, Fungi, Protozoa, Invertebrata, and Bacteria.
- **Size:** Approximately 32.5 billion nucleotides.
- **Preprocessing:** The alphabet was restricted to **A, C, G, T**. The letter **N** (unknown/ambiguous nucleotide) was omitted from the training data.
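
A minimal sketch of this preprocessing step (the authors' exact pipeline is not specified here; this simply drops every non-ACGT character, including `N`):

```python
import re

def clean_sequence(seq: str) -> str:
    """Keep only A, C, G, T; drop N and any other ambiguity codes."""
    return re.sub(r"[^ACGT]", "", seq.upper())

assert clean_sequence("acgNNtRg") == "ACGTG"
```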
|
|
|
|
|
## Training Procedure
|
|
|
|
|
The models were trained using the PyTorch framework and the `nanoGPT` recipe.
|
|
|
|
|
- **Tokenizer:** Byte-Pair Encoding (BPE) trained via SentencePiece.
- **Epochs:** 1
- **Optimization:** AdamW (betas: 0.9, 0.95; weight decay: 0.1)
- **Learning Rate:** Cosine decay (max: 8e-4, min: 8e-5) with linear warmup (a schedule sketch follows this list).
- **Batch Size:** $2^{19}$ (524,288) tokens per step.
- **Hardware:** Single NVIDIA A40 GPU.
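
The nanoGPT-style schedule can be sketched as follows; the warmup length and total decay steps below are placeholders, not values reported by the authors:

```python
import math

MAX_LR, MIN_LR = 8e-4, 8e-5

def get_lr(step: int, warmup_steps: int = 2000, decay_steps: int = 60000) -> float:
    """Linear warmup to MAX_LR, then cosine decay down to MIN_LR."""
    if step < warmup_steps:
        return MAX_LR * (step + 1) / warmup_steps
    if step > decay_steps:
        return MIN_LR
    ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # decays from 1 to 0
    return MIN_LR + coeff * (MAX_LR - MIN_LR)
```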
|
|
|
|
|
## Performance
|
|
|
|
|
The models were evaluated on their ability to compress DNA sequences (measured in **bits per symbol**, **bps**) using Arithmetic Encoding. Lower is better.
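
Arithmetic encoding achieves, up to a small constant overhead, the model's cross-entropy, so for a sequence $x_1, \dots, x_N$ the rate is

$$
\text{bps} = -\frac{1}{N} \sum_{i=1}^{N} \log_2 p_\theta\left(x_i \mid x_{<i}\right)
$$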
|
|
|
|
|
| Dataset | Metric | DNAGPT2_32 | gzip -9 | Jarvis3 |
| :--- | :--- | :--- | :--- | :--- |
| **Homo sapiens** (T2T-CHM13v2.0) | bits/symbol | **1.470** | 2.022 | 1.384 |
| **M. llanfair...** (Bacteria) | bits/symbol | **1.783** | 2.142 | 1.713 |
| **A. thaliana** (Plant, Chr1) | bits/symbol | **1.876** | 2.161 | 1.702 |
|
|
|
|
|
On the evaluated datasets, the `DNAGPT2_32` model outperforms general-purpose compressors (gzip) and deep-learning-based compressors such as `hyenaDNA` and `megaDNA` (not shown in the table above), while the specialized DNA compressor Jarvis3 remains ahead.
|
|
|
|
|
## How to Use
|
|
|
|
|
The model is compatible with the Hugging Face `transformers` library.
|
|
|
|
|
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Select the model variant (e.g., vocab size 128 or 32)
# Replace with the specific repository path if hosted on HF Hub
hf_model_repository = "vojtam/DNAGPT2_128"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    hf_model_repository,
    trust_remote_code=True,
).to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    hf_model_repository,
    trust_remote_code=True,
)

# Inference example: score a raw DNA string
dna_sequence = "ACGTTGCAAACG"
token_ids = tokenizer.encode(dna_sequence, return_tensors="pt").to(device)

with torch.no_grad():
    logits = model(token_ids).logits  # (batch, seq_len, vocab_size)

# Distribution over the next token, given the full input
next_token_probs = torch.softmax(logits[0, -1], dim=-1)

print(f"Input: {dna_sequence}")
print(f"Logits shape: {logits.shape}")
print(f"Next-token probabilities shape: {next_token_probs.shape}")
```