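"""Flask-based MCP server for BitTransformerLM.

Exposes HTTP endpoints for model lifecycle (init/scale_up/collapse), training,
inference, telemetry plots, Hugging Face checkpointing, and dataset creation.
Long-running work is dispatched to a thread pool and polled via /job/<job_id>.
"""
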
import io
import os
import gzip
import uuid
import traceback
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify, send_file
import matplotlib.pyplot as plt
import torch
from bit_transformer.dashboard_app import ModelManager
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.optimization import configure_optimizer
from bit_transformer.bit_io import text_to_bits
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset
app = Flask(__name__)
manager = ModelManager()
# background job management
executor = ThreadPoolExecutor(max_workers=4)
jobs: dict[str, dict] = {}
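# Job records are process-local: they vanish on restart and are never evicted,
# so long-lived deployments may want to prune completed entries.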
def _submit_job(fn, *args, **kwargs) -> str:
"""Schedule a function for background execution and return a job id."""
job_id = str(uuid.uuid4())
jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}
def wrapper():
jobs[job_id]["status"] = "running"
try:
jobs[job_id]["result"] = fn(*args, **kwargs)
jobs[job_id]["status"] = "completed"
except Exception as err: # pragma: no cover - captured for client
jobs[job_id]["status"] = "error"
jobs[job_id]["error"] = str(err)
jobs[job_id]["trace"] = traceback.format_exc()
executor.submit(wrapper)
return job_id
@app.errorhandler(Exception)
def handle_exception(err):
"""Return JSON error responses with stack traces."""
return (
jsonify({"error": str(err), "trace": traceback.format_exc()}),
getattr(err, "code", 500),
)
@app.route("/init", methods=["POST"])
def init_model():
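    """Initialize the managed model, coercing declared JSON fields to int/float/bool."""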
data = request.json or {}
int_fields = {
"d_model",
"nhead",
"num_layers",
"dim_feedforward",
"max_seq_len",
"chunk_size",
"overlap",
}
float_fields = {"act_threshold"}
bool_fields = {"reversible", "use_checkpoint"}
params = {}
for k, v in data.items():
if v is None:
params[k] = None
elif k in int_fields:
params[k] = int(v)
elif k in float_fields:
params[k] = float(v)
elif k in bool_fields:
params[k] = bool(v)
else:
params[k] = v
manager.init_model(params)
return jsonify({"status": "initialized", "params": params})
@app.route("/train", methods=["POST"])
def train_model():
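    """Queue a single training step on the posted bits; returns a job id to poll."""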
bits = request.json["bits"]
def task():
tensor = torch.tensor(bits, dtype=torch.long)
loss, ratio = manager.train_step(tensor)
return {"loss": loss, "ratio": ratio}
job_id = _submit_job(task)
return jsonify({"job_id": job_id})
@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
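    """Queue multi-epoch training with the posted compress_prob/direct_prob settings."""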
data = request.json
bits = data["bits"]
epochs = int(data.get("epochs", 1))
compress_prob = float(data.get("compress_prob", 0.5))
direct_prob = float(data.get("direct_prob", 0.0))
def task():
tensor = torch.tensor(bits, dtype=torch.long)
metrics = manager.train_epochs(
tensor,
epochs=epochs,
compress_prob=compress_prob,
direct_prob=direct_prob,
)
return {"metrics": metrics}
job_id = _submit_job(task)
return jsonify({"job_id": job_id})
@app.route("/scale_up", methods=["POST"])
def scale_up():
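    """Queue a model width scale-up; the job result reports the new layer count and d_model."""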
width_mult = float(request.json.get("width_mult", 1.0))
def task():
manager.scale_up(width_mult)
return {
"status": "scaled",
"layers": manager.model.num_layers,
"d_model": manager.model.d_model,
}
job_id = _submit_job(task)
return jsonify({"job_id": job_id})
@app.route("/collapse", methods=["POST"])
def collapse_model():
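    """Queue a model collapse onto the posted cluster bits with new parameters."""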
cluster_bits = request.json["clusters"]
params = {k: int(v) for k, v in request.json["params"].items()}
width_scale = float(request.json.get("width_scale", 1.0))
def task():
manager.collapse(cluster_bits, params, width_scale)
return {"status": "collapsed"}
job_id = _submit_job(task)
return jsonify({"job_id": job_id})
@app.route("/job/<job_id>", methods=["GET"])
def get_job(job_id: str):
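    """Return the record for a single background job, or 404 if unknown."""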
job = jobs.get(job_id)
if job is None:
return jsonify({"error": "not found"}), 404
return jsonify(job)
@app.route("/jobs", methods=["GET"])
def list_jobs():
return jsonify(jobs)
@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
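    """Get or set the manager's lambda_K/lambda_C/lambda_S coefficients."""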
if request.method == "POST":
data = request.json
manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
return jsonify({"status": "updated"})
else:
return jsonify({
"lambda_K": manager.lambda_K,
"lambda_C": manager.lambda_C,
"lambda_S": manager.lambda_S,
})
@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
if request.method == "POST":
manager.set_diffusion(bool(request.json.get("diffusion", False)))
return jsonify({"status": "updated"})
return jsonify({"diffusion": manager.diffusion})
@app.route("/qat", methods=["GET", "POST"])
def update_qat():
if request.method == "POST":
manager.set_qat(bool(request.json.get("qat", False)))
return jsonify({"status": "updated"})
return jsonify({"qat": manager.qat})
@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
if request.method == "POST":
manager.set_gpu(bool(request.json.get("use_gpu", False)))
return jsonify({"status": "updated"})
return jsonify({"use_gpu": manager.use_gpu})
@app.route("/infer", methods=["POST"])
def inference():
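    """Run a synchronous forward pass on the posted bit sequence."""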
bits = torch.tensor(request.json["bits"], dtype=torch.long)
result = manager.infer(bits)
return jsonify(result)
@app.route("/infer_long", methods=["POST"])
def inference_long():
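    """Run inference over a long bit sequence in ctx_bits-sized chunks with the given overlap."""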
bits = torch.tensor(request.json["bits"], dtype=torch.long)
ctx = int(request.json.get("ctx_bits", 4096))
overlap = int(request.json.get("overlap", 256))
result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
return jsonify(result)
@app.route("/infer_text", methods=["POST"])
def inference_text():
text = request.json.get("text", "")
result = manager.infer_text(text)
return jsonify(result)
@app.route("/status", methods=["GET"])
def status():
return jsonify(manager.get_status())
@app.route("/model_config", methods=["GET"])
def model_config():
return jsonify(manager.get_model_config())
@app.route("/metrics", methods=["GET"])
def metrics():
return jsonify(manager.get_metrics())
@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
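    """Log in to Hugging Face if a token is available, then push the current weights to repo_id."""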
repo_id = request.json.get("repo_id")
token = request.json.get("token") or os.getenv("HF_TOKEN")
if manager.model is None:
return jsonify({"error": "model not initialized"}), 400
if token:
hf_login(token=token)
save_checkpoint(manager.model, repo_id=repo_id)
return jsonify({"status": "saved"})
@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
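    """Fetch a gzipped checkpoint from the Hub and, if a model is initialized, load it and rebuild the optimizer."""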
repo_id = request.json.get("repo_id")
token = request.json.get("token") or os.getenv("HF_TOKEN")
if token:
hf_login(token=token)
dest = manager.weights_path + ".gz"
ok = download_checkpoint(dest, repo_id=repo_id)
if not ok:
return jsonify({"status": "failed"}), 500
if manager.model is None:
return jsonify({"status": "downloaded", "loaded": False})
with gzip.open(dest, "rb") as f:
state = torch.load(f, map_location="cpu")
manager.model.load_state_dict(state)
manager.optimizer, manager.scheduler = configure_optimizer(
manager.model, lr=1e-3, total_steps=manager.total_steps
)
manager._apply_device()
manager._save_state()
return jsonify({"status": "downloaded", "loaded": True})
@app.route("/plot.png")
def plot_png():
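    """Render the manager's telemetry metrics to a PNG image."""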
fig, _ = plot_telemetry(manager.metrics)
buf = io.BytesIO()
fig.savefig(buf, format="png")
plt.close(fig)
buf.seek(0)
return send_file(buf, mimetype="image/png")
@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
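    """Convert posted text to a bit list, rejecting inputs longer than 100,000 characters."""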
text = request.json.get("text", "")
if len(text) > 100_000:
return jsonify({"error": "text too large"}), 413
return jsonify({"bits": text_to_bits(text)})
@app.route("/dataset", methods=["GET"])
def dataset_route():
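    """Serve bit rows built from wikitext-2 text (name=wikitext2), padding short rows and falling back to random bits."""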
name = request.args.get("name", "")
split = request.args.get("split", "train")
size = int(request.args.get("size", 1))
seq_len = int(request.args.get("seq_len", 64))
if size * seq_len > 1_000_000:
return jsonify({"error": "dataset too large"}), 413
if name == "wikitext2":
try:
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
lines = [t for t in ds["text"] if t.strip()][:size]
except Exception:
bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
return jsonify({"bits": bits.tolist()})
bits_list = []
for text in lines:
b = text_to_bits(text)[:seq_len]
if len(b) < seq_len:
b.extend([0] * (seq_len - len(b)))
bits_list.append(b)
if len(bits_list) < size:
pad = size - len(bits_list)
bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
return jsonify({"bits": bits_list})
return jsonify({"error": "unknown dataset"}), 400
# Dataset Management Endpoints
@app.route("/dataset/create", methods=["POST"])
def create_dataset():
"""Create and upload a new BitTransformerLM dataset."""
data = request.json or {}
hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
repo_id = data.get("repo_id", "BitTransformerLM")
source_texts = data.get("source_texts", None)
if not hf_token:
return jsonify({"error": "HF token required"}), 400
def task():
try:
dataset_url = create_bittransformerlm_dataset(
hf_token=hf_token,
repo_id=repo_id,
source_texts=source_texts
)
return {
"status": "success",
"dataset_url": dataset_url,
"repo_id": repo_id
}
except Exception as e:
return {
"status": "error",
"error": str(e)
}
job_id = _submit_job(task)
return jsonify({"job_id": job_id, "message": "Dataset creation started"})
@app.route("/dataset/builder", methods=["POST"])
def create_dataset_builder():
"""Initialize a dataset builder for custom dataset creation."""
data = request.json or {}
hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
repo_id = data.get("repo_id", "BitTransformerLM")
if not hf_token:
return jsonify({"error": "HF token required"}), 400
try:
builder = BitTransformerDatasetBuilder(hf_token, repo_id)
# Store builder configuration
builder_info = {
"repo_id": repo_id,
"config": builder.config,
"status": "ready"
}
return jsonify({
"status": "builder_created",
"builder_info": builder_info
})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/dataset/generate", methods=["POST"])
def generate_dataset_samples():
"""Generate specific types of dataset samples."""
data = request.json or {}
sample_type = data.get("type", "text_to_bits") # text_to_bits, synthetic, safety, compression
count = int(data.get("count", 100))
max_len = int(data.get("max_len", 256))
texts = data.get("texts", None)
if count > 5000:
return jsonify({"error": "count too large, max 5000"}), 400
    def task():
        try:
            # Create a temporary builder; nothing is uploaded here.
            builder = BitTransformerDatasetBuilder("dummy_token", "temp")
            if sample_type == "text_to_bits":
                # Read the closure variable through a local name: assigning to
                # `texts` inside this function would make it function-local and
                # raise UnboundLocalError on the truthiness check.
                sample_texts = texts or builder._get_default_texts()[:count]
                samples = builder.generate_text_to_bits_data(sample_texts[:count], max_len)
elif sample_type == "synthetic":
samples = builder.generate_synthetic_patterns(count, max_len)
elif sample_type == "safety":
samples = builder.generate_safety_benchmarks(count)
elif sample_type == "compression":
# Need base samples first
base_texts = builder._get_default_texts()[:50]
base_samples = builder.generate_text_to_bits_data(base_texts, max_len)
samples = builder.generate_compression_variants(base_samples)[:count]
else:
return {"error": f"Unknown sample type: {sample_type}"}
return {
"status": "success",
"samples": samples[:10], # Return first 10 for preview
"total_generated": len(samples),
"sample_type": sample_type
}
except Exception as e:
return {"error": str(e)}
job_id = _submit_job(task)
return jsonify({"job_id": job_id, "message": f"Generating {sample_type} samples"})
@app.route("/dataset/info", methods=["GET"])
def dataset_info():
"""Get information about available dataset generation options."""
return jsonify({
"sample_types": [
{
"type": "text_to_bits",
"description": "Convert text to parity-protected bit sequences",
"parameters": ["texts", "max_len"]
},
{
"type": "synthetic",
"description": "Generate synthetic bit patterns",
"parameters": ["count", "max_len"],
"patterns": ["alternating", "blocks", "fibonacci", "prime_based", "random_walk"]
},
{
"type": "safety",
"description": "Generate safety benchmark sequences",
"parameters": ["count"],
"categories": ["low_entropy", "medium_entropy", "high_entropy", "edge_cases"]
},
{
"type": "compression",
"description": "Generate compressed variants of base sequences",
"parameters": ["count", "compression_ratios"]
}
],
"default_config": {
"max_sequence_length": 512,
"total_samples": 25000,
"safety_thresholds": {
"min_negentropy": 0.1,
"max_lz_complexity": 0.9,
"min_symbiosis": 0.3
}
}
})
@app.route("/health")
def health_check():
return jsonify({"status": "ok"})
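# Illustrative client session (payload shapes follow the handlers above;
# host and port match run_mcp_server's defaults):
#
#   curl -X POST http://localhost:7000/init \
#        -H "Content-Type: application/json" \
#        -d '{"d_model": 64, "nhead": 4, "num_layers": 2}'
#   curl -X POST http://localhost:7000/train \
#        -H "Content-Type: application/json" \
#        -d '{"bits": [0, 1, 1, 0, 1, 0, 0, 1]}'
#   curl http://localhost:7000/job/<job_id>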
def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
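    """Run the app with Flask's built-in server (debug=True enables the Werkzeug debugger and reloader)."""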
app.run(host=host, port=port, debug=True)
if __name__ == "__main__":
    run_mcp_server()