
SimpleStories Bilinear Attention Models

A collection of ToyTransformer models using CausalBilinearSelfAttention, trained on the SimpleStories dataset.

Models

Model                             Layers   Type             Final Val Loss
attention_only_1L_bilinear_attn   1        Attention only   ~2.91
attention_only_2L_bilinear_attn   2        Attention only   ~2.87

Architecture

These models use CausalBilinearSelfAttention, which computes squared attention patterns:

  • Two QK pairs: (Q1, K1) and (Q2, K2)
  • Attention scores: (Q1 @ K1.T) * (Q2 @ K2.T) (element-wise product)
  • Causal masking applied
  • RoPE (Rotary Positional Embeddings)
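
The bullet points above can be sketched in PyTorch. This is an illustrative reconstruction, not the repo's exact implementation: the module/parameter names, the 1/d_head score scaling, and the softmax over the product scores are assumptions, and RoPE is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalBilinearSelfAttention(nn.Module):
    """Sketch: attention scores are the element-wise product of two QK maps."""

    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_model // n_head
        # Two independent QK pairs, plus the usual value/output projections
        self.q1 = nn.Linear(d_model, d_model, bias=False)
        self.k1 = nn.Linear(d_model, d_model, bias=False)
        self.q2 = nn.Linear(d_model, d_model, bias=False)
        self.k2 = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, D = x.shape

        def split(t):  # (B, T, D) -> (B, n_head, T, d_head)
            return t.view(B, T, self.n_head, self.d_head).transpose(1, 2)

        q1, k1 = split(self.q1(x)), split(self.k1(x))
        q2, k2 = split(self.q2(x)), split(self.k2(x))
        v = split(self.v(x))

        # Element-wise product of the two QK score maps ("squared" pattern)
        scores = (q1 @ k1.transpose(-2, -1)) * (q2 @ k2.transpose(-2, -1))
        scores = scores / self.d_head  # scaling choice is an assumption

        # Causal mask: position t may only attend to positions <= t
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        y = (attn @ v).transpose(1, 2).reshape(B, T, D)
        return self.out(y)
```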

Training Details

  • Dataset: SimpleStories (4096 vocab)
  • d_model: 256
  • n_head: 4
  • n_ctx: 512
  • Optimizer: Muon
  • Batches: 100,000
  • Batch size: 128

Usage

from huggingface_hub import hf_hub_download
import torch
import json

# Download a model
model_name = "attention_only_1L_bilinear_attn"
config_path = hf_hub_download(repo_id="Elriggs/simplestories-bilinear-attn", filename=f"{model_name}/config.json")
weights_path = hf_hub_download(repo_id="Elriggs/simplestories-bilinear-attn", filename=f"{model_name}/pytorch_model.bin")

# Load config
with open(config_path) as f:
    config_dict = json.load(f)

# Use with ToyTransformer from this repo's source code
from utils import ToyTransformer, ModelConfig

# Drop training metadata stored in config.json that is not a ModelConfig field
for key in ['num_batches', 'final_train_loss', 'final_val_loss', 'learning_rate', 'batch_size']:
    config_dict.pop(key, None)

config = ModelConfig(**config_dict)
model = ToyTransformer(config)
model.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=False)
model.eval()
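
Once loaded, the model can be sampled autoregressively. A minimal greedy-decoding sketch, assuming only that the model maps token IDs of shape (batch, seq) to logits of shape (batch, seq, vocab); greedy argmax is used as a placeholder for a proper sampler, and the SimpleStories tokenizer is not shown here:

```python
import torch

def generate(model, prompt_ids, max_new_tokens=20, n_ctx=512):
    """Greedy decoding. prompt_ids: LongTensor of shape (1, T)."""
    model.eval()
    ids = prompt_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(ids[:, -n_ctx:])              # (1, T, vocab)
            next_id = logits[:, -1].argmax(-1, keepdim=True)  # most likely next token
            ids = torch.cat([ids, next_id], dim=1)
    return ids
```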

Repository

Source code: toy_models_of_tensor_networks
