# 🚀 BexaMask-v2 (~800M Parameters)
BexaMask-v2 is a pretrained base (foundation) decoder-only Transformer trained on large-scale, permissively licensed and uncopyrighted text data using the MaxText framework on a TPU v4-16.

⚠️ This is a base model: it is not instruction-tuned and may not follow prompts like ChatGPT without further fine-tuning.
## 🧠 Model Overview
- Type: Pretrained Base Model (Foundation Model)
- Architecture: Decoder-only Transformer
- Parameters: ~800M
- Layers: 16
- Embedding Dimension: 2048
- MLP Dimension: 5120
- Attention Heads:
  - Query Heads: 16
  - KV Heads: 4 (Grouped Query Attention)
- Head Dimension: 128
- Activation: SiLU + Linear
- Max Context Length: 4096 tokens
- Vocabulary Size: 32,000 (SentencePiece)
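As a sanity check, the listed dimensions can be combined into a rough parameter count. The sketch below assumes a Llama-style gated MLP (gate/up/down projections, matching the SiLU + Linear activation) and untied input/output embeddings; both are assumptions, not stated in the card:

```python
# Rough parameter count from the listed architecture. Assumes a
# Llama-style gated MLP (gate/up/down, matching SiLU + Linear) and
# untied input/output embeddings; both are assumptions.
vocab, d_model, d_mlp, layers = 32_000, 2048, 5120, 16
q_heads, kv_heads, head_dim = 16, 4, 128

attn = d_model * q_heads * head_dim            # Q projection
attn += 2 * d_model * kv_heads * head_dim      # K and V (GQA: 4 heads each)
attn += q_heads * head_dim * d_model           # output projection
mlp = 3 * d_model * d_mlp                      # gate, up, down projections
embed = 2 * vocab * d_model                    # input + output embeddings

total = layers * (attn + mlp) + embed
print(f"~{total / 1e6:.0f}M parameters")       # ~802M, close to the quoted ~800M
```

Norm layers and biases add a negligible fraction on top, so the listed sizes are consistent with the ~800M figure.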
## ⚙️ Training Details
- Framework: MaxText
- Hardware: TPU v4-16 (8 chips, 256GB HBM)
## 📦 Dataset
- Subset of The Pile (uncopyrighted / permissive sources only)
- Filtered to remove restricted or copyrighted data
## 🔧 Training Config
- Steps: 100,000
- Epochs: 2
- Batch Size: 16 per device
- Learning Rate: 3e-4
- Warmup Steps: 2,000
- Scheduler: Cosine decay
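The schedule above (linear warmup to the peak learning rate, then cosine decay) can be sketched in plain Python. Decaying all the way to zero is an assumption, since the card does not state a minimum learning rate:

```python
import math

# Sketch of the listed schedule: linear warmup for 2,000 steps to the
# 3e-4 peak, then cosine decay over the remaining steps. Decaying to
# zero is an assumption; the card does not state a minimum LR.
PEAK_LR, WARMUP, TOTAL = 3e-4, 2_000, 100_000

def lr_at(step):
    if step < WARMUP:
        return PEAK_LR * step / WARMUP                         # linear warmup
    progress = (step - WARMUP) / (TOTAL - WARMUP)              # 0 -> 1 after warmup
    return 0.5 * PEAK_LR * (1 + math.cos(math.pi * progress))  # cosine decay

print(lr_at(1_000))    # mid-warmup: 1.5e-4
print(lr_at(2_000))    # peak: 3e-4
print(lr_at(100_000))  # end of training: 0.0
```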
## ⚡ Optimization Techniques
- Flash Attention
- Full Rematerialization (Remat)
- Asynchronous Checkpointing
- Distributed GCS checkpointing
- IOTA embeddings
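A minimal sketch of how these options map onto MaxText config flags. The key names below follow MaxText's `base.yml` and are assumptions that may differ across versions, so verify against your checkout:

```yaml
# Assumed MaxText config keys; check configs/base.yml in your version.
attention: "flash"          # Flash Attention kernel
remat_policy: "full"        # full rematerialization
async_checkpointing: True   # non-blocking checkpoint writes
use_iota_embed: True        # iota embeddings
```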
## 🧪 Inference
Run inference using MaxText:

```bash
python3 -m MaxText.decode \
  maxtext/configs/pretrain.yml \
  run_name=inference \
  load_parameters_path=/home/pynatic079/bexamask_v2_inference_local/items \
  tokenizer_path=/path/to/llama/tokenizer.model \
  max_target_length=512 \
  'prompt="<Your prompt>"' \
  decode_sampling_strategy="topk" \
  decode_sampling_top_k=4 \
  decode_sampling_temperature=1.9 \
  attention=dot_product
```
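For reference, the `topk` sampling that the decode flags request (`decode_sampling_top_k=4`, `decode_sampling_temperature=1.9`) works roughly as follows. This is an illustrative sketch, not MaxText's actual implementation:

```python
import math
import random

# Illustrative top-k + temperature sampling, mirroring the decode flags
# above (top_k=4, temperature=1.9). A sketch, not MaxText's code.
def sample_top_k(logits, k=4, temperature=1.9, rng=None):
    rng = rng or random.Random(0)
    scaled = [x / temperature for x in logits]          # temperature scaling
    top = sorted(range(len(scaled)), key=scaled.__getitem__)[-k:]  # k best ids
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]    # stable softmax weights
    return rng.choices(top, weights=weights)[0]

token = sample_top_k([2.0, 5.0, 1.0, 4.0, 3.0, 0.5])
print(token)  # one of 0, 1, 3, 4: only the four highest-logit tokens survive
```

Note that a temperature of 1.9 flattens the distribution considerably, so the restrictive `top_k=4` is what keeps samples coherent.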