thejagstudio committed on
Commit 04b7245 · verified · 1 Parent(s): 527cb39

Update app.py

Files changed (1)
  1. app.py +283 -847
app.py CHANGED
@@ -1,480 +1,3 @@
1
- # import os
2
- # import sys
3
- # import json
4
- # import time
5
- # import importlib.util
6
- # from pathlib import Path
7
- # from flask import Flask, request, jsonify, Response, stream_with_context
8
- # from flask_cors import CORS
9
- # import torch
10
- # from transformers import AutoTokenizer
11
-
12
- # app = Flask(__name__, static_folder='static', static_url_path='/static')
13
- # CORS(app)
14
-
15
- # # Global state
16
- # model = None
17
- # tokenizer = None
18
- # config = None
19
- # device = None
20
- # DiffusionLLM = None
21
- # chat_function = None
22
-
23
-
24
- # def find_file(filename, search_dirs=None):
25
- # """Find a file in current directory or parent directories."""
26
- # if search_dirs is None:
27
- # search_dirs = [
28
- # os.path.dirname(__file__), # Current directory
29
- # os.path.dirname(os.path.dirname(__file__)), # Parent directory
30
- # os.getcwd(), # Working directory
31
- # ]
32
-
33
- # for directory in search_dirs:
34
- # filepath = os.path.join(directory, filename)
35
- # if os.path.exists(filepath):
36
- # print(f"Found {filename} at: {filepath}")
37
- # return filepath
38
-
39
- # return None
40
-
41
-
42
- # def try_import_module(filepath, module_name):
43
- # """Dynamically import a Python file as a module."""
44
- # if not filepath or not os.path.exists(filepath):
45
- # return None
46
-
47
- # try:
48
- # # Add the directory to sys.path
49
- # module_dir = os.path.dirname(filepath)
50
- # if module_dir not in sys.path:
51
- # sys.path.insert(0, module_dir)
52
-
53
- # spec = importlib.util.spec_from_file_location(module_name, filepath)
54
- # if spec is None:
55
- # print(f"Could not create spec for {filepath}")
56
- # return None
57
-
58
- # module = importlib.util.module_from_spec(spec)
59
- # sys.modules[module_name] = module
60
- # spec.loader.exec_module(module)
61
-
62
- # print(f"Successfully imported {module_name} from {filepath}")
63
- # return module
64
- # except Exception as e:
65
- # print(f"Error importing {filepath}: {e}")
66
- # import traceback
67
- # traceback.print_exc()
68
- # return None
69
-
70
-
71
- # def load_model_internal():
72
- # """Load the model and tokenizer."""
73
- # global model, tokenizer, config, device, DiffusionLLM, chat_function
74
-
75
- # if model is not None:
76
- # return True
77
-
78
- # try:
79
- # print("=" * 60)
80
- # print("Starting model loading process...")
81
- # print("=" * 60)
82
-
83
- # # Find and import infer-base.py
84
- # base_path = find_file("infer-base.py")
85
- # if base_path is None:
86
- # raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
87
-
88
- # print(f"\nImporting infer-base.py from: {base_path}")
89
- # base_mod = try_import_module(base_path, "infer_base")
90
-
91
- # if base_mod is None:
92
- # raise RuntimeError("Failed to import infer-base.py")
93
-
94
- # # Check for DiffusionLLM class
95
- # if not hasattr(base_mod, 'DiffusionLLM'):
96
- # print("Available attributes in infer_base:", dir(base_mod))
97
- # raise RuntimeError("DiffusionLLM class not found in infer-base.py")
98
-
99
- # DiffusionLLM = base_mod.DiffusionLLM
100
- # print("✓ Successfully loaded DiffusionLLM class")
101
-
102
- # # Find and import infer-chat.py
103
- # chat_path = find_file("infer-chat.py")
104
- # if chat_path is None:
105
- # raise RuntimeError("Could not find infer-chat.py")
106
-
107
- # print(f"\nImporting infer-chat.py from: {chat_path}")
108
- # chat_mod = try_import_module(chat_path, "infer_chat")
109
-
110
- # if chat_mod is None or not hasattr(chat_mod, 'chat'):
111
- # raise RuntimeError("Failed to import chat function from infer-chat.py")
112
-
113
- # chat_function = chat_mod.chat
114
- # print("✓ Successfully loaded chat function")
115
-
116
- # # Setup pickling workaround for torch.load
117
- # try:
118
- # if hasattr(base_mod, 'ModelConfig'):
119
- # sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
120
- # sys.modules['__main__'].DiffusionLLM = DiffusionLLM
121
- # print("✓ Configured pickle support for model loading")
122
- # except Exception as e:
123
- # print(f"Warning: Could not setup pickle workaround: {e}")
124
-
125
- # # Set device
126
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
- # print(f"\n✓ Using device: {device}")
128
-
129
- # # Load tokenizer
130
- # print("\nLoading tokenizer...")
131
- # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
132
- # if tokenizer.pad_token is None:
133
- # tokenizer.pad_token = tokenizer.eos_token
134
- # print("✓ Tokenizer loaded")
135
-
136
- # # Find model checkpoint
137
- # checkpoint_dirs = [
138
- # "checkpoints",
139
- # "../checkpoints",
140
- # "./checkpoints",
141
- # os.path.join(os.path.dirname(__file__), "checkpoints"),
142
- # os.path.join(os.path.dirname(__file__), "../checkpoints"),
143
- # ]
144
-
145
- # model_path = None
146
- # for checkpoint_dir in checkpoint_dirs:
147
- # best_path = os.path.join(checkpoint_dir, "best_model.pt")
148
- # fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
149
-
150
- # if os.path.exists(best_path):
151
- # model_path = best_path
152
- # break
153
- # elif os.path.exists(fp32_path):
154
- # model_path = fp32_path
155
- # break
156
-
157
- # if model_path is None:
158
- # raise RuntimeError(
159
- # "Could not find model checkpoint. Looking for:\n"
160
- # " - checkpoints/best_model.pt\n"
161
- # " - checkpoints/model_fp32.pt\n"
162
- # f"Searched directories: {checkpoint_dirs}"
163
- # )
164
-
165
- # print(f"\n✓ Found model checkpoint: {model_path}")
166
- # print("Loading model weights (this may take a minute)...")
167
-
168
- # # Load model
169
- # checkpoint = torch.load(model_path, map_location=device, weights_only=False)
170
- # config = checkpoint['config']
171
-
172
- # print("Creating model...")
173
- # model = DiffusionLLM(config)
174
-
175
- # print("Loading state dict...")
176
- # state_dict = checkpoint['model_state']
177
- # state_dict = {k: v.float() for k, v in state_dict.items()}
178
- # model.load_state_dict(state_dict)
179
-
180
- # model = model.to(device)
181
- # model.eval()
182
-
183
- # num_params = sum(p.numel() for p in model.parameters()) / 1e6
184
- # print(f"\n{'=' * 60}")
185
- # print(f"✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
186
- # print(f"{'=' * 60}")
187
- # print(f"Parameters: {num_params:.1f}M")
188
- # if 'step' in checkpoint:
189
- # print(f"Training steps: {checkpoint['step']}")
190
- # if 'best_val_loss' in checkpoint:
191
- # print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
192
- # print(f"{'=' * 60}\n")
193
-
194
- # return True
195
-
196
- # except Exception as e:
197
- # print("\n" + "=" * 60)
198
- # print("ERROR LOADING MODEL")
199
- # print("=" * 60)
200
- # print(f"Error: {e}")
201
- # import traceback
202
- # traceback.print_exc()
203
- # print("=" * 60 + "\n")
204
- # return False
205
-
206
-
207
- # def create_streaming_visualizer():
208
- # """Create a visualizer that yields SSE events instead of printing to terminal."""
209
- # def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
210
- # # Normalize inputs to lists
211
- # if not isinstance(mask_blocks, list):
212
- # mask_blocks = [mask_blocks]
213
- # is_masked_list = [is_masked_list]
214
-
215
- # # Decode context
216
- # try:
217
- # context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
218
- # except Exception:
219
- # context_text = str(context_ids[0].tolist())
220
-
221
- # # Build blocks visualization
222
- # all_blocks = []
223
- # for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
224
- # block_tokens = mask_block[0].tolist()
225
- # block_data = []
226
-
227
- # for i, token_id in enumerate(block_tokens):
228
- # if is_masked[0, i]:
229
- # block_data.append({
230
- # 'type': 'masked',
231
- # 'text': '███'
232
- # })
233
- # else:
234
- # try:
235
- # token_text = tok.decode([token_id], skip_special_tokens=False)
236
- # except Exception:
237
- # token_text = str(int(token_id))
238
- # block_data.append({
239
- # 'type': 'revealed',
240
- # 'text': token_text
241
- # })
242
-
243
- # all_blocks.append({
244
- # 'block_index': block_idx,
245
- # 'tokens': block_data
246
- # })
247
-
248
- # # Return data structure that will be sent as SSE
249
- # return {
250
- # 'context': context_text,
251
- # 'blocks': all_blocks,
252
- # 'num_blocks': len(mask_blocks)
253
- # }
254
-
255
- # return visualizer
256
-
257
-
258
- # @app.route('/')
259
- # def index():
260
- # """Serve the main HTML page."""
261
- # return app.send_static_file('index.html')
262
-
263
-
264
- # @app.route('/api/load', methods=['POST'])
265
- # def load_model_endpoint():
266
- # """Load the model."""
267
- # data = request.json or {}
268
- # check_only = data.get('check_only', False)
269
-
270
- # global model
271
-
272
- # if check_only:
273
- # return jsonify({
274
- # 'loaded': model is not None,
275
- # 'message': 'Model is loaded' if model is not None else 'Model not loaded'
276
- # })
277
-
278
- # if model is not None:
279
- # return jsonify({
280
- # 'loaded': True,
281
- # 'message': 'Model already loaded'
282
- # })
283
-
284
- # success = load_model_internal()
285
-
286
- # if success:
287
- # return jsonify({
288
- # 'loaded': True,
289
- # 'message': 'Model loaded successfully'
290
- # })
291
- # else:
292
- # return jsonify({
293
- # 'loaded': False,
294
- # 'message': 'Failed to load model. Check server logs for details.'
295
- # }), 500
296
-
297
-
298
- # @app.route('/api/generate', methods=['POST'])
299
- # def generate():
300
- # """Generate response without streaming."""
301
- # global model, tokenizer, config, device, chat_function
302
-
303
- # if model is None:
304
- # return jsonify({'error': 'Model not loaded'}), 400
305
-
306
- # if chat_function is None:
307
- # return jsonify({'error': 'Chat function not available'}), 400
308
-
309
- # data = request.json
310
- # instruction = data.get('instruction', '')
311
- # steps = data.get('steps', 64)
312
- # block_size = data.get('block_size', 128)
313
- # max_new_tokens = data.get('max_new_tokens', 128)
314
- # parallel_blocks = data.get('parallel_blocks', 1)
315
-
316
- # if not instruction:
317
- # return jsonify({'error': 'No instruction provided'}), 400
318
-
319
- # try:
320
- # # Generate response
321
- # raw_output, response = chat_function(
322
- # model,
323
- # tokenizer,
324
- # instruction,
325
- # steps=steps,
326
- # block_size=block_size,
327
- # max_new_tokens=max_new_tokens,
328
- # temperature=0.8,
329
- # top_k=50,
330
- # top_p=0.9,
331
- # repetition_penalty=1.2,
332
- # no_repeat_ngram_size=3,
333
- # verbose=False,
334
- # visualize_fn=None,
335
- # parallel_blocks=parallel_blocks,
336
- # )
337
-
338
- # return jsonify({
339
- # 'response': response,
340
- # 'raw_output': raw_output
341
- # })
342
- # except Exception as e:
343
- # import traceback
344
- # traceback.print_exc()
345
- # return jsonify({'error': str(e)}), 500
346
-
347
-
348
- # @app.route('/api/generate-stream', methods=['POST'])
349
- # def generate_stream():
350
- # """Generate response with streaming visualization."""
351
- # global model, tokenizer, config, device, chat_function
352
-
353
- # if model is None:
354
- # return jsonify({'error': 'Model not loaded'}), 400
355
-
356
- # if chat_function is None:
357
- # return jsonify({'error': 'Chat function not available'}), 400
358
-
359
- # data = request.json
360
- # instruction = data.get('instruction', '')
361
- # steps = data.get('steps', 64)
362
- # block_size = data.get('block_size', 128)
363
- # max_new_tokens = data.get('max_new_tokens', 128)
364
- # parallel_blocks = data.get('parallel_blocks', 1)
365
-
366
- # if not instruction:
367
- # return jsonify({'error': 'No instruction provided'}), 400
368
-
369
- # def generate_events():
370
- # try:
371
- # # Import threading to allow yielding from callback
372
- # import queue
373
- # event_queue = queue.Queue()
374
- # generation_complete = {'done': False, 'result': None}
375
-
376
- # def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
377
- # """This gets called during generation - we need to send events immediately"""
378
- # visualizer = create_streaming_visualizer()
379
- # data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
380
- # # Put the update in the queue so it can be yielded immediately
381
- # event_queue.put({'type': 'update', 'data': data})
382
-
383
- # # Start generation in a separate thread so we can yield events as they come
384
- # import threading
385
-
386
- # def run_generation():
387
- # try:
388
- # raw_output, response = chat_function(
389
- # model,
390
- # tokenizer,
391
- # instruction,
392
- # steps=steps,
393
- # block_size=block_size,
394
- # max_new_tokens=max_new_tokens,
395
- # temperature=0.8,
396
- # top_k=50,
397
- # top_p=0.9,
398
- # repetition_penalty=1.2,
399
- # no_repeat_ngram_size=3,
400
- # verbose=False,
401
- # visualize_fn=streaming_visualizer,
402
- # parallel_blocks=parallel_blocks,
403
- # )
404
- # generation_complete['result'] = (raw_output, response)
405
- # except Exception as e:
406
- # generation_complete['result'] = ('error', str(e))
407
- # finally:
408
- # generation_complete['done'] = True
409
- # event_queue.put(None) # Signal completion
410
-
411
- # # Start generation thread
412
- # gen_thread = threading.Thread(target=run_generation)
413
- # gen_thread.daemon = True
414
- # gen_thread.start()
415
-
416
- # # Yield start event
417
- # yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"
418
-
419
- # # Yield events as they come from the queue
420
- # while not generation_complete['done'] or not event_queue.empty():
421
- # try:
422
- # event = event_queue.get(timeout=0.1)
423
- # if event is None: # Completion signal
424
- # break
425
- # yield f"data: {json.dumps(event)}\n\n"
426
- # except queue.Empty:
427
- # continue
428
-
429
- # # Wait for thread to finish
430
- # gen_thread.join(timeout=1.0)
431
-
432
- # # Send final response
433
- # if generation_complete['result']:
434
- # raw_output, response = generation_complete['result']
435
- # if raw_output == 'error':
436
- # yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
437
- # else:
438
- # yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
439
-
440
- # except Exception as e:
441
- # import traceback
442
- # traceback.print_exc()
443
- # yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
444
-
445
- # return Response(
446
- # stream_with_context(generate_events()),
447
- # mimetype='text/event-stream',
448
- # headers={
449
- # 'Cache-Control': 'no-cache',
450
- # 'X-Accel-Buffering': 'no'
451
- # }
452
- # )
453
-
454
-
455
- # @app.route('/api/test-stream', methods=['GET'])
456
- # def test_stream():
457
- # """Test streaming endpoint."""
458
- # def generate():
459
- # for i in range(10):
460
- # yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
461
- # time.sleep(0.5)
462
- # yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
463
-
464
- # return Response(
465
- # stream_with_context(generate()),
466
- # mimetype='text/event-stream',
467
- # headers={
468
- # 'Cache-Control': 'no-cache',
469
- # 'X-Accel-Buffering': 'no'
470
- # }
471
- # )
472
-
473
-
474
- # if __name__ == '__main__':
475
- # app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
476
-
477
-
478
  import os
479
  import sys
480
  import json
@@ -485,45 +8,9 @@ from flask import Flask, request, jsonify, Response, stream_with_context
485
  from flask_cors import CORS
486
  import torch
487
  from transformers import AutoTokenizer
488
- import threading
489
- import queue
490
- import warnings
491
-
492
- # ============ CRITICAL: CONFIGURE THREADS BEFORE TORCH OPERATIONS ============
493
- # Must be set IMMEDIATELY at module import time
494
- def setup_cpu_threads():
495
- """Configure CPU threads BEFORE any PyTorch parallel work starts."""
496
- cpu_count = os.cpu_count() or 1
497
- physical_cores = cpu_count // 2 if cpu_count > 1 else 1
498
-
499
- # Set environment variables FIRST
500
- os.environ["OMP_NUM_THREADS"] = str(physical_cores)
501
- os.environ["MKL_NUM_THREADS"] = str(physical_cores)
502
- os.environ["NUMEXPR_NUM_THREADS"] = str(physical_cores)
503
-
504
- # Set PyTorch threads BEFORE any operations
505
- try:
506
- torch.set_num_threads(physical_cores)
507
- torch.set_num_interop_threads(physical_cores)
508
- except RuntimeError as e:
509
- warnings.warn(f"Could not set threads: {e} (already initialized)")
510
-
511
- print(f"✓ CPU threads configured: {physical_cores} physical cores")
512
- return physical_cores
513
-
514
- # Call immediately
515
- PHYSICAL_CORES = setup_cpu_threads()
516
- # ============================================================================
517
-
518
- # Configuration flags
519
- USE_ONNX_RUNTIME = False
520
- USE_IPEX = False
521
- USE_TORCH_COMPILE = True
522
- QUANTIZE_MODEL = True
523
- WARMUP_ITERATIONS = 3
524
 
525
  app = Flask(__name__, static_folder='static', static_url_path='/static')
526
- CORS(app, resources={r"/api/*": {"origins": "*"}}) # More permissive for testing
527
 
528
  # Global state
529
  model = None
@@ -532,508 +19,457 @@ config = None
532
  device = None
533
  DiffusionLLM = None
534
  chat_function = None
535
- ModelConfig = None # Will be imported from infer-base.py
536
 
537
  def find_file(filename, search_dirs=None):
538
  """Find a file in current directory or parent directories."""
539
  if search_dirs is None:
540
  search_dirs = [
541
- os.path.dirname(__file__),
542
- os.path.dirname(os.path.dirname(__file__)),
543
- os.getcwd(),
544
  ]
 
545
  for directory in search_dirs:
546
  filepath = os.path.join(directory, filename)
547
  if os.path.exists(filepath):
548
  print(f"Found {filename} at: {filepath}")
549
  return filepath
 
550
  return None
551
 
 
552
  def try_import_module(filepath, module_name):
553
  """Dynamically import a Python file as a module."""
554
  if not filepath or not os.path.exists(filepath):
555
  return None
 
556
  try:
 
557
  module_dir = os.path.dirname(filepath)
558
  if module_dir not in sys.path:
559
  sys.path.insert(0, module_dir)
560
-
561
  spec = importlib.util.spec_from_file_location(module_name, filepath)
562
  if spec is None:
563
  print(f"Could not create spec for {filepath}")
564
  return None
565
-
566
  module = importlib.util.module_from_spec(spec)
567
  sys.modules[module_name] = module
568
  spec.loader.exec_module(module)
569
-
570
- print(f"Successfully imported {module_name}")
571
  return module
572
  except Exception as e:
573
- print(f"Error importing {filepath}: {e}")
574
- if __debug__:
575
- import traceback
576
- traceback.print_exc()
577
  return None
578
 
579
- def quantize_model(model):
580
- """Apply quantization for faster inference."""
581
- if not QUANTIZE_MODEL:
582
- return model
583
-
584
- print("\nApplying quantization...")
585
- try:
586
- # Dynamic quantization - no calibration needed, works on any model
587
- model = torch.quantization.quantize_dynamic(
588
- model,
589
- {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding},
590
- dtype=torch.qint8
591
- )
592
- print("✓ INT8 dynamic quantization applied")
593
- return model
594
- except Exception as e:
595
- print(f"⚠ Quantization failed: {e}")
596
- return model
597
-
598
- def compile_model(model):
599
- """Compile model for maximum speed."""
600
- print("\nCompiling model...")
601
-
602
- # ONNX Runtime (BEST performance)
603
- if USE_ONNX_RUNTIME:
604
- try:
605
- import onnxruntime as ort
606
- onnx_path = "model_optimized.onnx"
607
-
608
- # Export if not exists
609
- if not os.path.exists(onnx_path):
610
- print("Exporting model to ONNX format...")
611
- dummy_input = torch.randint(0, 100, (1, 64))
612
- torch.onnx.export(
613
- model, dummy_input, onnx_path,
614
- input_names=['input_ids'],
615
- output_names=['logits'],
616
- dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
617
- opset_version=16
618
- )
619
-
620
- # Create optimized session
621
- sess_options = ort.SessionOptions()
622
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
623
- sess_options.intra_op_num_threads = PHYSICAL_CORES
624
-
625
- return ort.InferenceSession(onnx_path, sess_options)
626
- except Exception as e:
627
- print(f"⚠ ONNX Runtime failed: {e}")
628
-
629
- # Intel IPEX
630
- if USE_IPEX:
631
- try:
632
- import intel_extension_for_pytorch as ipex
633
- model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
634
- print("✓ Intel IPEX optimization applied")
635
- return model
636
- except Exception as e:
637
- print(f"⚠ IPEX failed: {e}")
638
-
639
- # torch.compile
640
- if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
641
- try:
642
- model = torch.compile(model, mode="max-autotune")
643
- print("✓ torch.compile applied")
644
- except Exception as e:
645
- print(f"⚠ torch.compile failed: {e}")
646
-
647
- return model
648
-
649
- def warmup_model(model, tokenizer, chat_func):
650
- """Warmup the model."""
651
- if WARMUP_ITERATIONS == 0:
652
- return
653
-
654
- print("\nWarming up model...")
655
- start = time.time()
656
-
657
- try:
658
- with torch.inference_mode():
659
- for i in range(WARMUP_ITERATIONS):
660
- chat_func(
661
- model, tokenizer, "Hello",
662
- steps=4, block_size=16, max_new_tokens=8,
663
- temperature=0.7, top_k=50, top_p=0.9,
664
- repetition_penalty=1.2, no_repeat_ngram_size=3,
665
- verbose=False, visualize_fn=None, parallel_blocks=PHYSICAL_CORES
666
- )
667
- print(f" Warmup {i+1}/{WARMUP_ITERATIONS}...")
668
- except Exception as e:
669
- print(f"⚠ Warmup failed: {e}")
670
-
671
- print(f"✓ Warmup complete ({time.time() - start:.2f}s)")
672
 
673
  def load_model_internal():
674
- """Load model with ultra-fast optimizations."""
675
- global model, tokenizer, config, device, DiffusionLLM, chat_function, ModelConfig
676
-
677
  if model is not None:
678
  return True
679
-
680
  try:
681
- print("\n" + "=" * 70)
682
- print("ULTRA-FAST CPU MODEL LOADING")
683
- print("=" * 70)
684
-
685
- # FIRST: Import modules to get ModelConfig
686
- print("\n1. Loading modules...")
687
  base_path = find_file("infer-base.py")
688
  if base_path is None:
689
- raise RuntimeError("Could not find infer-base.py")
690
-
 
691
  base_mod = try_import_module(base_path, "infer_base")
 
692
  if base_mod is None:
693
  raise RuntimeError("Failed to import infer-base.py")
694
-
695
- # CRITICAL: Register ModelConfig for pickle
696
- if hasattr(base_mod, 'ModelConfig'):
697
- ModelConfig = base_mod.ModelConfig
698
- sys.modules['__main__'].ModelConfig = ModelConfig
699
- print("✓ ModelConfig registered for pickle")
700
-
701
  if not hasattr(base_mod, 'DiffusionLLM'):
702
- raise RuntimeError("DiffusionLLM class not found")
703
-
 
704
  DiffusionLLM = base_mod.DiffusionLLM
705
- print("✓ DiffusionLLM loaded")
706
-
707
- # Import chat function
708
  chat_path = find_file("infer-chat.py")
709
  if chat_path is None:
710
  raise RuntimeError("Could not find infer-chat.py")
711
-
 
712
  chat_mod = try_import_module(chat_path, "infer_chat")
 
713
  if chat_mod is None or not hasattr(chat_mod, 'chat'):
714
- raise RuntimeError("Chat function not found")
715
-
716
  chat_function = chat_mod.chat
717
- print("✓ Chat function loaded")
718
-
719
- # Device
720
- device = torch.device("cpu")
721
-
 
 
722
  # Load tokenizer
723
- print("\n2. Loading tokenizer...")
724
- tokenizer = AutoTokenizer.from_pretrained(
725
- "Qwen/Qwen2.5-0.5B",
726
- use_fast=True,
727
- trust_remote_code=True
728
- )
729
  if tokenizer.pad_token is None:
730
  tokenizer.pad_token = tokenizer.eos_token
731
- print("✓ Fast tokenizer ready")
732
-
733
- # Find model
734
- checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
 
735
  model_path = None
736
  for checkpoint_dir in checkpoint_dirs:
737
  best_path = os.path.join(checkpoint_dir, "best_model.pt")
738
  fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
 
739
  if os.path.exists(best_path):
740
  model_path = best_path
741
  break
742
  elif os.path.exists(fp32_path):
743
  model_path = fp32_path
744
-
 
745
  if model_path is None:
746
- raise RuntimeError("Model checkpoint not found")
747
-
748
- print(f"\n3. Loading checkpoint: {model_path}")
749
-
750
- # CRITICAL: Load checkpoint AFTER ModelConfig is registered
751
- checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
 
752
  config = checkpoint['config']
753
-
754
- # Create model
755
- print("4. Building model...")
756
  model = DiffusionLLM(config)
757
-
758
- # Load weights
759
  state_dict = checkpoint['model_state']
760
- if not USE_IPEX:
761
- state_dict = {k: v.float() for k, v in state_dict.items()}
762
-
763
  model.load_state_dict(state_dict)
764
- model.eval()
765
  model = model.to(device)
766
-
767
- # Apply optimizations
768
- model = quantize_model(model)
769
- model = compile_model(model)
770
-
771
- # Warmup
772
- warmup_model(model, tokenizer, chat_function)
773
-
774
- # Print summary
775
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
776
- framework = "ONNX Runtime" if USE_ONNX_RUNTIME else "IPEX" if USE_IPEX else "PyTorch"
777
- precision = "INT8" if QUANTIZE_MODEL and not USE_IPEX else "BF16" if USE_IPEX else "FP32"
778
-
779
- print("\n" + "=" * 70)
780
- print(f"✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED ({framework} + {precision}) ✓✓✓")
781
- print("=" * 70)
782
  print(f"Parameters: {num_params:.1f}M")
783
- print(f"CPU Threads: {PHYSICAL_CORES}")
784
  if 'step' in checkpoint:
785
  print(f"Training steps: {checkpoint['step']}")
786
  if 'best_val_loss' in checkpoint:
787
- print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
788
- print("=" * 70 + "\n")
789
-
790
  return True
791
-
792
  except Exception as e:
793
- print(f"\n ERROR LOADING MODEL: {e}")
794
- if __debug__:
795
- import traceback
796
- traceback.print_exc()
797
- print("=" * 70 + "\n")
 
 
798
  return False
799
 
 
800
  def create_streaming_visualizer():
801
- """Create optimized visualizer."""
802
  def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
 
803
  if not isinstance(mask_blocks, list):
804
  mask_blocks = [mask_blocks]
805
  is_masked_list = [is_masked_list]
806
-
 
807
  try:
808
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
809
  except Exception:
810
  context_text = str(context_ids[0].tolist())
811
-
 
812
  all_blocks = []
813
  for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
814
  block_tokens = mask_block[0].tolist()
815
  block_data = []
816
-
817
- # Efficient batch decoding
818
- token_ids_to_decode = []
819
- positions = []
820
- for i, token_id in enumerate(block_tokens):
821
- if not is_masked[0, i]:
822
- token_ids_to_decode.append(token_id)
823
- positions.append(i)
824
-
825
- try:
826
- decoded_tokens = tok.batch_decode(
827
- [[tid] for tid in token_ids_to_decode],
828
- skip_special_tokens=False
829
- )
830
- except Exception:
831
- decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
832
-
833
- # Reconstruct
834
- decoded_idx = 0
835
  for i, token_id in enumerate(block_tokens):
836
  if is_masked[0, i]:
837
- block_data.append({'type': 'masked', 'text': '███'})
 
 
 
838
  else:
 
 
 
 
839
  block_data.append({
840
  'type': 'revealed',
841
- 'text': decoded_tokens[decoded_idx]
842
  })
843
- decoded_idx += 1
844
-
845
- all_blocks.append({'block_index': block_idx, 'tokens': block_data})
846
-
 
 
 
847
  return {
848
  'context': context_text,
849
  'blocks': all_blocks,
850
  'num_blocks': len(mask_blocks)
851
  }
852
-
853
  return visualizer
854
 
 
855
  @app.route('/')
856
  def index():
857
  """Serve the main HTML page."""
858
  return app.send_static_file('index.html')
859
 
 
860
  @app.route('/api/load', methods=['POST'])
861
  def load_model_endpoint():
862
  """Load the model."""
863
- global model
864
-
865
  data = request.json or {}
866
  check_only = data.get('check_only', False)
867
-
 
 
868
  if check_only:
869
  return jsonify({
870
  'loaded': model is not None,
871
  'message': 'Model is loaded' if model is not None else 'Model not loaded'
872
  })
873
-
874
  if model is not None:
875
  return jsonify({
876
  'loaded': True,
877
  'message': 'Model already loaded'
878
  })
879
-
880
  success = load_model_internal()
 
881
  if success:
882
  return jsonify({
883
  'loaded': True,
884
- 'message': 'Model loaded with ultra-fast CPU optimizations'
885
  })
886
  else:
887
  return jsonify({
888
  'loaded': False,
889
- 'message': 'Failed to load model. Check server logs.'
890
  }), 500
891
 
 
892
  @app.route('/api/generate', methods=['POST'])
893
  def generate():
894
- """Generate response - ultra-fast path."""
895
- global model, tokenizer, chat_function
896
-
897
  if model is None:
898
  return jsonify({'error': 'Model not loaded'}), 400
899
-
900
  if chat_function is None:
901
  return jsonify({'error': 'Chat function not available'}), 400
902
-
903
  data = request.json
904
  instruction = data.get('instruction', '')
905
- steps = data.get('steps', 16) # Minimal for speed
906
- block_size = data.get('block_size', 32)
907
- max_new_tokens = data.get('max_new_tokens', 64)
908
- parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
909
-
910
  if not instruction:
911
  return jsonify({'error': 'No instruction provided'}), 400
912
-
913
  try:
914
- with torch.inference_mode():
915
- raw_output, response = chat_function(
916
- model, tokenizer, instruction,
917
- steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
918
- temperature=0.7, top_k=50, top_p=0.9,
919
- repetition_penalty=1.2, no_repeat_ngram_size=3,
920
- verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks,
921
- )
922
-
923
- return jsonify({'response': response, 'raw_output': raw_output})
924
-
 
 
 
 
 
 
 
 
 
 
 
925
  except Exception as e:
926
- if __debug__:
927
- import traceback
928
- traceback.print_exc()
929
  return jsonify({'error': str(e)}), 500
930
 
 
931
  @app.route('/api/generate-stream', methods=['POST'])
932
  def generate_stream():
933
- """Generate with streaming - optimized."""
934
- global model, tokenizer, chat_function
935
-
936
  if model is None:
937
  return jsonify({'error': 'Model not loaded'}), 400
938
-
 
 
 
939
  data = request.json
940
  instruction = data.get('instruction', '')
941
- steps = data.get('steps', 16)
942
- block_size = data.get('block_size', 32)
943
- max_new_tokens = data.get('max_new_tokens', 64)
944
- parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
945
-
946
  if not instruction:
947
  return jsonify({'error': 'No instruction provided'}), 400
948
-
949
  def generate_events():
950
- event_queue = queue.Queue(maxsize=50) # Limited queue
951
- generation_complete = {'done': False, 'result': None}
952
-
953
- def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
954
- try:
 
 
 
955
  visualizer = create_streaming_visualizer()
956
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
957
- event_queue.put({'type': 'update', 'data': data}, block=False, timeout=0.1)
958
- except queue.Full:
959
- pass # Drop frames if too slow
960
-
961
- def run_generation():
962
- try:
963
- with torch.inference_mode():
 
964
  raw_output, response = chat_function(
965
- model, tokenizer, instruction,
966
- steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
967
- temperature=0.7, top_k=50, top_p=0.9,
968
- repetition_penalty=1.2, no_repeat_ngram_size=3,
969
- verbose=False, visualize_fn=streaming_visualizer,
 
 
 
 
 
 
 
 
970
  parallel_blocks=parallel_blocks,
971
  )
972
- generation_complete['result'] = (raw_output, response)
973
- except Exception as e:
974
- generation_complete['result'] = ('error', str(e))
975
- finally:
976
- generation_complete['done'] = True
977
- event_queue.put(None)
978
-
979
- # Start generation thread
980
- gen_thread = threading.Thread(target=run_generation, daemon=True)
981
- gen_thread.start()
982
-
983
- yield f"data: {json.dumps({'type': 'start', 'ts': time.time()})}\n\n"
984
-
985
- # Stream events
986
- while not generation_complete['done'] or not event_queue.empty():
987
- try:
988
- event = event_queue.get(timeout=0.1)
989
- if event is None:
990
- break
991
- yield f"data: {json.dumps(event)}\n\n"
992
- except queue.Empty:
993
- continue
994
-
995
- gen_thread.join(timeout=2.0)
996
-
997
- # Send final result
998
- if generation_complete['result']:
999
- raw_output, response = generation_complete['result']
1000
- yield f"data: {json.dumps({'type': 'complete' if raw_output != 'error' else 'error',
1001
- 'response': response, 'error': response if raw_output == 'error' else None})}\n\n"
1002
-
 
1003
  return Response(
1004
  stream_with_context(generate_events()),
1005
  mimetype='text/event-stream',
1006
  headers={
1007
  'Cache-Control': 'no-cache',
1008
- 'X-Accel-Buffering': 'no',
1009
- 'Connection': 'keep-alive'
1010
  }
1011
  )
1012
 
1013
- @app.route('/api/status', methods=['GET'])
1014
- def status():
1015
- """Get detailed status."""
1016
- return jsonify({
1017
- 'model_loaded': model is not None,
1018
- 'cpu_cores': os.cpu_count(),
1019
- 'physical_cores': PHYSICAL_CORES,
1020
- 'torch_threads': torch.get_num_threads(),
1021
- 'interop_threads': torch.get_num_interop_threads(),
1022
- 'optimizations': {
1023
- 'onnx_runtime': USE_ONNX_RUNTIME,
1024
- 'ipex': USE_IPEX,
1025
- 'torch_compile': USE_TORCH_COMPILE,
1026
- 'quantization': QUANTIZE_MODEL,
1027
- 'warmup_iterations': WARMUP_ITERATIONS
 
1028
  }
1029
- })
 
1030
 
1031
  if __name__ == '__main__':
1032
- print("\n" + "=" * 70)
1033
- print("ULTRA-FAST CPU INFERENCE SERVER v2.0")
1034
- print("=" * 70)
1035
- print(f"CPU Configuration: {PHYSICAL_CORES} physical cores")
1036
- print(f"Optimizations: ONNX={USE_ONNX_RUNTIME} | IPEX={USE_IPEX} | Compile={USE_TORCH_COMPILE}")
1037
- print("=" * 70 + "\n")
1038
-
1039
- app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
 
1
  import os
2
  import sys
3
  import json
 
8
  from flask_cors import CORS
9
  import torch
10
  from transformers import AutoTokenizer
11
 
12
  app = Flask(__name__, static_folder='static', static_url_path='/static')
13
+ CORS(app)
14
 
15
  # Global state
16
  model = None
 
19
  device = None
20
  DiffusionLLM = None
21
  chat_function = None
22
+
23
 
24
  def find_file(filename, search_dirs=None):
25
  """Find a file in current directory or parent directories."""
26
  if search_dirs is None:
27
  search_dirs = [
28
+ os.path.dirname(__file__), # Current directory
29
+ os.path.dirname(os.path.dirname(__file__)), # Parent directory
30
+ os.getcwd(), # Working directory
31
  ]
32
+
33
  for directory in search_dirs:
34
  filepath = os.path.join(directory, filename)
35
  if os.path.exists(filepath):
36
  print(f"Found {filename} at: {filepath}")
37
  return filepath
38
+
39
  return None
40
 
41
+
42
  def try_import_module(filepath, module_name):
43
  """Dynamically import a Python file as a module."""
44
  if not filepath or not os.path.exists(filepath):
45
  return None
46
+
47
  try:
48
+ # Add the directory to sys.path
49
  module_dir = os.path.dirname(filepath)
50
  if module_dir not in sys.path:
51
  sys.path.insert(0, module_dir)
52
+
53
  spec = importlib.util.spec_from_file_location(module_name, filepath)
54
  if spec is None:
55
  print(f"Could not create spec for {filepath}")
56
  return None
57
+
58
  module = importlib.util.module_from_spec(spec)
59
  sys.modules[module_name] = module
60
  spec.loader.exec_module(module)
61
+
62
+ print(f"Successfully imported {module_name} from {filepath}")
63
  return module
64
  except Exception as e:
65
+ print(f"Error importing {filepath}: {e}")
66
+ import traceback
67
+ traceback.print_exc()
 
68
  return None
69
 
70
 
71
  def load_model_internal():
72
+ """Load the model and tokenizer."""
73
+ global model, tokenizer, config, device, DiffusionLLM, chat_function
74
+
75
  if model is not None:
76
  return True
77
+
78
  try:
79
+ print("=" * 60)
80
+ print("Starting model loading process...")
81
+ print("=" * 60)
82
+
83
+ # Find and import infer-base.py
 
84
  base_path = find_file("infer-base.py")
85
  if base_path is None:
86
+ raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
87
+
88
+ print(f"\nImporting infer-base.py from: {base_path}")
89
  base_mod = try_import_module(base_path, "infer_base")
90
+
91
  if base_mod is None:
92
  raise RuntimeError("Failed to import infer-base.py")
93
+
94
+ # Check for DiffusionLLM class
 
 
 
 
 
95
  if not hasattr(base_mod, 'DiffusionLLM'):
96
+ print("Available attributes in infer_base:", dir(base_mod))
97
+ raise RuntimeError("DiffusionLLM class not found in infer-base.py")
98
+
99
  DiffusionLLM = base_mod.DiffusionLLM
100
+ print("✓ Successfully loaded DiffusionLLM class")
101
+
102
+ # Find and import infer-chat.py
103
  chat_path = find_file("infer-chat.py")
104
  if chat_path is None:
105
  raise RuntimeError("Could not find infer-chat.py")
106
+
107
+ print(f"\nImporting infer-chat.py from: {chat_path}")
108
  chat_mod = try_import_module(chat_path, "infer_chat")
109
+
110
  if chat_mod is None or not hasattr(chat_mod, 'chat'):
111
+ raise RuntimeError("Failed to import chat function from infer-chat.py")
112
+
113
  chat_function = chat_mod.chat
114
+ print("✓ Successfully loaded chat function")
115
+
116
+ # Setup pickling workaround for torch.load
117
+ try:
118
+ if hasattr(base_mod, 'ModelConfig'):
119
+ sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
120
+ sys.modules['__main__'].DiffusionLLM = DiffusionLLM
121
+ print("✓ Configured pickle support for model loading")
122
+ except Exception as e:
123
+ print(f"Warning: Could not setup pickle workaround: {e}")
124
+
125
+ # Set device
126
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
+ print(f"\n✓ Using device: {device}")
128
+
129
  # Load tokenizer
130
+ print("\nLoading tokenizer...")
131
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
 
 
 
 
132
  if tokenizer.pad_token is None:
133
  tokenizer.pad_token = tokenizer.eos_token
134
+ print("✓ Tokenizer loaded")
135
+
136
+ # Find model checkpoint
137
+ checkpoint_dirs = [
138
+ "checkpoints",
139
+ "../checkpoints",
140
+ "./checkpoints",
141
+ os.path.join(os.path.dirname(__file__), "checkpoints"),
142
+ os.path.join(os.path.dirname(__file__), "../checkpoints"),
143
+ ]
144
+
145
  model_path = None
146
  for checkpoint_dir in checkpoint_dirs:
147
  best_path = os.path.join(checkpoint_dir, "best_model.pt")
148
  fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
149
+
150
  if os.path.exists(best_path):
151
  model_path = best_path
152
  break
153
  elif os.path.exists(fp32_path):
154
  model_path = fp32_path
155
+ break
156
+
157
  if model_path is None:
158
+ raise RuntimeError(
159
+ "Could not find model checkpoint. Looking for:\n"
160
+ " - checkpoints/best_model.pt\n"
161
+ " - checkpoints/model_fp32.pt\n"
162
+ f"Searched directories: {checkpoint_dirs}"
163
+ )
164
+
165
+ print(f"\n✓ Found model checkpoint: {model_path}")
166
+ print("Loading model weights (this may take a minute)...")
167
+
168
+ # Load model
169
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
170
  config = checkpoint['config']
171
+
172
+ print("Creating model...")
 
173
  model = DiffusionLLM(config)
174
+
175
+ print("Loading state dict...")
176
  state_dict = checkpoint['model_state']
177
+ state_dict = {k: v.float() for k, v in state_dict.items()}
 
 
178
  model.load_state_dict(state_dict)
179
+
180
  model = model.to(device)
181
+ model.eval()
182
+
 
 
 
 
 
 
 
183
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
184
+ print(f"\n{'=' * 60}")
185
+ print(f"✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
186
+ print(f"{'=' * 60}")
 
 
 
187
  print(f"Parameters: {num_params:.1f}M")
 
188
  if 'step' in checkpoint:
189
  print(f"Training steps: {checkpoint['step']}")
190
  if 'best_val_loss' in checkpoint:
191
+ print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
192
+ print(f"{'=' * 60}\n")
193
+
194
  return True
195
+
196
  except Exception as e:
197
+ print("\n" + "=" * 60)
198
+ print("ERROR LOADING MODEL")
199
+ print("=" * 60)
200
+ print(f"Error: {e}")
201
+ import traceback
202
+ traceback.print_exc()
203
+ print("=" * 60 + "\n")
204
  return False
205
 
206
+
207
  def create_streaming_visualizer():
208
+ """Create a visualizer that yields SSE events instead of printing to terminal."""
209
  def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
210
+ # Normalize inputs to lists
211
  if not isinstance(mask_blocks, list):
212
  mask_blocks = [mask_blocks]
213
  is_masked_list = [is_masked_list]
214
+
215
+ # Decode context
216
  try:
217
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
218
  except Exception:
219
  context_text = str(context_ids[0].tolist())
220
+
221
+ # Build blocks visualization
222
  all_blocks = []
223
  for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
224
  block_tokens = mask_block[0].tolist()
225
  block_data = []
226
+
 
227
  for i, token_id in enumerate(block_tokens):
228
  if is_masked[0, i]:
229
+ block_data.append({
230
+ 'type': 'masked',
231
+ 'text': '███'
232
+ })
233
  else:
234
+ try:
235
+ token_text = tok.decode([token_id], skip_special_tokens=False)
236
+ except Exception:
237
+ token_text = str(int(token_id))
238
  block_data.append({
239
  'type': 'revealed',
240
+ 'text': token_text
241
  })
242
+
243
+ all_blocks.append({
244
+ 'block_index': block_idx,
245
+ 'tokens': block_data
246
+ })
247
+
248
+ # Return data structure that will be sent as SSE
249
  return {
250
  'context': context_text,
251
  'blocks': all_blocks,
252
  'num_blocks': len(mask_blocks)
253
  }
254
+
255
  return visualizer
256
 
257
+
258
  @app.route('/')
259
  def index():
260
  """Serve the main HTML page."""
261
  return app.send_static_file('index.html')
262
 
263
+
264
  @app.route('/api/load', methods=['POST'])
265
  def load_model_endpoint():
266
  """Load the model."""
 
 
267
  data = request.json or {}
268
  check_only = data.get('check_only', False)
269
+
270
+ global model
271
+
272
  if check_only:
273
  return jsonify({
274
  'loaded': model is not None,
275
  'message': 'Model is loaded' if model is not None else 'Model not loaded'
276
  })
277
+
278
  if model is not None:
279
  return jsonify({
280
  'loaded': True,
281
  'message': 'Model already loaded'
282
  })
283
+
284
  success = load_model_internal()
285
+
286
  if success:
287
  return jsonify({
288
  'loaded': True,
289
+ 'message': 'Model loaded successfully'
290
  })
291
  else:
292
  return jsonify({
293
  'loaded': False,
294
+ 'message': 'Failed to load model. Check server logs for details.'
295
  }), 500
296
 
297
+
298
  @app.route('/api/generate', methods=['POST'])
299
  def generate():
300
+ """Generate response without streaming."""
301
+ global model, tokenizer, config, device, chat_function
302
+
303
  if model is None:
304
  return jsonify({'error': 'Model not loaded'}), 400
305
+
306
  if chat_function is None:
307
  return jsonify({'error': 'Chat function not available'}), 400
308
+
309
  data = request.json
310
  instruction = data.get('instruction', '')
311
+ steps = data.get('steps', 64)
312
+ block_size = data.get('block_size', 128)
313
+ max_new_tokens = data.get('max_new_tokens', 128)
314
+ parallel_blocks = data.get('parallel_blocks', 1)
315
+
316
  if not instruction:
317
  return jsonify({'error': 'No instruction provided'}), 400
318
+
319
  try:
320
+ # Generate response
321
+ raw_output, response = chat_function(
322
+ model,
323
+ tokenizer,
324
+ instruction,
325
+ steps=steps,
326
+ block_size=block_size,
327
+ max_new_tokens=max_new_tokens,
328
+ temperature=0.8,
329
+ top_k=50,
330
+ top_p=0.9,
331
+ repetition_penalty=1.2,
332
+ no_repeat_ngram_size=3,
333
+ verbose=False,
334
+ visualize_fn=None,
335
+ parallel_blocks=parallel_blocks,
336
+ )
337
+
338
+ return jsonify({
339
+ 'response': response,
340
+ 'raw_output': raw_output
341
+ })
342
  except Exception as e:
343
+ import traceback
344
+ traceback.print_exc()
 
345
  return jsonify({'error': str(e)}), 500
346
 
347
+
348
  @app.route('/api/generate-stream', methods=['POST'])
349
  def generate_stream():
350
+ """Generate response with streaming visualization."""
351
+ global model, tokenizer, config, device, chat_function
352
+
353
  if model is None:
354
  return jsonify({'error': 'Model not loaded'}), 400
355
+
356
+ if chat_function is None:
357
+ return jsonify({'error': 'Chat function not available'}), 400
358
+
359
  data = request.json
360
  instruction = data.get('instruction', '')
361
+ steps = data.get('steps', 64)
362
+ block_size = data.get('block_size', 128)
363
+ max_new_tokens = data.get('max_new_tokens', 128)
364
+ parallel_blocks = data.get('parallel_blocks', 1)
365
+
366
  if not instruction:
367
  return jsonify({'error': 'No instruction provided'}), 400
368
+
369
  def generate_events():
370
+ try:
371
+ # Import threading to allow yielding from callback
372
+ import queue
373
+ event_queue = queue.Queue()
374
+ generation_complete = {'done': False, 'result': None}
375
+
376
+ def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
377
+ """This gets called during generation - we need to send events immediately"""
378
  visualizer = create_streaming_visualizer()
379
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
380
+ # Put the update in the queue so it can be yielded immediately
381
+ event_queue.put({'type': 'update', 'data': data})
382
+
383
+ # Start generation in a separate thread so we can yield events as they come
384
+ import threading
385
+
386
+ def run_generation():
387
+ try:
388
  raw_output, response = chat_function(
389
+ model,
390
+ tokenizer,
391
+ instruction,
392
+ steps=steps,
393
+ block_size=block_size,
394
+ max_new_tokens=max_new_tokens,
395
+ temperature=0.8,
396
+ top_k=50,
397
+ top_p=0.9,
398
+ repetition_penalty=1.2,
399
+ no_repeat_ngram_size=3,
400
+ verbose=False,
401
+ visualize_fn=streaming_visualizer,
402
  parallel_blocks=parallel_blocks,
403
  )
404
+ generation_complete['result'] = (raw_output, response)
405
+ except Exception as e:
406
+ generation_complete['result'] = ('error', str(e))
407
+ finally:
408
+ generation_complete['done'] = True
409
+ event_queue.put(None) # Signal completion
410
+
411
+ # Start generation thread
412
+ gen_thread = threading.Thread(target=run_generation)
413
+ gen_thread.daemon = True
414
+ gen_thread.start()
415
+
416
+ # Yield start event
417
+ yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"
418
+
419
+ # Yield events as they come from the queue
420
+ while not generation_complete['done'] or not event_queue.empty():
421
+ try:
422
+ event = event_queue.get(timeout=0.1)
423
+ if event is None: # Completion signal
424
+ break
425
+ yield f"data: {json.dumps(event)}\n\n"
426
+ except queue.Empty:
427
+ continue
428
+
429
+ # Wait for thread to finish
430
+ gen_thread.join(timeout=1.0)
431
+
432
+ # Send final response
433
+ if generation_complete['result']:
434
+ raw_output, response = generation_complete['result']
435
+ if raw_output == 'error':
436
+ yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
437
+ else:
438
+ yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
439
+
440
+ except Exception as e:
441
+ import traceback
442
+ traceback.print_exc()
443
+ yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
444
+
445
  return Response(
446
  stream_with_context(generate_events()),
447
  mimetype='text/event-stream',
448
  headers={
449
  'Cache-Control': 'no-cache',
450
+ 'X-Accel-Buffering': 'no'
 
451
  }
452
  )
453
 
454
+
455
+ @app.route('/api/test-stream', methods=['GET'])
456
+ def test_stream():
457
+ """Test streaming endpoint."""
458
+ def generate():
459
+ for i in range(10):
460
+ yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
461
+ time.sleep(0.5)
462
+ yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
463
+
464
+ return Response(
465
+ stream_with_context(generate()),
466
+ mimetype='text/event-stream',
467
+ headers={
468
+ 'Cache-Control': 'no-cache',
469
+ 'X-Accel-Buffering': 'no'
470
  }
471
+ )
472
+
473
 
474
  if __name__ == '__main__':
475
+ app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)