import io
import json
import os
import traceback
import inspect
from typing import Any, Dict, List, Optional, Tuple
from flask import Flask, jsonify, request, render_template, send_file
import subprocess
import sys
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import requests
import gzip
from .model import BitTransformerLM, infer_long_sequence
from .optimization import configure_optimizer
from .collapse import collapse_submodel
from .dashboard import plot_telemetry
from .scale import expand_model
from .bit_io import text_to_bits, bits_to_text
from .safety import hil_safe_inference
from .compression import model_output_decompress, compress_bits
from .distributed import wrap_fsdp
from .training import train_loop
from .telemetry import detect_metric_drift
from .quantization import prepare_qat_fx, convert_qat_fx
from torch.distributed.fsdp import FullyShardedDataParallel
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024 # 1MB upload limit
MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")
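# When MCP_SERVER_ADDR is set, the HTTP routes below proxy requests to that
# server instead of operating on the local ModelManager.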
@app.errorhandler(Exception)
def handle_exception(err):
"""Return JSON error responses with stack traces."""
return (
jsonify({"error": str(err), "trace": traceback.format_exc()}),
getattr(err, "code", 500),
)
class MetricDriftWarning(UserWarning):
"""Raised when telemetry metrics drift beyond the configured threshold."""
def _switch_torch(use_gpu: bool) -> None:
"""Install the appropriate PyTorch wheel and restart the process."""
have_cuda = torch.version.cuda is not None
if use_gpu == have_cuda:
return
wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
url = "https://download.pytorch.org/whl/cu118" if use_gpu else "https://download.pytorch.org/whl/cpu"
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--index-url", url, wheel],
        check=True,
    )
os.execv(sys.executable, [sys.executable] + sys.argv)
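# Note: os.execv above replaces the running process, so any unsaved in-memory
# state (including in-flight Flask requests) is lost when the wheel switches.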
def mcp_post(path: str, data=None):
    """POST ``data`` to the MCP server; return bytes for images, else JSON."""
if not MCP_SERVER_ADDR:
return None
url = MCP_SERVER_ADDR.rstrip("/") + path
    # Bound the request so a stalled MCP server cannot hang the dashboard.
    resp = requests.post(url, json=data, timeout=60)
resp.raise_for_status()
if resp.headers.get("Content-Type", "").startswith("image/"):
return resp.content
return resp.json()
def mcp_get(path: str):
    """GET ``path`` from the MCP server; return bytes for images, else JSON."""
if not MCP_SERVER_ADDR:
return None
url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.get(url, timeout=60)
resp.raise_for_status()
if resp.headers.get("Content-Type", "").startswith("image/"):
return resp.content
return resp.json()
class ModelManager:
"""Manage model state and training utilities for the dashboard."""
def __init__(
self,
snapshot_dir: Optional[str] = None,
telemetry_log: Optional[str] = None,
*,
drift_window: int = 10,
drift_threshold: float = 0.2,
) -> None:
self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
if self.telemetry_log is None:
self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
os.makedirs(self.snapshot_dir, exist_ok=True)
self.weights_path = os.path.join(self.snapshot_dir, "model.pt")
self.model: Optional[BitTransformerLM] = None
self.optimizer: Optional[torch.optim.Optimizer] = None
self.scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
self.total_steps = 100
self.metrics: Dict[str, List[float]] = {
"negentropy_logits": [],
"lz_complexity_logits": [],
"symbiosis_score": [],
}
self.drift_window = drift_window
self.drift_threshold = drift_threshold
self.lambda_K = 1.0
self.lambda_C = 1.0
self.lambda_S = 1.0
self.c_floor = 0.3
self.s_floor = 0.5
self.causal = True
self.diffusion = False
self.decompress_output = False
self.use_compression = False
self.use_gpu = False
self.qat = False
# Load any existing state
if os.path.exists(self.telemetry_log):
try:
with open(self.telemetry_log) as f:
saved = json.load(f)
for key in self.metrics:
self.metrics[key] = saved.get(key, [])
            except Exception:
                # Corrupt or unreadable telemetry log; start with empty metrics.
                pass
if os.path.exists(self.weights_path):
try:
                # weights_only=False: the snapshot stores the full module
                # object, which torch>=2.6 refuses to unpickle by default.
                self.model = torch.load(
                    self.weights_path, map_location="cpu", weights_only=False
                )
self.optimizer, self.scheduler = configure_optimizer(
self.model, lr=1e-3, total_steps=self.total_steps
)
self._apply_device()
            except Exception:
                # Incompatible or corrupt snapshot; fall back to fresh init.
                self.model = None
config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
if self.model is None and os.path.exists(config_path):
try:
with open(config_path) as f:
params = json.load(f)
self.init_model(params)
            except Exception:
                # Malformed config file; leave the model uninitialized.
                pass
    def init_model(self, params: Dict) -> None:
        """Construct a fresh ``BitTransformerLM``, coercing parameter types."""
int_fields = {
"d_model",
"nhead",
"num_layers",
"dim_feedforward",
"max_seq_len",
"chunk_size",
"overlap",
}
float_fields = {"act_threshold"}
bool_fields = {"reversible", "use_checkpoint"}
clean: Dict[str, Any] = {}
for k, v in params.items():
if v is None:
clean[k] = None
elif k in int_fields:
clean[k] = int(v)
elif k in float_fields:
clean[k] = float(v)
elif k in bool_fields:
clean[k] = bool(v)
else:
clean[k] = v
self.model = BitTransformerLM(
**clean,
lambda_K=self.lambda_K,
lambda_C=self.lambda_C,
lambda_S=self.lambda_S,
)
self.optimizer, self.scheduler = configure_optimizer(
self.model, lr=1e-3, total_steps=self.total_steps
)
self._apply_device()
for key in self.metrics:
self.metrics[key].clear()
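    # Example (sketch, assuming defaults exist for any omitted hyperparameters):
    #   mgr = ModelManager()
    #   mgr.init_model({"d_model": "32", "nhead": "4", "num_layers": "1",
    #                   "dim_feedforward": "64", "max_seq_len": "16"})
    # String values are coerced to the field types declared above.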
def set_lambdas(self, k: float, c: float, s: float) -> None:
"""Update λ weights and propagate to the model."""
self.lambda_K = k
self.lambda_C = c
self.lambda_S = s
if self.model is not None:
self.model.set_lambdas(k, c, s)
def set_floors(self, c_floor: float, s_floor: float) -> None:
"""Update safety floors for complexity (C) and symbiosis (S)."""
self.c_floor = c_floor
self.s_floor = s_floor
def set_diffusion(self, flag: bool) -> None:
"""Toggle Diffusion LM mode which disables causal masking and chunking."""
self.diffusion = flag
self.causal = not flag
if self.model is not None and flag:
self.model.chunk_size = None
def set_decompress_output(self, flag: bool) -> None:
"""Enable or disable decompression of model outputs."""
self.decompress_output = flag
def set_compression(self, flag: bool) -> None:
"""Toggle automatic compression of inputs."""
self.use_compression = flag
def set_qat(self, flag: bool) -> None:
"""Enable or disable 4-bit quantization-aware training."""
self.qat = flag
if self.model is None:
return
if flag:
self.model = prepare_qat_fx(self.model)
else:
self.model = convert_qat_fx(self.model)
def set_gpu(self, flag: bool) -> None:
"""Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
_switch_torch(flag)
self.use_gpu = flag and torch.cuda.is_available()
self._apply_device()
def _apply_device(self) -> None:
"""Move the model to the selected device and wrap with FSDP if needed."""
if self.model is None:
return
if self.use_gpu:
device = torch.device("cuda")
if isinstance(self.model, FullyShardedDataParallel):
base = self.model.module
else:
base = self.model
base = base.to(device)
self.model = wrap_fsdp(base, device_id=device)
else:
device = torch.device("cpu")
if isinstance(self.model, FullyShardedDataParallel):
self.model = self.model.module
self.model = self.model.to(device)
    def train_step(self, bits: torch.Tensor) -> Tuple[float, float]:
        """Run one optimization step; return ``(loss, compression_ratio)``."""
assert (
self.model is not None
and self.optimizer is not None
and self.scheduler is not None
)
self.model.train()
device = next(self.model.parameters()).device
bits = bits.to(device)
ratio = 1.0
if self.use_compression:
comps = [compress_bits(row.to(torch.uint8)) for row in bits]
comp_len = sum(c.numel() for c in comps)
ratio = min(comp_len / bits.numel(), 1.0)
logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
else:
logits, telemetry = self.model(bits, causal=self.causal)
pred = logits[:, :-1, :].reshape(-1, 2)
target = bits[:, 1:].reshape(-1)
loss = F.cross_entropy(pred, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self._log_metrics(telemetry)
self._save_state()
return loss.item(), ratio
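    # Example (sketch): one manual training step on random bits, assuming an
    # initialized model:
    #   bits = torch.randint(0, 2, (2, 16), dtype=torch.long)
    #   loss, ratio = mgr.train_step(bits)  # ratio == 1.0 unless compression is on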
def train_epochs(
self,
bits: torch.Tensor,
*,
epochs: int = 1,
compress_prob: float = 0.5,
direct_prob: float = 0.0,
batch_size: int = 8,
num_workers: int = 0,
accum_steps: int = 1,
amp: bool = False,
compile_model: bool = False,
) -> List[Dict[str, float]]:
"""Run ``train_loop`` on a batch tensor and persist the state."""
assert self.model is not None
device = next(self.model.parameters()).device
bits = bits.to(device)
import math
steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
self.optimizer, self.scheduler = configure_optimizer(
self.model, lr=1e-3, total_steps=self.total_steps
)
metrics = train_loop(
self.model,
bits,
epochs=epochs,
compress_prob=compress_prob if self.use_compression else 0.0,
direct_prob=direct_prob,
batch_size=batch_size,
num_workers=num_workers,
accum_steps=accum_steps,
amp=amp,
compile_model=compile_model,
forward_kwargs={"causal": self.causal},
optimizer=self.optimizer,
scheduler=self.scheduler,
)
self._save_state()
return metrics
    def scale_up(self, width_mult: float = 1.0) -> None:
        """Double the layer count and scale model width by ``width_mult``."""
assert self.model is not None
params = dict(
d_model=int(self.model.d_model * width_mult),
nhead=self.model.layers[0].self_attn.num_heads,
num_layers=self.model.num_layers * 2,
dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
max_seq_len=self.model.pos_enc.pe.size(0),
)
self.model = expand_model(self.model, {
**params,
"lambda_K": self.lambda_K,
"lambda_C": self.lambda_C,
"lambda_S": self.lambda_S,
})
self.optimizer, self.scheduler = configure_optimizer(
self.model, lr=1e-3, total_steps=self.total_steps
)
self._save_state()
    def collapse(
        self,
        cluster_bits: List[List[int]],
        target_params: Dict,
        width_scale: float = 1.0,
    ) -> None:
        """Distill the model into a smaller submodel trained on ``cluster_bits``."""
self.model, _ = collapse_submodel(
cluster_bits,
target_params,
width_scale=width_scale,
forward_kwargs={"causal": self.causal},
)
self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
self.optimizer, self.scheduler = configure_optimizer(
self.model, lr=1e-3, total_steps=self.total_steps
)
self._apply_device()
for key in self.metrics:
self.metrics[key].clear()
    def infer(self, bits: torch.Tensor) -> Dict:
        """Run a forward pass and return predicted bits plus telemetry."""
assert self.model is not None
self.model.eval()
device = next(self.model.parameters()).device
bits = bits.to(device)
ratio = 1.0
with torch.no_grad():
if self.use_compression:
comps = [compress_bits(row.to(torch.uint8)) for row in bits]
comp_len = sum(c.numel() for c in comps)
ratio = min(comp_len / bits.numel(), 1.0)
logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
else:
logits, telemetry = self.model(bits, causal=self.causal)
self._log_metrics(telemetry)
pred_bits = logits.argmax(-1)
if self.decompress_output:
try:
pred_bits = model_output_decompress(pred_bits)
except Exception as e:
return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}
def _to_python(obj):
if isinstance(obj, torch.Tensor):
return obj.tolist()
if isinstance(obj, list):
return [_to_python(o) for o in obj]
if isinstance(obj, dict):
return {kk: _to_python(vv) for kk, vv in obj.items()}
return obj
tele = {k: _to_python(v) for k, v in telemetry.items()}
return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}
def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
"""Run sliding-window inference on a long sequence."""
assert self.model is not None
device = next(self.model.parameters()).device
bits = bits.to(device)
preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
for tele in logs:
self._log_metrics(tele)
return {"predicted": preds.tolist(), "windows": len(logs)}
def _log_metrics(self, telemetry: Dict) -> None:
for key in self.metrics:
val = telemetry[key].mean().item()
self.metrics[key].append(val)
drift = detect_metric_drift(
self.metrics, window=self.drift_window, threshold=self.drift_threshold
)
bad = [k for k, v in drift.items() if v]
if bad:
warnings.warn(
f"Metric drift detected: {', '.join(bad)}",
MetricDriftWarning,
)
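    # Example (sketch): catching drift warnings in a caller, assuming ``mgr``
    # is a ModelManager and ``bits`` a batch tensor:
    #   with warnings.catch_warnings(record=True) as caught:
    #       warnings.simplefilter("always", MetricDriftWarning)
    #       mgr.train_step(bits)
    #       drifted = [w for w in caught
    #                  if issubclass(w.category, MetricDriftWarning)]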
def infer_text(self, text: str) -> Dict[str, Any]:
"""Run text through the model using the safety gate."""
assert self.model is not None
device = next(self.model.parameters()).device
bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
out_bits, telemetry = hil_safe_inference(
self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
)
        self._log_metrics(telemetry)
        # jsonify cannot serialize tensors; convert telemetry values first.
        tele = {
            k: v.tolist() if isinstance(v, torch.Tensor) else v
            for k, v in telemetry.items()
        }
        return {
            "output": bits_to_text(out_bits.squeeze(0).tolist()),
            "telemetry": tele,
        }
def get_status(self) -> Dict[str, Any]:
info: Dict[str, Any] = {
"use_gpu": self.use_gpu,
"diffusion": self.diffusion,
"compression": self.use_compression,
"lambda_K": self.lambda_K,
"lambda_C": self.lambda_C,
"lambda_S": self.lambda_S,
"c_floor": self.c_floor,
"s_floor": self.s_floor,
"qat": self.qat,
}
if self.model is not None:
info.update(
{
"d_model": self.model.d_model,
"num_layers": self.model.num_layers,
"d_ff": self.model.layers[0].linear1.out_features,
"nhead": self.model.layers[0].self_attn.num_heads,
"max_seq_len": self.model.pos_enc.pe.size(0),
}
)
else:
info.update(
{
"d_model": None,
"num_layers": 0,
"d_ff": None,
"nhead": None,
"max_seq_len": None,
}
)
return info
def get_model_config(self) -> Dict[str, Any]:
"""Return current model hyperparameters and safety settings."""
cfg: Dict[str, Any] = {
"lambda_K": self.lambda_K,
"lambda_C": self.lambda_C,
"lambda_S": self.lambda_S,
"c_floor": self.c_floor,
"s_floor": self.s_floor,
}
if self.model is not None:
cfg.update(
{
"d_model": self.model.d_model,
"nhead": self.model.layers[0].self_attn.num_heads,
"num_layers": self.model.num_layers,
"dim_feedforward": self.model.layers[0].linear1.out_features,
"max_seq_len": self.model.pos_enc.pe.size(0),
"chunk_size": self.model.chunk_size,
"reversible": self.model.reversible,
"use_checkpoint": self.model.use_checkpoint,
}
)
else:
cfg.update(
{
"d_model": None,
"nhead": None,
"num_layers": 0,
"dim_feedforward": None,
"max_seq_len": None,
"chunk_size": None,
"reversible": None,
"use_checkpoint": None,
}
)
return cfg
def get_metrics(self) -> Dict[str, Any]:
"""Return logged telemetry metrics with summary statistics."""
from statistics import mean, stdev
data = {
"negentropy": self.metrics["negentropy_logits"],
"lz_complexity": self.metrics["lz_complexity_logits"],
"symbiosis": self.metrics["symbiosis_score"],
}
summary: Dict[str, Dict[str, Optional[float]]] = {}
for key, values in data.items():
if values:
m = mean(values)
s = stdev(values) if len(values) > 1 else 0.0
summary[key] = {"mean": m, "std": s}
else:
summary[key] = {"mean": None, "std": None}
data["summary"] = summary
return data
def _save_state(self) -> None:
if self.model is None:
return
torch.save(self.model, self.weights_path)
with open(self.telemetry_log, "w") as f:
json.dump(self.metrics, f)
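    # _save_state pickles the full module via torch.save, so reloading the
    # snapshot requires the bit_transformer package to be importable.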
# Global ModelManager instance; created lazily by run_dashboard().
manager: Optional[ModelManager] = None
@app.route("/")
def index():
return render_template(
"dashboard.html",
metrics=manager.metrics,
lambdas={
"lambda_K": manager.lambda_K,
"lambda_C": manager.lambda_C,
"lambda_S": manager.lambda_S,
},
diffusion=manager.diffusion,
compression=manager.use_compression,
        defaults={
            k: v.default
            for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items()
            if v.default is not inspect.Parameter.empty
        },
c_floor=manager.c_floor,
s_floor=manager.s_floor,
qat=manager.qat,
)
@app.route("/status", methods=["GET"])
def status():
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/status"))
return jsonify(manager.get_status())
@app.route("/model_config", methods=["GET"])
def model_config():
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/model_config"))
return jsonify(manager.get_model_config())
@app.route("/metrics", methods=["GET"])
def metrics():
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/metrics"))
return jsonify(manager.get_metrics())
@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
repo_id = request.json.get("repo_id")
token = request.json.get("token") or os.getenv("HF_TOKEN")
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
if manager.model is None:
return jsonify({"error": "model not initialized"}), 400
if token:
hf_login(token=token)
save_checkpoint(manager.model, repo_id=repo_id)
return jsonify({"status": "saved"})
@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
repo_id = request.json.get("repo_id")
token = request.json.get("token") or os.getenv("HF_TOKEN")
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
if token:
hf_login(token=token)
dest = manager.weights_path + ".gz"
ok = download_checkpoint(dest, repo_id=repo_id)
if not ok:
return jsonify({"status": "failed"}), 500
if manager.model is None:
return jsonify({"status": "downloaded", "loaded": False})
with gzip.open(dest, "rb") as f:
state = torch.load(f, map_location="cpu")
manager.model.load_state_dict(state)
manager.optimizer, manager.scheduler = configure_optimizer(
manager.model, lr=1e-3, total_steps=manager.total_steps
)
manager._apply_device()
manager._save_state()
return jsonify({"status": "downloaded", "loaded": True})
@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
text = request.json.get("text", "")
if len(text) > 100_000:
return jsonify({"error": "text too large"}), 413
return jsonify({"bits": text_to_bits(text)})
@app.route("/dataset", methods=["GET"])
def dataset_route():
name = request.args.get("name", "")
split = request.args.get("split", "train")
size = int(request.args.get("size", 1))
seq_len = int(request.args.get("seq_len", 64))
if size * seq_len > 1_000_000:
return jsonify({"error": "dataset too large"}), 413
if name == "wikitext2":
try:
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            # Dataset unavailable (e.g. offline); fall back to random bits.
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
bits_list = []
for text in lines:
b = text_to_bits(text)[:seq_len]
if len(b) < seq_len:
b.extend([0] * (seq_len - len(b)))
bits_list.append(b)
if len(bits_list) < size:
pad = size - len(bits_list)
bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
return jsonify({"bits": bits_list})
return jsonify({"error": "unknown dataset"}), 400
@app.route("/init", methods=["POST"])
def init_model():
data = request.json or {}
int_fields = {
"d_model",
"nhead",
"num_layers",
"dim_feedforward",
"max_seq_len",
"chunk_size",
"overlap",
}
float_fields = {"act_threshold"}
bool_fields = {"reversible", "use_checkpoint"}
params = {}
for k, v in data.items():
if v is None:
params[k] = None
elif k in int_fields:
params[k] = int(v)
elif k in float_fields:
params[k] = float(v)
elif k in bool_fields:
params[k] = bool(v)
else:
params[k] = v
if MCP_SERVER_ADDR:
data = mcp_post("/init", params)
return jsonify(data)
manager.init_model(params)
return jsonify({"status": "initialized", "params": params})
@app.route("/train", methods=["POST"])
def train_model():
bits = torch.tensor(request.json["bits"], dtype=torch.long)
if MCP_SERVER_ADDR:
data = mcp_post("/train", {"bits": request.json["bits"]})
return jsonify(data)
loss, ratio = manager.train_step(bits)
return jsonify({"loss": loss, "ratio": ratio})
@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
bits = torch.tensor(request.json["bits"], dtype=torch.long)
epochs = int(request.json.get("epochs", 1))
compress_prob = float(request.json.get("compress_prob", 0.5))
direct_prob = float(request.json.get("direct_prob", 0.0))
if MCP_SERVER_ADDR:
data = mcp_post(
"/train_epochs",
{
"bits": request.json["bits"],
"epochs": epochs,
"compress_prob": compress_prob,
"direct_prob": direct_prob,
},
)
return jsonify(data)
metrics = manager.train_epochs(
bits,
epochs=epochs,
compress_prob=compress_prob,
direct_prob=direct_prob,
)
return jsonify({"metrics": metrics})
@app.route("/scale_up", methods=["POST"])
def scale_up():
width_mult = float(request.json.get("width_mult", 1.0))
if MCP_SERVER_ADDR:
data = mcp_post("/scale_up", {"width_mult": width_mult})
return jsonify(data)
manager.scale_up(width_mult)
return jsonify({
"status": "scaled",
"layers": manager.model.num_layers,
"d_model": manager.model.d_model,
})
@app.route("/collapse", methods=["POST"])
def collapse_model():
cluster_bits = request.json["clusters"]
params = {k: int(v) for k, v in request.json["params"].items()}
width_scale = float(request.json.get("width_scale", 1.0))
if MCP_SERVER_ADDR:
data = mcp_post(
"/collapse",
{"clusters": cluster_bits, "params": params, "width_scale": width_scale},
)
return jsonify(data)
manager.collapse(cluster_bits, params, width_scale)
return jsonify({"status": "collapsed"})
@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
if request.method == "POST":
data = request.json
if MCP_SERVER_ADDR:
res = mcp_post("/lambdas", data)
return jsonify(res)
manager.set_lambdas(
float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"])
)
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/lambdas"))
return jsonify(
{
"lambda_K": manager.lambda_K,
"lambda_C": manager.lambda_C,
"lambda_S": manager.lambda_S,
}
)
@app.route("/config/telemetry", methods=["GET", "POST"])
def telemetry_config():
"""Get or update telemetry λ weights and safety floors."""
if request.method == "POST":
data = request.json
if MCP_SERVER_ADDR:
res = mcp_post("/config/telemetry", data)
return jsonify(res)
manager.set_lambdas(
float(data.get("lambda_K", manager.lambda_K)),
float(data.get("lambda_C", manager.lambda_C)),
float(data.get("lambda_S", manager.lambda_S)),
)
manager.set_floors(
float(data.get("c_floor", manager.c_floor)),
float(data.get("s_floor", manager.s_floor)),
)
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/config/telemetry"))
return jsonify(
{
"lambda_K": manager.lambda_K,
"lambda_C": manager.lambda_C,
"lambda_S": manager.lambda_S,
"c_floor": manager.c_floor,
"s_floor": manager.s_floor,
}
)
@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
if request.method == "POST":
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/diffusion", request.json))
manager.set_diffusion(bool(request.json.get("diffusion", False)))
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/diffusion"))
return jsonify({"diffusion": manager.diffusion})
@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
if request.method == "POST":
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/gpu", request.json))
manager.set_gpu(bool(request.json.get("use_gpu", False)))
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/gpu"))
return jsonify({"use_gpu": manager.use_gpu})
@app.route("/compression", methods=["GET", "POST"])
def update_compression():
if request.method == "POST":
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/compression", request.json))
manager.set_compression(bool(request.json.get("compression", False)))
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/compression"))
return jsonify({"compression": manager.use_compression})
@app.route("/qat", methods=["GET", "POST"])
def update_qat():
if request.method == "POST":
if MCP_SERVER_ADDR:
return jsonify(mcp_post("/qat", request.json))
manager.set_qat(bool(request.json.get("qat", False)))
return jsonify({"status": "updated"})
else:
if MCP_SERVER_ADDR:
return jsonify(mcp_get("/qat"))
return jsonify({"qat": manager.qat})
@app.route("/infer", methods=["POST"])
def inference():
bits = torch.tensor(request.json["bits"], dtype=torch.long)
if MCP_SERVER_ADDR:
data = mcp_post("/infer", {"bits": request.json["bits"]})
return jsonify(data)
result = manager.infer(bits)
return jsonify(result)
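# Example (sketch): calling /infer with a JSON payload of bit rows:
#   curl -X POST http://localhost:5000/infer \
#        -H 'Content-Type: application/json' \
#        -d '{"bits": [[0, 1, 0, 1, 1, 0]]}'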
@app.route("/infer_long", methods=["POST"])
def inference_long():
bits = torch.tensor(request.json["bits"], dtype=torch.long)
ctx = int(request.json.get("ctx_bits", 4096))
overlap = int(request.json.get("overlap", 256))
if MCP_SERVER_ADDR:
data = mcp_post(
"/infer_long",
{"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap},
)
return jsonify(data)
result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
return jsonify(result)
@app.route("/infer_text", methods=["POST"])
def inference_text():
text = request.json.get("text", "")
if MCP_SERVER_ADDR:
data = mcp_post("/infer_text", {"text": text})
return jsonify(data)
result = manager.infer_text(text)
return jsonify(result)
@app.route("/plot.png")
def plot_png():
if MCP_SERVER_ADDR:
resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png")
resp.raise_for_status()
return send_file(io.BytesIO(resp.content), mimetype="image/png")
fig, _ = plot_telemetry(manager.metrics)
buf = io.BytesIO()
fig.savefig(buf, format="png")
plt.close(fig)
buf.seek(0)
return send_file(buf, mimetype="image/png")
def run_dashboard(host: Optional[str] = None, port: Optional[int] = None,
snapshot_dir: Optional[str] = None, telemetry_log: Optional[str] = None) -> None:
"""Launch the Flask dashboard server."""
env_host = os.getenv("HOST", "0.0.0.0")
env_port = int(os.getenv("PORT", "5000"))
host = host or env_host
port = port or env_port
global manager
if manager is None:
manager = ModelManager(snapshot_dir, telemetry_log)
    # debug=True enables the Werkzeug reloader/debugger; avoid exposing it on
    # untrusted networks.
    app.run(host=host, port=port, debug=True)
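# Example (sketch): driving the dashboard from a client script, assuming the
# server is running locally on the default port:
#   import requests
#   base = "http://localhost:5000"
#   requests.post(f"{base}/init", json={"d_model": 32, "nhead": 4,
#                                       "num_layers": 1, "dim_feedforward": 64,
#                                       "max_seq_len": 16})
#   requests.post(f"{base}/train", json={"bits": [[0, 1] * 8]})
#   print(requests.get(f"{base}/metrics").json())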
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run dashboard server")
parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
args = parser.parse_args()
run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)