Upload python/train_cpu.py with huggingface_hub
python/train_cpu.py  (+354, -0)

python/train_cpu.py
ADDED
@@ -0,0 +1,354 @@
"""
H4 Polytopic Attention — CPU autoresearch training script.

This is the ONLY file the agent modifies during autonomous research.

Follows the autoresearch pattern: modify → run (2 min budget) → measure → keep/discard.

The frozen H4 geometry is off-limits. Only the trainable adapters, hyperparameters,
training loop details, and architecture of trainable layers may be changed.
"""

import os
import math
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from h4_polytopic_attention import generate_600_cell_vertices, build_coxeter_chambers
from h4_language_model import H4LanguageModel
from bitlinear import BitLinear

# ---------------------------------------------------------------------------
# Constants (DO NOT MODIFY the frozen geometry section)
# ---------------------------------------------------------------------------

PHI = (1 + math.sqrt(5)) / 2

# Frozen geometric constants — loaded from existing code
VERTICES = torch.tensor(generate_600_cell_vertices(), dtype=torch.float32)
CHAMBERS = build_coxeter_chambers(VERTICES.numpy())
SIMPLE_ROOTS = torch.tensor(CHAMBERS['simple_roots'], dtype=torch.float32)
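
# Geometry notes: the 600-cell has 120 vertices whose coordinates are built
# from the golden ratio PHI, and its symmetry group is the Coxeter group H4
# (order 14400). H4 has rank 4, so SIMPLE_ROOTS holds four roots, which bound
# the fundamental Weyl chamber. These tensors are frozen and never trained.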

# ---------------------------------------------------------------------------
# Hyperparameters (AGENT MAY MODIFY THESE)
# ---------------------------------------------------------------------------

# Time budget: 2 minutes on CPU
TIME_BUDGET = 120  # seconds

# Dataset: 'synthetic', 'shakespeare', or 'tinystories'
DATASET = 'synthetic'

# Data
MAX_SEQ_LEN = 128
BATCH_SIZE = 8

# Model
D_MODEL = 256
N_HEADS = 8
N_LAYERS = 4
D_VALUE = 16
D_FFN = 512
TOP_K = 16
DROPOUT = 0.0
USE_BITLINEAR = True  # set True for ternary {-1, 0, +1} weights
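# BitLinear layers (see bitlinear.py) constrain trainable weights to the
# ternary set {-1, 0, +1}. Nets of this kind typically keep full-precision
# shadow weights for the optimizer and quantize on the forward pass via a
# straight-through estimator; the frozen geometry tensors are buffers, not
# parameters, so they are untouched.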

# Optimizer
LR = 5e-3
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50
GRAD_CLIP = 1.0

# Eval
EVAL_INTERVAL = 25
EVAL_BATCHES = 5

# ---------------------------------------------------------------------------
# Data: character-level Shakespeare (or synthetic if not available)
# ---------------------------------------------------------------------------

def load_text_data():
    """Load training text. Falls back to synthetic data if no file is available."""
    # Try to load Shakespeare or another plain-text corpus
    data_paths = [
        os.path.join(os.path.dirname(__file__), '..', 'data', 'shakespeare.txt'),
        os.path.join(os.path.dirname(__file__), '..', 'data', 'input.txt'),
        os.path.join(os.path.dirname(__file__), 'data', 'input.txt'),
    ]

    text = None
    for path in data_paths:
        if os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()
            print(f"Loaded data from {path} ({len(text)} chars)")
            break

    if text is None:
        # Generate synthetic text with mathematical structure:
        # Fibonacci-structured repetitions to test the geometric inductive bias.
        print("No data file found, generating synthetic text...")
        base_phrases = [
            "the golden ratio appears in nature ",
            "fibonacci numbers grow exponentially ",
            "symmetry underlies all of physics ",
            "the icosahedron has twenty faces ",
            "phi equals one plus one over phi ",
            "geometry is the language of space ",
            "five fold symmetry cannot tile a plane ",
            "the dodecahedron has twelve faces ",
        ]
        # Build text with Fibonacci-structured repetitions
        text = ""
        a, b = 1, 1
        for _ in range(200):
            phrase = base_phrases[a % len(base_phrases)]
            text += phrase * (b % 3 + 1)
            a, b = b, a + b

    return text


def prepare_char_dataset(text: str):
    """Prepare a character-level dataset from text."""
    chars = sorted(set(text))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}

    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

    # Split 90/10 into train/val
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]

    return train_data, val_data, vocab_size, stoi, itos


def get_batch(data: torch.Tensor, batch_size: int, seq_len: int):
    """Sample a random batch of input/target sequences."""
    max_start = len(data) - seq_len - 1
    if max_start <= 0:
        max_start = 1  # assumes len(data) > seq_len + 1; this guard only avoids a randint error
    ix = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix])
    y = torch.stack([data[i + 1:i + seq_len + 1] for i in ix])
    return x, y
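
# Example: with data = [t0, t1, t2, ...] and start index i, x = data[i:i+L]
# and y = data[i+1:i+L+1], so the prediction at position t is scored against
# token t+1 (standard next-token language-modeling targets).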


# ---------------------------------------------------------------------------
# Training loop (follows autoresearch pattern)
# ---------------------------------------------------------------------------

def main():
    t_start = time.time()
    torch.manual_seed(42)
    np.random.seed(42)

    # Load data
    if DATASET != 'synthetic':
        from prepare_data import load_and_prepare
        train_data, val_data, vocab_size, stoi, itos = load_and_prepare(DATASET)
    else:
        text = load_text_data()
        train_data, val_data, vocab_size, stoi, itos = prepare_char_dataset(text)
    print(f"Vocab size: {vocab_size}, Train: {len(train_data)}, Val: {len(val_data)}")

    # Create model
    model = H4LanguageModel(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_value=D_VALUE,
        d_ffn=D_FFN,
        top_k=TOP_K,
        max_seq_len=MAX_SEQ_LEN * 2,
        dropout=DROPOUT,
        use_bitlinear=USE_BITLINEAR,
    )
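    # Note: max_seq_len is doubled relative to the training length, leaving
    # positional headroom for generation below (4 seed chars + 80 new tokens).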

    param_info = model.count_params()
    print(f"Model params: {param_info['trainable']:,} trainable, {param_info['buffers']:,} buffer elements")

    # Optimizer: AdamW with warmup + cosine schedule
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    # Cosine LR schedule with linear warmup; returns a multiplier on LR
    def lr_schedule(step):
        if step < WARMUP_STEPS:
            return step / max(WARMUP_STEPS, 1)
        # Cosine decay to 10% of peak, assuming a nominal run of ~500 steps
        progress = (step - WARMUP_STEPS) / max(1, 500 - WARMUP_STEPS)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
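    # Worked values of the multiplier: step 0 → 0.0, step 50 → 1.0 (peak),
    # step 275 → 0.55 (mid-cosine), step ≥ 500 → 0.10 (floor).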

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    # Training state
    step = 0
    total_training_time = 0.0
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # Use full attention (no tree) for short sequences during training.
    # The tree pays off on long sequences; at seq_len=128, full attention is faster.
    use_tree = MAX_SEQ_LEN > 256

    print(f"\nTraining for {TIME_BUDGET}s budget, seq_len={MAX_SEQ_LEN}, use_tree={use_tree}")
    print(f"{'step':>6} {'loss':>8} {'val_loss':>8} {'lr':>10} {'dt':>6} {'progress':>8}")
    print("-" * 56)

    model.train()

    while True:
        t0 = time.time()

        # Get batch
        x, y = get_batch(train_data, BATCH_SIZE, MAX_SEQ_LEN)

        # Forward
        logits = model(x, use_tree=use_tree)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

        optimizer.step()
        scheduler.step()

        dt = time.time() - t0
        if step > 2:  # exclude the first few steps from timing (startup noise)
            total_training_time += dt

        train_losses.append(loss.item())

        # Eval
        val_loss = None
        if step % EVAL_INTERVAL == 0:
            model.eval()
            with torch.no_grad():
                vl = []
                for _ in range(EVAL_BATCHES):
                    xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
                    vlogits = model(xv, use_tree=False)
                    vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
                val_loss = sum(vl) / len(vl)
                val_losses.append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss

            current_lr = scheduler.get_last_lr()[0]
            progress = min(total_training_time / TIME_BUDGET, 1.0)
            print(f"{step:6d} {loss.item():8.4f} {val_loss:8.4f} {current_lr:10.6f} {dt:6.3f} {progress:7.1%}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # -----------------------------------------------------------------------
    # Final evaluation
    # -----------------------------------------------------------------------

    model.eval()
    with torch.no_grad():
        # Final val loss, averaged over more batches than the in-loop eval
        vl = []
        for _ in range(EVAL_BATCHES * 4):
            xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
            vlogits = model(xv, use_tree=False)
            vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
        final_val_loss = sum(vl) / len(vl)

        # Bits per byte (for character-level text: loss_nats / ln 2)
        val_bpb = final_val_loss / math.log(2)
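        # Sanity check on units: cross_entropy returns nats. For example,
        # ln(4) ≈ 1.386 nats/char converts to 1.386 / ln(2) = 2.0 bits/char,
        # the loss of a uniform guess over 4 symbols. Bits per byte equals
        # bits per char for ASCII-range text, where one character is one byte.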

        # Geometric diagnostics on a sample batch
        xd, _ = get_batch(val_data, 1, MAX_SEQ_LEN)
        _, diag_list = model(xd, use_tree=False, return_diagnostics=True)

        # Aggregate diagnostics across layers
        avg_chamber_entropy = np.mean([d['chamber_entropy'] for d in diag_list])
        nudge_ranks = []
        geo_aligns = []
        for d in diag_list:
            nudge_ranks.extend(d['nudge_rank'])
            geo_aligns.extend(d['geo_alignment'])
        # Ignore infinite ranks; fall back to 0 if nothing finite remains
        avg_nudge_rank = np.mean([r for r in nudge_ranks if r != float('inf')] or [0])
        avg_geo_alignment = np.mean(geo_aligns)

        # Generate sample text
        seed_text = list(stoi.keys())[:4]  # the four lowest-codepoint chars
        seed_ids = torch.tensor([[stoi[c] for c in seed_text]], dtype=torch.long)
        generated = model.generate(seed_ids, max_new_tokens=80, temperature=0.8, top_k_sample=10)
        gen_text = ''.join([itos.get(i.item(), '?') for i in generated[0]])

    # -----------------------------------------------------------------------
    # Summary (autoresearch-parseable format)
    # -----------------------------------------------------------------------

    # Ternary diagnostics (if using BitLinear)
    has_bitlinear = any(isinstance(m, BitLinear) for m in model.modules())
    ternary_info = {}
    if has_bitlinear:
        from ternary_diagnostics import chamber_preservation, bitlinear_layer_stats, size_comparison
        cp = chamber_preservation(model)
        mean_cp = sum(cp.values()) / len(cp) if cp else 0.0
        bl_stats = bitlinear_layer_stats(model)
        mean_zero_pct = np.mean([s['zero'] for s in bl_stats.values()]) if bl_stats else 0.0
        sz = size_comparison(model)
        ternary_info = {
            'chamber_preserve': mean_cp,
            'mean_zero_pct': mean_zero_pct,
            'compression': sz['compression'],
            'mixed_kb': sz['mixed_kb'],
        }
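
        # Size arithmetic behind 'compression': a ternary weight carries
        # log2(3) ≈ 1.58 bits and packs into 2 bits, versus 32 bits for fp32,
        # i.e. roughly 16x smaller for the quantized layers. The exact figure
        # from size_comparison() depends on packing and on which layers stay
        # full-precision.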

    print("\n" + "=" * 60)
    print("GENERATED SAMPLE:")
    print(gen_text[:200])
    print("=" * 60)

    print("\n---")
    print(f"val_bpb: {val_bpb:.6f}")
    print(f"val_loss: {final_val_loss:.6f}")
    print(f"best_val_loss: {best_val_loss:.6f}")
    print(f"chamber_entropy: {avg_chamber_entropy:.4f}")
    print(f"avg_nudge_rank: {avg_nudge_rank:.4f}")
    print(f"avg_geo_alignment: {avg_geo_alignment:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    print(f"peak_memory_mb: {0:.1f}")  # placeholder: peak memory is not tracked on CPU
    print(f"num_steps: {step}")
    print(f"num_params: {param_info['trainable']}")
    print(f"vocab_size: {vocab_size}")
    print(f"seq_len: {MAX_SEQ_LEN}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")
    if ternary_info:
        print(f"chamber_preserve: {ternary_info['chamber_preserve']:.4f}")
        print(f"mean_zero_pct: {ternary_info['mean_zero_pct']:.4f}")
        print(f"compression: {ternary_info['compression']:.1f}x")
        print(f"model_size_kb: {ternary_info['mixed_kb']:.1f}")


if __name__ == '__main__':
    main()