thejagstudio commited on
Commit
ded94a3
Β·
verified Β·
1 Parent(s): 0265bc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +873 -270
app.py CHANGED
@@ -1,3 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import json
@@ -8,6 +485,8 @@ from flask import Flask, request, jsonify, Response, stream_with_context
8
  from flask_cors import CORS
9
  import torch
10
  from transformers import AutoTokenizer
 
 
11
 
12
  app = Flask(__name__, static_folder='static', static_url_path='/static')
13
  CORS(app)
@@ -19,457 +498,581 @@ config = None
19
  device = None
20
  DiffusionLLM = None
21
  chat_function = None
 
22
 
 
 
 
 
 
 
 
23
 
24
  def find_file(filename, search_dirs=None):
25
  """Find a file in current directory or parent directories."""
26
  if search_dirs is None:
27
  search_dirs = [
28
- os.path.dirname(__file__), # Current directory
29
- os.path.dirname(os.path.dirname(__file__)), # Parent directory
30
- os.getcwd(), # Working directory
31
  ]
32
-
33
  for directory in search_dirs:
34
  filepath = os.path.join(directory, filename)
35
  if os.path.exists(filepath):
36
  print(f"Found {filename} at: {filepath}")
37
  return filepath
38
-
39
  return None
40
 
41
-
42
  def try_import_module(filepath, module_name):
43
  """Dynamically import a Python file as a module."""
44
  if not filepath or not os.path.exists(filepath):
45
  return None
46
-
47
  try:
48
- # Add the directory to sys.path
49
  module_dir = os.path.dirname(filepath)
50
  if module_dir not in sys.path:
51
  sys.path.insert(0, module_dir)
52
-
53
  spec = importlib.util.spec_from_file_location(module_name, filepath)
54
  if spec is None:
55
  print(f"Could not create spec for {filepath}")
56
  return None
57
-
58
  module = importlib.util.module_from_spec(spec)
59
  sys.modules[module_name] = module
60
  spec.loader.exec_module(module)
61
-
62
  print(f"Successfully imported {module_name} from {filepath}")
63
  return module
64
  except Exception as e:
65
  print(f"Error importing {filepath}: {e}")
66
- import traceback
67
- traceback.print_exc()
 
68
  return None
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def load_model_internal():
72
- """Load the model and tokenizer."""
73
- global model, tokenizer, config, device, DiffusionLLM, chat_function
74
-
75
  if model is not None:
76
  return True
77
-
78
  try:
79
  print("=" * 60)
80
- print("Starting model loading process...")
81
  print("=" * 60)
82
-
83
- # Find and import infer-base.py
 
 
 
84
  base_path = find_file("infer-base.py")
85
  if base_path is None:
86
- raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
87
-
88
- print(f"\nImporting infer-base.py from: {base_path}")
89
  base_mod = try_import_module(base_path, "infer_base")
90
-
91
- if base_mod is None:
92
- raise RuntimeError("Failed to import infer-base.py")
93
-
94
- # Check for DiffusionLLM class
95
- if not hasattr(base_mod, 'DiffusionLLM'):
96
- print("Available attributes in infer_base:", dir(base_mod))
97
- raise RuntimeError("DiffusionLLM class not found in infer-base.py")
98
-
99
  DiffusionLLM = base_mod.DiffusionLLM
100
- print("βœ“ Successfully loaded DiffusionLLM class")
101
-
102
- # Find and import infer-chat.py
103
  chat_path = find_file("infer-chat.py")
104
  if chat_path is None:
105
  raise RuntimeError("Could not find infer-chat.py")
106
-
107
- print(f"\nImporting infer-chat.py from: {chat_path}")
108
  chat_mod = try_import_module(chat_path, "infer_chat")
109
-
110
  if chat_mod is None or not hasattr(chat_mod, 'chat'):
111
- raise RuntimeError("Failed to import chat function from infer-chat.py")
112
-
113
  chat_function = chat_mod.chat
114
- print("βœ“ Successfully loaded chat function")
115
-
116
- # Setup pickling workaround for torch.load
117
- try:
118
- if hasattr(base_mod, 'ModelConfig'):
119
- sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
120
- sys.modules['__main__'].DiffusionLLM = DiffusionLLM
121
- print("βœ“ Configured pickle support for model loading")
122
- except Exception as e:
123
- print(f"Warning: Could not setup pickle workaround: {e}")
124
-
125
- # Set device
126
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
- print(f"\nβœ“ Using device: {device}")
128
-
129
  # Load tokenizer
130
  print("\nLoading tokenizer...")
131
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
 
 
 
 
132
  if tokenizer.pad_token is None:
133
  tokenizer.pad_token = tokenizer.eos_token
134
- print("βœ“ Tokenizer loaded")
135
-
136
- # Find model checkpoint
137
- checkpoint_dirs = [
138
- "checkpoints",
139
- "../checkpoints",
140
- "./checkpoints",
141
- os.path.join(os.path.dirname(__file__), "checkpoints"),
142
- os.path.join(os.path.dirname(__file__), "../checkpoints"),
143
- ]
144
-
145
  model_path = None
146
  for checkpoint_dir in checkpoint_dirs:
147
  best_path = os.path.join(checkpoint_dir, "best_model.pt")
148
  fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
149
-
150
  if os.path.exists(best_path):
151
  model_path = best_path
152
  break
153
  elif os.path.exists(fp32_path):
154
  model_path = fp32_path
155
- break
156
-
157
  if model_path is None:
158
- raise RuntimeError(
159
- "Could not find model checkpoint. Looking for:\n"
160
- " - checkpoints/best_model.pt\n"
161
- " - checkpoints/model_fp32.pt\n"
162
- f"Searched directories: {checkpoint_dirs}"
163
- )
164
-
165
- print(f"\nβœ“ Found model checkpoint: {model_path}")
166
- print("Loading model weights (this may take a minute)...")
167
-
168
- # Load model
169
- checkpoint = torch.load(model_path, map_location=device, weights_only=False)
170
  config = checkpoint['config']
171
-
172
- print("Creating model...")
173
  model = DiffusionLLM(config)
174
-
175
- print("Loading state dict...")
176
  state_dict = checkpoint['model_state']
177
- state_dict = {k: v.float() for k, v in state_dict.items()}
 
 
 
 
178
  model.load_state_dict(state_dict)
179
-
180
- model = model.to(device)
181
  model.eval()
182
-
 
 
 
 
 
 
 
 
 
183
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
184
  print(f"\n{'=' * 60}")
185
- print(f"βœ“βœ“βœ“ MODEL LOADED SUCCESSFULLY βœ“βœ“βœ“")
186
  print(f"{'=' * 60}")
 
187
  print(f"Parameters: {num_params:.1f}M")
 
 
188
  if 'step' in checkpoint:
189
  print(f"Training steps: {checkpoint['step']}")
190
- if 'best_val_loss' in checkpoint:
191
- print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
192
  print(f"{'=' * 60}\n")
193
-
194
  return True
195
-
196
  except Exception as e:
197
- print("\n" + "=" * 60)
198
- print("ERROR LOADING MODEL")
199
- print("=" * 60)
200
- print(f"Error: {e}")
201
- import traceback
202
- traceback.print_exc()
203
- print("=" * 60 + "\n")
204
  return False
205
 
206
-
207
  def create_streaming_visualizer():
208
- """Create a visualizer that yields SSE events instead of printing to terminal."""
209
  def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
210
- # Normalize inputs to lists
211
  if not isinstance(mask_blocks, list):
212
  mask_blocks = [mask_blocks]
213
  is_masked_list = [is_masked_list]
214
-
215
- # Decode context
216
  try:
 
217
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
218
  except Exception:
219
  context_text = str(context_ids[0].tolist())
220
-
221
- # Build blocks visualization
222
  all_blocks = []
223
  for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
224
  block_tokens = mask_block[0].tolist()
225
  block_data = []
226
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  for i, token_id in enumerate(block_tokens):
228
  if is_masked[0, i]:
229
- block_data.append({
230
- 'type': 'masked',
231
- 'text': 'β–ˆβ–ˆβ–ˆ'
232
- })
233
  else:
234
- try:
235
- token_text = tok.decode([token_id], skip_special_tokens=False)
236
- except Exception:
237
- token_text = str(int(token_id))
238
  block_data.append({
239
  'type': 'revealed',
240
- 'text': token_text
241
  })
242
-
243
- all_blocks.append({
244
- 'block_index': block_idx,
245
- 'tokens': block_data
246
- })
247
-
248
- # Return data structure that will be sent as SSE
249
  return {
250
  'context': context_text,
251
  'blocks': all_blocks,
252
  'num_blocks': len(mask_blocks)
253
  }
254
-
255
  return visualizer
256
 
257
-
258
  @app.route('/')
259
  def index():
260
  """Serve the main HTML page."""
261
  return app.send_static_file('index.html')
262
 
263
-
264
  @app.route('/api/load', methods=['POST'])
265
  def load_model_endpoint():
266
  """Load the model."""
267
  data = request.json or {}
268
  check_only = data.get('check_only', False)
269
-
270
  global model
271
-
272
  if check_only:
273
  return jsonify({
274
  'loaded': model is not None,
275
  'message': 'Model is loaded' if model is not None else 'Model not loaded'
276
  })
277
-
278
  if model is not None:
279
  return jsonify({
280
  'loaded': True,
281
  'message': 'Model already loaded'
282
  })
283
-
284
  success = load_model_internal()
285
-
286
  if success:
287
  return jsonify({
288
  'loaded': True,
289
- 'message': 'Model loaded successfully'
290
  })
291
  else:
292
  return jsonify({
293
  'loaded': False,
294
- 'message': 'Failed to load model. Check server logs for details.'
295
  }), 500
296
 
297
-
298
  @app.route('/api/generate', methods=['POST'])
299
  def generate():
300
- """Generate response without streaming."""
301
- global model, tokenizer, config, device, chat_function
302
-
303
  if model is None:
304
  return jsonify({'error': 'Model not loaded'}), 400
305
-
306
  if chat_function is None:
307
  return jsonify({'error': 'Chat function not available'}), 400
308
-
309
  data = request.json
310
  instruction = data.get('instruction', '')
311
- steps = data.get('steps', 64)
312
- block_size = data.get('block_size', 128)
313
- max_new_tokens = data.get('max_new_tokens', 128)
314
- parallel_blocks = data.get('parallel_blocks', 1)
315
-
316
  if not instruction:
317
  return jsonify({'error': 'No instruction provided'}), 400
318
-
319
  try:
320
- # Generate response
321
- raw_output, response = chat_function(
322
- model,
323
- tokenizer,
324
- instruction,
325
- steps=steps,
326
- block_size=block_size,
327
- max_new_tokens=max_new_tokens,
328
- temperature=0.8,
329
- top_k=50,
330
- top_p=0.9,
331
- repetition_penalty=1.2,
332
- no_repeat_ngram_size=3,
333
- verbose=False,
334
- visualize_fn=None,
335
- parallel_blocks=parallel_blocks,
336
- )
337
-
 
338
  return jsonify({
339
  'response': response,
340
  'raw_output': raw_output
341
  })
 
342
  except Exception as e:
343
- import traceback
344
- traceback.print_exc()
 
345
  return jsonify({'error': str(e)}), 500
346
 
347
-
348
  @app.route('/api/generate-stream', methods=['POST'])
349
  def generate_stream():
350
- """Generate response with streaming visualization."""
351
- global model, tokenizer, config, device, chat_function
352
-
353
  if model is None:
354
  return jsonify({'error': 'Model not loaded'}), 400
355
-
356
  if chat_function is None:
357
  return jsonify({'error': 'Chat function not available'}), 400
358
-
359
  data = request.json
360
  instruction = data.get('instruction', '')
361
- steps = data.get('steps', 64)
362
- block_size = data.get('block_size', 128)
363
- max_new_tokens = data.get('max_new_tokens', 128)
364
- parallel_blocks = data.get('parallel_blocks', 1)
365
-
366
  if not instruction:
367
  return jsonify({'error': 'No instruction provided'}), 400
368
-
369
  def generate_events():
370
- try:
371
- # Import threading to allow yielding from callback
372
- import queue
373
- event_queue = queue.Queue()
374
- generation_complete = {'done': False, 'result': None}
375
-
376
- def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
377
- """This gets called during generation - we need to send events immediately"""
378
  visualizer = create_streaming_visualizer()
379
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
380
- # Put the update in the queue so it can be yielded immediately
381
- event_queue.put({'type': 'update', 'data': data})
382
-
383
- # Start generation in a separate thread so we can yield events as they come
384
- import threading
385
-
386
- def run_generation():
387
- try:
388
  raw_output, response = chat_function(
389
- model,
390
- tokenizer,
391
- instruction,
392
- steps=steps,
393
- block_size=block_size,
394
- max_new_tokens=max_new_tokens,
395
- temperature=0.8,
396
- top_k=50,
397
- top_p=0.9,
398
- repetition_penalty=1.2,
399
- no_repeat_ngram_size=3,
400
- verbose=False,
401
- visualize_fn=streaming_visualizer,
402
  parallel_blocks=parallel_blocks,
403
  )
404
- generation_complete['result'] = (raw_output, response)
405
- except Exception as e:
406
- generation_complete['result'] = ('error', str(e))
407
- finally:
408
- generation_complete['done'] = True
409
- event_queue.put(None) # Signal completion
410
-
411
- # Start generation thread
412
- gen_thread = threading.Thread(target=run_generation)
413
- gen_thread.daemon = True
414
- gen_thread.start()
415
-
416
- # Yield start event
417
- yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"
418
-
419
- # Yield events as they come from the queue
420
- while not generation_complete['done'] or not event_queue.empty():
421
- try:
422
- event = event_queue.get(timeout=0.1)
423
- if event is None: # Completion signal
424
- break
425
- yield f"data: {json.dumps(event)}\n\n"
426
- except queue.Empty:
427
- continue
428
-
429
- # Wait for thread to finish
430
- gen_thread.join(timeout=1.0)
431
-
432
- # Send final response
433
- if generation_complete['result']:
434
- raw_output, response = generation_complete['result']
435
- if raw_output == 'error':
436
- yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
437
- else:
438
- yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
439
-
440
- except Exception as e:
441
- import traceback
442
- traceback.print_exc()
443
- yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
444
-
445
  return Response(
446
  stream_with_context(generate_events()),
447
  mimetype='text/event-stream',
448
  headers={
449
  'Cache-Control': 'no-cache',
450
- 'X-Accel-Buffering': 'no'
 
451
  }
452
  )
453
 
454
-
455
- @app.route('/api/test-stream', methods=['GET'])
456
- def test_stream():
457
- """Test streaming endpoint."""
458
- def generate():
459
- for i in range(10):
460
- yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
461
- time.sleep(0.5)
462
- yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
463
-
464
- return Response(
465
- stream_with_context(generate()),
466
- mimetype='text/event-stream',
467
- headers={
468
- 'Cache-Control': 'no-cache',
469
- 'X-Accel-Buffering': 'no'
470
  }
471
- )
472
-
473
 
474
  if __name__ == '__main__':
475
- app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import sys
3
+ # import json
4
+ # import time
5
+ # import importlib.util
6
+ # from pathlib import Path
7
+ # from flask import Flask, request, jsonify, Response, stream_with_context
8
+ # from flask_cors import CORS
9
+ # import torch
10
+ # from transformers import AutoTokenizer
11
+
12
+ # app = Flask(__name__, static_folder='static', static_url_path='/static')
13
+ # CORS(app)
14
+
15
+ # # Global state
16
+ # model = None
17
+ # tokenizer = None
18
+ # config = None
19
+ # device = None
20
+ # DiffusionLLM = None
21
+ # chat_function = None
22
+
23
+
24
+ # def find_file(filename, search_dirs=None):
25
+ # """Find a file in current directory or parent directories."""
26
+ # if search_dirs is None:
27
+ # search_dirs = [
28
+ # os.path.dirname(__file__), # Current directory
29
+ # os.path.dirname(os.path.dirname(__file__)), # Parent directory
30
+ # os.getcwd(), # Working directory
31
+ # ]
32
+
33
+ # for directory in search_dirs:
34
+ # filepath = os.path.join(directory, filename)
35
+ # if os.path.exists(filepath):
36
+ # print(f"Found {filename} at: {filepath}")
37
+ # return filepath
38
+
39
+ # return None
40
+
41
+
42
+ # def try_import_module(filepath, module_name):
43
+ # """Dynamically import a Python file as a module."""
44
+ # if not filepath or not os.path.exists(filepath):
45
+ # return None
46
+
47
+ # try:
48
+ # # Add the directory to sys.path
49
+ # module_dir = os.path.dirname(filepath)
50
+ # if module_dir not in sys.path:
51
+ # sys.path.insert(0, module_dir)
52
+
53
+ # spec = importlib.util.spec_from_file_location(module_name, filepath)
54
+ # if spec is None:
55
+ # print(f"Could not create spec for {filepath}")
56
+ # return None
57
+
58
+ # module = importlib.util.module_from_spec(spec)
59
+ # sys.modules[module_name] = module
60
+ # spec.loader.exec_module(module)
61
+
62
+ # print(f"Successfully imported {module_name} from {filepath}")
63
+ # return module
64
+ # except Exception as e:
65
+ # print(f"Error importing {filepath}: {e}")
66
+ # import traceback
67
+ # traceback.print_exc()
68
+ # return None
69
+
70
+
71
+ # def load_model_internal():
72
+ # """Load the model and tokenizer."""
73
+ # global model, tokenizer, config, device, DiffusionLLM, chat_function
74
+
75
+ # if model is not None:
76
+ # return True
77
+
78
+ # try:
79
+ # print("=" * 60)
80
+ # print("Starting model loading process...")
81
+ # print("=" * 60)
82
+
83
+ # # Find and import infer-base.py
84
+ # base_path = find_file("infer-base.py")
85
+ # if base_path is None:
86
+ # raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")
87
+
88
+ # print(f"\nImporting infer-base.py from: {base_path}")
89
+ # base_mod = try_import_module(base_path, "infer_base")
90
+
91
+ # if base_mod is None:
92
+ # raise RuntimeError("Failed to import infer-base.py")
93
+
94
+ # # Check for DiffusionLLM class
95
+ # if not hasattr(base_mod, 'DiffusionLLM'):
96
+ # print("Available attributes in infer_base:", dir(base_mod))
97
+ # raise RuntimeError("DiffusionLLM class not found in infer-base.py")
98
+
99
+ # DiffusionLLM = base_mod.DiffusionLLM
100
+ # print("βœ“ Successfully loaded DiffusionLLM class")
101
+
102
+ # # Find and import infer-chat.py
103
+ # chat_path = find_file("infer-chat.py")
104
+ # if chat_path is None:
105
+ # raise RuntimeError("Could not find infer-chat.py")
106
+
107
+ # print(f"\nImporting infer-chat.py from: {chat_path}")
108
+ # chat_mod = try_import_module(chat_path, "infer_chat")
109
+
110
+ # if chat_mod is None or not hasattr(chat_mod, 'chat'):
111
+ # raise RuntimeError("Failed to import chat function from infer-chat.py")
112
+
113
+ # chat_function = chat_mod.chat
114
+ # print("βœ“ Successfully loaded chat function")
115
+
116
+ # # Setup pickling workaround for torch.load
117
+ # try:
118
+ # if hasattr(base_mod, 'ModelConfig'):
119
+ # sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
120
+ # sys.modules['__main__'].DiffusionLLM = DiffusionLLM
121
+ # print("βœ“ Configured pickle support for model loading")
122
+ # except Exception as e:
123
+ # print(f"Warning: Could not setup pickle workaround: {e}")
124
+
125
+ # # Set device
126
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
+ # print(f"\nβœ“ Using device: {device}")
128
+
129
+ # # Load tokenizer
130
+ # print("\nLoading tokenizer...")
131
+ # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
132
+ # if tokenizer.pad_token is None:
133
+ # tokenizer.pad_token = tokenizer.eos_token
134
+ # print("βœ“ Tokenizer loaded")
135
+
136
+ # # Find model checkpoint
137
+ # checkpoint_dirs = [
138
+ # "checkpoints",
139
+ # "../checkpoints",
140
+ # "./checkpoints",
141
+ # os.path.join(os.path.dirname(__file__), "checkpoints"),
142
+ # os.path.join(os.path.dirname(__file__), "../checkpoints"),
143
+ # ]
144
+
145
+ # model_path = None
146
+ # for checkpoint_dir in checkpoint_dirs:
147
+ # best_path = os.path.join(checkpoint_dir, "best_model.pt")
148
+ # fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
149
+
150
+ # if os.path.exists(best_path):
151
+ # model_path = best_path
152
+ # break
153
+ # elif os.path.exists(fp32_path):
154
+ # model_path = fp32_path
155
+ # break
156
+
157
+ # if model_path is None:
158
+ # raise RuntimeError(
159
+ # "Could not find model checkpoint. Looking for:\n"
160
+ # " - checkpoints/best_model.pt\n"
161
+ # " - checkpoints/model_fp32.pt\n"
162
+ # f"Searched directories: {checkpoint_dirs}"
163
+ # )
164
+
165
+ # print(f"\nβœ“ Found model checkpoint: {model_path}")
166
+ # print("Loading model weights (this may take a minute)...")
167
+
168
+ # # Load model
169
+ # checkpoint = torch.load(model_path, map_location=device, weights_only=False)
170
+ # config = checkpoint['config']
171
+
172
+ # print("Creating model...")
173
+ # model = DiffusionLLM(config)
174
+
175
+ # print("Loading state dict...")
176
+ # state_dict = checkpoint['model_state']
177
+ # state_dict = {k: v.float() for k, v in state_dict.items()}
178
+ # model.load_state_dict(state_dict)
179
+
180
+ # model = model.to(device)
181
+ # model.eval()
182
+
183
+ # num_params = sum(p.numel() for p in model.parameters()) / 1e6
184
+ # print(f"\n{'=' * 60}")
185
+ # print(f"βœ“βœ“βœ“ MODEL LOADED SUCCESSFULLY βœ“βœ“βœ“")
186
+ # print(f"{'=' * 60}")
187
+ # print(f"Parameters: {num_params:.1f}M")
188
+ # if 'step' in checkpoint:
189
+ # print(f"Training steps: {checkpoint['step']}")
190
+ # if 'best_val_loss' in checkpoint:
191
+ # print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
192
+ # print(f"{'=' * 60}\n")
193
+
194
+ # return True
195
+
196
+ # except Exception as e:
197
+ # print("\n" + "=" * 60)
198
+ # print("ERROR LOADING MODEL")
199
+ # print("=" * 60)
200
+ # print(f"Error: {e}")
201
+ # import traceback
202
+ # traceback.print_exc()
203
+ # print("=" * 60 + "\n")
204
+ # return False
205
+
206
+
207
+ # def create_streaming_visualizer():
208
+ # """Create a visualizer that yields SSE events instead of printing to terminal."""
209
+ # def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
210
+ # # Normalize inputs to lists
211
+ # if not isinstance(mask_blocks, list):
212
+ # mask_blocks = [mask_blocks]
213
+ # is_masked_list = [is_masked_list]
214
+
215
+ # # Decode context
216
+ # try:
217
+ # context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
218
+ # except Exception:
219
+ # context_text = str(context_ids[0].tolist())
220
+
221
+ # # Build blocks visualization
222
+ # all_blocks = []
223
+ # for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
224
+ # block_tokens = mask_block[0].tolist()
225
+ # block_data = []
226
+
227
+ # for i, token_id in enumerate(block_tokens):
228
+ # if is_masked[0, i]:
229
+ # block_data.append({
230
+ # 'type': 'masked',
231
+ # 'text': 'β–ˆβ–ˆβ–ˆ'
232
+ # })
233
+ # else:
234
+ # try:
235
+ # token_text = tok.decode([token_id], skip_special_tokens=False)
236
+ # except Exception:
237
+ # token_text = str(int(token_id))
238
+ # block_data.append({
239
+ # 'type': 'revealed',
240
+ # 'text': token_text
241
+ # })
242
+
243
+ # all_blocks.append({
244
+ # 'block_index': block_idx,
245
+ # 'tokens': block_data
246
+ # })
247
+
248
+ # # Return data structure that will be sent as SSE
249
+ # return {
250
+ # 'context': context_text,
251
+ # 'blocks': all_blocks,
252
+ # 'num_blocks': len(mask_blocks)
253
+ # }
254
+
255
+ # return visualizer
256
+
257
+
258
+ # @app.route('/')
259
+ # def index():
260
+ # """Serve the main HTML page."""
261
+ # return app.send_static_file('index.html')
262
+
263
+
264
+ # @app.route('/api/load', methods=['POST'])
265
+ # def load_model_endpoint():
266
+ # """Load the model."""
267
+ # data = request.json or {}
268
+ # check_only = data.get('check_only', False)
269
+
270
+ # global model
271
+
272
+ # if check_only:
273
+ # return jsonify({
274
+ # 'loaded': model is not None,
275
+ # 'message': 'Model is loaded' if model is not None else 'Model not loaded'
276
+ # })
277
+
278
+ # if model is not None:
279
+ # return jsonify({
280
+ # 'loaded': True,
281
+ # 'message': 'Model already loaded'
282
+ # })
283
+
284
+ # success = load_model_internal()
285
+
286
+ # if success:
287
+ # return jsonify({
288
+ # 'loaded': True,
289
+ # 'message': 'Model loaded successfully'
290
+ # })
291
+ # else:
292
+ # return jsonify({
293
+ # 'loaded': False,
294
+ # 'message': 'Failed to load model. Check server logs for details.'
295
+ # }), 500
296
+
297
+
298
+ # @app.route('/api/generate', methods=['POST'])
299
+ # def generate():
300
+ # """Generate response without streaming."""
301
+ # global model, tokenizer, config, device, chat_function
302
+
303
+ # if model is None:
304
+ # return jsonify({'error': 'Model not loaded'}), 400
305
+
306
+ # if chat_function is None:
307
+ # return jsonify({'error': 'Chat function not available'}), 400
308
+
309
+ # data = request.json
310
+ # instruction = data.get('instruction', '')
311
+ # steps = data.get('steps', 64)
312
+ # block_size = data.get('block_size', 128)
313
+ # max_new_tokens = data.get('max_new_tokens', 128)
314
+ # parallel_blocks = data.get('parallel_blocks', 1)
315
+
316
+ # if not instruction:
317
+ # return jsonify({'error': 'No instruction provided'}), 400
318
+
319
+ # try:
320
+ # # Generate response
321
+ # raw_output, response = chat_function(
322
+ # model,
323
+ # tokenizer,
324
+ # instruction,
325
+ # steps=steps,
326
+ # block_size=block_size,
327
+ # max_new_tokens=max_new_tokens,
328
+ # temperature=0.8,
329
+ # top_k=50,
330
+ # top_p=0.9,
331
+ # repetition_penalty=1.2,
332
+ # no_repeat_ngram_size=3,
333
+ # verbose=False,
334
+ # visualize_fn=None,
335
+ # parallel_blocks=parallel_blocks,
336
+ # )
337
+
338
+ # return jsonify({
339
+ # 'response': response,
340
+ # 'raw_output': raw_output
341
+ # })
342
+ # except Exception as e:
343
+ # import traceback
344
+ # traceback.print_exc()
345
+ # return jsonify({'error': str(e)}), 500
346
+
347
+
348
+ # @app.route('/api/generate-stream', methods=['POST'])
349
+ # def generate_stream():
350
+ # """Generate response with streaming visualization."""
351
+ # global model, tokenizer, config, device, chat_function
352
+
353
+ # if model is None:
354
+ # return jsonify({'error': 'Model not loaded'}), 400
355
+
356
+ # if chat_function is None:
357
+ # return jsonify({'error': 'Chat function not available'}), 400
358
+
359
+ # data = request.json
360
+ # instruction = data.get('instruction', '')
361
+ # steps = data.get('steps', 64)
362
+ # block_size = data.get('block_size', 128)
363
+ # max_new_tokens = data.get('max_new_tokens', 128)
364
+ # parallel_blocks = data.get('parallel_blocks', 1)
365
+
366
+ # if not instruction:
367
+ # return jsonify({'error': 'No instruction provided'}), 400
368
+
369
+ # def generate_events():
370
+ # try:
371
+ # # Import threading to allow yielding from callback
372
+ # import queue
373
+ # event_queue = queue.Queue()
374
+ # generation_complete = {'done': False, 'result': None}
375
+
376
+ # def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
377
+ # """This gets called during generation - we need to send events immediately"""
378
+ # visualizer = create_streaming_visualizer()
379
+ # data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
380
+ # # Put the update in the queue so it can be yielded immediately
381
+ # event_queue.put({'type': 'update', 'data': data})
382
+
383
+ # # Start generation in a separate thread so we can yield events as they come
384
+ # import threading
385
+
386
+ # def run_generation():
387
+ # try:
388
+ # raw_output, response = chat_function(
389
+ # model,
390
+ # tokenizer,
391
+ # instruction,
392
+ # steps=steps,
393
+ # block_size=block_size,
394
+ # max_new_tokens=max_new_tokens,
395
+ # temperature=0.8,
396
+ # top_k=50,
397
+ # top_p=0.9,
398
+ # repetition_penalty=1.2,
399
+ # no_repeat_ngram_size=3,
400
+ # verbose=False,
401
+ # visualize_fn=streaming_visualizer,
402
+ # parallel_blocks=parallel_blocks,
403
+ # )
404
+ # generation_complete['result'] = (raw_output, response)
405
+ # except Exception as e:
406
+ # generation_complete['result'] = ('error', str(e))
407
+ # finally:
408
+ # generation_complete['done'] = True
409
+ # event_queue.put(None) # Signal completion
410
+
411
+ # # Start generation thread
412
+ # gen_thread = threading.Thread(target=run_generation)
413
+ # gen_thread.daemon = True
414
+ # gen_thread.start()
415
+
416
+ # # Yield start event
417
+ # yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"
418
+
419
+ # # Yield events as they come from the queue
420
+ # while not generation_complete['done'] or not event_queue.empty():
421
+ # try:
422
+ # event = event_queue.get(timeout=0.1)
423
+ # if event is None: # Completion signal
424
+ # break
425
+ # yield f"data: {json.dumps(event)}\n\n"
426
+ # except queue.Empty:
427
+ # continue
428
+
429
+ # # Wait for thread to finish
430
+ # gen_thread.join(timeout=1.0)
431
+
432
+ # # Send final response
433
+ # if generation_complete['result']:
434
+ # raw_output, response = generation_complete['result']
435
+ # if raw_output == 'error':
436
+ # yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
437
+ # else:
438
+ # yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"
439
+
440
+ # except Exception as e:
441
+ # import traceback
442
+ # traceback.print_exc()
443
+ # yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
444
+
445
+ # return Response(
446
+ # stream_with_context(generate_events()),
447
+ # mimetype='text/event-stream',
448
+ # headers={
449
+ # 'Cache-Control': 'no-cache',
450
+ # 'X-Accel-Buffering': 'no'
451
+ # }
452
+ # )
453
+
454
+
455
+ # @app.route('/api/test-stream', methods=['GET'])
456
+ # def test_stream():
457
+ # """Test streaming endpoint."""
458
+ # def generate():
459
+ # for i in range(10):
460
+ # yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
461
+ # time.sleep(0.5)
462
+ # yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"
463
+
464
+ # return Response(
465
+ # stream_with_context(generate()),
466
+ # mimetype='text/event-stream',
467
+ # headers={
468
+ # 'Cache-Control': 'no-cache',
469
+ # 'X-Accel-Buffering': 'no'
470
+ # }
471
+ # )
472
+
473
+
474
+ # if __name__ == '__main__':
475
+ # app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
476
+
477
+
478
  import os
479
  import sys
480
  import json
 
485
  from flask_cors import CORS
486
  import torch
487
  from transformers import AutoTokenizer
488
+ import threading
489
+ import queue
490
 
491
  app = Flask(__name__, static_folder='static', static_url_path='/static')
492
  CORS(app)
 
498
  device = None
499
  DiffusionLLM = None
500
  chat_function = None
501
+ optimized_pipeline = None # For ONNX/IPEX
502
 
503
+ # ==================== CONFIGURATION ====================
504
+ USE_ONNX_RUNTIME = False # Set True for ONNX (fastest)
505
+ USE_IPEX = False # Set True for Intel CPUs
506
+ USE_TORCH_COMPILE = True # Set True for PyTorch 2.0+ (good default)
507
+ QUANTIZE_MODEL = True # INT8/BF16 quantization
508
+ WARMUP_ITERATIONS = 3 # Warmup for stable performance
509
+ # =======================================================
510
 
511
  def find_file(filename, search_dirs=None):
512
  """Find a file in current directory or parent directories."""
513
  if search_dirs is None:
514
  search_dirs = [
515
+ os.path.dirname(__file__),
516
+ os.path.dirname(os.path.dirname(__file__)),
517
+ os.getcwd(),
518
  ]
 
519
  for directory in search_dirs:
520
  filepath = os.path.join(directory, filename)
521
  if os.path.exists(filepath):
522
  print(f"Found {filename} at: {filepath}")
523
  return filepath
 
524
  return None
525
 
 
526
  def try_import_module(filepath, module_name):
527
  """Dynamically import a Python file as a module."""
528
  if not filepath or not os.path.exists(filepath):
529
  return None
 
530
  try:
 
531
  module_dir = os.path.dirname(filepath)
532
  if module_dir not in sys.path:
533
  sys.path.insert(0, module_dir)
 
534
  spec = importlib.util.spec_from_file_location(module_name, filepath)
535
  if spec is None:
536
  print(f"Could not create spec for {filepath}")
537
  return None
 
538
  module = importlib.util.module_from_spec(spec)
539
  sys.modules[module_name] = module
540
  spec.loader.exec_module(module)
 
541
  print(f"Successfully imported {module_name} from {filepath}")
542
  return module
543
  except Exception as e:
544
  print(f"Error importing {filepath}: {e}")
545
+ if __debug__:
546
+ import traceback
547
+ traceback.print_exc()
548
  return None
549
 
550
+ def configure_cpu_optimization():
551
+ """Configure CPU for maximum performance."""
552
+ print("\n" + "=" * 60)
553
+ print("CPU OPTIMIZATION CONFIGURATION")
554
+ print("=" * 60)
555
+
556
+ # Get CPU info
557
+ cpu_count = os.cpu_count()
558
+ physical_cores = cpu_count // 2 if cpu_count else cpu_count
559
+
560
+ # Optimal thread settings
561
+ threads = physical_cores or cpu_count or 1
562
+ torch.set_num_threads(threads)
563
+ torch.set_num_interop_threads(threads)
564
+
565
+ # Environment variables for MKL/OMP
566
+ os.environ["OMP_NUM_THREADS"] = str(threads)
567
+ os.environ["MKL_NUM_THREADS"] = str(threads)
568
+ os.environ["NUMEXPR_NUM_THREADS"] = str(threads)
569
+
570
+ print(f"βœ“ CPU Cores: {cpu_count} ({physical_cores} physical)")
571
+ print(f"βœ“ Threads: {threads}")
572
+ print(f"βœ“ OMP/MKL threads: {threads}")
573
+
574
+ # Intel-specific optimizations
575
+ if "intel" in torch.__version__.lower() or USE_IPEX:
576
+ torch.backends.quantized.engine = 'fbgemm'
577
+ print("βœ“ Intel FBGEMM backend enabled")
578
+
579
+ print("=" * 60 + "\n")
580
+
581
+ def quantize_model(model):
582
+ """Apply quantization for faster inference."""
583
+ if not QUANTIZE_MODEL:
584
+ return model
585
+
586
+ print("\nApplying quantization...")
587
+ try:
588
+ # Use torch.compile with quantization if available
589
+ if hasattr(torch, 'ao') and hasattr(torch.ao, 'quantization'):
590
+ # Dynamic quantization (fastest, no calibration needed)
591
+ model = torch.quantization.quantize_dynamic(
592
+ model,
593
+ {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d},
594
+ dtype=torch.qint8
595
+ )
596
+ print("βœ“ Applied INT8 dynamic quantization")
597
+ elif USE_IPEX:
598
+ # IPEX BF16 optimization
599
+ model = model.to(torch.bfloat16)
600
+ print("βœ“ Applied BF16 precision (IPEX)")
601
+
602
+ return model
603
+ except Exception as e:
604
+ print(f"⚠ Quantization failed: {e}")
605
+ return model
606
+
607
+ def compile_model(model):
608
+ """Compile model for maximum speed."""
609
+ global USE_TORCH_COMPILE, USE_ONNX_RUNTIME
610
+
611
+ print("\nCompiling model...")
612
+
613
+ # Option 1: ONNX Runtime (BEST performance)
614
+ if USE_ONNX_RUNTIME:
615
+ try:
616
+ import onnxruntime as ort
617
+ # Export to ONNX (one-time cost, but worth it)
618
+ onnx_path = "model_optimized.onnx"
619
+ if not os.path.exists(onnx_path):
620
+ print("Exporting to ONNX format...")
621
+ dummy_input = torch.randint(0, 100, (1, 128)) # Adjust shape as needed
622
+ torch.onnx.export(
623
+ model,
624
+ dummy_input,
625
+ onnx_path,
626
+ input_names=['input_ids'],
627
+ output_names=['logits'],
628
+ dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
629
+ opset_version=16
630
+ )
631
+
632
+ # Create ONNX Runtime session
633
+ sess_options = ort.SessionOptions()
634
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
635
+ sess_options.enable_cpu_mem_arena = True
636
+ sess_options.enable_mem_pattern = True
637
+
638
+ # Use all cores
639
+ sess_options.intra_op_num_threads = torch.get_num_threads()
640
+
641
+ provider = 'CPUExecutionProvider'
642
+ compiled_model = ort.InferenceSession(onnx_path, sess_options, providers=[provider])
643
+
644
+ print("βœ“ ONNX Runtime compilation complete")
645
+ return compiled_model
646
+ except Exception as e:
647
+ print(f"⚠ ONNX Runtime failed: {e}, using torch.compile")
648
+ USE_ONNX_RUNTIME = False
649
+
650
+ # Option 2: Intel IPEX
651
+ if USE_IPEX:
652
+ try:
653
+ import intel_extension_for_pytorch as ipex
654
+ model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
655
+ print("βœ“ Intel IPEX O3 optimization applied")
656
+ return model
657
+ except Exception as e:
658
+ print(f"⚠ IPEX failed: {e}")
659
+ USE_IPEX = False
660
+
661
+ # Option 3: torch.compile (PyTorch 2.0+)
662
+ if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
663
+ try:
664
+ # Use "max-autotune" for best performance
665
+ model = torch.compile(
666
+ model,
667
+ mode="max-autotune",
668
+ fullgraph=True,
669
+ backend="inductor"
670
+ )
671
+ print("βœ“ torch.compile (max-autotune) applied")
672
+ except Exception as e:
673
+ print(f"⚠ torch.compile failed: {e}, using eager mode")
674
+
675
+ return model
676
+
677
+ def warmup_model(model, tokenizer, chat_func):
678
+ """Warmup the model for consistent performance."""
679
+ if WARMUP_ITERATIONS == 0:
680
+ return
681
+
682
+ print("\nWarming up model...")
683
+ start_time = time.time()
684
+
685
+ try:
686
+ with torch.inference_mode():
687
+ for i in range(WARMUP_ITERATIONS):
688
+ _ = chat_func(
689
+ model, tokenizer, "Hello world",
690
+ steps=8, block_size=32, max_new_tokens=16,
691
+ temperature=0.8, top_k=50, top_p=0.9,
692
+ repetition_penalty=1.2, no_repeat_ngram_size=3,
693
+ verbose=False, visualize_fn=None, parallel_blocks=2
694
+ )
695
+ print(f" Warmup {i+1}/{WARMUP_ITERATIONS} complete")
696
+ except Exception as e:
697
+ print(f"⚠ Warmup failed: {e}")
698
+
699
+ print(f"βœ“ Warmup finished in {time.time() - start_time:.2f}s")
700
 
701
  def load_model_internal():
702
+ """Load the model with ultra-fast optimizations."""
703
+ global model, tokenizer, config, device, DiffusionLLM, chat_function, optimized_pipeline
704
+
705
  if model is not None:
706
  return True
707
+
708
  try:
709
  print("=" * 60)
710
+ print("ULTRA-FAST CPU MODEL LOADING")
711
  print("=" * 60)
712
+
713
+ # Configure CPU
714
+ configure_cpu_optimization()
715
+
716
+ # Import modules
717
  base_path = find_file("infer-base.py")
718
  if base_path is None:
719
+ raise RuntimeError("Could not find infer-base.py")
720
+
 
721
  base_mod = try_import_module(base_path, "infer_base")
722
+ if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
723
+ raise RuntimeError("DiffusionLLM class not found")
724
+
 
 
 
 
 
 
725
  DiffusionLLM = base_mod.DiffusionLLM
726
+
 
 
727
  chat_path = find_file("infer-chat.py")
728
  if chat_path is None:
729
  raise RuntimeError("Could not find infer-chat.py")
730
+
 
731
  chat_mod = try_import_module(chat_path, "infer_chat")
 
732
  if chat_mod is None or not hasattr(chat_mod, 'chat'):
733
+ raise RuntimeError("Chat function not found")
734
+
735
  chat_function = chat_mod.chat
736
+
737
+ # Device
738
+ device = torch.device("cpu")
739
+
 
 
 
 
 
 
 
 
 
 
 
740
  # Load tokenizer
741
  print("\nLoading tokenizer...")
742
+ tokenizer = AutoTokenizer.from_pretrained(
743
+ "Qwen/Qwen2.5-0.5B",
744
+ use_fast=True,
745
+ trust_remote_code=True
746
+ )
747
  if tokenizer.pad_token is None:
748
  tokenizer.pad_token = tokenizer.eos_token
749
+ print("βœ“ Fast tokenizer loaded")
750
+
751
+ # Find model
752
+ checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
 
 
 
 
 
 
 
753
  model_path = None
754
  for checkpoint_dir in checkpoint_dirs:
755
  best_path = os.path.join(checkpoint_dir, "best_model.pt")
756
  fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
 
757
  if os.path.exists(best_path):
758
  model_path = best_path
759
  break
760
  elif os.path.exists(fp32_path):
761
  model_path = fp32_path
762
+
 
763
  if model_path is None:
764
+ raise RuntimeError("Model checkpoint not found")
765
+
766
+ print(f"\nLoading model from: {model_path}")
767
+
768
+ # Load checkpoint
769
+ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
 
 
 
 
 
 
770
  config = checkpoint['config']
771
+
772
+ # Create and load model
773
  model = DiffusionLLM(config)
 
 
774
  state_dict = checkpoint['model_state']
775
+
776
+ # Convert to float32 for CPU
777
+ if not USE_IPEX: # Keep FP32 for ONNX, use BF16 for IPEX
778
+ state_dict = {k: v.float() for k, v in state_dict.items()}
779
+
780
  model.load_state_dict(state_dict)
 
 
781
  model.eval()
782
+
783
+ # Apply optimizations
784
+ model = quantize_model(model)
785
+ model = model.to(device)
786
+ model = compile_model(model)
787
+
788
+ # Warmup
789
+ warmup_model(model, tokenizer, chat_function)
790
+
791
+ # Print summary
792
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
793
  print(f"\n{'=' * 60}")
794
+ print(f"βœ“βœ“βœ“ MODEL LOADED & ULTRA-OPTIMIZED FOR CPU βœ“βœ“βœ“")
795
  print(f"{'=' * 60}")
796
+ print(f"Framework: {'ONNX Runtime' if USE_ONNX_RUNTIME else 'IPEX' if USE_IPEX else 'PyTorch'}")
797
  print(f"Parameters: {num_params:.1f}M")
798
+ print(f"CPU Threads: {torch.get_num_threads()}")
799
+ print(f"Quantization: {'INT8' if QUANTIZE_MODEL else 'BF16' if USE_IPEX else 'FP32'}")
800
  if 'step' in checkpoint:
801
  print(f"Training steps: {checkpoint['step']}")
 
 
802
  print(f"{'=' * 60}\n")
803
+
804
  return True
805
+
806
  except Exception as e:
807
+ print(f"\nERROR LOADING MODEL: {e}")
808
+ if __debug__:
809
+ import traceback
810
+ traceback.print_exc()
 
 
 
811
  return False
812
 
 
813
  def create_streaming_visualizer():
814
+ """Create optimized visualizer."""
815
  def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
 
816
  if not isinstance(mask_blocks, list):
817
  mask_blocks = [mask_blocks]
818
  is_masked_list = [is_masked_list]
819
+
 
820
  try:
821
+ # Decode only once for efficiency
822
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
823
  except Exception:
824
  context_text = str(context_ids[0].tolist())
825
+
 
826
  all_blocks = []
827
  for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
828
  block_tokens = mask_block[0].tolist()
829
  block_data = []
830
+
831
+ # Batch decode for speed
832
+ token_ids_to_decode = []
833
+ positions = []
834
+
835
+ for i, token_id in enumerate(block_tokens):
836
+ if not is_masked[0, i]:
837
+ token_ids_to_decode.append(token_id)
838
+ positions.append(i)
839
+
840
+ try:
841
+ decoded_tokens = tok.batch_decode(
842
+ [[tid] for tid in token_ids_to_decode],
843
+ skip_special_tokens=False
844
+ )
845
+ except Exception:
846
+ decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
847
+
848
+ # Reconstruct block
849
+ decoded_idx = 0
850
  for i, token_id in enumerate(block_tokens):
851
  if is_masked[0, i]:
852
+ block_data.append({'type': 'masked', 'text': 'β–ˆβ–ˆβ–ˆ'})
 
 
 
853
  else:
 
 
 
 
854
  block_data.append({
855
  'type': 'revealed',
856
+ 'text': decoded_tokens[decoded_idx]
857
  })
858
+ decoded_idx += 1
859
+
860
+ all_blocks.append({'block_index': block_idx, 'tokens': block_data})
861
+
 
 
 
862
  return {
863
  'context': context_text,
864
  'blocks': all_blocks,
865
  'num_blocks': len(mask_blocks)
866
  }
867
+
868
  return visualizer
869
 
 
870
  @app.route('/')
871
  def index():
872
  """Serve the main HTML page."""
873
  return app.send_static_file('index.html')
874
 
 
875
  @app.route('/api/load', methods=['POST'])
876
  def load_model_endpoint():
877
  """Load the model."""
878
  data = request.json or {}
879
  check_only = data.get('check_only', False)
880
+
881
  global model
882
+
883
  if check_only:
884
  return jsonify({
885
  'loaded': model is not None,
886
  'message': 'Model is loaded' if model is not None else 'Model not loaded'
887
  })
888
+
889
  if model is not None:
890
  return jsonify({
891
  'loaded': True,
892
  'message': 'Model already loaded'
893
  })
894
+
895
  success = load_model_internal()
 
896
  if success:
897
  return jsonify({
898
  'loaded': True,
899
+ 'message': 'Model loaded with ultra-fast CPU optimizations'
900
  })
901
  else:
902
  return jsonify({
903
  'loaded': False,
904
+ 'message': 'Failed to load model. Check server logs.'
905
  }), 500
906
 
 
907
  @app.route('/api/generate', methods=['POST'])
908
  def generate():
909
+ """Generate response - optimized for minimal latency."""
910
+ global model, tokenizer, device, chat_function
911
+
912
  if model is None:
913
  return jsonify({'error': 'Model not loaded'}), 400
914
+
915
  if chat_function is None:
916
  return jsonify({'error': 'Chat function not available'}), 400
917
+
918
  data = request.json
919
  instruction = data.get('instruction', '')
920
+ steps = data.get('steps', 16) # Further reduced for speed
921
+ block_size = data.get('block_size', 32)
922
+ max_new_tokens = data.get('max_new_tokens', 64)
923
+ parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
924
+
925
  if not instruction:
926
  return jsonify({'error': 'No instruction provided'}), 400
927
+
928
  try:
929
+ # Fast path: no overhead
930
+ with torch.inference_mode():
931
+ raw_output, response = chat_function(
932
+ model,
933
+ tokenizer,
934
+ instruction,
935
+ steps=steps,
936
+ block_size=block_size,
937
+ max_new_tokens=max_new_tokens,
938
+ temperature=0.7, # Slightly lower for faster sampling
939
+ top_k=50,
940
+ top_p=0.9,
941
+ repetition_penalty=1.2,
942
+ no_repeat_ngram_size=3,
943
+ verbose=False,
944
+ visualize_fn=None,
945
+ parallel_blocks=parallel_blocks,
946
+ )
947
+
948
  return jsonify({
949
  'response': response,
950
  'raw_output': raw_output
951
  })
952
+
953
  except Exception as e:
954
+ if __debug__:
955
+ import traceback
956
+ traceback.print_exc()
957
  return jsonify({'error': str(e)}), 500
958
 
 
959
  @app.route('/api/generate-stream', methods=['POST'])
960
  def generate_stream():
961
+ """Generate with streaming - optimized with queue."""
962
+ global model, tokenizer, chat_function
963
+
964
  if model is None:
965
  return jsonify({'error': 'Model not loaded'}), 400
966
+
967
  if chat_function is None:
968
  return jsonify({'error': 'Chat function not available'}), 400
969
+
970
  data = request.json
971
  instruction = data.get('instruction', '')
972
+ steps = data.get('steps', 16)
973
+ block_size = data.get('block_size', 32)
974
+ max_new_tokens = data.get('max_new_tokens', 64)
975
+ parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
976
+
977
  if not instruction:
978
  return jsonify({'error': 'No instruction provided'}), 400
979
+
980
  def generate_events():
981
+ event_queue = queue.Queue(maxsize=100) # Limit queue size
982
+ generation_complete = {'done': False, 'result': None}
983
+
984
+ def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
985
+ try:
 
 
 
986
  visualizer = create_streaming_visualizer()
987
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
988
+ event_queue.put({'type': 'update', 'data': data}, block=False)
989
+ except queue.Full:
990
+ pass # Drop updates if queue full (prevents memory issues)
991
+
992
+ def run_generation():
993
+ try:
994
+ with torch.inference_mode():
 
995
  raw_output, response = chat_function(
996
+ model, tokenizer, instruction,
997
+ steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
998
+ temperature=0.7, top_k=50, top_p=0.9,
999
+ repetition_penalty=1.2, no_repeat_ngram_size=3,
1000
+ verbose=False, visualize_fn=streaming_visualizer,
 
 
 
 
 
 
 
 
1001
  parallel_blocks=parallel_blocks,
1002
  )
1003
+ generation_complete['result'] = (raw_output, response)
1004
+ except Exception as e:
1005
+ generation_complete['result'] = ('error', str(e))
1006
+ finally:
1007
+ generation_complete['done'] = True
1008
+ event_queue.put(None)
1009
+
1010
+ import threading
1011
+ gen_thread = threading.Thread(target=run_generation)
1012
+ gen_thread.daemon = True
1013
+ gen_thread.start()
1014
+
1015
+ yield f"data: {json.dumps({'type': 'start'})}\n\n"
1016
+
1017
+ while not generation_complete['done'] or not event_queue.empty():
1018
+ try:
1019
+ event = event_queue.get(timeout=0.1)
1020
+ if event is None:
1021
+ break
1022
+ yield f"data: {json.dumps(event)}\n\n"
1023
+ except queue.Empty:
1024
+ continue
1025
+
1026
+ gen_thread.join(timeout=2.0)
1027
+
1028
+ if generation_complete['result']:
1029
+ raw_output, response = generation_complete['result']
1030
+ if raw_output == 'error':
1031
+ yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
1032
+ else:
1033
+ yield f"data: {json.dumps({'type': 'complete', 'response': response})}\n\n"
1034
+
 
 
 
 
 
 
 
 
 
1035
  return Response(
1036
  stream_with_context(generate_events()),
1037
  mimetype='text/event-stream',
1038
  headers={
1039
  'Cache-Control': 'no-cache',
1040
+ 'X-Accel-Buffering': 'no',
1041
+ 'Connection': 'keep-alive'
1042
  }
1043
  )
1044
 
1045
+ @app.route('/api/status', methods=['GET'])
1046
+ def status():
1047
+ """Get system status."""
1048
+ return jsonify({
1049
+ 'model_loaded': model is not None,
1050
+ 'torch_threads': torch.get_num_threads(),
1051
+ 'interop_threads': torch.get_num_interop_threads(),
1052
+ 'cpu_count': os.cpu_count(),
1053
+ 'optimizations': {
1054
+ 'onnx_runtime': USE_ONNX_RUNTIME,
1055
+ 'ipex': USE_IPEX,
1056
+ 'torch_compile': USE_TORCH_COMPILE,
1057
+ 'quantization': QUANTIZE_MODEL
 
 
 
1058
  }
1059
+ })
 
1060
 
1061
  if __name__ == '__main__':
1062
+ print("\n" + "=" * 70)
1063
+ print("ULTRA-FAST CPU INFERENCE SERVER")
1064
+ print("=" * 70)
1065
+ print("Available optimizations:")
1066
+ print(" βœ“ ONNX Runtime (best):", USE_ONNX_RUNTIME)
1067
+ print(" βœ“ Intel IPEX:", USE_IPEX)
1068
+ print(" βœ“ torch.compile:", USE_TORCH_COMPILE)
1069
+ print(" βœ“ Quantization:", QUANTIZE_MODEL)
1070
+ print(" βœ“ Multi-threading")
1071
+ print(" βœ“ Inference mode")
1072
+ print(" βœ“ Fast tokenizer")
1073
+ print(" βœ“ Memory layout optimization")
1074
+ print("\nTo install ONNX Runtime: pip install onnxruntime")
1075
+ print("To install Intel IPEX: pip install intel-extension-for-pytorch")
1076
+ print("=" * 70 + "\n")
1077
+
1078
+ app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)