metadata
language: en
license: apache-2.0
library_name: transformers
tags:
  - bert
  - fast
  - monarch-matrices
  - mnli
  - efficiency
  - triton
  - hardware-efficient
  - sub-quadratic
  - fast-inference
  - h100-optimized
datasets:
  - glue
  - wikipedia
metrics:
  - accuracy
  - throughput
  - latency
pipeline_tag: text-classification
model-index:
  - name: Monarch-BERT-Base-MNLI-Full
    results:
      - task:
          type: text-classification
          name: Natural Language Inference
        dataset:
          name: GLUE MNLI
          type: glue
          config: mnli
          split: validation_matched
        metrics:
          - name: Accuracy
            type: accuracy
            value: 79.4
            description: >-
              Approx. 5-point accuracy trade-off for maximum speed compared to
              dense BERT.
          - name: Throughput (TPS on H100)
            type: throughput
            value: 9029.4
            description: >-
              Measured with torch.compile(mode='max-autotune') on NVIDIA H100.
              Represents a 24.3% increase over the optimized Triton baseline.
          - name: Latency (ms)
            type: latency
            value: 3.54
            description: >-
              Batch size 32, sequence length 128. Roughly 20% lower latency
              than the dense baseline.

Monarch-BERT-MNLI (Full)

Breaking the Efficiency Barrier: -51.5% Parameters, +24% Speed.

tl;dr: Extreme resource efficiency on MNLI. We replaced every dense FFN layer in BERT-Base with structured Monarch matrices. Distilled in just 3 hours on one H100 using only 500k Wikipedia tokens and MNLI data, this model cuts parameters by 51.5% and boosts throughput by 24% (vs. the optimized baseline).
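The core change in the tl;dr, swapping each dense FFN projection for a structured Monarch matrix, can be sketched in plain PyTorch. This is a minimal illustration assuming the usual two-factor Monarch form (block-diagonal factor, fixed stride permutation, block-diagonal factor); `MonarchLinear`, `nblocks=16`, and the block shapes are hypothetical choices, not the model's actual `MonarchUp`/`MonarchDown` code.

```python
import torch
import torch.nn as nn


class MonarchLinear(nn.Module):
    """Illustrative (hypothetical) Monarch-style replacement for nn.Linear(d_in, d_out).

    Computes y = P2 @ R @ P1 @ L @ x, where L and R are block-diagonal and
    P1, P2 are fixed stride ("transpose") permutations done with reshapes.
    """

    def __init__(self, d_in: int, d_out: int, nblocks: int, bias: bool = True):
        super().__init__()
        assert d_in % nblocks == 0, "d_in must be divisible by nblocks"
        self.b = nblocks
        self.p = d_in // nblocks  # size of each block in the first factor
        assert d_out % self.p == 0, "d_out must be divisible by d_in // nblocks"
        self.s = d_out // self.p  # outputs per block in the second factor
        self.d_out = d_out
        # L: b square blocks, each (p x p)
        self.L = nn.Parameter(torch.randn(self.b, self.p, self.p) / self.p**0.5)
        # R: p blocks, each mapping b inputs to s outputs
        self.R = nn.Parameter(torch.randn(self.p, self.s, self.b) / self.b**0.5)
        self.bias = nn.Parameter(torch.zeros(d_out)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead = x.shape[:-1]
        x = x.reshape(-1, self.b, self.p)                 # split features into b blocks
        x = torch.einsum("nbp,bqp->nbq", x, self.L)       # block-diagonal L
        x = x.transpose(1, 2)                             # P1: (N, b, p) -> (N, p, b)
        x = torch.einsum("npb,psb->nps", x, self.R)       # block-diagonal R
        x = x.transpose(1, 2).reshape(*lead, self.d_out)  # P2, then flatten
        if self.bias is not None:
            x = x + self.bias
        return x


if __name__ == "__main__":
    dense = nn.Linear(768, 3072)  # BERT-Base FFN up-projection size
    monarch = MonarchLinear(768, 3072, nblocks=16)
    x = torch.randn(2, 128, 768)
    print(monarch(x).shape)                              # torch.Size([2, 128, 3072])
    print(sum(p.numel() for p in dense.parameters()))    # 2362368
    print(sum(p.numel() for p in monarch.parameters()))  # 89088
```

With 16 blocks, a single 768-to-3072 projection drops from about 2.36M to about 89k parameters. The card replaces only the FFN layers and does not state its real block configuration, so the whole-model reduction (-51.5%) is much smaller than this per-layer ratio.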

High Performance, Low Cost

Training models from scratch typically requires billions of tokens. We took a different path to shift the efficiency curve:

  • Training Time: A few hours on 1x NVIDIA H100.
  • Data: Only MNLI + 500k Wikipedia samples.
  • Trade-off: This extreme compression comes with a moderate accuracy drop (~5 percentage points). Need higher accuracy? Check out our Hybrid Version (<1% loss).
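The card says the model was distilled but does not describe the objective. As background only, a common logit-distillation loss for a classification student is sketched below: temperature-softened KL divergence against a dense teacher's logits plus cross-entropy on the gold labels. The function name, `temperature=2.0`, and `alpha=0.5` are assumptions, not the authors' training recipe.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Hypothetical Hinton-style logit distillation: soft KL term plus hard-label CE."""
    # Soften both distributions with the same temperature; the teacher gets no gradients.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits.detach() / temperature, dim=-1)
    # KL(teacher || student); the T^2 factor keeps the gradient scale comparable across temperatures.
    kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
    # Standard cross-entropy against the gold labels (MNLI has 3 classes).
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1.0 - alpha) * ce
```

In a training loop, `teacher_logits` would come from a dense BERT-Base MNLI model run under `torch.no_grad()`, and only the student receives gradients.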

Key Benchmarks

Measured on a single NVIDIA H100 using torch.compile(mode="max-autotune").

| Metric | BERT-Base (Baseline) | Monarch-Full (This) | Delta |
|---|---|---|---|
| Parameters | 110.0M | 53.3M | 📉 -51.5% |
| Compute (GFLOPs) | 696.5 | 232.6 | 📉 -66.6% |
| Throughput (TPS) | 7261 | 9029 | 🚀 +24.3% |
| Latency (batch 32, seq. length 128) | 4.41 ms | 3.54 ms | ⚡ 19.7% faster |
| Accuracy (MNLI) | 83.62% | 78.18% | 📉 -5.44 pp |
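The deltas in the table can be rechecked from the raw numbers. The sketch below recomputes them and also derives the throughput implied by the batch-32 latency; treating TPS as sequences per second is an assumption, since the card does not define the unit.

```python
BATCH_SIZE = 32  # from the latency row and the model-index metadata

# (baseline, monarch) pairs copied from the benchmark table.
rows = {
    "params_M": (110.0, 53.3),
    "gflops": (696.5, 232.6),
    "throughput": (7261, 9029),
    "latency_ms": (4.41, 3.54),
}
accuracy = (83.62, 78.18)


def pct_change(old, new):
    """Relative change in percent; negative means a reduction."""
    return 100.0 * (new - old) / old


for name, (old, new) in rows.items():
    print(f"{name:>10}: {pct_change(old, new):+.1f}%")

# Accuracy is already a percentage, so its delta is in percentage points.
print(f"  accuracy: {accuracy[1] - accuracy[0]:+.2f} pp")

# Throughput implied by latency, assuming each timed call processes one batch of 32 sequences.
for label, lat in (("baseline", 4.41), ("monarch", 3.54)):
    print(f"implied {label} throughput: {BATCH_SIZE / (lat / 1000):.0f} seq/s")
```

The printed deltas reproduce the table's -51.5%, -66.6%, +24.3% and -19.7%. The implied throughputs (about 7256 and 9040 sequences/s) are within 0.2% of the reported TPS values, which suggests TPS counts sequences per second at batch size 32.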

Usage

This model uses a custom architecture. You must pass trust_remote_code=True to from_pretrained so that the custom Monarch layers (MonarchUp, MonarchDown, MonarchFFN) can be loaded.

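To see the real speedup, compilation is mandatory: without it, Python overhead in eager-mode PyTorch masks the hardware gains. The evaluation script below measures accuracy, so its compilation lines are left commented out.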

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "ykae/monarch-bert-base-mnli"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, 
    trust_remote_code=True
).to(device)

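# Compilation is required to reproduce the speed numbers but not for accuracy.
# Uncomment both lines to enable it; expect the first compiled calls to be slow.
# torch.set_float32_matmul_precision('high')
# model = torch.compile(model, mode="max-autotune")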
model.eval()

print("📊 Loading MNLI validation set...")
dataset = load_dataset("glue", "mnli", split="validation_matched")

def tokenize_fn(ex):
    return tokenizer(ex['premise'], ex['hypothesis'], 
                     padding="max_length", truncation=True, max_length=128)

tokenized_ds = dataset.map(tokenize_fn, batched=True)
tokenized_ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
loader = DataLoader(tokenized_ds, batch_size=32)

correct = 0
total = 0

print(f"🚀 Starting evaluation on {len(tokenized_ds)} samples...")
with torch.no_grad():
    for batch in tqdm(loader):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(ids, attention_mask=mask)
        preds = torch.argmax(outputs.logits, dim=1)
        
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("\n✅ Evaluation finished!")
print(f"📈 Accuracy: {100 * correct / total:.2f}%")
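The script above only measures accuracy. For the speed side, which the card says needs compilation, a minimal timing helper is sketched below. It excludes warmup calls (where `torch.compile` compiles and autotunes) and synchronizes CUDA before reading the clock, since kernel launches are asynchronous. `measure_latency_ms`, `benchmark_model`, the iteration counts, and the random-token input are assumptions; this is not the script behind the benchmark table.

```python
import time

import torch


def measure_latency_ms(fn, warmup=10, iters=50):
    """Average wall-clock latency of fn() in milliseconds, excluding warmup calls."""
    for _ in range(warmup):
        fn()  # with torch.compile, the first calls trigger compilation/autotuning
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA launches are async; wait before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / iters


def benchmark_model(model, batch_size=32, seq_len=128, vocab_size=30522, device="cuda"):
    """Hypothetical helper: latency (ms) and throughput (sequences/s) on random token IDs.

    vocab_size=30522 assumes the standard BERT-Base WordPiece vocabulary.
    """
    ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    mask = torch.ones_like(ids)
    with torch.inference_mode():
        latency_ms = measure_latency_ms(lambda: model(ids, attention_mask=mask))
    return latency_ms, batch_size / (latency_ms / 1000.0)
```

For example, after uncommenting the two compilation lines in the evaluation script, `benchmark_model(model, device=device)` returns the batch-32, length-128 latency and sequences per second; running a dense BERT-Base classifier through the same helper gives a like-for-like baseline.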

The "Memory Paradox" (Read this!)

You might notice that while the parameter count is lower (53.3M vs 110M), the peak VRAM usage during inference can be slightly higher than the baseline.

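Why? This is a software artifact, not a hardware limitation. The current implementation runs the Monarch multiplication as a sequence of separate PyTorch operations (block-diagonal multiply, permutation, block-diagonal multiply), and each step writes its intermediate activations out to VRAM before the next step reads them back.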

  • Solution: A custom fused Triton kernel (planned) would fuse the steps of the Monarch multiplication, keeping intermediate activations in the GPU's on-chip SRAM. This should bring dynamic VRAM usage well below the baseline, in line with the FLOPs reduction.
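To observe the paradox on your own GPU, and to reason about where the extra memory comes from, here is a minimal sketch. `peak_memory_mib` wraps PyTorch's peak-allocation counters and `activation_mib` estimates the size of one materialized activation tensor; both helper names are hypothetical, and the Monarch FFN's real intermediate widths are not documented on this card, so the widths below are only examples.

```python
import torch


def activation_mib(batch, seq_len, width, bytes_per_elem=4):
    """Size in MiB of one (batch, seq_len, width) activation tensor (fp32 by default)."""
    return batch * seq_len * width * bytes_per_elem / 2**20


def peak_memory_mib(fn):
    """Peak CUDA memory allocated while running fn(), in MiB; None without a GPU."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20


if __name__ == "__main__":
    # Each unfused step that writes a batch-32, length-128 fp32 tensor costs this much VRAM:
    print(activation_mib(32, 128, 768))   # 12.0
    print(activation_mib(32, 128, 3072))  # 48.0
```

For example, call `peak_memory_mib(lambda: model(ids, attention_mask=mask))` inside `torch.no_grad()`, once for this model and once for the dense baseline. The reported peak also includes the model weights, which are smaller for this model, so a higher peak points at activation memory.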

Citation

@misc{ykae-monarch-bert-mnli-2026,
  author = {Yusuf Kalyoncuoglu and {YKAE-Vision}},
  title = {Monarch-BERT-MNLI: Extreme Compression via Monarch FFNs},
  year = {2026},
  publisher = {Hugging Face},
  journal = {Hugging Face Model Hub},
  howpublished = {\url{https://huggingface.co/ykae/monarch-bert-base-mnli}}
}