# SimpleStories Bilinear Attention Models

A collection of ToyTransformer models with `CausalBilinearSelfAttention`, trained on the SimpleStories dataset.

## Models

| Model | Layers | Type | Final Val Loss |
|---|---|---|---|
| `attention_only_1L_bilinear_attn` | 1 | Attention only | ~2.91 |
| `attention_only_2L_bilinear_attn` | 2 | Attention only | ~2.87 |

## Architecture

These models use `CausalBilinearSelfAttention`, which computes squared ("bilinear") attention patterns:

- Two QK pairs: (Q1, K1) and (Q2, K2)
- Attention scores: `(Q1 @ K1.T) * (Q2 @ K2.T)` (element-wise product)
- Causal masking
- RoPE (Rotary Positional Embeddings)
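
The score computation above can be sketched as follows. This is a minimal single-head illustration under assumed names (`Wq1`, `Wk1`, `Wq2`, `Wk2` are hypothetical); it omits RoPE, multi-head reshaping, scaling, and whatever normalization the repo's actual implementation applies:

```python
import torch

def bilinear_attention_scores(x, Wq1, Wk1, Wq2, Wk2):
    """Sketch: pre-normalization bilinear attention scores for one head."""
    # Two independent QK projections
    q1, k1 = x @ Wq1, x @ Wk1
    q2, k2 = x @ Wq2, x @ Wk2
    # Element-wise product of the two score matrices
    scores = (q1 @ k1.T) * (q2 @ k2.T)
    # Causal mask: position i may only attend to positions j <= i
    T = x.shape[0]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    return scores.masked_fill(~mask, float("-inf"))
```

In a full layer, RoPE would be applied to the queries and keys before the matmuls, and the masked scores would then be normalized into an attention pattern.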

## Training Details

- Dataset: SimpleStories (4,096-token vocabulary)
- `d_model`: 256
- `n_head`: 4
- `n_ctx`: 512
- Optimizer: Muon
- Batches: 100,000
- Batch size: 128
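
For reference, the architectural hyperparameters above as a config dict. The field names here are illustrative assumptions, not the repo's actual `ModelConfig` schema:

```python
# Hypothetical field names; check config.json for the real schema.
config_dict = {
    "vocab_size": 4096,
    "d_model": 256,
    "n_head": 4,
    "n_ctx": 512,
    "n_layer": 1,  # 2 for the 2L model
}
head_dim = config_dict["d_model"] // config_dict["n_head"]  # 64
```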

## Usage

```python
from huggingface_hub import hf_hub_download
import torch
import json

# Download a model's config and weights
model_name = "attention_only_1L_bilinear_attn"
config_path = hf_hub_download(
    repo_id="Elriggs/simplestories-bilinear-attn",
    filename=f"{model_name}/config.json",
)
weights_path = hf_hub_download(
    repo_id="Elriggs/simplestories-bilinear-attn",
    filename=f"{model_name}/pytorch_model.bin",
)

# Load the config
with open(config_path) as f:
    config_dict = json.load(f)

# Use with ToyTransformer from this repo
from utils import ToyTransformer, ModelConfig

# Drop training metadata that ModelConfig does not accept
for key in ["num_batches", "final_train_loss", "final_val_loss", "learning_rate", "batch_size"]:
    config_dict.pop(key, None)

config = ModelConfig(**config_dict)
model = ToyTransformer(config)
model.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)
```
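
The `ToyTransformer` forward interface isn't documented here; assuming it follows the common convention of mapping token ids of shape `[batch, seq]` to logits of shape `[batch, seq, vocab]`, generation can be sketched with a simple greedy loop (the function below works with any model matching that convention):

```python
import torch

@torch.no_grad()
def greedy_decode(model, prompt_ids, max_new_tokens=20):
    """Greedy decoding sketch for a model mapping [B, T] ids -> [B, T, vocab] logits."""
    ids = prompt_ids
    for _ in range(max_new_tokens):
        # Keep only the last n_ctx tokens (these models use n_ctx = 512)
        logits = model(ids[:, -512:])
        # Take the highest-probability token at the final position
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids
```

Decoding the resulting ids back to text requires the SimpleStories tokenizer used during training.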

## Repository

Source code: toy_models_of_tensor_networks
