thejagstudio committed
Commit 486838c · verified · 1 Parent(s): 3f3700b

Upload 10 files

Files changed (10)
  1. Dockerfile +7 -0
  2. app.py +475 -0
  3. checkpoints/model_fp32.pt +3 -0
  4. design.json +185 -0
  5. infer-base.py +778 -0
  6. infer-chat.py +656 -0
  7. requirements.txt +5 -0
  8. static/ai.mp4 +0 -0
  9. static/index.html +156 -0
  10. static/main.js +346 -0
Dockerfile ADDED
@@ -0,0 +1,7 @@
FROM python:3
WORKDIR /usr/src/app
COPY requirements.txt ./
RUN pip install -r requirements.txt
COPY . .
EXPOSE 7860
CMD ["python", "./app.py"]
app.py ADDED
@@ -0,0 +1,475 @@
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None


def find_file(filename, search_dirs=None):
    """Find a file in current directory or parent directories."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),  # Current directory
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),  # Working directory
        ]

    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath

    return None


def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None

    try:
        # Add the directory to sys.path
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)

        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        import traceback
        traceback.print_exc()
        return None


def load_model_internal():
    """Load the model and tokenizer."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function

    if model is not None:
        return True

    try:
        print("=" * 60)
        print("Starting model loading process...")
        print("=" * 60)

        # Find and import infer-base.py
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")

        print(f"\nImporting infer-base.py from: {base_path}")
        base_mod = try_import_module(base_path, "infer_base")

        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # Check for DiffusionLLM class
        if not hasattr(base_mod, 'DiffusionLLM'):
            print("Available attributes in infer_base:", dir(base_mod))
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")

        DiffusionLLM = base_mod.DiffusionLLM
        print("✓ Successfully loaded DiffusionLLM class")

        # Find and import infer-chat.py
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")

        print(f"\nImporting infer-chat.py from: {chat_path}")
        chat_mod = try_import_module(chat_path, "infer_chat")

        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Failed to import chat function from infer-chat.py")

        chat_function = chat_mod.chat
        print("✓ Successfully loaded chat function")

        # Setup pickling workaround for torch.load
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
            print("✓ Configured pickle support for model loading")
        except Exception as e:
            print(f"Warning: Could not setup pickle workaround: {e}")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n✓ Using device: {device}")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            "./checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]

        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")

            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break

        if model_path is None:
            raise RuntimeError(
                "Could not find model checkpoint. Looking for:\n"
                "  - checkpoints/best_model.pt\n"
                "  - checkpoints/model_fp32.pt\n"
                f"Searched directories: {checkpoint_dirs}"
            )

        print(f"\n✓ Found model checkpoint: {model_path}")
        print("Loading model weights (this may take a minute)...")

        # Load model
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']

        print("Creating model...")
        model = DiffusionLLM(config)

        print("Loading state dict...")
        state_dict = checkpoint['model_state']
        state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print("✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Parameters: {num_params:.1f}M")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"{'=' * 60}\n")

        return True

    except Exception as e:
        print("\n" + "=" * 60)
        print("ERROR LOADING MODEL")
        print("=" * 60)
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False


def create_streaming_visualizer():
    """Create a visualizer that yields SSE events instead of printing to terminal."""
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        # Decode context
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())

        # Build blocks visualization
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []

            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({
                        'type': 'masked',
                        'text': '███'
                    })
                else:
                    try:
                        token_text = tok.decode([token_id], skip_special_tokens=False)
                    except Exception:
                        token_text = str(int(token_id))
                    block_data.append({
                        'type': 'revealed',
                        'text': token_text
                    })

            all_blocks.append({
                'block_index': block_idx,
                'tokens': block_data
            })

        # Return data structure that will be sent as SSE
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }

    return visualizer


@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')


@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model."""
    data = request.json or {}
    check_only = data.get('check_only', False)

    global model

    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })

    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })

    success = load_model_internal()

    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded successfully'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs for details.'
        }), 500


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate response without streaming."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        # Generate response
        raw_output, response = chat_function(
            model,
            tokenizer,
            instruction,
            steps=steps,
            block_size=block_size,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            verbose=False,
            visualize_fn=None,
            parallel_blocks=parallel_blocks,
        )

        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
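
(Aside, not part of app.py: a minimal smoke test for the two endpoints above. It assumes the server is reachable on port 7860 and that the `requests` package is installed; the payload keys mirror the route handlers.)

# Sketch: load the model once, then call the non-streaming endpoint (assumed host/port).
import requests

HOST = "http://localhost:7860"

requests.post(f"{HOST}/api/load", json={}, timeout=600).raise_for_status()

resp = requests.post(
    f"{HOST}/api/generate",
    json={"instruction": "Write a haiku about diffusion.",
          "steps": 64, "block_size": 128,
          "max_new_tokens": 128, "parallel_blocks": 1},
    timeout=600,  # generation can be slow on CPU
)
resp.raise_for_status()
print(resp.json()["response"])
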


@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate response with streaming visualization."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        try:
            # A queue lets the visualizer callback hand events back to this generator
            import queue
            event_queue = queue.Queue()
            generation_complete = {'done': False, 'result': None}

            def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
                """Called during generation - we need to send events immediately."""
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                # Put the update in the queue so it can be yielded immediately
                event_queue.put({'type': 'update', 'data': data})

            # Start generation in a separate thread so we can yield events as they come
            import threading

            def run_generation():
                try:
                    raw_output, response = chat_function(
                        model,
                        tokenizer,
                        instruction,
                        steps=steps,
                        block_size=block_size,
                        max_new_tokens=max_new_tokens,
                        temperature=0.8,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        verbose=False,
                        visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                    generation_complete['result'] = (raw_output, response)
                except Exception as e:
                    generation_complete['result'] = ('error', str(e))
                finally:
                    generation_complete['done'] = True
                    event_queue.put(None)  # Signal completion

            # Start generation thread
            gen_thread = threading.Thread(target=run_generation)
            gen_thread.daemon = True
            gen_thread.start()

            # Yield start event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"

            # Yield events as they come from the queue
            while not generation_complete['done'] or not event_queue.empty():
                try:
                    event = event_queue.get(timeout=0.1)
                    if event is None:  # Completion signal
                        break
                    yield f"data: {json.dumps(event)}\n\n"
                except queue.Empty:
                    continue

            # Wait for thread to finish
            gen_thread.join(timeout=1.0)

            # Send final response
            if generation_complete['result']:
                raw_output, response = generation_complete['result']
                if raw_output == 'error':
                    yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"

        except Exception as e:
            import traceback
            traceback.print_exc()
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


@app.route('/api/test-stream', methods=['GET'])
def test_stream():
    """Test streaming endpoint."""
    def generate():
        for i in range(10):
            yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
            time.sleep(0.5)
        yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


if __name__ == '__main__':
    # Serve on 7860, the port the Dockerfile EXPOSEs; debug stays off so the
    # Flask reloader doesn't import the app twice inside the container.
    app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
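
(Aside: for manual testing of the SSE endpoint, a minimal streaming client might look like the sketch below. It is not part of this commit; it assumes the server is running locally on port 7860 and `requests` is installed, and the event shapes mirror the `start`/`update`/`complete` payloads emitted above.)

# Sketch: consume /api/generate-stream as Server-Sent Events (assumed host/port).
import json
import requests

def stream_generation(instruction: str, host: str = "http://localhost:7860"):
    payload = {"instruction": instruction, "steps": 64, "block_size": 128,
               "max_new_tokens": 128, "parallel_blocks": 1}
    with requests.post(f"{host}/api/generate-stream", json=payload, stream=True) as r:
        r.raise_for_status()
        for line in r.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank SSE separators
            event = json.loads(line[len("data: "):])
            if event["type"] == "update":
                # Count how many tokens across all blocks are still masked
                masked = sum(t["type"] == "masked"
                             for b in event["data"]["blocks"] for t in b["tokens"])
                print(f"masked tokens remaining: {masked}")
            elif event["type"] == "complete":
                print("response:", event["response"])
            elif event["type"] == "error":
                raise RuntimeError(event["error"])

if __name__ == "__main__":
    stream_generation("Explain block diffusion in one sentence.")
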
checkpoints/model_fp32.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26b941c479671cff7d0d93fc1d30711ce717de1abedee1e30c0871a4874db79d
size 491091299
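
(Aside: this file is a Git LFS pointer, not the weights themselves; the real ~491 MB checkpoint is fetched by `git lfs pull`. A downloaded copy can be checked against the pointer with a short sketch like this; the local path is an assumption.)

# Sketch: verify a downloaded checkpoint against the LFS pointer's oid/size.
import hashlib
from pathlib import Path

EXPECTED_OID = "26b941c479671cff7d0d93fc1d30711ce717de1abedee1e30c0871a4874db79d"
EXPECTED_SIZE = 491091299

path = Path("checkpoints/model_fp32.pt")  # assumed local path
assert path.stat().st_size == EXPECTED_SIZE, "size mismatch - still an LFS pointer?"

h = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("checkpoint matches LFS pointer")
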
design.json ADDED
@@ -0,0 +1,185 @@
{
  "design_system": {
    "name": "Cortex Luminance System",
    "description": "A physics-based design system combining soft aesthetic minimalism with strict luminance layering. It relies on lighting simulation (top highlights, bottom shadows) rather than diverse hues to create depth hierarchy.",
    "version": "1.0.0",
    "mode": "light",
    "philosophy": {
      "core_principle": "Depth through Luminance",
      "lighting_source": "Top-down (90 degrees)",
      "surface_material": "Matte white & Soft Glass",
      "accent_strategy": "Functional Purple (oklch 0.65 0.22 290)",
      "layering_logic": "Higher elevation = Higher lightness (or pure white) + Stronger Shadow. Lower elevation = Lower lightness + Inset Shadow."
    }
  },
  "tokens": {
    "colors": {
      "primitives": {
        "base_hue": "270 (Purple/Violet)",
        "neutral_hue": "265 (Cool Gray)"
      },
      "layers": {
        "bg_root": {
          "value": "linear-gradient(135deg, oklch(0.95 0.02 270) 0%, oklch(0.92 0.03 290) 100%)",
          "description": "Level 0: The ambient canvas. Corresponds to the blurry cloud/gradient background."
        },
        "bg_layer_1": {
          "value": "oklch(0.99 0.005 265)",
          "description": "Level 1: The main application window/sidebar surface. Almost white."
        },
        "bg_layer_2": {
          "value": "oklch(1.0 0 0)",
          "description": "Level 2: Cards, Floating Inputs, Modals. Pure White."
        },
        "bg_sunken": {
          "value": "oklch(0.96 0.01 265)",
          "description": "For inset elements (search bars, progress tracks). Slightly darker than layer 1 to simulate depth."
        }
      },
      "text": {
        "primary": "oklch(0.20 0.02 265)",
        "secondary": "oklch(0.55 0.03 265)",
        "accent": "oklch(0.65 0.22 290)"
      },
      "borders": {
        "subtle": "rgba(0, 0, 0, 0.06)",
        "highlight": "rgba(255, 255, 255, 0.8)"
      }
    },
    "typography": {
      "font_family": "Inter, SF Pro Display, system-ui, sans-serif",
      "weights": {
        "regular": 400,
        "medium": 500,
        "semibold": 600
      },
      "scale": {
        "h1": {
          "size": "32px",
          "weight": 600,
          "letter_spacing": "-0.02em"
        },
        "h2": {
          "size": "24px",
          "weight": 500,
          "letter_spacing": "-0.01em"
        },
        "body_lg": {
          "size": "16px",
          "weight": 400
        },
        "body_sm": {
          "size": "14px",
          "weight": 400
        },
        "caption": {
          "size": "12px",
          "weight": 500,
          "uppercase": false
        }
      }
    },
    "spacing": {
      "xs": "4px",
      "sm": "8px",
      "md": "16px",
      "lg": "24px",
      "xl": "32px",
      "container_padding": "20px"
    },
    "radii": {
      "sm": "8px",
      "md": "12px",
      "lg": "16px",
      "full": "9999px (Pill)"
    },
    "shadows": {
      "note": "Shadows must imply a top-down light source. Always pair drop-shadows with top-edge inset highlights.",
      "elevation_low": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 1px 2px 0 rgba(0, 0, 0, 0.05)",
        "use_case": "Interactive buttons, list items."
      },
      "elevation_medium": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03)",
        "use_case": "Standard Cards (Saved Prompts, Suggestions)."
      },
      "elevation_high": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 10px 15px -3px rgba(0, 0, 0, 0.08), 0 4px 6px -2px rgba(0, 0, 0, 0.04)",
        "use_case": "Floating Input Area, Modals."
      },
      "inset_sunken": {
        "css_value": "box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06), inset 0 -1px 0 0 rgba(255, 255, 255, 0.5)",
        "use_case": "Search bars, tracks, unselected states."
      }
    }
  },
  "components": {
    "layout_structure": {
      "sidebar": {
        "width": "260px",
        "background": "bg_layer_1",
        "border_right": "1px solid borders.subtle",
        "padding": "md",
        "style": "Flat surface, low contrast."
      },
      "main_area": {
        "background": "bg_layer_2 (with large rounded corners) OR transparent over bg_root",
        "layout": "Flex-col, centered content, maximum width 900px."
      }
    },
    "buttons": {
      "primary": {
        "bg": "black (or dark purple)",
        "text": "white",
        "radius": "md",
        "shadow": "elevation_low",
        "lighting": "Subtle top gradient (lighter top) to show curvature."
      },
      "ghost": {
        "bg": "transparent",
        "hover_bg": "rgba(0,0,0,0.04)",
        "text": "text.secondary"
      },
      "new_chat": {
        "style": "Pill shape / Full radius",
        "bg": "#1A1A1A",
        "text": "white",
        "icon": "plus"
      }
    },
    "cards": {
      "prompt_card": {
        "bg": "bg_layer_2",
        "radius": "lg",
        "shadow": "elevation_medium",
        "border": "1px solid borders.subtle",
        "hover": "Transform Y -2px, increase shadow to elevation_high."
      }
    },
    "inputs": {
      "search_bar": {
        "style": "Sunken / Inset",
        "bg": "bg_sunken",
        "shadow": "inset_sunken",
        "radius": "md",
        "icon_color": "text.secondary"
      },
      "main_prompt_area": {
        "style": "Elevated Container",
        "bg": "white",
        "shadow": "elevation_high",
        "radius": "lg",
        "border": "1px solid rgba(0,0,0,0.04)"
      }
    },
    "navigation_items": {
      "base_style": "text.secondary, font-medium, md padding",
      "active_state": {
        "bg": "bg_layer_2",
        "text": "text.primary",
        "shadow": "elevation_low",
        "radius": "sm"
      }
    }
  }
}
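
(Aside: the token file is plain data, so it is straightforward to consume from build tooling. As an illustration only, with file path and output names assumed rather than taken from this commit, a few lines of Python can flatten the color layers into CSS custom properties.)

# Sketch: emit CSS custom properties from design.json's color tokens (assumed usage).
import json

with open("design.json") as f:
    tokens = json.load(f)["tokens"]

lines = [":root {"]
for name, layer in tokens["colors"]["layers"].items():
    lines.append(f"  --{name.replace('_', '-')}: {layer['value']};")
for name, value in tokens["colors"]["text"].items():
    lines.append(f"  --text-{name}: {value};")
lines.append("}")
print("\n".join(lines))
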
infer-base.py ADDED
@@ -0,0 +1,778 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from dataclasses import dataclass
import os
import math


# ============== Model Architecture ==============

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE) with NTK extrapolation."""

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.inv_freq = inv_freq

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]

        if self.inv_freq is None or self.inv_freq.device != x.device:
            self._update_freqs(x.device)

        cache_key = (seq_len, x.device, x.dtype)
        if cache_key in self._cache:
            return self._cache[cache_key]

        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]

        self._cache[cache_key] = (cos, sin)
        if len(self._cache) > 10:
            self._cache.pop(next(iter(self._cache)))

        return cos, sin


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to Q and K."""
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
+
79
+
80
+ class DiffusionAttention(nn.Module):
81
+ """Multi-head attention with GQA and Flash Attention support."""
82
+
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ self.hidden_size = config.hidden_size
86
+ self.num_heads = config.num_attention_heads
87
+ self.head_dim = self.hidden_size // self.num_heads
88
+ self.num_key_value_heads = config.num_key_value_heads
89
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
90
+ self.use_flash_attn = config.use_flash_attn
91
+
92
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
93
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
94
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
95
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
96
+
97
+ def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
98
+ bsz, q_len, _ = hidden_states.size()
99
+
100
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
101
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
102
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
103
+
104
+ cos, sin = freqs_cis
105
+ cos = cos[:, :, :q_len, :]
106
+ sin = sin[:, :, :q_len, :]
107
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
108
+
109
+ if past_kv is not None:
110
+ cache_k, cache_v = past_kv
111
+ k = torch.cat([cache_k, k], dim=2)
112
+ v = torch.cat([cache_v, v], dim=2)
113
+
114
+ current_kv = (k, v)
115
+
116
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
117
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
118
+
119
+ attn_mask = None
120
+ if attention_mask is not None:
121
+ attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
122
+ attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
123
+
124
+ output = F.scaled_dot_product_attention(
125
+ q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
126
+ )
127
+
128
+ output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
129
+ return self.o_proj(output), current_kv
130
+
131
+
132
+ class MLP(nn.Module):
133
+ """Gated MLP with SiLU activation."""
134
+
135
+ def __init__(self, config):
136
+ super().__init__()
137
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
138
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
139
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
140
+ self.act_fn = nn.SiLU()
141
+
142
+ def forward(self, x):
143
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
144
+
145
+
146
+ class BlockDiffusionBlock(nn.Module):
147
+ """Transformer block with pre-norm, attention, and MLP."""
148
+
149
+ def __init__(self, config):
150
+ super().__init__()
151
+ self.self_attn = DiffusionAttention(config)
152
+ self.mlp = MLP(config)
153
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
154
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
155
+ self.use_activation_checkpointing = config.use_activation_checkpointing
156
+
157
+ def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
158
+ return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)
159
+
160
+ def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
161
+ residual = hidden_states
162
+ hidden_states = self.input_layernorm(hidden_states)
163
+ attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
164
+ hidden_states = residual + attn_out
165
+
166
+ residual = hidden_states
167
+ hidden_states = self.post_attention_layernorm(hidden_states)
168
+ hidden_states = residual + self.mlp(hidden_states)
169
+ return hidden_states, new_kv
170
+
171
+
172
+ @dataclass
173
+ class ModelConfig:
174
+ """Model architecture configuration."""
175
+ vocab_size: int = 151936
176
+ hidden_size: int = 1024
177
+ intermediate_size: int = 2816
178
+ num_hidden_layers: int = 16
179
+ num_attention_heads: int = 16
180
+ num_key_value_heads: int = 4
181
+ max_position_embeddings: int = 16384
182
+ rms_norm_eps: float = 1e-6
183
+ rope_theta: float = 100000.0
184
+ pad_token_id: int = 0
185
+ mask_token_id: int = 1
186
+ use_flash_attn: bool = True
187
+ use_activation_checkpointing: bool = False
188
+ attention_dropout: float = 0.0
189
+ hidden_dropout: float = 0.0
190
+
191
+
192
+ class DiffusionLLM(nn.Module):
193
+ """Complete diffusion language model."""
194
+
195
+ def __init__(self, config: ModelConfig):
196
+ super().__init__()
197
+ self.config = config
198
+
199
+ pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
200
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
201
+
202
+ self.layers = nn.ModuleList([BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
203
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
204
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
205
+ self.rotary_emb = RotaryEmbedding(
206
+ config.hidden_size // config.num_attention_heads,
207
+ config.max_position_embeddings
208
+ )
209
+
210
+ self.lm_head.weight = self.embed_tokens.weight
211
+
212
+ def forward(self, input_ids, attention_mask=None, past_key_values=None):
213
+ bsz, seqlen = input_ids.shape
214
+ hidden_states = self.embed_tokens(input_ids)
215
+ freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
216
+
217
+ if past_key_values is None:
218
+ past_key_values = [None] * len(self.layers)
219
+
220
+ new_kvs = []
221
+ for i, layer in enumerate(self.layers):
222
+ hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
223
+ new_kvs.append(kv)
224
+
225
+ hidden_states = self.norm(hidden_states)
226
+ logits = self.lm_head(hidden_states)
227
+ return logits, new_kvs
228
+
229
+ def get_num_params(self, trainable_only=True):
230
+ if trainable_only:
231
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
232
+ else:
233
+ return sum(p.numel() for p in self.parameters())
234
+
235
+
236
+ # ============== Inference Functions ==============
237
+
238
+ def load_model(model_path: str, device: str = 'cuda'):
239
+ """Load a saved model (fp16 or fp32) for inference."""
240
+ print(f"Loading model from {model_path}...")
241
+
242
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
243
+ config = checkpoint['config']
244
+
245
+ model = DiffusionLLM(config)
246
+
247
+ state_dict = checkpoint['model_state']
248
+ state_dict = {k: v.float() for k, v in state_dict.items()}
249
+ model.load_state_dict(state_dict)
250
+
251
+ model = model.to(device)
252
+ model.eval()
253
+
254
+ num_params = model.get_num_params() / 1e6
255
+ file_size = os.path.getsize(model_path) / 1e6
256
+ print(f"✓ Model loaded: {num_params:.1f}M params from {file_size:.1f} MB file")
257
+
258
+ return model, config
259
+
260
+
261
+ def visualize_diffusion_state(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
262
+ """Visualize the current state of diffusion generation with multiple blocks.
263
+
264
+ Args:
265
+ mask_blocks: Either a single block tensor (1, block_size) or list of block tensors
266
+ is_masked_list: Either a single mask tensor (1, block_size) or list of mask tensors
267
+ block_colors: List of ANSI color codes for each block. If None, uses defaults.
268
+ """
269
+ import sys
270
+ import os
271
+
272
+ # Default colors for different blocks (green, cyan, yellow, magenta)
273
+ DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
274
+ MASK_COLOR = '\033[90m' # Gray for masked tokens
275
+ RESET = '\033[0m'
276
+
277
+ # Normalize inputs to lists
278
+ if not isinstance(mask_blocks, list):
279
+ mask_blocks = [mask_blocks]
280
+ is_masked_list = [is_masked_list]
281
+
282
+ if block_colors is None:
283
+ block_colors = DEFAULT_COLORS
284
+
285
+ # Decode context (prompt + previously generated blocks) and replace newlines
286
+ context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
287
+
288
+ # Build visualization for all blocks
289
+ all_blocks_text = []
290
+ for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
291
+ color = block_colors[block_idx % len(block_colors)]
292
+ block_tokens = mask_block[0].tolist()
293
+ block_color_tokens = []
294
+
295
+ for i, token_id in enumerate(block_tokens):
296
+ if is_masked[0, i]:
297
+ # Use block-specific color for masked tokens to distinguish blocks
298
+ block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
299
+ else:
300
+ # Decode individual token; use block color for revealed tokens
301
+ token_text = tokenizer.decode([token_id], skip_special_tokens=False)
302
+ block_color_tokens.append(f'{color}{token_text}{RESET}')
303
+
304
+ all_blocks_text.append(''.join(block_color_tokens))
305
+
306
+ # Join all blocks with a subtle separator
307
+ blocks_combined = ''.join(all_blocks_text)
308
+
309
+ # Clear entire terminal
310
+ if clear:
311
+ clear_cmd = 'cls' if os.name == 'nt' else 'clear'
312
+ try:
313
+ os.system(clear_cmd)
314
+ except Exception:
315
+ sys.stdout.write('\r\033[K')
316
+
317
+ # Print legend for parallel blocks
318
+ if len(mask_blocks) > 1:
319
+ legend_parts = []
320
+ for i in range(len(mask_blocks)):
321
+ color = block_colors[i % len(block_colors)]
322
+ legend_parts.append(f'{color}Block {i+1}{RESET}')
323
+ print(f"Generating: {' | '.join(legend_parts)}\n")
324
+
325
+ # Print the full context with colored blocks
326
+ print(f"{context_text}{blocks_combined}", flush=True)
327
+
328
+
329
+ def demo_visualize_truncation():
330
+ """Demo for visualize_diffusion_state without a full model.
331
+ Simulates streaming output and verifies there is no line duplication when content exceeds terminal width.
332
+ """
333
+ class MockTokenizer:
334
+ def __init__(self):
335
+ # Map token id to token text (simple ASCII characters and spaces)
336
+ self.vocab = {i: chr(65 + (i % 26)) for i in range(256)}
337
+ self.vocab[32] = ' '
338
+ self.eos_token = '\n'
339
+ self.pad_token = ' '
340
+
341
+ def decode(self, ids, skip_special_tokens=True):
342
+ # ids can be tensor or list
343
+ if isinstance(ids, torch.Tensor):
344
+ ids = ids.tolist()
345
+ if isinstance(ids, (list, tuple)):
346
+ return ''.join(self.vocab.get(int(i) % 256, '?') for i in ids)
347
+ return str(ids)
348
+
349
+ tok = MockTokenizer()
350
+ # Create a long context and a block that's also long
351
+ # Make context exceed terminal width
352
+ term_width = 80
353
+ long_context_ids = torch.tensor([[i % 26 + 65 for i in range(120)]], dtype=torch.long)
354
+ block_size = 32
355
+ mask_block = torch.full((1, block_size), 32, dtype=torch.long) # spaces
356
+ is_masked = torch.ones(1, block_size, dtype=torch.bool)
357
+ for i in range(0, block_size, 3):
358
+ is_masked[0, i] = False
359
+ mask_block[0, i] = 65 + (i % 26)
360
+
361
+ print('\nRunning demo: long prompt + block to test truncation\n')
362
+ for i in range(8):
363
+ visualize_diffusion_state(tok, long_context_ids, [mask_block], [is_masked], ModelConfig(), clear=(i > 0))
364
+ # rotate some tokens to simulate diffusion
365
+ mask_block = torch.roll(mask_block, shifts=1, dims=1)
366
+ time_delay = 0.08
367
+ try:
368
+ import time
369
+ time.sleep(time_delay)
370
+ except Exception:
371
+ pass
372
+ print('\n\nDemo completed.')


@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 16,
    block_size: int = 64,
    max_new_tokens: int = 256,
    device: str = 'cuda',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    visualize: bool = False,
    parallel_blocks: int = 1,  # Number of blocks to generate in parallel
):
    """Generate text using block diffusion with proper sampling and repetition control.

    Args:
        visualize: If True, stream output in real-time showing the diffusion effect.
        parallel_blocks: Number of blocks to generate in parallel (1-4 recommended).
    """
    import time
    model.eval()

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    config = model.module.config if hasattr(model, 'module') else model.config
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config

    num_blocks = max_new_tokens // block_size
    parallel_blocks = min(parallel_blocks, num_blocks)  # Can't parallelize more than total blocks

    if not visualize:
        if parallel_blocks > 1:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each ({parallel_blocks} blocks in parallel)...")
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each...")
    else:
        print("\n\033[94mStarting diffusion generation...\033[0m\n")
        print(prompt, end='', flush=True)

    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    # Process blocks in batches of parallel_blocks
    blocks_generated = 0
    while blocks_generated < num_blocks:
        # Determine how many blocks to generate this iteration
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)

        if current_parallel > 1:
            # Parallel block generation
            generated_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            # Concatenate all generated blocks to context
            for block in generated_blocks:
                context_ids = torch.cat([context_ids, block], dim=1)
                all_generated_tokens.update(block[0].tolist())

            if not visualize:
                print(f"  Blocks {blocks_generated + 1}-{blocks_generated + current_parallel}/{num_blocks} complete")
            blocks_generated += current_parallel
        else:
            # Single block generation (original logic)
            mask_block, block_token_history = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            context_ids = torch.cat([context_ids, mask_block], dim=1)
            all_generated_tokens.update(mask_block[0].tolist())

            if not visualize:
                print(f"  Block {blocks_generated + 1}/{num_blocks} complete")
            blocks_generated += 1

    if visualize:
        # Final newline after visualization
        print("\n")

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
+ return tokenizer.decode(generated_ids, skip_special_tokens=True)
467
+
468
+
469
+ def _generate_single_block(
470
+ model, tokenizer, context_ids, config, device,
471
+ block_size, steps, temperature, top_k, top_p,
472
+ repetition_penalty, no_repeat_ngram_size,
473
+ all_generated_tokens, visualize
474
+ ):
475
+ """Generate a single block using diffusion."""
476
+ mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
477
+ is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
478
+ block_token_history = []
479
+
480
+ for step_idx in range(steps):
481
+ full_input = torch.cat([context_ids, mask_block], dim=1)
482
+ attention_mask = torch.ones_like(full_input, dtype=torch.float32)
483
+
484
+ logits, _ = model(full_input, attention_mask=attention_mask)
485
+ block_logits = logits[:, -block_size:, :]
486
+
487
+ block_logits = _apply_sampling_controls(
488
+ block_logits, context_ids, mask_block, is_masked,
489
+ repetition_penalty, temperature, top_k, top_p,
490
+ no_repeat_ngram_size, block_token_history
491
+ )
492
+
493
+ probs = F.softmax(block_logits, dim=-1)
494
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
495
+ probs = probs.clamp(min=1e-10)
496
+ probs = probs / probs.sum(dim=-1, keepdim=True)
497
+
498
+ sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
499
+ sampled_tokens = sampled_tokens.view(1, block_size)
500
+
501
+ confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
502
+
503
+ tokens_to_unmask = max(1, block_size // steps)
504
+ if step_idx == steps - 1:
505
+ tokens_to_unmask = is_masked.sum().item()
506
+
507
+ if tokens_to_unmask > 0 and is_masked.sum() > 0:
508
+ masked_confidence = confidence.clone()
509
+ masked_confidence[~is_masked] = -1.0
510
+
511
+ num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
512
+ _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
513
+
514
+ for idx in top_indices:
515
+ mask_block[0, idx] = sampled_tokens[0, idx]
516
+ is_masked[0, idx] = False
517
+ block_token_history.append(sampled_tokens[0, idx].item())
518
+ all_generated_tokens.add(sampled_tokens[0, idx].item())
519
+
520
+ if visualize:
521
+ visualize_diffusion_state(tokenizer, context_ids, [mask_block], [is_masked], config, clear=(step_idx > 0))
522
+
523
+ return mask_block, block_token_history
524
+
525
+
526
+ def _generate_parallel_blocks(
527
+ model, tokenizer, context_ids, config, device,
528
+ num_parallel, block_size, steps, temperature,
529
+ top_k, top_p, repetition_penalty, no_repeat_ngram_size,
530
+ all_generated_tokens, visualize
531
+ ):
532
+ """Generate multiple blocks in parallel using batched computation.
533
+
534
+ Each block sees all previous blocks in the sequence, maintaining proper order:
535
+ - Block 0: context + [block0]
536
+ - Block 1: context + [block0] + [block1]
537
+ - Block 2: context + [block0] + [block1] + [block2]
538
+ - etc.
539
+
540
+ This ensures sequential coherence while still benefiting from batched computation.
541
+ """
542
+ batch_size = num_parallel
543
+ context_len = context_ids.shape[1]
544
+
545
+ # Initialize mask blocks for all parallel blocks
546
+ # Shape: (num_parallel, block_size)
547
+ mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
548
+ is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
549
+ block_token_histories = [[] for _ in range(batch_size)]
550
+
551
+ for step_idx in range(steps):
552
+ # Build inputs with proper sequential structure
553
+ # Each batch item has context + all blocks up to and including its own position
554
+ # Block i sees: context + block_0 + block_1 + ... + block_i
555
+
556
+ # Create padded inputs - each batch item has different length
557
+ # We'll pad to the longest sequence (which is the last block)
558
+ max_seq_len = context_len + (num_parallel * block_size)
559
+
560
+ # Build full input for each batch item
561
+ full_inputs = []
562
+ attention_masks = []
563
+
564
+ for b in range(batch_size):
565
+ # This block sees: context + all previous blocks + its own block
566
+ seq_parts = [context_ids[0]] # Start with context
567
+
568
+ # Add all blocks from 0 to b (inclusive)
569
+ for prev_b in range(b + 1):
570
+ seq_parts.append(mask_blocks[prev_b])
571
+
572
+ # Concatenate to form this batch item's input
573
+ batch_input = torch.cat(seq_parts, dim=0) # (seq_len,)
574
+ current_len = batch_input.shape[0]
575
+
576
+ # Pad to max_seq_len
577
+ padding_needed = max_seq_len - current_len
578
+ if padding_needed > 0:
579
+ padding = torch.full((padding_needed,), config.pad_token_id, device=device)
580
+ batch_input = torch.cat([batch_input, padding], dim=0)
581
+
582
+ full_inputs.append(batch_input)
583
+
584
+ # Create attention mask (1 for real tokens, 0 for padding)
585
+ attn_mask = torch.zeros(max_seq_len, device=device)
586
+ attn_mask[:current_len] = 1.0
587
+ attention_masks.append(attn_mask)
588
+
589
+ # Stack into batched tensors
590
+ full_input = torch.stack(full_inputs, dim=0) # (batch, max_seq_len)
591
+ attention_mask = torch.stack(attention_masks, dim=0) # (batch, max_seq_len)
592
+
593
+ # Single forward pass for all blocks
594
+ logits, _ = model(full_input, attention_mask=attention_mask)
595
+
596
+ # Extract logits for each block's position
597
+ # Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
598
+ block_logits_list = []
599
+ for b in range(batch_size):
600
+ start_pos = context_len + (b * block_size)
601
+ end_pos = start_pos + block_size
602
+ block_logits_list.append(logits[b, start_pos:end_pos, :])
603
+
604
+ block_logits = torch.stack(block_logits_list, dim=0) # (batch, block_size, vocab)
605
+
606
+ # Apply sampling controls per batch item
607
+ for b in range(batch_size):
608
+ # Build context that includes previous blocks for repetition penalty
609
+ extended_context = context_ids
610
+ if b > 0:
611
+ prev_blocks = torch.cat([mask_blocks[pb:pb+1] for pb in range(b)], dim=1)
612
+ extended_context = torch.cat([context_ids, prev_blocks], dim=1)
613
+
614
+ block_logits[b:b+1] = _apply_sampling_controls(
615
+ block_logits[b:b+1],
616
+ extended_context,
617
+ mask_blocks[b:b+1],
618
+ is_masked[b:b+1],
619
+ repetition_penalty, temperature, top_k, top_p,
620
+ no_repeat_ngram_size, block_token_histories[b]
621
+ )
622
+
623
+ probs = F.softmax(block_logits, dim=-1)
624
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
625
+ probs = probs.clamp(min=1e-10)
626
+ probs = probs / probs.sum(dim=-1, keepdim=True)
627
+
628
+ # Sample for all batches
629
+ sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
630
+ sampled_tokens = sampled_tokens.view(batch_size, block_size)
631
+
632
+ confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
633
+
634
+ tokens_to_unmask = max(1, block_size // steps)
635
+ if step_idx == steps - 1:
636
+ tokens_to_unmask = block_size # Unmask all remaining
637
+
638
+ # Unmask for each batch item
639
+ for b in range(batch_size):
640
+ if is_masked[b].sum() > 0:
641
+ masked_confidence = confidence[b].clone()
642
+ masked_confidence[~is_masked[b]] = -1.0
643
+
644
+ num_to_unmask = min(tokens_to_unmask, is_masked[b].sum().item())
645
+ if num_to_unmask > 0:
646
+ _, top_indices = torch.topk(masked_confidence, num_to_unmask)
647
+
648
+ for idx in top_indices:
649
+ mask_blocks[b, idx] = sampled_tokens[b, idx]
650
+ is_masked[b, idx] = False
651
+ block_token_histories[b].append(sampled_tokens[b, idx].item())
652
+
653
+ if visualize:
654
+ # Visualize all blocks with different colors
655
+ block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
656
+ is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
657
+ visualize_diffusion_state(
658
+ tokenizer, context_ids, block_list, is_masked_list,
659
+ config, clear=(step_idx > 0)
660
+ )
661
+
662
+ # Return list of generated blocks
663
+ return [mask_blocks[b:b+1] for b in range(batch_size)]
664
+
665
+
666
+ def _apply_sampling_controls(
667
+ block_logits, context_ids, mask_block, is_masked,
668
+ repetition_penalty, temperature, top_k, top_p,
669
+ no_repeat_ngram_size, block_token_history
670
+ ):
671
+ """Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
672
+ if repetition_penalty != 1.0:
673
+ seen_tokens = set(context_ids[0].tolist())
674
+ for i in range(mask_block.shape[1]):
675
+ if not is_masked[0, i]:
676
+ seen_tokens.add(mask_block[0, i].item())
677
+
678
+ for token_id in seen_tokens:
679
+ if token_id < block_logits.shape[-1]:
680
+ if block_logits[0, :, token_id].mean() > 0:
681
+ block_logits[:, :, token_id] /= repetition_penalty
682
+ else:
683
+ block_logits[:, :, token_id] *= repetition_penalty
684
+
685
+ block_logits = block_logits / temperature
686
+
687
+ if top_k > 0:
688
+ top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
689
+ block_logits = torch.full_like(block_logits, float('-inf'))
690
+ block_logits.scatter_(-1, top_k_indices, top_k_logits)
691
+
692
+ if top_p < 1.0:
693
+ sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
694
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
695
+
696
+ sorted_indices_to_remove = cumulative_probs > top_p
697
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
698
+ sorted_indices_to_remove[..., 0] = 0
699
+
700
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
701
+ block_logits[indices_to_remove] = float('-inf')
702
+
703
+ if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
704
+ recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size-1):])
705
+ full_history = context_ids[0].tolist() + block_token_history
706
+ for i in range(len(full_history) - no_repeat_ngram_size + 1):
707
+ if tuple(full_history[i:i+no_repeat_ngram_size-1]) == recent_ngram:
708
+ blocked_token = full_history[i + no_repeat_ngram_size - 1]
709
+ if blocked_token < block_logits.shape[-1]:
710
+ block_logits[:, :, blocked_token] = float('-inf')
711
+
712
+ # Safety check: if all logits are -inf, reset to uniform distribution
713
+ all_inf_mask = torch.isinf(block_logits).all(dim=-1)
714
+ if all_inf_mask.any():
715
+ block_logits[all_inf_mask] = 0.0
716
+
717
+ return block_logits
718
+
719
+
720
+ # ============== Main Entry Point ==============
721
+
722
+ def main():
723
+ """Main inference function."""
724
+ # Configuration
725
+ model_path = "../extra-final-boss/checkpoints/model_fp32.pt"
726
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
727
+
728
+ print(f"Using device: {device}")
729
+
730
+ # Allow a quick demo mode to test visualization without loading the model
731
+ import sys
732
+ if len(sys.argv) > 1 and sys.argv[1] == 'demo':
733
+ demo_visualize_truncation()
734
+ return
735
+
736
+ # Load tokenizer
737
+ print("Loading tokenizer...")
738
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
739
+ if tokenizer.pad_token is None:
740
+ tokenizer.pad_token = tokenizer.eos_token
741
+
742
+ # Load model
743
+ model, config = load_model(model_path, device)
744
+
745
+ # Generate text
746
+ print("\n" + "=" * 50)
747
+ print("Text Generation")
748
+ print("=" * 50)
749
+
750
+ prompt = "Barrack Obama was born in "
751
+ print(f"Prompt: {prompt}\n")
752
+
753
+ # Set visualize=True to see real-time diffusion effect
754
+ visualize = True
755
+ parallel_blocks = 4 # Generate 2-4 blocks in parallel for speedup
756
+
757
+ generated = generate_block_diffusion(
758
+ model,
759
+ tokenizer,
760
+ prompt=prompt,
761
+ steps=64,
762
+ block_size=64,
763
+ max_new_tokens=512,
764
+ device=device,
765
+ temperature=1,
766
+ top_k=40,
767
+ top_p=0.9,
768
+ repetition_penalty=1.3,
769
+ no_repeat_ngram_size=3,
770
+ visualize=visualize,
771
+ parallel_blocks=parallel_blocks,
772
+ )
773
+
774
+ print(f"\nGenerated text:\n{generated}")
775
+
776
+
777
+ if __name__ == "__main__":
778
+ main()
infer-chat.py ADDED
@@ -0,0 +1,656 @@
import os
import sys
import time
import argparse
import importlib.util

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

# Tracks how many lines the last visualization printed so we can overwrite it
_visualize_last_lines = 0


def try_import_infer_base(base_path: str):
    """Dynamically import `infer-base.py` as a module and return it, or None on failure."""
    if not os.path.exists(base_path):
        return None
    try:
        spec = importlib.util.spec_from_file_location("infer_base", base_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    except Exception as e:
        print(f"Warning: failed to import {base_path}: {e}")
        return None


def load_finetuned_model(model_path: str, device: str = 'cuda'):
    """Load a saved fine-tuned model for inference."""
    print(f"Loading model from {model_path}...")

    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']

    # Create model
    model = DiffusionLLM(config)

    # Load weights
    state_dict = checkpoint['model_state']
    state_dict = {k: v.float() for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✓ Loaded model: {num_params:.1f}M parameters")

    # Print training info if available
    if 'step' in checkpoint:
        print(f"  Trained for {checkpoint['step']} steps")
    if 'best_val_loss' in checkpoint:
        print(f"  Best validation loss: {checkpoint['best_val_loss']:.4f}")

    return model, config


@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 32,
    block_size: int = 32,
    max_new_tokens: int = 128,
    device: str = 'cuda',
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    verbose: bool = True,
    visualize_fn=None,
    parallel_blocks: int = 1,
):
    """
    Generate text using block diffusion with sampling controls.

    If `visualize_fn` is provided it will be called as:
        visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=True)

    Returns the decoded generated string (including prompt).
    """
    model.eval()

    # Encode prompt
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Get model config
    config = model.module.config if hasattr(model, 'module') else getattr(model, 'config', None)
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config

    if config is None:
        raise RuntimeError("Could not determine model config")

    num_blocks = max_new_tokens // block_size
    parallel_blocks = min(parallel_blocks, num_blocks)

    if verbose:
        print(f"Generating {num_blocks} blocks of {block_size} tokens ({max_new_tokens} max_new_tokens)\n")

    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    blocks_generated = 0
    while blocks_generated < num_blocks:
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)

        if current_parallel > 1:
            new_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize_fn
            )
            for block in new_blocks:
                context_ids = torch.cat([context_ids, block], dim=1)
                blocks_generated += 1
        else:
            mask_block, block_token_history = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize_fn
            )
            context_ids = torch.cat([context_ids, mask_block], dim=1)
            blocks_generated += 1

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=False)


def _apply_sampling_controls(
    block_logits, context_ids, mask_block, is_masked,
    repetition_penalty, temperature, top_k, top_p,
    no_repeat_ngram_size, block_token_history
):
    """Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
    if repetition_penalty != 1.0:
        seen_tokens = set(context_ids[0].tolist())
        for i in range(mask_block.shape[1]):
            if not is_masked[0, i]:
                seen_tokens.add(mask_block[0, i].item())

        for token_id in seen_tokens:
            if token_id < block_logits.shape[-1]:
                avg = block_logits[0, :, token_id].mean()
                if avg > 0:
151
+ block_logits[:, :, token_id] /= repetition_penalty
152
+ else:
153
+ block_logits[:, :, token_id] *= repetition_penalty
154
+
155
+ block_logits = block_logits / temperature
156
+
157
+ if top_k > 0:
158
+ k = min(top_k, block_logits.size(-1))
159
+ top_k_logits, top_k_indices = torch.topk(block_logits, k, dim=-1)
160
+ filtered = torch.full_like(block_logits, float('-inf'))
161
+ filtered.scatter_(-1, top_k_indices, top_k_logits)
162
+ block_logits = filtered
163
+
164
+ if top_p < 1.0:
165
+ sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
166
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
167
+
168
+ sorted_indices_to_remove = cumulative_probs > top_p
169
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
170
+ sorted_indices_to_remove[..., 0] = 0
171
+
172
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
173
+ block_logits[indices_to_remove] = float('-inf')
174
+
175
+ if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
176
+ recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size - 1):])
177
+ full_history = context_ids[0].tolist() + block_token_history
178
+ for i in range(len(full_history) - no_repeat_ngram_size + 1):
179
+ if tuple(full_history[i:i + no_repeat_ngram_size - 1]) == recent_ngram:
180
+ blocked_token = full_history[i + no_repeat_ngram_size - 1]
181
+ if blocked_token < block_logits.shape[-1]:
182
+ block_logits[:, :, blocked_token] = float('-inf')
183
+
184
+ # Safety: reset if all logits are -inf
185
+ all_inf_mask = torch.isinf(block_logits).all(dim=-1)
186
+ if all_inf_mask.any():
187
+ block_logits[all_inf_mask] = 0.0
188
+
189
+ return block_logits
190
+
191
+
192
+ def _generate_single_block(
193
+ model, tokenizer, context_ids, config, device,
194
+ block_size, steps, temperature, top_k, top_p,
195
+ repetition_penalty, no_repeat_ngram_size,
196
+ all_generated_tokens, visualize_fn=None
197
+ ):
198
+ """Generate a single block using diffusion."""
199
+ mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
200
+ is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
201
+ block_token_history = []
202
+
203
+ for step_idx in range(steps):
204
+ full_input = torch.cat([context_ids, mask_block], dim=1)
205
+ attention_mask = torch.ones_like(full_input, dtype=torch.float32)
206
+
207
+ logits, _ = model(full_input, attention_mask=attention_mask)
208
+ block_logits = logits[:, -block_size:, :]
209
+
210
+ block_logits = _apply_sampling_controls(
211
+ block_logits, context_ids, mask_block, is_masked,
212
+ repetition_penalty, temperature, top_k, top_p,
213
+ no_repeat_ngram_size, block_token_history
214
+ )
215
+
216
+ probs = F.softmax(block_logits, dim=-1)
217
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
218
+ probs = probs.clamp(min=1e-10)
219
+ probs = probs / probs.sum(dim=-1, keepdim=True)
220
+
221
+ sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
222
+ sampled_tokens = sampled_tokens.view(1, block_size)
223
+
224
+ confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
225
+
226
+ tokens_to_unmask = max(1, block_size // steps)
227
+ if step_idx == steps - 1:
228
+ tokens_to_unmask = int(is_masked.sum().item())
229
+
230
+ if tokens_to_unmask > 0 and is_masked.sum() > 0:
231
+ masked_confidence = confidence.clone()
232
+ masked_confidence[~is_masked] = -1.0
233
+
234
+ num_to_unmask = min(int(tokens_to_unmask), int(is_masked.sum().item()))
235
+ _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
236
+
237
+ for idx in top_indices:
238
+ idx = int(idx.item())
239
+ mask_block[0, idx] = sampled_tokens[0, idx]
240
+ is_masked[0, idx] = False
241
+ block_token_history.append(sampled_tokens[0, idx].item())
242
+ all_generated_tokens.add(sampled_tokens[0, idx].item())
243
+
244
+ if callable(visualize_fn):
245
+ try:
246
+ visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
247
+ except Exception:
248
+ pass
249
+ elif visualize_fn:
250
+ visualize_diffusion_state_local(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
251
+
252
+ return mask_block, block_token_history
253
+
254
+
255
+ def _generate_parallel_blocks(
256
+ model, tokenizer, context_ids, config, device,
257
+ num_parallel, block_size, steps, temperature,
258
+ top_k, top_p, repetition_penalty, no_repeat_ngram_size,
259
+ all_generated_tokens, visualize_fn=None
260
+ ):
261
+ """Generate multiple blocks in parallel using batched computation.
262
+
263
+ Each block sees all previous blocks in the sequence, maintaining proper order:
264
+ - Block 0: context + [block0]
265
+ - Block 1: context + [block0] + [block1]
266
+ - Block 2: context + [block0] + [block1] + [block2]
267
+ - etc.
268
+
269
+ This ensures sequential coherence while still benefiting from batched computation.
270
+ """
271
+ batch_size = num_parallel
272
+ context_len = context_ids.shape[1]
273
+
274
+ # Initialize mask blocks for all parallel blocks
275
+ # Shape: (num_parallel, block_size)
276
+ mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
277
+ is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
278
+ block_token_histories = [[] for _ in range(batch_size)]
279
+
280
+ for step_idx in range(steps):
281
+ # Build inputs with proper sequential structure
282
+ # Each batch item has context + all previous blocks + its own block
283
+ # Block i sees: context + block_0 + block_1 + ... + block_i
284
+
285
+ # Create padded inputs - each batch item has different length
286
+ # We'll pad to the longest sequence (which is the last block)
287
+ max_seq_len = context_len + (num_parallel * block_size)
288
+
289
+ # Build full input for each batch item
290
+ full_inputs = []
291
+ attention_masks = []
292
+
293
+ for b in range(batch_size):
294
+ # This block sees: context + all previous blocks + its own block
295
+ seq_parts = [context_ids[0]] # Start with context
296
+
297
+ # Add all blocks from 0 to b (inclusive)
298
+ for prev_b in range(b + 1):
299
+ seq_parts.append(mask_blocks[prev_b])
300
+
301
+ # Concatenate to form this batch item's input
302
+ batch_input = torch.cat(seq_parts, dim=0) # (seq_len,)
303
+ current_len = batch_input.shape[0]
304
+
305
+ # Pad to max_seq_len
306
+ padding_needed = max_seq_len - current_len
307
+ if padding_needed > 0:
308
+ pad_token = config.pad_token_id if config.pad_token_id is not None else 0
309
+ padding = torch.full((padding_needed,), pad_token, device=device)
310
+ batch_input = torch.cat([batch_input, padding], dim=0)
311
+
312
+ full_inputs.append(batch_input)
313
+
314
+ # Create attention mask (1 for real tokens, 0 for padding)
315
+ attn_mask = torch.zeros(max_seq_len, device=device)
316
+ attn_mask[:current_len] = 1.0
317
+ attention_masks.append(attn_mask)
318
+
319
+ # Stack into batched tensors
320
+ full_input = torch.stack(full_inputs, dim=0) # (batch, max_seq_len)
321
+ attention_mask = torch.stack(attention_masks, dim=0) # (batch, max_seq_len)
322
+
323
+ # Single forward pass for all blocks
324
+ logits, _ = model(full_input, attention_mask=attention_mask)
325
+
326
+ # Extract logits for each block's position
327
+ # Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
328
+ block_logits_list = []
329
+ for b in range(batch_size):
330
+ start_pos = context_len + (b * block_size)
331
+ end_pos = start_pos + block_size
332
+ block_logits_list.append(logits[b, start_pos:end_pos, :])
333
+
334
+ block_logits = torch.stack(block_logits_list, dim=0) # (batch, block_size, vocab)
335
+
336
+ # Apply sampling controls per batch item
337
+ for b in range(batch_size):
338
+ # Build context that includes previous blocks for repetition penalty
339
+ extended_context = context_ids
340
+ if b > 0:
341
+ prev_blocks = mask_blocks[:b]
342
+ extended_context = torch.cat([context_ids] + [prev_blocks.view(1, -1)], dim=1)
343
+
344
+ block_logits[b:b+1] = _apply_sampling_controls(
345
+ block_logits[b:b+1],
346
+ extended_context,
347
+ mask_blocks[b:b+1],
348
+ is_masked[b:b+1],
349
+ repetition_penalty, temperature, top_k, top_p,
350
+ no_repeat_ngram_size, block_token_histories[b]
351
+ )
352
+
353
+ probs = F.softmax(block_logits, dim=-1)
354
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
355
+ probs = probs.clamp(min=1e-10)
356
+ probs = probs / probs.sum(dim=-1, keepdim=True)
357
+
358
+ # Sample for all batches
359
+ sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
360
+ sampled_tokens = sampled_tokens.view(batch_size, block_size)
361
+
362
+ confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
363
+
364
+ tokens_to_unmask = max(1, block_size // steps)
365
+ if step_idx == steps - 1:
366
+ tokens_to_unmask = block_size # Unmask all remaining
367
+
368
+ # Unmask for each batch item
369
+ for b in range(batch_size):
370
+ if is_masked[b].sum() > 0:
371
+ masked_confidence = confidence[b]
372
+ masked_confidence = masked_confidence.clone()
373
+ masked_confidence[~is_masked[b]] = -1.0
374
+
375
+ num_to_unmask = min(int(tokens_to_unmask), int(is_masked[b].sum().item()))
376
+ _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
377
+
378
+ for idx in top_indices:
379
+ idx = int(idx.item())
380
+ mask_blocks[b, idx] = sampled_tokens[b, idx]
381
+ is_masked[b, idx] = False
382
+ block_token_histories[b].append(sampled_tokens[b, idx].item())
383
+ all_generated_tokens.add(sampled_tokens[b, idx].item())
384
+
385
+ if callable(visualize_fn):
386
+ try:
387
+ block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
388
+ is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
389
+ visualize_fn(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
390
+ except Exception:
391
+ pass
392
+ elif visualize_fn:
393
+ block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
394
+ is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
395
+ visualize_diffusion_state_local(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
396
+
397
+ # Return list of generated blocks
398
+ return [mask_blocks[b:b+1] for b in range(batch_size)]
399
+
400
+
401
+ def chat(model, tokenizer, instruction: str, parallel_blocks: int = 1, **kwargs):
402
+ """Simple chat interface."""
403
+ device = next(model.parameters()).device
404
+
405
+ prompt = format_instruct_prompt(instruction)
406
+
407
+ generated = generate_block_diffusion(
408
+ model,
409
+ tokenizer,
410
+ prompt=prompt,
411
+ device=device,
412
+ parallel_blocks=parallel_blocks,
413
+ **kwargs
414
+ )
415
+
416
+ # Extract all assistant responses using ChatML tags
417
+ start_tag = "<|im_start|>assistant"
418
+ end_tag = "<|im_end|>"
419
+ resp_parts = []
420
+ pos = 0
421
+ while True:
422
+ start_idx = generated.find(start_tag, pos)
423
+ if start_idx == -1:
424
+ break
425
+ start_idx += len(start_tag)
426
+ end_idx = generated.find(end_tag, start_idx)
427
+ if end_idx == -1:
428
+ resp_parts.append(generated[start_idx:].strip())
429
+ break
430
+ resp_parts.append(generated[start_idx:end_idx].strip())
431
+ pos = end_idx + len(end_tag)
432
+
433
+ if resp_parts:
434
+ resp = "\n\n".join(p for p in resp_parts if p)
435
+ else:
436
+ # Fallback if no assistant tags found
437
+ resp = generated.replace("<|im_start|>assistant", "").replace("<|im_end|>", "").strip()
438
+
439
+ return generated, resp
440
+
441
+
442
+ def format_instruct_prompt(instruction: str) -> str:
443
+ """Format instruction using a simple ChatML-like template."""
444
+ return (
445
+ "<|im_start|>system\n"
446
+ "Answer this question truthfully<|im_end|>\n"
447
+ "<|im_start|>user\n"
448
+ f"{instruction}\n"
449
+ "<|im_end|>\n"
450
+ "<|im_start|>assistant\n"
451
+ )
452
+
453
+
454
+ def visualize_diffusion_state_local(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
455
+ """Local visualization copied from infer-base.py to ensure consistent terminal output."""
456
+ import sys
457
+ import os
458
+
459
+ # Default colors for different blocks (green, cyan, yellow, magenta)
460
+ DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
461
+ MASK_COLOR = '\033[90m' # Gray for masked tokens
462
+ RESET = '\033[0m'
463
+
464
+ # Normalize inputs to lists
465
+ if not isinstance(mask_blocks, list):
466
+ mask_blocks = [mask_blocks]
467
+ is_masked_list = [is_masked_list]
468
+
469
+ if block_colors is None:
470
+ block_colors = DEFAULT_COLORS
471
+
472
+ # Decode context (prompt + previously generated blocks) and replace newlines
473
+ try:
474
+ context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
475
+ except Exception:
476
+ # Fallback to str
477
+ context_text = str(context_ids[0].tolist())
478
+
479
+ # Build visualization for all blocks
480
+ all_blocks_text = []
481
+ for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
482
+ color = block_colors[block_idx % len(block_colors)]
483
+ block_tokens = mask_block[0].tolist()
484
+ block_color_tokens = []
485
+
486
+ for i, token_id in enumerate(block_tokens):
487
+ if is_masked[0, i]:
488
+ # Masked tokens render as gray blocks; revealed tokens get the block color
489
+ block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
490
+ else:
491
+ # Decode individual token; use block color for revealed tokens
492
+ try:
493
+ token_text = tokenizer.decode([token_id], skip_special_tokens=False)
494
+ except Exception:
495
+ token_text = str(int(token_id))
496
+ block_color_tokens.append(f'{color}{token_text}{RESET}')
497
+
498
+ all_blocks_text.append(''.join(block_color_tokens))
499
+
500
+ # Concatenate all blocks directly (no separator) so the text reads continuously
501
+ blocks_combined = ''.join(all_blocks_text)
502
+
503
+ # Overwrite previous visualization area (if any) by moving cursor up and clearing lines.
504
+ # This prevents accumulation of repeated frames in terminals like VSCode integrated terminal.
505
+ global _visualize_last_lines
506
+ if clear and _visualize_last_lines > 0:
507
+ try:
508
+ # Move cursor up to the start of the previous block
509
+ sys.stdout.write(f'\x1b[{_visualize_last_lines}A')
510
+ # Clear each line that was previously printed
511
+ for _ in range(_visualize_last_lines):
512
+ sys.stdout.write('\x1b[2K') # Erase entire line
513
+ sys.stdout.write('\x1b[1B') # Move cursor down one line
514
+ # Move cursor back to the top of cleared region
515
+ sys.stdout.write(f'\x1b[{_visualize_last_lines}A')
516
+ sys.stdout.flush()
517
+ except Exception:
518
+ # Fallback to whole-screen clear
519
+ try:
520
+ sys.stdout.write('\x1b[2J\x1b[H')
521
+ sys.stdout.flush()
522
+ except Exception:
523
+ try:
524
+ clear_cmd = 'cls' if os.name == 'nt' else 'clear'
525
+ os.system(clear_cmd)
526
+ except Exception:
527
+ sys.stdout.write('\r\033[K')
528
+ sys.stdout.flush()
529
+ elif clear:
530
+ # No previous region to overwrite; do a simple ANSI clear to start fresh
531
+ try:
532
+ sys.stdout.write('\x1b[2J\x1b[H')
533
+ sys.stdout.flush()
534
+ except Exception:
535
+ try:
536
+ clear_cmd = 'cls' if os.name == 'nt' else 'clear'
537
+ os.system(clear_cmd)
538
+ except Exception:
539
+ sys.stdout.write('\r\033[K')
540
+ sys.stdout.flush()
541
+
542
+ # Print legend for parallel blocks
543
+ if len(mask_blocks) > 1:
544
+ legend_parts = []
545
+ for i in range(len(mask_blocks)):
546
+ color = block_colors[i % len(block_colors)]
547
+ legend_parts.append(f'{color}Block {i+1}{RESET}')
548
+ print(f"Generating: {' | '.join(legend_parts)}\n")
549
+
550
+ # Print the full context with colored blocks
551
+ # Ensure trailing newline so subsequent clears have predictable behavior
552
+ out_text = f"{context_text}{blocks_combined}\n"
553
+ try:
554
+ sys.stdout.write(out_text)
555
+ sys.stdout.flush()
556
+ except Exception:
557
+ print(out_text, flush=True)
558
+
559
+ # Update last-lines counter so next frame can overwrite this one
560
+ try:
561
+ _visualize_last_lines = out_text.count('\n') + (2 if len(mask_blocks) > 1 else 0)
562
+ except Exception:
563
+ _visualize_last_lines = out_text.count('\n')
564
+
565
+
566
+ def main():
567
+ base_path = os.path.join(os.path.dirname(__file__), "infer-base.py")
568
+ base_mod = try_import_infer_base(base_path)
569
+
570
+ if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
571
+ raise RuntimeError("DiffusionLLM not found in infer-base.py")
572
+
573
+ DiffusionLLM = base_mod.DiffusionLLM
574
+
575
+ # Workaround for torch.load pickling: checkpoints may reference classes under __main__; exposing them there also lets load_finetuned_model resolve DiffusionLLM
576
+ try:
577
+ main_mod = sys.modules.get('__main__')
578
+ if main_mod is not None:
579
+ if hasattr(base_mod, 'ModelConfig'):
580
+ setattr(main_mod, 'ModelConfig', base_mod.ModelConfig)
581
+ setattr(main_mod, 'DiffusionLLM', DiffusionLLM)
582
+ except Exception:
583
+ pass
584
+
585
+ parser = argparse.ArgumentParser()
586
+ parser.add_argument("--model", type=str, default="./checkpoints/model_fp32.pt", help="Path to model checkpoint")
587
+ parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-0.5B", help="Tokenizer model id or path")
588
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
589
+ parser.add_argument("--visualize", action="store_true", default=False, help="Enable visualization during generation")
590
+ parser.add_argument("--steps", type=int, default=64)
591
+ parser.add_argument("--block_size", type=int, default=128)
592
+ parser.add_argument("--max_new_tokens", type=int, default=128)
593
+ parser.add_argument("--parallel_blocks", type=int, default=1, help="Number of blocks to generate in parallel")
594
+ args = parser.parse_args()
595
+
596
+ device = torch.device(args.device)
597
+ print(f"Using device: {device}")
598
+
599
+ # Load tokenizer
600
+ print("Loading tokenizer...")
601
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
602
+ if tokenizer.pad_token is None:
603
+ # set pad token if not present
604
+ tokenizer.pad_token = tokenizer.eos_token
605
+
606
+ # Load model
607
+ best_model_path = "checkpoints/best_model.pt"
608
+ if os.path.exists(best_model_path):
609
+ print("Loading best model...")
610
+ model, config = load_finetuned_model(best_model_path, device)
611
+ else:
612
+ model, config = load_finetuned_model(args.model, device)
613
+
614
+ # Use the local visualization implementation for consistency
615
+ visualize_fn = None
616
+ if args.visualize:
617
+ visualize_fn = visualize_diffusion_state_local
618
+
619
+ print("Ready. Type a message and press Enter (empty line to quit).\n")
620
+
621
+ while True:
622
+ try:
623
+ user_input = input("User: ").strip()
624
+ except (EOFError, KeyboardInterrupt):
625
+ print("\nExiting.")
626
+ break
627
+ if user_input == "":
628
+ print("Goodbye.")
629
+ break
630
+
631
+ raw_output, response = chat(
632
+ model,
633
+ tokenizer,
634
+ user_input,
635
+ steps=args.steps,
636
+ block_size=args.block_size,
637
+ max_new_tokens=args.max_new_tokens,
638
+ temperature=0.8,
639
+ top_k=50,
640
+ top_p=0.9,
641
+ repetition_penalty=1.2,
642
+ no_repeat_ngram_size=3,
643
+ verbose=False,
644
+ visualize_fn=visualize_fn,
645
+ parallel_blocks=args.parallel_blocks,
646
+ )
647
+
648
+ print("\nRaw Output:\n")
649
+ print(raw_output)
650
+ print("\nAssistant:\n")
651
+ print(response)
652
+ print("\n" + ("=" * 60) + "\n")
653
+
654
+
655
+ if __name__ == "__main__":
656
+ main()
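
Note: the staggered batch layout described in `_generate_parallel_blocks` (row b sees the context plus blocks 0..b, right-padded so all rows share one length, with the attention mask marking real tokens) can be checked in isolation. A minimal sketch; `build_parallel_inputs` is an illustrative helper, not part of this repo:

    import torch

    def build_parallel_inputs(context, num_parallel, block_size, mask_id, pad_id):
        # Row b = context + blocks 0..b (all still masked), right-padded so the
        # whole batch shares one sequence length; mask marks the real tokens.
        ctx_len = context.shape[0]
        max_len = ctx_len + num_parallel * block_size
        blocks = torch.full((num_parallel, block_size), mask_id)
        rows, masks = [], []
        for b in range(num_parallel):
            row = torch.cat([context, blocks[: b + 1].reshape(-1)])
            rows.append(torch.cat([row, torch.full((max_len - row.shape[0],), pad_id)]))
            mask = torch.zeros(max_len)
            mask[: row.shape[0]] = 1.0
            masks.append(mask)
        return torch.stack(rows), torch.stack(masks)

    inp, attn = build_parallel_inputs(torch.arange(5), num_parallel=3,
                                      block_size=4, mask_id=-1, pad_id=0)
    print(inp.shape)          # torch.Size([3, 17])
    print(attn.sum(dim=1))    # tensor([ 9., 13., 17.])

The mask sums confirm each row's visible length grows by exactly one block over the previous row.
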
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ flask>=2.0
2
+ transformers>=4.0.0
3
+ torch
4
+ sentencepiece
5
+ flask_cors
static/ai.mp4 ADDED
Binary file (77.4 kB).
static/index.html ADDED
@@ -0,0 +1,156 @@
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
6
+ <title>Diffusion LLM – Chat</title>
7
+ <!-- Tailwind CDN -->
8
+ <script src="https://cdn.tailwindcss.com"></script>
9
+ <!-- Inter Font -->
10
+ <link rel="preconnect" href="https://fonts.gstatic.com">
11
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
12
+ <style>
13
+ html,body{font-family:Inter,ui-sans-serif,system-ui,-apple-system,'Segoe UI',Roboto,'Helvetica Neue',Arial}
14
+ /* custom slider */
15
+ input[type=range] {
16
+ -webkit-appearance: none;
17
+ width: 100%;
18
+ height: 6px;
19
+ border-radius: 5px;
20
+ background: #e0e0e0;
21
+ outline: none;
22
+ }
23
+
24
+ input[type=range]::-webkit-slider-thumb {
25
+ -webkit-appearance: none;
26
+ appearance: none;
27
+ width: 16px;
28
+ height: 16px;
29
+ border-radius: 50%;
30
+ background: #6b21a8;
31
+ cursor: pointer;
32
+ box-shadow: 0 0 2px rgba(0,0,0,0.2);
33
+ transition: background 0.3s ease;
34
+ }
35
+
36
+ input[type=range]::-webkit-slider-thumb:hover {
37
+ background: #7c2dbe;
38
+ }
39
+
40
+ input[type=range]::-moz-range-thumb {
41
+ width: 16px;
42
+ height: 16px;
43
+ border-radius: 50%;
44
+ background: #6b21a8;
45
+ cursor: pointer;
46
+ box-shadow: 0 0 2px rgba(0,0,0,0.2);
47
+ transition: background 0.3s ease;
48
+ }
49
+
50
+ input[type=range]::-moz-range-thumb:hover {
51
+ background: #7c2dbe;
52
+ }
53
+
54
+
55
+
56
+ </style>
57
+ </head>
58
+ <body>
59
+ <div class="h-screen w-screen flex items-start gap-6 p-8 bg-gradient-to-br from-purple-50 to-purple-100">
60
+
61
+ <!-- Sidebar -->
62
+ <aside id="sidebar" class="w-64 h-full bg-white/90 flex flex-col items-center justify-between backdrop-blur-sm rounded-xl p-5 shadow-sm border border-gray-100">
63
+ <div>
64
+ <div class="flex items-center gap-3 mb-4">
65
+ <div class="w-9 h-9 rounded-md bg-gradient-to-br from-purple-200 to-purple-300"></div>
66
+ <div>
67
+ <div class="text-sm font-semibold text-slate-900">Cortex</div>
68
+ <div class="text-xs text-slate-500">Diffusion LLM</div>
69
+ </div>
70
+ </div>
71
+
72
+ <button id="new-chat" class="w-full inline-flex items-center justify-center gap-2 bg-black text-white py-2 rounded-full text-sm font-medium shadow-sm mb-4">+ New chat</button>
73
+
74
+ <nav class="w-full min-w-48 flex-1 flex flex-col gap-2 text-sm" id="chat-list" aria-label="Saved chats">
75
+ <!-- Chat items are dynamically injected here by JavaScript -->
76
+ </nav>
77
+ </div>
78
+
79
+ <div class="mt-6 text-xs text-slate-500 ">Signed in as <strong class="text-slate-700">you@example.com</strong></div>
80
+ </aside>
81
+
82
+ <!-- Main content -->
83
+ <main class="flex-1 flex items-center justify-center w-full h-full">
84
+ <div class="w-full bg-white rounded-2xl p-7 shadow-lg border border-gray-100 flex flex-col h-full">
85
+
86
+ <header class="flex items-center justify-between mb-3 border-b border-gray-200 pb-3">
87
+ <div class="flex items-center gap-3">
88
+ <button id="btn-toggle-sidebar" aria-label="Toggle sidebar" class="inline-flex items-center justify-center p-2 rounded-md bg-white shadow sm:hidden">☰</button>
89
+ <h1 id="app-title" class="text-lg font-semibold">Diffusion LLM Chat</h1>
90
+ </div>
91
+
92
+ <div class="flex items-center gap-3">
93
+ <button id="btn-load" class="bg-black text-white px-3 py-2 rounded-md text-sm font-medium">Load Model</button>
94
+ <span id="load-status" class="text-sm text-slate-500">Not loaded</span>
95
+ </div>
96
+ </header>
97
+
98
+ <section class="flex-1 flex flex-col overflow-hidden">
99
+ <div id="welcome" class="text-center py-6">
100
+ <div class="mx-auto w-24 h-24">
101
+ <video src="/static/ai.mp4" aria-label="Assistant Avatar" autoplay loop muted class="w-full h-full scale-[2] object-cover mix-blend-multiply" style="filter: hue-rotate(45deg)"></video>
102
+ </div>
103
+ <p class="mt-4 text-purple-600 font-medium">Hello, Jagrat Patel</p>
104
+ <h2 class="mt-2 text-2xl font-semibold text-slate-900">How can I assist you today?</h2>
105
+
106
+ <div class="mt-6 flex items-center justify-center gap-4 flex-wrap">
107
+ <button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Deeper Research &nbsp;<span class="block text-xs text-slate-500 mt-1">Ask for long-form, research-backed answers.</span></button>
108
+ <button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Saved prompts &nbsp;<span class="block text-xs text-slate-500 mt-1">Quickly reuse your favorite prompts.</span></button>
109
+ <button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Check Facts &nbsp;<span class="block text-xs text-slate-500 mt-1">Compare GDPR vs CCPA differences.</span></button>
110
+ </div>
111
+ </div>
112
+
113
+ <div id="chat" class="hidden flex-1 overflow-auto px-2 py-3" role="log" aria-live="polite">
114
+ <!-- messages injected here -->
115
+ </div>
116
+ </section>
117
+
118
+ <form id="prompt-form" class="mt-4 bg-white p-4 rounded-xl shadow-inner border border-gray-100" aria-label="Chat prompt">
119
+ <div class="mb-4 flex flex-row gap-4 flex-wrap items-center justify-between">
120
+ <div class="flex items-center gap-4 w-[24%]">
121
+ <label for="steps" class="text-sm font-medium text-slate-700">Steps:</label>
122
+ <input type="range" id="steps" min="1" max="100" value="64" class="flex-1">
123
+ <span id="steps-value" class="text-sm text-slate-500 w-8">64</span>
124
+ </div>
125
+ <div class="flex items-center gap-4 w-[24%]">
126
+ <label for="block_size" class="text-sm font-medium text-slate-700">Block Size:</label>
127
+ <input type="range" id="block_size" min="8" max="256" value="128" step="8" class="flex-1">
128
+ <span id="block_size-value" class="text-sm text-slate-500 w-8">128</span>
129
+ </div>
130
+ <div class="flex items-center gap-4 w-[24%]">
131
+ <label for="max_new_tokens" class="text-sm font-medium text-slate-700">Max New Tokens:</label>
132
+ <input type="range" id="max_new_tokens" min="32" max="1024" value="128" step="32" class="flex-1">
133
+ <span id="max_new_tokens-value" class="text-sm text-slate-500 w-8">128</span>
134
+ </div>
135
+ <div class="flex items-center gap-4 w-[24%]">
136
+ <label for="parallel_blocks" class="text-sm font-medium text-slate-700">Parallel Blocks:</label>
137
+ <input type="range" id="parallel_blocks" min="1" max="4" value="1" step="1" class="flex-1">
138
+ <span id="parallel_blocks-value" class="text-sm text-slate-500 w-8">1</span>
139
+ </div>
140
+ </div>
141
+ <div class="flex gap-3">
142
+ <textarea id="prompt" class="flex-1 resize-y rounded-lg border border-gray-200 p-3 text-sm focus:outline-none focus:ring-[1px] focus:ring-purple-500 focus:border-purple-500" placeholder="Ask me anything..." rows="2" aria-label="Message input"></textarea>
143
+ <div class="flex flex-col justify-between">
144
+ <button type="submit" id="btn-send" class="bg-black text-white px-4 py-2 rounded-md">Send</button>
145
+ </div>
146
+ </div>
147
+ </form>
148
+
149
+ <div class="mt-4 text-center text-xs text-slate-500">Model served by Flask. See README for run instructions.</div>
150
+ </div>
151
+ </main>
152
+ </div>
153
+
154
+ <script src="/static/main.js"></script>
155
+ </body>
156
+ </html>
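
Note: the sliders above map one-to-one onto the JSON payload that main.js posts to `/api/generate-stream` (implemented in app.py, not shown in this excerpt). A minimal sketch of the same request from Python, assuming the server is running on port 7860 and that the `requests` package is available (it is not pinned in requirements.txt):

    import json
    import requests  # assumed extra dependency; not in requirements.txt

    payload = {
        "instruction": "Who wrote Hamlet?",
        "steps": 64,
        "block_size": 128,
        "max_new_tokens": 128,
        "parallel_blocks": 1,
    }
    with requests.post("http://localhost:7860/api/generate-stream",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                if event.get("type") == "complete":
                    print(event.get("response"))
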
static/main.js ADDED
@@ -0,0 +1,346 @@
1
+ // Global state
2
+ let isModelLoaded = false;
3
+
4
+ // DOM Elements
5
+ const els = {
6
+ chat: document.getElementById("chat"),
7
+ promptForm: document.getElementById("prompt-form"),
8
+ promptInput: document.getElementById("prompt"),
9
+ loadBtn: document.getElementById("btn-load"),
10
+ testStreamBtn: document.getElementById("btn-test-stream"),
11
+ status: document.getElementById("load-status"),
12
+ sidebar: document.getElementById("sidebar"),
13
+ sidebarToggle: document.getElementById("btn-toggle-sidebar"),
14
+ chatList: document.getElementById("chat-list"),
15
+ newChatBtn: document.getElementById("new-chat"),
16
+ sendBtn: document.getElementById("btn-send"),
17
+ steps: document.getElementById("steps"),
18
+ block_size: document.getElementById("block_size"),
19
+ max_new_tokens: document.getElementById("max_new_tokens"),
20
+ parallel_blocks: document.getElementById("parallel_blocks"),
21
+ stepsValue: document.getElementById("steps-value"),
22
+ block_sizeValue: document.getElementById("block_size-value"),
23
+ max_new_tokensValue: document.getElementById("max_new_tokens-value"),
24
+ parallel_blocksValue: document.getElementById("parallel_blocks-value"),
25
+ };
26
+
27
+ // Update slider values
28
+ els.steps.addEventListener("input", () => {
29
+ els.stepsValue.textContent = els.steps.value;
30
+ });
31
+ els.block_size.addEventListener("input", () => {
32
+ els.block_sizeValue.textContent = els.block_size.value;
33
+ });
34
+ els.max_new_tokens.addEventListener("input", () => {
35
+ els.max_new_tokensValue.textContent = els.max_new_tokens.value;
36
+ });
37
+ els.parallel_blocks.addEventListener("input", () => {
38
+ els.parallel_blocksValue.textContent = els.parallel_blocks.value;
39
+ });
40
+
41
+ // --- Logic ---
42
+
43
+ async function checkLoadStatus() {
44
+ try {
45
+ const res = await fetch("/api/load", {
46
+ method: "POST",
47
+ headers: { "Content-Type": "application/json" },
48
+ body: JSON.stringify({ check_only: true }),
49
+ });
50
+
51
+ if (res.ok) {
52
+ const data = await res.json();
53
+ if (data.loaded) {
54
+ isModelLoaded = true;
55
+ els.status.textContent = "Ready";
56
+ els.status.className = "text-sm text-green-600 font-medium";
57
+ els.loadBtn.style.display = 'none';
58
+ }
59
+ }
60
+ } catch (e) {
61
+ console.log("Model check failed:", e);
62
+ }
63
+ }
64
+
65
+ els.loadBtn.addEventListener("click", async () => {
66
+ els.loadBtn.disabled = true;
67
+ els.status.textContent = "Loading Model (this may take time)...";
68
+ els.status.className = "text-sm text-yellow-600 font-medium";
69
+
70
+ try {
71
+ const res = await fetch("/api/load", {
72
+ method: "POST",
73
+ headers: { "Content-Type": "application/json" },
74
+ body: JSON.stringify({ check_only: false }),
75
+ });
76
+ const data = await res.json();
77
+
78
+ if (res.ok) {
79
+ isModelLoaded = true;
80
+ els.status.textContent = "Model Loaded";
81
+ els.status.className = "text-sm text-green-600 font-medium";
82
+ els.loadBtn.style.display = 'none';
83
+ } else {
84
+ throw new Error(data.message || "Load failed");
85
+ }
86
+ } catch (e) {
87
+ els.status.textContent = "Error Loading";
88
+ els.status.className = "text-sm text-red-500";
89
+ alert("Error: " + e.message);
90
+ } finally {
91
+ els.loadBtn.disabled = false;
92
+ }
93
+ });
94
+
95
+ els.promptForm.addEventListener("submit", async (e) => {
96
+ e.preventDefault();
97
+
98
+ const text = els.promptInput.value.trim();
99
+ if (!text) return;
100
+
101
+ // UI Updates
102
+ addMessage("user", text);
103
+ els.promptInput.value = "";
104
+
105
+ // Create Assistant Bubble
106
+ const assistantBubble = addMessage("assistant", "");
107
+ const contentPre = assistantBubble.querySelector(".content");
108
+ const textContent = contentPre.querySelector(".text-content");
109
+
110
+ const visualizationDiv = document.createElement("div");
111
+ visualizationDiv.className = "visualization mb-2 font-mono text-xs";
112
+
113
+ // Loading spinner (SVG)
114
+ const spinner = document.createElement("div");
115
+ spinner.className = "flex items-center gap-2 text-slate-400";
116
+ spinner.innerHTML = `
117
+ <svg class="animate-spin h-4 w-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
118
+ <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
119
+ <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v4a4 4 0 00-4 4H4z"></path>
120
+ </svg>
121
+ <span class="text-xs">Generating...</span>
122
+ `;
123
+ visualizationDiv.appendChild(spinner);
124
+
125
+ contentPre.insertBefore(visualizationDiv, textContent);
126
+
127
+ // Disable send button
128
+ els.sendBtn.disabled = true;
129
+ els.sendBtn.textContent = "Generating...";
130
+ els.promptInput.disabled = true;
131
+
132
+ // Generate Request with Streaming
133
+ try {
134
+ const res = await fetch("/api/generate-stream", {
135
+ method: "POST",
136
+ headers: { "Content-Type": "application/json" },
137
+ body: JSON.stringify({
138
+ instruction: text,
139
+ steps: parseInt(els.steps.value),
140
+ block_size: parseInt(els.block_size.value),
141
+ max_new_tokens: parseInt(els.max_new_tokens.value),
142
+ parallel_blocks: parseInt(els.parallel_blocks.value),
143
+ }),
144
+ });
145
+
146
+ if (!res.ok) {
147
+ throw new Error(`Server Error ${res.status}`);
148
+ }
149
+
150
+ const reader = res.body.getReader();
151
+ const decoder = new TextDecoder();
152
+ let buffer = "";
153
+
154
+ while (true) {
155
+ const { done, value } = await reader.read();
156
+
157
+ if (done) break;
158
+
159
+ buffer += decoder.decode(value, { stream: true });
160
+ const lines = buffer.split("\n");
161
+ buffer = lines.pop(); // Keep incomplete line in buffer
162
+
163
+ for (const line of lines) {
164
+ if (line.startsWith("data: ")) {
165
+ const jsonStr = line.slice(6);
166
+ if (jsonStr.trim()) {
167
+ try {
168
+ const data = JSON.parse(jsonStr);
169
+ handleStreamEvent(data, visualizationDiv, textContent);
170
+ } catch (e) {
171
+ console.error("Failed to parse SSE data:", e);
172
+ }
173
+ }
174
+ }
175
+ }
176
+ }
177
+ } catch (error) {
178
+ if (textContent) textContent.textContent = `Error: ${error.message}`;
179
+ } finally {
180
+ els.sendBtn.disabled = false;
181
+ els.sendBtn.textContent = "Send";
182
+ els.promptInput.disabled = false;
183
+ }
184
+ });
185
+
186
+ function handleStreamEvent(data, visualizationDiv, textContent) {
187
+ if (data.type === "start") {
188
+ textContent.textContent = "";
189
+ } else if (data.type === "update") {
190
+ // Render visualization
191
+ renderVisualization(data.data, visualizationDiv);
192
+ scrollToBottom();
193
+ } else if (data.type === "complete") {
194
+ // Clear visualization and show final response
195
+ visualizationDiv.innerHTML = "";
196
+ textContent.textContent = data.response || "No response";
197
+ scrollToBottom();
198
+ } else if (data.type === "error") {
199
+ textContent.textContent = `Error: ${data.error}`;
200
+ }
201
+ }
202
+
203
+ function renderVisualization(vizData, container) {
204
+ // Clear previous content
205
+ container.innerHTML = "";
206
+
207
+ // Show context
208
+ const contextDiv = document.createElement("div");
209
+ contextDiv.className = "text-slate-600 mb-1";
210
+ contextDiv.textContent = vizData.context;
211
+ container.appendChild(contextDiv);
212
+
213
+ // Show blocks
214
+ const blocksDiv = document.createElement("div");
215
+ blocksDiv.classList.add("flex", "flex-wrap", "gap-0");
216
+
217
+ const blockColors = ["text-green-600", "text-cyan-600", "text-yellow-600", "text-purple-600"];
218
+
219
+ vizData.blocks.forEach((block, blockIdx) => {
220
+ const blockSpan = document.createElement("span");
221
+ blockSpan.className = blockColors[blockIdx % blockColors.length];
222
+
223
+ block.tokens.forEach((token) => {
224
+ if (token.type === "masked") {
225
+ const maskedSpan = document.createElement("span");
226
+ maskedSpan.className = blockColors[blockIdx % blockColors.length];
227
+ maskedSpan.innerText = token.text + " ";
228
+ blockSpan.appendChild(maskedSpan);
229
+ } else {
230
+ const textNode = document.createTextNode(token.text);
231
+ blockSpan.appendChild(textNode);
232
+ }
233
+ });
234
+
235
+ blocksDiv.appendChild(blockSpan);
236
+ });
237
+
238
+ container.appendChild(blocksDiv);
239
+
240
+ // Add legend if multiple blocks
241
+ if (vizData.num_blocks > 1) {
242
+ const legendDiv = document.createElement("div");
243
+ legendDiv.className = "text-xs text-slate-500 mt-1";
244
+ const legends = [];
245
+ for (let i = 0; i < vizData.num_blocks; i++) {
246
+ legends.push(`Block ${i + 1}`);
247
+ }
248
+ legendDiv.textContent = `Generating: ${legends.join(" | ")}`;
249
+ container.appendChild(legendDiv);
250
+ }
251
+ }
252
+
253
+ // --- UI Helpers ---
254
+
255
+ function addMessage(role, text) {
256
+ const wrapper = document.createElement("div");
257
+ wrapper.className = "mb-6 max-w-[100%] flex flex-col";
258
+
259
+ const bubble = document.createElement("div");
260
+ const isUser = role === "user";
261
+
262
+ bubble.className = isUser ? "self-end bg-slate-900 text-white p-4 rounded-2xl rounded-tr-sm max-w-[85%]" : "self-start bg-white border border-gray-200 text-slate-800 p-4 rounded-2xl rounded-tl-sm max-w-[65%] whitespace-pre-wrap overflow-x-auto shadow-sm flex flex-wrap";
263
+
264
+ // Main Content container that holds the response text
265
+ const pre = document.createElement("div");
266
+ pre.className = "content whitespace-pre-wrap font-sans text-sm leading-relaxed";
267
+
268
+ // The actual text content
269
+ const textSpan = document.createElement("span");
270
+ textSpan.className = "text-content";
271
+ textSpan.textContent = text;
272
+
273
+ pre.appendChild(textSpan);
274
+ bubble.appendChild(pre);
275
+ wrapper.appendChild(bubble);
276
+ els.chat.appendChild(wrapper);
277
+ scrollToBottom();
278
+
279
+ // Hide welcome screen
280
+ const welcome = document.getElementById("welcome");
281
+ if (welcome) {
282
+ welcome.classList.add("hidden");
283
+ }
284
+ els.chat.classList.remove("hidden");
285
+
286
+ return bubble;
287
+ }
288
+
289
+ function scrollToBottom() {
290
+ els.chat.scrollTop = els.chat.scrollHeight;
291
+ }
292
+
293
+ // Sidebar Toggle
294
+ els.sidebarToggle.addEventListener("click", () => {
295
+ els.sidebar.classList.toggle("-translate-x-full");
296
+ });
297
+
298
+ // New Chat Button
299
+ els.newChatBtn.addEventListener("click", () => {
300
+ // Clear chat
301
+ els.chat.innerHTML = "";
302
+ els.chat.classList.add("hidden");
303
+
304
+ // Show welcome screen
305
+ const welcome = document.getElementById("welcome");
306
+ if (welcome) {
307
+ welcome.classList.remove("hidden");
308
+ }
309
+
310
+ // Clear input
311
+ els.promptInput.value = "";
312
+ });
313
+
314
+ // Initialize
315
+ (async () => {
316
+ await checkLoadStatus();
317
+ if (!isModelLoaded) {
318
+ els.loadBtn.disabled = true;
319
+ els.status.textContent = "Loading Model (this may take time)...";
320
+ els.status.className = "text-sm text-yellow-600 font-medium";
321
+
322
+ try {
323
+ const res = await fetch("/api/load", {
324
+ method: "POST",
325
+ headers: { "Content-Type": "application/json" },
326
+ body: JSON.stringify({ check_only: false }),
327
+ });
328
+ const data = await res.json();
329
+
330
+ if (res.ok) {
331
+ isModelLoaded = true;
332
+ els.status.textContent = "Model Loaded";
333
+ els.status.className = "text-sm text-green-600 font-medium";
334
+ els.loadBtn.style.display = 'none';
335
+ } else {
336
+ throw new Error(data.message || "Load failed");
337
+ }
338
+ } catch (e) {
339
+ els.status.textContent = "Error Loading";
340
+ els.status.className = "text-sm text-red-500";
341
+ } finally {
342
+ els.loadBtn.disabled = false;
343
+ }
344
+ }
345
+ })();
346
+ els.chat.classList.add("hidden");
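
Note: `handleStreamEvent` above expects Server-Sent Events with a `type` of `start`, `update`, `complete`, or `error`. The real handler lives in app.py; the following is only a minimal sketch of an endpoint shape compatible with this client, with a hard-coded echo standing in for actual generation:

    import json
    from flask import Flask, Response, request, stream_with_context

    app = Flask(__name__)

    @app.route("/api/generate-stream", methods=["POST"])
    def generate_stream():
        params = request.get_json(force=True)

        def events():
            yield "data: " + json.dumps({"type": "start"}) + "\n\n"
            # per-step visualization frames would be yielded here as
            # {"type": "update", "data": {...}} events
            yield "data: " + json.dumps({
                "type": "complete",
                "response": f"(echo) {params.get('instruction', '')}",
            }) + "\n\n"

        return Response(stream_with_context(events()), mimetype="text/event-stream")
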