thejagstudio committed
Commit 0265bc8 · verified · 1 parent: 486838c

Update app.py

Files changed (1): app.py +475 -475
app.py CHANGED
@@ -1,475 +1,475 @@
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None


def find_file(filename, search_dirs=None):
    """Find a file in current directory or parent directories."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),  # Current directory
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),  # Working directory
        ]

    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath

    return None


def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None

    try:
        # Add the directory to sys.path
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)

        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        import traceback
        traceback.print_exc()
        return None


def load_model_internal():
    """Load the model and tokenizer."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function

    if model is not None:
        return True

    try:
        print("=" * 60)
        print("Starting model loading process...")
        print("=" * 60)

        # Find and import infer-base.py
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")

        print(f"\nImporting infer-base.py from: {base_path}")
        base_mod = try_import_module(base_path, "infer_base")

        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # Check for DiffusionLLM class
        if not hasattr(base_mod, 'DiffusionLLM'):
            print("Available attributes in infer_base:", dir(base_mod))
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")

        DiffusionLLM = base_mod.DiffusionLLM
        print("✓ Successfully loaded DiffusionLLM class")

        # Find and import infer-chat.py
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")

        print(f"\nImporting infer-chat.py from: {chat_path}")
        chat_mod = try_import_module(chat_path, "infer_chat")

        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Failed to import chat function from infer-chat.py")

        chat_function = chat_mod.chat
        print("✓ Successfully loaded chat function")

        # Setup pickling workaround for torch.load
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
            print("✓ Configured pickle support for model loading")
        except Exception as e:
            print(f"Warning: Could not setup pickle workaround: {e}")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n✓ Using device: {device}")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            "./checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]

        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")

            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break

        if model_path is None:
            raise RuntimeError(
                "Could not find model checkpoint. Looking for:\n"
                "  - checkpoints/best_model.pt\n"
                "  - checkpoints/model_fp32.pt\n"
                f"Searched directories: {checkpoint_dirs}"
            )

        print(f"\n✓ Found model checkpoint: {model_path}")
        print("Loading model weights (this may take a minute)...")

        # Load model
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']

        print("Creating model...")
        model = DiffusionLLM(config)

        print("Loading state dict...")
        state_dict = checkpoint['model_state']
        state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print(f"✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Parameters: {num_params:.1f}M")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"{'=' * 60}\n")

        return True

    except Exception as e:
        print("\n" + "=" * 60)
        print("ERROR LOADING MODEL")
        print("=" * 60)
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False


def create_streaming_visualizer():
    """Create a visualizer that yields SSE events instead of printing to terminal."""
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        # Decode context
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())

        # Build blocks visualization
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []

            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({
                        'type': 'masked',
                        'text': '███'
                    })
                else:
                    try:
                        token_text = tok.decode([token_id], skip_special_tokens=False)
                    except Exception:
                        token_text = str(int(token_id))
                    block_data.append({
                        'type': 'revealed',
                        'text': token_text
                    })

            all_blocks.append({
                'block_index': block_idx,
                'tokens': block_data
            })

        # Return data structure that will be sent as SSE
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }

    return visualizer


@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')


@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model."""
    data = request.json or {}
    check_only = data.get('check_only', False)

    global model

    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })

    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })

    success = load_model_internal()

    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded successfully'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs for details.'
        }), 500


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate response without streaming."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        # Generate response
        raw_output, response = chat_function(
            model,
            tokenizer,
            instruction,
            steps=steps,
            block_size=block_size,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            verbose=False,
            visualize_fn=None,
            parallel_blocks=parallel_blocks,
        )

        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate response with streaming visualization."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        try:
            # Import threading to allow yielding from callback
            import queue
            event_queue = queue.Queue()
            generation_complete = {'done': False, 'result': None}

            def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
                """This gets called during generation - we need to send events immediately"""
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                # Put the update in the queue so it can be yielded immediately
                event_queue.put({'type': 'update', 'data': data})

            # Start generation in a separate thread so we can yield events as they come
            import threading

            def run_generation():
                try:
                    raw_output, response = chat_function(
                        model,
                        tokenizer,
                        instruction,
                        steps=steps,
                        block_size=block_size,
                        max_new_tokens=max_new_tokens,
                        temperature=0.8,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        verbose=False,
                        visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                    generation_complete['result'] = (raw_output, response)
                except Exception as e:
                    generation_complete['result'] = ('error', str(e))
                finally:
                    generation_complete['done'] = True
                    event_queue.put(None)  # Signal completion

            # Start generation thread
            gen_thread = threading.Thread(target=run_generation)
            gen_thread.daemon = True
            gen_thread.start()

            # Yield start event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"

            # Yield events as they come from the queue
            while not generation_complete['done'] or not event_queue.empty():
                try:
                    event = event_queue.get(timeout=0.1)
                    if event is None:  # Completion signal
                        break
                    yield f"data: {json.dumps(event)}\n\n"
                except queue.Empty:
                    continue

            # Wait for thread to finish
            gen_thread.join(timeout=1.0)

            # Send final response
            if generation_complete['result']:
                raw_output, response = generation_complete['result']
                if raw_output == 'error':
                    yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"

        except Exception as e:
            import traceback
            traceback.print_exc()
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


@app.route('/api/test-stream', methods=['GET'])
def test_stream():
    """Test streaming endpoint."""
    def generate():
        for i in range(10):
            yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
            time.sleep(0.5)
        yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


if __name__ == '__main__':
-    app.run(debug=True, host='0.0.0.0', port=5000, threaded=True)
+    app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
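
The diff counts all 475 lines because the file was rewritten wholesale, but the only functional change is the last line: the server now binds to port 7860, the port Hugging Face Spaces routes to by default, instead of 5000. Below is a minimal client sketch for the streaming endpoint on the new port; the base URL and the instruction text are illustrative assumptions, not part of the commit.

import json
import requests

BASE_URL = 'http://localhost:7860'  # assumption: app running locally on the new port

# Ask the server to load the model first (a no-op if it is already loaded).
print(requests.post(f'{BASE_URL}/api/load', json={}).json())

# Consume the SSE stream emitted by /api/generate-stream.
with requests.post(
    f'{BASE_URL}/api/generate-stream',
    json={'instruction': 'Explain block-wise diffusion decoding in one sentence.'},
    stream=True,
) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue  # skip the blank lines that separate SSE events
        event = json.loads(line[len('data: '):])
        if event['type'] == 'update':
            print(f"denoising update across {event['data']['num_blocks']} block(s)")
        elif event['type'] == 'complete':
            print('response:', event['response'])
        elif event['type'] == 'error':
            print('error:', event['error'])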
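
The non-streaming endpoint returns the finished text as a single JSON payload; a sketch under the same assumptions, with request parameters mirroring the server-side defaults shown above.

import requests

payload = {
    'instruction': 'Write a short greeting.',  # illustrative prompt
    'steps': 64,
    'block_size': 128,
    'max_new_tokens': 128,
    'parallel_blocks': 1,
}
r = requests.post('http://localhost:7860/api/generate', json=payload)
r.raise_for_status()  # a 400 here usually means the model is not loaded yet
print(r.json()['response'])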