Update README.md
Browse files
README.md
CHANGED
|
@@ -79,24 +79,56 @@ To see the real speedup, **compilation is mandatory** (otherwise PyTorch Python
|
|
| 79 |
```python
|
| 80 |
import torch
|
| 81 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
# 1. Load Model
|
|
|
|
| 84 |
model_id = "ykae/monarch-bert-base-mnli"
|
|
|
|
| 85 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 86 |
model = AutoModelForSequenceClassification.from_pretrained(
|
| 87 |
model_id,
|
| 88 |
-
trust_remote_code=True
|
| 89 |
-
).to("cuda")
|
| 90 |
|
| 91 |
-
# 2.
|
| 92 |
torch.set_float32_matmul_precision('high')
|
| 93 |
model = torch.compile(model, mode="max-autotune")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
inputs = tokenizer("Monarch matrices are efficiently sparse.", return_tensors="pt").to("cuda")
|
| 97 |
with torch.no_grad():
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
```
|
| 101 |
|
| 102 |
## 🧠 The "Memory Paradox" (Read this!)
|
|
|
|
| 79 |
```python
|
| 80 |
import torch
|
| 81 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 82 |
+
from datasets import load_dataset
|
| 83 |
+
from torch.utils.data import DataLoader
|
| 84 |
+
from tqdm import tqdm
|
| 85 |
|
| 86 |
+
# 1. Setup & Load Model
# Run on the GPU when one is visible, otherwise fall back to the CPU.
model_id = "ykae/monarch-bert-base-mnli"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code=True lets transformers execute the modeling code that is
# shipped inside the checkpoint repository.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, trust_remote_code=True
)
model = model.to(device)

# 2. Performance Optimization (Mandatory for Monarch Speed)
# 'high' allows float32 matmuls to use faster reduced-precision kernels.
torch.set_float32_matmul_precision("high")
model = torch.compile(model, mode="max-autotune")
# Switch to inference behaviour before any forward pass is run.
model.eval()
|
| 100 |
+
|
| 101 |
+
# 3. Load MNLI Validation Set
# Only the "matched" validation split of GLUE/MNLI is evaluated here.
print("📊 Loading MNLI Validation set...")
dataset = load_dataset(path="glue", name="mnli", split="validation_matched")
|
| 104 |
+
|
| 105 |
+
def tokenize_fn(ex):
    """Tokenize MNLI premise/hypothesis pairs to a fixed length.

    Each pair is truncated and padded to exactly 128 tokens.

    Args:
        ex: Mapping with ``premise`` and ``hypothesis`` entries.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the pair(s).
    """
    premises = ex["premise"]
    hypotheses = ex["hypothesis"]
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    return tokenizer(premises, hypotheses, **encoding_options)
|
| 108 |
+
|
| 109 |
+
tokenized_ds = dataset.map(tokenize_fn, batched=True)

# Forward every tensor the tokenizer declares as a model input, not just
# input_ids/attention_mask. For BERT-style tokenizers this normally includes
# `token_type_ids`, which separates premise tokens from hypothesis tokens;
# dropping it silently feeds the model all-zero segment ids.
# NOTE(review): assumes the remote-code model's forward() accepts every name in
# tokenizer.model_input_names -- confirm against the checkpoint's modeling code.
model_input_columns = [
    name
    for name in tokenizer.model_input_names
    if name in tokenized_ds.column_names
]
tokenized_ds.set_format(type="torch", columns=[*model_input_columns, "label"])
loader = DataLoader(tokenized_ds, batch_size=32)

# 4. Scientific Evaluation
correct = 0
total = 0

print(f"🚀 Starting evaluation on {len(tokenized_ds)} samples...")
with torch.no_grad():
    for batch in tqdm(loader):
        labels = batch["label"].to(device)
        model_inputs = {
            name: batch[name].to(device) for name in model_input_columns
        }

        outputs = model(**model_inputs)
        preds = torch.argmax(outputs.logits, dim=1)

        # Accumulate plain Python ints so the final ratio is computed on the host.
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("\n✅ Evaluation Finished!")
print(f"📈 Accuracy: {100 * correct / total:.2f}%")
|
| 132 |
```
|
| 133 |
|
| 134 |
## 🧠 The "Memory Paradox" (Read this!)
|