WCNegentropy committed on
Commit d3e2188 · verified · 1 Parent(s): 7d0df52

Remove nested directory: BitTransformerLM/mcp_server.py

Files changed (1)
  1. BitTransformerLM/mcp_server.py +0 -322
BitTransformerLM/mcp_server.py DELETED
@@ -1,322 +0,0 @@
- import io
- import os
- import gzip
- import uuid
- import traceback
- from concurrent.futures import ThreadPoolExecutor
- from flask import Flask, request, jsonify, send_file
- import matplotlib.pyplot as plt
- import torch
-
- from bit_transformer.dashboard_app import ModelManager
- from bit_transformer.dashboard import plot_telemetry
- from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
- from bit_transformer.optimization import configure_optimizer
- from bit_transformer.bit_io import text_to_bits
-
- app = Flask(__name__)
- manager = ModelManager()
-
- # background job management
- executor = ThreadPoolExecutor(max_workers=4)
- jobs: dict[str, dict] = {}
-
-
- def _submit_job(fn, *args, **kwargs) -> str:
-     """Schedule a function for background execution and return a job id."""
-     job_id = str(uuid.uuid4())
-     jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}
-
-     def wrapper():
-         jobs[job_id]["status"] = "running"
-         try:
-             jobs[job_id]["result"] = fn(*args, **kwargs)
-             jobs[job_id]["status"] = "completed"
-         except Exception as err:  # pragma: no cover - captured for client
-             jobs[job_id]["status"] = "error"
-             jobs[job_id]["error"] = str(err)
-             jobs[job_id]["trace"] = traceback.format_exc()
-
-     executor.submit(wrapper)
-     return job_id
-
-
- @app.errorhandler(Exception)
- def handle_exception(err):
-     """Return JSON error responses with stack traces."""
-     return (
-         jsonify({"error": str(err), "trace": traceback.format_exc()}),
-         getattr(err, "code", 500),
-     )
-
-
- @app.route("/init", methods=["POST"])
- def init_model():
-     data = request.json or {}
-     int_fields = {
-         "d_model",
-         "nhead",
-         "num_layers",
-         "dim_feedforward",
-         "max_seq_len",
-         "chunk_size",
-         "overlap",
-     }
-     float_fields = {"act_threshold"}
-     bool_fields = {"reversible", "use_checkpoint"}
-     params = {}
-     for k, v in data.items():
-         if v is None:
-             params[k] = None
-         elif k in int_fields:
-             params[k] = int(v)
-         elif k in float_fields:
-             params[k] = float(v)
-         elif k in bool_fields:
-             params[k] = bool(v)
-         else:
-             params[k] = v
-     manager.init_model(params)
-     return jsonify({"status": "initialized", "params": params})
-
- @app.route("/train", methods=["POST"])
- def train_model():
-     bits = request.json["bits"]
-
-     def task():
-         tensor = torch.tensor(bits, dtype=torch.long)
-         loss, ratio = manager.train_step(tensor)
-         return {"loss": loss, "ratio": ratio}
-
-     job_id = _submit_job(task)
-     return jsonify({"job_id": job_id})
-
-
- @app.route("/train_epochs", methods=["POST"])
- def train_epochs_route():
-     data = request.json
-     bits = data["bits"]
-     epochs = int(data.get("epochs", 1))
-     compress_prob = float(data.get("compress_prob", 0.5))
-     direct_prob = float(data.get("direct_prob", 0.0))
-
-     def task():
-         tensor = torch.tensor(bits, dtype=torch.long)
-         metrics = manager.train_epochs(
-             tensor,
-             epochs=epochs,
-             compress_prob=compress_prob,
-             direct_prob=direct_prob,
-         )
-         return {"metrics": metrics}
-
-     job_id = _submit_job(task)
-     return jsonify({"job_id": job_id})
-
- @app.route("/scale_up", methods=["POST"])
- def scale_up():
-     width_mult = float(request.json.get("width_mult", 1.0))
-
-     def task():
-         manager.scale_up(width_mult)
-         return {
-             "status": "scaled",
-             "layers": manager.model.num_layers,
-             "d_model": manager.model.d_model,
-         }
-
-     job_id = _submit_job(task)
-     return jsonify({"job_id": job_id})
-
- @app.route("/collapse", methods=["POST"])
- def collapse_model():
-     cluster_bits = request.json["clusters"]
-     params = {k: int(v) for k, v in request.json["params"].items()}
-     width_scale = float(request.json.get("width_scale", 1.0))
-
-     def task():
-         manager.collapse(cluster_bits, params, width_scale)
-         return {"status": "collapsed"}
-
-     job_id = _submit_job(task)
-     return jsonify({"job_id": job_id})
-
-
- @app.route("/job/<job_id>", methods=["GET"])
- def get_job(job_id: str):
-     job = jobs.get(job_id)
-     if job is None:
-         return jsonify({"error": "not found"}), 404
-     return jsonify(job)
-
-
- @app.route("/jobs", methods=["GET"])
- def list_jobs():
-     return jsonify(jobs)
-
- @app.route("/lambdas", methods=["GET", "POST"])
- def update_lambdas():
-     if request.method == "POST":
-         data = request.json
-         manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
-         return jsonify({"status": "updated"})
-     else:
-         return jsonify({
-             "lambda_K": manager.lambda_K,
-             "lambda_C": manager.lambda_C,
-             "lambda_S": manager.lambda_S,
-         })
-
- @app.route("/diffusion", methods=["GET", "POST"])
- def update_diffusion():
-     if request.method == "POST":
-         manager.set_diffusion(bool(request.json.get("diffusion", False)))
-         return jsonify({"status": "updated"})
-     return jsonify({"diffusion": manager.diffusion})
-
-
- @app.route("/qat", methods=["GET", "POST"])
- def update_qat():
-     if request.method == "POST":
-         manager.set_qat(bool(request.json.get("qat", False)))
-         return jsonify({"status": "updated"})
-     return jsonify({"qat": manager.qat})
-
-
- @app.route("/gpu", methods=["GET", "POST"])
- def update_gpu():
-     if request.method == "POST":
-         manager.set_gpu(bool(request.json.get("use_gpu", False)))
-         return jsonify({"status": "updated"})
-     return jsonify({"use_gpu": manager.use_gpu})
-
- @app.route("/infer", methods=["POST"])
- def inference():
-     bits = torch.tensor(request.json["bits"], dtype=torch.long)
-     result = manager.infer(bits)
-     return jsonify(result)
-
-
- @app.route("/infer_long", methods=["POST"])
- def inference_long():
-     bits = torch.tensor(request.json["bits"], dtype=torch.long)
-     ctx = int(request.json.get("ctx_bits", 4096))
-     overlap = int(request.json.get("overlap", 256))
-     result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
-     return jsonify(result)
-
- @app.route("/infer_text", methods=["POST"])
- def inference_text():
-     text = request.json.get("text", "")
-     result = manager.infer_text(text)
-     return jsonify(result)
-
- @app.route("/status", methods=["GET"])
- def status():
-     return jsonify(manager.get_status())
-
-
- @app.route("/model_config", methods=["GET"])
- def model_config():
-     return jsonify(manager.get_model_config())
-
-
- @app.route("/metrics", methods=["GET"])
- def metrics():
-     return jsonify(manager.get_metrics())
-
-
- @app.route("/save_checkpoint", methods=["POST"])
- def save_checkpoint_route():
-     repo_id = request.json.get("repo_id")
-     token = request.json.get("token") or os.getenv("HF_TOKEN")
-     if manager.model is None:
-         return jsonify({"error": "model not initialized"}), 400
-     if token:
-         hf_login(token=token)
-     save_checkpoint(manager.model, repo_id=repo_id)
-     return jsonify({"status": "saved"})
-
-
- @app.route("/download_checkpoint", methods=["POST"])
- def download_checkpoint_route():
-     repo_id = request.json.get("repo_id")
-     token = request.json.get("token") or os.getenv("HF_TOKEN")
-     if token:
-         hf_login(token=token)
-     dest = manager.weights_path + ".gz"
-     ok = download_checkpoint(dest, repo_id=repo_id)
-     if not ok:
-         return jsonify({"status": "failed"}), 500
-     if manager.model is None:
-         return jsonify({"status": "downloaded", "loaded": False})
-     with gzip.open(dest, "rb") as f:
-         state = torch.load(f, map_location="cpu")
-     manager.model.load_state_dict(state)
-     manager.optimizer, manager.scheduler = configure_optimizer(
-         manager.model, lr=1e-3, total_steps=manager.total_steps
-     )
-     manager._apply_device()
-     manager._save_state()
-     return jsonify({"status": "downloaded", "loaded": True})
-
- @app.route("/plot.png")
- def plot_png():
-     fig, _ = plot_telemetry(manager.metrics)
-     buf = io.BytesIO()
-     fig.savefig(buf, format="png")
-     plt.close(fig)
-     buf.seek(0)
-     return send_file(buf, mimetype="image/png")
-
-
- @app.route("/text_to_bits", methods=["POST"])
- def text_to_bits_route():
-     text = request.json.get("text", "")
-     if len(text) > 100_000:
-         return jsonify({"error": "text too large"}), 413
-     return jsonify({"bits": text_to_bits(text)})
-
-
- @app.route("/dataset", methods=["GET"])
- def dataset_route():
-     name = request.args.get("name", "")
-     split = request.args.get("split", "train")
-     size = int(request.args.get("size", 1))
-     seq_len = int(request.args.get("seq_len", 64))
-     if size * seq_len > 1_000_000:
-         return jsonify({"error": "dataset too large"}), 413
-     if name == "wikitext2":
-         try:
-             from datasets import load_dataset
-
-             ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
-             lines = [t for t in ds["text"] if t.strip()][:size]
-         except Exception:
-             bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
-             return jsonify({"bits": bits.tolist()})
-         bits_list = []
-         for text in lines:
-             b = text_to_bits(text)[:seq_len]
-             if len(b) < seq_len:
-                 b.extend([0] * (seq_len - len(b)))
-             bits_list.append(b)
-         if len(bits_list) < size:
-             pad = size - len(bits_list)
-             bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
-         return jsonify({"bits": bits_list})
-     return jsonify({"error": "unknown dataset"}), 400
-
-
- @app.route("/health")
- def health_check():
-     return jsonify({"status": "ok"})
-
-
- def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
-     app.run(host=host, port=port, debug=True)
-
-
- if __name__ == "__main__":
-     import torch
-     run_mcp_server()
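
Note on the removed API: the server ran training and scaling work on a background ThreadPoolExecutor, so POSTs to /train, /train_epochs, /scale_up, and /collapse returned a job_id immediately, and clients polled /job/<job_id> for status and results. A minimal client sketch of that flow follows; it is illustrative only, not part of the repository, and it assumes a server on the default host/port from run_mcp_server(), the third-party requests package, and (a guess) that /train expects a batch of bit sequences.

import time
import requests

BASE = "http://localhost:7000"  # default port from run_mcp_server()

# Initialize a model; keys follow the /init route's int_fields (values illustrative).
requests.post(f"{BASE}/init", json={"d_model": 64, "nhead": 4, "num_layers": 2}).raise_for_status()

# Encode text into bits server-side via /text_to_bits.
bits = requests.post(f"{BASE}/text_to_bits", json={"text": "hello"}).json()["bits"]

# Submit a background training job; the wrapping list adds a batch
# dimension, which manager.train_step is assumed to expect.
job_id = requests.post(f"{BASE}/train", json={"bits": [bits]}).json()["job_id"]

# Poll the job until the worker thread completes or reports an error.
while True:
    job = requests.get(f"{BASE}/job/{job_id}").json()
    if job["status"] in ("completed", "error"):
        break
    time.sleep(1)
print(job)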