|
|
--- |
|
|
language: en |
|
|
license: apache-2.0 |
|
|
library_name: transformers |
|
|
tags: |
|
|
- bert |
|
|
- fast |
|
|
- monarch-matrices |
|
|
- mnli |
|
|
- efficiency |
|
|
- triton |
|
|
- hardware-efficient |
|
|
- sub-quadratic |
|
|
- fast-inference |
|
|
- h100-optimized |
|
|
datasets: |
|
|
- glue |
|
|
- wikipedia |
|
|
metrics: |
|
|
- accuracy |
|
|
- throughput |
|
|
- latency |
|
|
pipeline_tag: text-classification |
|
|
model-index: |
|
|
- name: Monarch-BERT-Base-MNLI-Full |
|
|
results: |
|
|
- task: |
|
|
type: text-classification |
|
|
name: Natural Language Inference |
|
|
dataset: |
|
|
name: GLUE MNLI |
|
|
type: glue |
|
|
config: mnli |
|
|
split: validation_matched |
|
|
metrics: |
|
|
- name: Accuracy |
|
|
type: accuracy |
|
|
value: 78.34
|
|
description: "Approx. 5% accuracy trade-off for maximum speed compared to dense BERT." |
|
|
- name: Throughput (TPS on H100) |
|
|
type: throughput |
|
|
value: 9029.4 |
|
|
description: "Measured with torch.compile(mode='max-autotune') on NVIDIA H100. Represents a 24.3% increase over the optimized Triton baseline." |
|
|
- name: Latency (ms) |
|
|
type: latency |
|
|
value: 3.54 |
|
|
description: "Batch size 32, sequence length 128. Achieves a ~20% faster inference time." |
|
|
--- |
|
|
|
|
|
# Monarch-BERT-MNLI (Full) |
|
|
|
|
|
**Breaking the Efficiency Barrier: -51.5% Parameters, +24% Speed.** |
|
|
|
|
|
> **tl;dr:** Extreme resource efficiency on **MNLI**. We replaced *every* dense FFN layer in BERT-Base with structured Monarch matrices. Distilled in **just 3 hours on one H100** using only **500k Wikipedia samples** plus the **MNLI** training data, this model cuts parameters by **51.5%** and boosts throughput by **+24%** over the optimized baseline.
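For intuition, a Monarch layer factors a dense weight into two block-diagonal matrices separated by a fixed permutation, which is what cuts both parameters and FLOPs. Below is a minimal, self-contained sketch of the square case. It is illustrative only, not the repo's actual `MonarchUp`/`MonarchDown`/`MonarchFFN` code; the block count, shapes, and initialization here are assumptions.

```python
import torch
import torch.nn as nn

class MonarchLinear(nn.Module):
    """Illustrative square Monarch layer: y = P^T B P A x, where A and B are
    block-diagonal and P is a fixed reshape/transpose permutation.
    NOT this model's real implementation."""

    def __init__(self, dim: int, nblocks: int = 4):
        super().__init__()
        assert dim % nblocks == 0
        self.n = nblocks          # blocks in the first factor
        self.m = dim // nblocks   # block size of the first factor
        # First factor: n blocks of shape (m, m); second factor: m blocks of (n, n).
        self.blkdiag1 = nn.Parameter(torch.randn(self.n, self.m, self.m) * self.m**-0.5)
        self.blkdiag2 = nn.Parameter(torch.randn(self.m, self.n, self.n) * self.n**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead = x.shape[:-1]
        x = x.reshape(-1, self.n, self.m)                    # split features into n blocks
        x = torch.einsum("bnm,nkm->bnk", x, self.blkdiag1)   # n independent (m x m) matmuls
        x = x.transpose(1, 2)                                # permutation P
        x = torch.einsum("bmn,mkn->bmk", x, self.blkdiag2)   # m independent (n x n) matmuls
        return x.transpose(1, 2).reshape(*lead, -1)          # undo P, restore input shape

layer = MonarchLinear(768, nblocks=4)
y = layer(torch.randn(2, 128, 768))   # (batch, seq, dim) -> same shape
```

With `dim=768` and 4 blocks this stores 4·192² + 192·4² ≈ 0.15M parameters in place of 768² ≈ 0.59M for the dense equivalent, roughly a 4x reduction per layer.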
|
|
|
|
|
## High Performance, Low Cost |
|
|
|
|
|
Training models from scratch typically requires billions of tokens. We took a different path to bend the efficiency curve:
|
|
|
|
|
* **Training Time:** About **3 hours** on **1x NVIDIA H100** (see the distillation sketch after this list).
|
|
* **Data:** Only **MNLI** + **500k Wikipedia samples**.
|
|
* **Trade-off:** This extreme compression comes with a moderate accuracy drop (~5%). *Need higher accuracy? Check out our [Hybrid Version](https://huggingface.co/ykae/monarch-bert-base-mnli-hybrid) (<1% loss).* |
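The exact distillation recipe is not documented here, so the following is only a minimal sketch of standard logit distillation against the dense BERT-Base teacher; the temperature `T` and mixing weight `alpha` are assumed values, not the ones actually used.

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits: torch.Tensor,
                 teacher_logits: torch.Tensor,
                 labels: torch.Tensor,
                 T: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
    """Soft-target KL (scaled by T^2, per standard practice) mixed with the
    hard cross-entropy loss on the MNLI labels."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```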
|
|
|
|
|
## Key Benchmarks |
|
|
|
|
|
Measured on a single NVIDIA H100 using `torch.compile(mode="max-autotune")`. |
|
|
|
|
|
| Metric | BERT-Base (Baseline) | **Monarch-Full (This)** | Delta |
| :--- | :--- | :--- | :--- |
| **Parameters** | 110.0M | **53.3M** | **-51.5%** |
| **Compute (GFLOPs)** | 696.5 | **232.6** | **-66.6%** |
| **Throughput (TPS)** | 7261 | **9029** | **+24.3%** |
| **Latency (batch 32, seq 128)** | 4.41 ms | **3.54 ms** | **19.7% faster** |
| **Accuracy (MNLI-matched)** | 83.62% | **78.34%** | **-5.28 pts** |
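You can verify the parameter count directly with a quick check; expect roughly 53.3M here vs ~110M for `bert-base-uncased`.

```python
from transformers import AutoModelForSequenceClassification

# Custom Monarch layers require trust_remote_code (see Usage below).
model = AutoModelForSequenceClassification.from_pretrained(
    "ykae/monarch-bert-base-mnli", trust_remote_code=True
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```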
|
|
|
|
|
## Usage |
|
|
|
|
|
This model uses a **custom architecture**. You must pass `trust_remote_code=True` so the custom Monarch layers (`MonarchUp`, `MonarchDown`, `MonarchFFN`) can be loaded.
|
|
|
|
|
To see the real speedup, **compilation is mandatory** (otherwise eager-mode Python overhead masks the hardware gains).
|
|
|
|
|
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "ykae/monarch-bert-base-mnli"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    trust_remote_code=True,
).to(device)
model.eval()

if device == "cuda":
    # Compilation is required to realize the advertised speedup (see note above).
    torch.set_float32_matmul_precision("high")
    model = torch.compile(model, mode="max-autotune")

print("Loading MNLI validation set...")
dataset = load_dataset("glue", "mnli", split="validation_matched")

def tokenize_fn(ex):
    return tokenizer(ex["premise"], ex["hypothesis"],
                     padding="max_length", truncation=True, max_length=128)

tokenized_ds = dataset.map(tokenize_fn, batched=True)
tokenized_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
loader = DataLoader(tokenized_ds, batch_size=32)

correct = 0
total = 0

print(f"Starting evaluation on {len(tokenized_ds)} samples...")
with torch.no_grad():
    for batch in tqdm(loader):
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(ids, attention_mask=mask)
        preds = torch.argmax(outputs.logits, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("\nEvaluation finished!")
print(f"Accuracy: {100 * correct / total:.2f}%")
```
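To reproduce the throughput and latency figures, a micro-benchmark along these lines should suffice. This is a sketch continuing from the snippet above; it assumes a CUDA device, and TPS here works out to samples per second (32 / 3.54 ms ≈ 9,040). Exact numbers will vary with GPU, driver, and PyTorch version.

```python
import time

# Synthetic batch matching the benchmark setting: batch 32, seq len 128.
ids = torch.randint(0, tokenizer.vocab_size, (32, 128), device=device)
mask = torch.ones(32, 128, dtype=torch.long, device=device)

with torch.no_grad():
    for _ in range(10):                      # warmup; also triggers compilation
        model(ids, attention_mask=mask)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    iters = 50
    for _ in range(iters):
        model(ids, attention_mask=mask)
    torch.cuda.synchronize()
    dt = time.perf_counter() - t0

print(f"Latency:    {1000 * dt / iters:.2f} ms / batch")
print(f"Throughput: {32 * iters / dt:.1f} samples/s")
```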
|
|
|
|
|
## The "Memory Paradox" (Read this!) |
|
|
|
|
|
You might notice that while the parameter count is lower (53.3M vs 110M), the peak VRAM usage during inference can be slightly higher than the baseline.
|
|
|
|
|
**Why?**

This is a **software artifact**, not a hardware limitation. The Monarch FFN currently executes as a chain of separate PyTorch ops (block-diagonal matmuls and permutations), so each intermediate activation is materialized in HBM, raising the peak allocation even though the weights themselves are much smaller. A quick way to observe this is sketched below.
* **Solution:** A custom **fused Triton kernel** (planned) would fuse these steps, keeping intermediate activations in the GPU's SRAM. This would drop dynamic VRAM usage significantly below the baseline, matching the FLOPs reduction.
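Until then, the effect is easy to observe yourself (reusing `model`, `ids`, and `mask` from the snippets above; assumes a CUDA device):

```python
# Measure peak VRAM for one batch-32 / seq-128 forward pass.
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(ids, attention_mask=mask)
torch.cuda.synchronize()
print(f"Peak VRAM: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
```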
|
|
|
|
|
## Citation |
|
|
```bibtex
@misc{ykae-monarch-bert-mnli-2026,
  author       = {Yusuf Kalyoncuoglu and YKAE-Vision},
  title        = {Monarch-BERT-MNLI: Extreme Compression via Monarch FFNs},
  year         = {2026},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/ykae/monarch-bert-base-mnli}}
}
```