# SimpleStories Bilinear Attention Models

A collection of ToyTransformer models with `CausalBilinearSelfAttention`, trained on the SimpleStories dataset.
## Models

| Model | Layers | Type | Final Val Loss |
|---|---|---|---|
| attention_only_1L_bilinear_attn | 1 | Attention only | ~2.91 |
| attention_only_2L_bilinear_attn | 2 | Attention only | ~2.87 |
## Architecture

These models use `CausalBilinearSelfAttention`, which computes squared (bilinear) attention patterns:

- Two QK pairs: `(Q1, K1)` and `(Q2, K2)`
- Attention scores: `(Q1 @ K1.T) * (Q2 @ K2.T)` (element-wise product)
- Causal masking applied
- RoPE (Rotary Positional Embeddings)
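The score computation above can be sketched as follows. This is a minimal illustration of the bilinear scoring and causal masking only; the function name is hypothetical, and RoPE and the actual module internals are omitted:

```python
import torch

def bilinear_attention_scores(q1, k1, q2, k2):
    # Element-wise product of two standard QK score maps
    s1 = q1 @ k1.transpose(-2, -1)
    s2 = q2 @ k2.transpose(-2, -1)
    scores = s1 * s2
    # Causal mask: position i may only attend to positions j <= i
    T = scores.size(-1)
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    return scores.masked_fill(mask, float("-inf"))
```

Because the two score maps are multiplied element-wise, each attention logit is quadratic in the keys and queries, which is what makes the pattern "bilinear".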
## Training Details
- Dataset: SimpleStories (4096 vocab)
- d_model: 256
- n_head: 4
- n_ctx: 512
- Optimizer: Muon
- Batches: 100,000
- Batch size: 128
## Usage

```python
from huggingface_hub import hf_hub_download
import torch
import json

# Download a model
model_name = "attention_only_1L_bilinear_attn"
config_path = hf_hub_download(repo_id="Elriggs/simplestories-bilinear-attn", filename=f"{model_name}/config.json")
weights_path = hf_hub_download(repo_id="Elriggs/simplestories-bilinear-attn", filename=f"{model_name}/pytorch_model.bin")

# Load config
with open(config_path) as f:
    config_dict = json.load(f)

# Use with ToyTransformer from this repo
from utils import ToyTransformer, ModelConfig

# Drop training-only keys that ModelConfig does not accept
for key in ['num_batches', 'final_train_loss', 'final_val_loss', 'learning_rate', 'batch_size']:
    config_dict.pop(key, None)

config = ModelConfig(**config_dict)
model = ToyTransformer(config)
model.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=False)
```
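Once loaded, the model can be sampled with a simple decoding loop. This is a hedged sketch, not code from the repo: it assumes the model maps a `(1, T)` tensor of token ids to `(1, T, vocab)` logits, and uses the card's `n_ctx: 512` to crop the context:

```python
import torch

@torch.no_grad()
def greedy_generate(model, token_ids, max_new_tokens=50, n_ctx=512):
    """Greedy decoding; `model` is assumed to map (1, T) ids to (1, T, vocab) logits."""
    for _ in range(max_new_tokens):
        logits = model(token_ids[:, -n_ctx:])               # crop to the context window
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # (1, 1) most likely next token
        token_ids = torch.cat([token_ids, next_id], dim=-1)
    return token_ids
```

Decoding the resulting ids back to text requires the SimpleStories tokenizer (4096-token vocabulary), which is not bundled with these weights.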
## Repository

Source code: toy_models_of_tensor_networks