WCNegentropy committed on
Commit 060d6ba · verified · 1 parent: 2b57524

Remove nested directory: BitTransformerLM/bit_transformer/dashboard_app.py

BitTransformerLM/bit_transformer/dashboard_app.py DELETED
@@ -1,927 +0,0 @@
import io
import json
import os
import traceback
import inspect
from typing import Any, Dict, List

from flask import Flask, jsonify, request, render_template, send_file
import subprocess
import sys
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import requests
import gzip

from .model import BitTransformerLM, infer_long_sequence
from .optimization import configure_optimizer
from .collapse import collapse_submodel
from .dashboard import plot_telemetry
from .scale import expand_model
from .bit_io import text_to_bits, bits_to_text
from .safety import hil_safe_inference
from .compression import model_output_decompress, compress_bits
from .distributed import wrap_fsdp
from .training import train_loop
from .telemetry import detect_metric_drift
from .quantization import prepare_qat_fx, convert_qat_fx
from torch.distributed.fsdp import FullyShardedDataParallel
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint


app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024  # 1MB upload limit

MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")


@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


class MetricDriftWarning(UserWarning):
    """Issued when telemetry metrics drift beyond the configured threshold."""


def _switch_torch(use_gpu: bool) -> None:
    """Install the appropriate PyTorch wheel and restart the process."""
    have_cuda = torch.version.cuda is not None
    if use_gpu == have_cuda:
        return
    wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
    url = "https://download.pytorch.org/whl/cu118" if use_gpu else "https://download.pytorch.org/whl/cpu"
    subprocess.run([
        sys.executable,
        "-m",
        "pip",
        "install",
        "--extra-index-url",
        url,
        wheel,
    ], check=True)
    os.execv(sys.executable, [sys.executable] + sys.argv)


def mcp_post(path: str, data=None):
    """POST to the MCP server; return raw bytes for images, JSON otherwise."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.post(url, json=data)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()


def mcp_get(path: str):
    """GET from the MCP server; return raw bytes for images, JSON otherwise."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.get(url)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()


class ModelManager:
    """Manage model state and training utilities for the dashboard."""

    def __init__(
        self,
        snapshot_dir: str | None = None,
        telemetry_log: str | None = None,
        *,
        drift_window: int = 10,
        drift_threshold: float = 0.2,
    ) -> None:
        self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
        self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
        if self.telemetry_log is None:
            self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.weights_path = os.path.join(self.snapshot_dir, "model.pt")

        self.model: BitTransformerLM | None = None
        self.optimizer: torch.optim.Optimizer | None = None
        self.scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
        self.total_steps = 100
        self.metrics: Dict[str, List[float]] = {
            "negentropy_logits": [],
            "lz_complexity_logits": [],
            "symbiosis_score": [],
        }
        self.drift_window = drift_window
        self.drift_threshold = drift_threshold
        self.lambda_K = 1.0
        self.lambda_C = 1.0
        self.lambda_S = 1.0
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.causal = True
        self.diffusion = False
        self.decompress_output = False
        self.use_compression = False
        self.use_gpu = False
        self.qat = False

        # Load any existing state
        if os.path.exists(self.telemetry_log):
            try:
                with open(self.telemetry_log) as f:
                    saved = json.load(f)
                for key in self.metrics:
                    self.metrics[key] = saved.get(key, [])
            except Exception:
                pass
        if os.path.exists(self.weights_path):
            try:
                self.model = torch.load(self.weights_path, map_location="cpu")
                self.optimizer, self.scheduler = configure_optimizer(
                    self.model, lr=1e-3, total_steps=self.total_steps
                )
                self._apply_device()
            except Exception:
                self.model = None

        config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
        if self.model is None and os.path.exists(config_path):
            try:
                with open(config_path) as f:
                    params = json.load(f)
                self.init_model(params)
            except Exception:
                pass

    def init_model(self, params: Dict) -> None:
        """Build a fresh BitTransformerLM from raw (possibly string-typed) params."""
        int_fields = {
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
            "chunk_size",
            "overlap",
        }
        float_fields = {"act_threshold"}
        bool_fields = {"reversible", "use_checkpoint"}
        clean: Dict[str, Any] = {}
        for k, v in params.items():
            if v is None:
                clean[k] = None
            elif k in int_fields:
                clean[k] = int(v)
            elif k in float_fields:
                clean[k] = float(v)
            elif k in bool_fields:
                clean[k] = bool(v)
            else:
                clean[k] = v
        self.model = BitTransformerLM(
            **clean,
            lambda_K=self.lambda_K,
            lambda_C=self.lambda_C,
            lambda_S=self.lambda_S,
        )
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()

    def set_lambdas(self, k: float, c: float, s: float) -> None:
        """Update λ weights and propagate to the model."""
        self.lambda_K = k
        self.lambda_C = c
        self.lambda_S = s
        if self.model is not None:
            self.model.set_lambdas(k, c, s)

    def set_floors(self, c_floor: float, s_floor: float) -> None:
        """Update safety floors for complexity (C) and symbiosis (S)."""
        self.c_floor = c_floor
        self.s_floor = s_floor

    def set_diffusion(self, flag: bool) -> None:
        """Toggle Diffusion LM mode, which disables causal masking and chunking."""
        self.diffusion = flag
        self.causal = not flag
        if self.model is not None and flag:
            self.model.chunk_size = None

    def set_decompress_output(self, flag: bool) -> None:
        """Enable or disable decompression of model outputs."""
        self.decompress_output = flag

    def set_compression(self, flag: bool) -> None:
        """Toggle automatic compression of inputs."""
        self.use_compression = flag

    def set_qat(self, flag: bool) -> None:
        """Enable or disable 4-bit quantization-aware training."""
        self.qat = flag
        if self.model is None:
            return
        if flag:
            self.model = prepare_qat_fx(self.model)
        else:
            self.model = convert_qat_fx(self.model)

    def set_gpu(self, flag: bool) -> None:
        """Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
        _switch_torch(flag)
        self.use_gpu = flag and torch.cuda.is_available()
        self._apply_device()

    def _apply_device(self) -> None:
        """Move the model to the selected device and wrap with FSDP if needed."""
        if self.model is None:
            return
        if self.use_gpu:
            device = torch.device("cuda")
            if isinstance(self.model, FullyShardedDataParallel):
                base = self.model.module
            else:
                base = self.model
            base = base.to(device)
            self.model = wrap_fsdp(base, device_id=device)
        else:
            device = torch.device("cpu")
            if isinstance(self.model, FullyShardedDataParallel):
                self.model = self.model.module
            self.model = self.model.to(device)

    def train_step(self, bits: torch.Tensor) -> tuple[float, float]:
        """Run a single optimization step and return ``(loss, compression_ratio)``."""
        assert (
            self.model is not None
            and self.optimizer is not None
            and self.scheduler is not None
        )
        self.model.train()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        if self.use_compression:
            comps = [compress_bits(row.to(torch.uint8)) for row in bits]
            comp_len = sum(c.numel() for c in comps)
            ratio = min(comp_len / bits.numel(), 1.0)
            logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
        else:
            logits, telemetry = self.model(bits, causal=self.causal)
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = bits[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self._log_metrics(telemetry)
        self._save_state()
        return loss.item(), ratio

    def train_epochs(
        self,
        bits: torch.Tensor,
        *,
        epochs: int = 1,
        compress_prob: float = 0.5,
        direct_prob: float = 0.0,
        batch_size: int = 8,
        num_workers: int = 0,
        accum_steps: int = 1,
        amp: bool = False,
        compile_model: bool = False,
    ) -> List[Dict[str, float]]:
        """Run ``train_loop`` on a batch tensor and persist the state."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        import math

        steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
        self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        metrics = train_loop(
            self.model,
            bits,
            epochs=epochs,
            compress_prob=compress_prob if self.use_compression else 0.0,
            direct_prob=direct_prob,
            batch_size=batch_size,
            num_workers=num_workers,
            accum_steps=accum_steps,
            amp=amp,
            compile_model=compile_model,
            forward_kwargs={"causal": self.causal},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._save_state()
        return metrics

    def scale_up(self, width_mult: float = 1.0) -> None:
        """Double the layer count and optionally widen the model via ``expand_model``."""
        assert self.model is not None
        params = dict(
            d_model=int(self.model.d_model * width_mult),
            nhead=self.model.layers[0].self_attn.num_heads,
            num_layers=self.model.num_layers * 2,
            dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
            max_seq_len=self.model.pos_enc.pe.size(0),
        )
        self.model = expand_model(self.model, {
            **params,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
        })
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._save_state()

    def collapse(self, cluster_bits: List[List[int]], target_params: Dict, width_scale: float = 1.0) -> None:
        """Replace the model with a collapsed submodel built from ``cluster_bits``."""
        self.model, _ = collapse_submodel(
            cluster_bits,
            target_params,
            width_scale=width_scale,
            forward_kwargs={"causal": self.causal},
        )
        self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()

    def infer(self, bits: torch.Tensor) -> Dict:
        """Run a forward pass and return predicted bits, telemetry, and ratio."""
        assert self.model is not None
        self.model.eval()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        with torch.no_grad():
            if self.use_compression:
                comps = [compress_bits(row.to(torch.uint8)) for row in bits]
                comp_len = sum(c.numel() for c in comps)
                ratio = min(comp_len / bits.numel(), 1.0)
                logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
            else:
                logits, telemetry = self.model(bits, causal=self.causal)
        self._log_metrics(telemetry)
        pred_bits = logits.argmax(-1)
        if self.decompress_output:
            try:
                pred_bits = model_output_decompress(pred_bits)
            except Exception as e:
                return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}

        def _to_python(obj):
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            if isinstance(obj, list):
                return [_to_python(o) for o in obj]
            if isinstance(obj, dict):
                return {kk: _to_python(vv) for kk, vv in obj.items()}
            return obj

        tele = {k: _to_python(v) for k, v in telemetry.items()}
        return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}

    def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
        """Run sliding-window inference on a long sequence."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
        for tele in logs:
            self._log_metrics(tele)
        return {"predicted": preds.tolist(), "windows": len(logs)}

    def _log_metrics(self, telemetry: Dict) -> None:
        """Append telemetry means to the history and warn on metric drift."""
        for key in self.metrics:
            val = telemetry[key].mean().item()
            self.metrics[key].append(val)
        drift = detect_metric_drift(
            self.metrics, window=self.drift_window, threshold=self.drift_threshold
        )
        bad = [k for k, v in drift.items() if v]
        if bad:
            warnings.warn(
                f"Metric drift detected: {', '.join(bad)}",
                MetricDriftWarning,
            )

    def infer_text(self, text: str) -> Dict[str, Any]:
        """Run text through the model using the safety gate."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
        out_bits, telemetry = hil_safe_inference(
            self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
        )
        self._log_metrics(telemetry)
        return {
            "output": bits_to_text(out_bits.squeeze(0).tolist()),
            "telemetry": telemetry,
        }

    def get_status(self) -> Dict[str, Any]:
        """Return runtime toggles plus the current model dimensions."""
        info: Dict[str, Any] = {
            "use_gpu": self.use_gpu,
            "diffusion": self.diffusion,
            "compression": self.use_compression,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
            "qat": self.qat,
        }
        if self.model is not None:
            info.update(
                {
                    "d_model": self.model.d_model,
                    "num_layers": self.model.num_layers,
                    "d_ff": self.model.layers[0].linear1.out_features,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                }
            )
        else:
            info.update(
                {
                    "d_model": None,
                    "num_layers": 0,
                    "d_ff": None,
                    "nhead": None,
                    "max_seq_len": None,
                }
            )
        return info

    def get_model_config(self) -> Dict[str, Any]:
        """Return current model hyperparameters and safety settings."""
        cfg: Dict[str, Any] = {
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
        }
        if self.model is not None:
            cfg.update(
                {
                    "d_model": self.model.d_model,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "num_layers": self.model.num_layers,
                    "dim_feedforward": self.model.layers[0].linear1.out_features,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                    "chunk_size": self.model.chunk_size,
                    "reversible": self.model.reversible,
                    "use_checkpoint": self.model.use_checkpoint,
                }
            )
        else:
            cfg.update(
                {
                    "d_model": None,
                    "nhead": None,
                    "num_layers": 0,
                    "dim_feedforward": None,
                    "max_seq_len": None,
                    "chunk_size": None,
                    "reversible": None,
                    "use_checkpoint": None,
                }
            )
        return cfg

    def get_metrics(self) -> Dict[str, Any]:
        """Return logged telemetry metrics with summary statistics."""
        from statistics import mean, stdev

        data = {
            "negentropy": self.metrics["negentropy_logits"],
            "lz_complexity": self.metrics["lz_complexity_logits"],
            "symbiosis": self.metrics["symbiosis_score"],
        }
        summary: Dict[str, Dict[str, float | None]] = {}
        for key, values in data.items():
            if values:
                m = mean(values)
                s = stdev(values) if len(values) > 1 else 0.0
                summary[key] = {"mean": m, "std": s}
            else:
                summary[key] = {"mean": None, "std": None}
        data["summary"] = summary
        return data

    def _save_state(self) -> None:
        """Persist model weights and telemetry history to the snapshot dir."""
        if self.model is None:
            return
        torch.save(self.model, self.weights_path)
        with open(self.telemetry_log, "w") as f:
            json.dump(self.metrics, f)


manager: ModelManager | None = None


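# HTTP routes. When MCP_SERVER_ADDR is set, each route proxies the request to
# that server via mcp_get/mcp_post; otherwise it operates on the local
# ModelManager instance.
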
@app.route("/")
def index():
    """Render the dashboard page with current metrics and settings."""
    defaults = {
        k: v.default
        for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    return render_template(
        "dashboard.html",
        metrics=manager.metrics,
        lambdas={
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        },
        diffusion=manager.diffusion,
        compression=manager.use_compression,
        defaults=defaults,
        c_floor=manager.c_floor,
        s_floor=manager.s_floor,
        qat=manager.qat,
    )


@app.route("/status", methods=["GET"])
def status():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/status"))
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/model_config"))
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/metrics"))
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    if MCP_SERVER_ADDR:
        data = mcp_post("/init", params)
        return jsonify(data)
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})


@app.route("/train", methods=["POST"])
def train_model():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    if MCP_SERVER_ADDR:
        data = mcp_post("/train", {"bits": request.json["bits"]})
        return jsonify(data)
    loss, ratio = manager.train_step(bits)
    return jsonify({"loss": loss, "ratio": ratio})


@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    epochs = int(request.json.get("epochs", 1))
    compress_prob = float(request.json.get("compress_prob", 0.5))
    direct_prob = float(request.json.get("direct_prob", 0.0))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/train_epochs",
            {
                "bits": request.json["bits"],
                "epochs": epochs,
                "compress_prob": compress_prob,
                "direct_prob": direct_prob,
            },
        )
        return jsonify(data)
    metrics = manager.train_epochs(
        bits,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
    )
    return jsonify({"metrics": metrics})


@app.route("/scale_up", methods=["POST"])
def scale_up():
    width_mult = float(request.json.get("width_mult", 1.0))
    if MCP_SERVER_ADDR:
        data = mcp_post("/scale_up", {"width_mult": width_mult})
        return jsonify(data)
    manager.scale_up(width_mult)
    return jsonify({
        "status": "scaled",
        "layers": manager.model.num_layers,
        "d_model": manager.model.d_model,
    })


@app.route("/collapse", methods=["POST"])
def collapse_model():
    cluster_bits = request.json["clusters"]
    params = {k: int(v) for k, v in request.json["params"].items()}
    width_scale = float(request.json.get("width_scale", 1.0))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/collapse",
            {"clusters": cluster_bits, "params": params, "width_scale": width_scale},
        )
        return jsonify(data)
    manager.collapse(cluster_bits, params, width_scale)
    return jsonify({"status": "collapsed"})


@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    if request.method == "POST":
        data = request.json
        if MCP_SERVER_ADDR:
            res = mcp_post("/lambdas", data)
            return jsonify(res)
        manager.set_lambdas(
            float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"])
        )
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/lambdas"))
        return jsonify(
            {
                "lambda_K": manager.lambda_K,
                "lambda_C": manager.lambda_C,
                "lambda_S": manager.lambda_S,
            }
        )


@app.route("/config/telemetry", methods=["GET", "POST"])
def telemetry_config():
    """Get or update telemetry λ weights and safety floors."""
    if request.method == "POST":
        data = request.json
        if MCP_SERVER_ADDR:
            res = mcp_post("/config/telemetry", data)
            return jsonify(res)
        manager.set_lambdas(
            float(data.get("lambda_K", manager.lambda_K)),
            float(data.get("lambda_C", manager.lambda_C)),
            float(data.get("lambda_S", manager.lambda_S)),
        )
        manager.set_floors(
            float(data.get("c_floor", manager.c_floor)),
            float(data.get("s_floor", manager.s_floor)),
        )
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/config/telemetry"))
        return jsonify(
            {
                "lambda_K": manager.lambda_K,
                "lambda_C": manager.lambda_C,
                "lambda_S": manager.lambda_S,
                "c_floor": manager.c_floor,
                "s_floor": manager.s_floor,
            }
        )


@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/diffusion", request.json))
        manager.set_diffusion(bool(request.json.get("diffusion", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/diffusion"))
        return jsonify({"diffusion": manager.diffusion})


@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/gpu", request.json))
        manager.set_gpu(bool(request.json.get("use_gpu", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/gpu"))
        return jsonify({"use_gpu": manager.use_gpu})


@app.route("/compression", methods=["GET", "POST"])
def update_compression():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/compression", request.json))
        manager.set_compression(bool(request.json.get("compression", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/compression"))
        return jsonify({"compression": manager.use_compression})


@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/qat", request.json))
        manager.set_qat(bool(request.json.get("qat", False)))
        return jsonify({"status": "updated"})
    else:
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/qat"))
        return jsonify({"qat": manager.qat})


@app.route("/infer", methods=["POST"])
def inference():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    if MCP_SERVER_ADDR:
        data = mcp_post("/infer", {"bits": request.json["bits"]})
        return jsonify(data)
    result = manager.infer(bits)
    return jsonify(result)


@app.route("/infer_long", methods=["POST"])
def inference_long():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    ctx = int(request.json.get("ctx_bits", 4096))
    overlap = int(request.json.get("overlap", 256))
    if MCP_SERVER_ADDR:
        data = mcp_post(
            "/infer_long",
            {"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap},
        )
        return jsonify(data)
    result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
    return jsonify(result)


@app.route("/infer_text", methods=["POST"])
def inference_text():
    text = request.json.get("text", "")
    if MCP_SERVER_ADDR:
        data = mcp_post("/infer_text", {"text": text})
        return jsonify(data)
    result = manager.infer_text(text)
    return jsonify(result)


@app.route("/plot.png")
def plot_png():
    if MCP_SERVER_ADDR:
        resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png")
        resp.raise_for_status()
        return send_file(io.BytesIO(resp.content), mimetype="image/png")
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


def run_dashboard(host: str | None = None, port: int | None = None,
                  snapshot_dir: str | None = None, telemetry_log: str | None = None) -> None:
    """Launch the Flask dashboard server."""
    env_host = os.getenv("HOST", "0.0.0.0")
    env_port = int(os.getenv("PORT", "5000"))
    host = host or env_host
    port = port or env_port
    global manager
    if manager is None:
        manager = ModelManager(snapshot_dir, telemetry_log)
    app.run(host=host, port=port, debug=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dashboard server")
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
    parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
    parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
    args = parser.parse_args()
    run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
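
For reference, the deleted module exposed a small HTTP API. The sketch below shows one way to drive it with requests; it is illustrative only: it assumes a server started via run_dashboard() on the default localhost:5000, and the tiny hyperparameters and bit sequence are made up, while the endpoint paths and payload keys come from the routes above.

# Client sketch (assumptions: local server on port 5000, illustrative values).
import requests

BASE = "http://localhost:5000"

# Build a tiny model; these field names match the /init route's coercion sets.
requests.post(f"{BASE}/init", json={
    "d_model": 32,
    "nhead": 4,
    "num_layers": 1,
    "dim_feedforward": 64,
    "max_seq_len": 16,
}).raise_for_status()

# One optimizer step on a single 16-bit sequence (entries must be 0 or 1).
bits = [[1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1]]
print(requests.post(f"{BASE}/train", json={"bits": bits}).json())   # {"loss": ..., "ratio": ...}

# Inference returns predicted bits plus telemetry (negentropy, LZ complexity, symbiosis).
print(requests.post(f"{BASE}/infer", json={"bits": bits}).json())

# Adjust λ weights and safety floors through the telemetry config route.
requests.post(f"{BASE}/config/telemetry", json={"lambda_K": 2.0, "s_floor": 0.6}).raise_for_status()

When MCP_SERVER_ADDR is set in the dashboard's environment, the same payloads are forwarded verbatim to the MCP server instead of the local ModelManager, so a client written this way does not need to change.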