Upload 10 files
- Dockerfile +7 -0
- app.py +475 -0
- checkpoints/model_fp32.pt +3 -0
- design.json +185 -0
- infer-base.py +778 -0
- infer-chat.py +656 -0
- requirements.txt +5 -0
- static/ai.mp4 +0 -0
- static/index.html +156 -0
- static/main.js +346 -0
Dockerfile
ADDED
@@ -0,0 +1,7 @@
FROM python:3
WORKDIR /usr/src/app
COPY requirements.txt ./
RUN pip install -r requirements.txt
COPY . .
EXPOSE 7860
CMD ["python","./app.py"]
app.py
ADDED
@@ -0,0 +1,475 @@
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None


def find_file(filename, search_dirs=None):
    """Find a file in current directory or parent directories."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),  # Current directory
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),  # Working directory
        ]

    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath

    return None


def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None

    try:
        # Add the directory to sys.path
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)

        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)

        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        import traceback
        traceback.print_exc()
        return None


def load_model_internal():
    """Load the model and tokenizer."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function

    if model is not None:
        return True

    try:
        print("=" * 60)
        print("Starting model loading process...")
        print("=" * 60)

        # Find and import infer-base.py
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.")

        print(f"\nImporting infer-base.py from: {base_path}")
        base_mod = try_import_module(base_path, "infer_base")

        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # Check for DiffusionLLM class
        if not hasattr(base_mod, 'DiffusionLLM'):
            print("Available attributes in infer_base:", dir(base_mod))
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")

        DiffusionLLM = base_mod.DiffusionLLM
        print("✓ Successfully loaded DiffusionLLM class")

        # Find and import infer-chat.py
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")

        print(f"\nImporting infer-chat.py from: {chat_path}")
        chat_mod = try_import_module(chat_path, "infer_chat")

        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Failed to import chat function from infer-chat.py")

        chat_function = chat_mod.chat
        print("✓ Successfully loaded chat function")

        # Setup pickling workaround for torch.load
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
            print("✓ Configured pickle support for model loading")
        except Exception as e:
            print(f"Warning: Could not setup pickle workaround: {e}")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n✓ Using device: {device}")

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Find model checkpoint
        checkpoint_dirs = [
            "checkpoints",
            "../checkpoints",
            "./checkpoints",
            os.path.join(os.path.dirname(__file__), "checkpoints"),
            os.path.join(os.path.dirname(__file__), "../checkpoints"),
        ]

        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")

            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break

        if model_path is None:
            raise RuntimeError(
                "Could not find model checkpoint. Looking for:\n"
                "  - checkpoints/best_model.pt\n"
                "  - checkpoints/model_fp32.pt\n"
                f"Searched directories: {checkpoint_dirs}"
            )

        print(f"\n✓ Found model checkpoint: {model_path}")
        print("Loading model weights (this may take a minute)...")

        # Load model
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        config = checkpoint['config']

        print("Creating model...")
        model = DiffusionLLM(config)

        print("Loading state dict...")
        state_dict = checkpoint['model_state']
        state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print(f"✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Parameters: {num_params:.1f}M")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"{'=' * 60}\n")

        return True

    except Exception as e:
        print("\n" + "=" * 60)
        print("ERROR LOADING MODEL")
        print("=" * 60)
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False


def create_streaming_visualizer():
    """Create a visualizer that yields SSE events instead of printing to terminal."""
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        # Decode context
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())

        # Build blocks visualization
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []

            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({
                        'type': 'masked',
                        'text': '███'
                    })
                else:
                    try:
                        token_text = tok.decode([token_id], skip_special_tokens=False)
                    except Exception:
                        token_text = str(int(token_id))
                    block_data.append({
                        'type': 'revealed',
                        'text': token_text
                    })

            all_blocks.append({
                'block_index': block_idx,
                'tokens': block_data
            })

        # Return data structure that will be sent as SSE
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }

    return visualizer


@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')


@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model."""
    data = request.json or {}
    check_only = data.get('check_only', False)

    global model

    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })

    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })

    success = load_model_internal()

    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded successfully'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs for details.'
        }), 500


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate response without streaming."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        # Generate response
        raw_output, response = chat_function(
            model,
            tokenizer,
            instruction,
            steps=steps,
            block_size=block_size,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            verbose=False,
            visualize_fn=None,
            parallel_blocks=parallel_blocks,
        )

        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate response with streaming visualization."""
    global model, tokenizer, config, device, chat_function

    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400

    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 64)
    block_size = data.get('block_size', 128)
    max_new_tokens = data.get('max_new_tokens', 128)
    parallel_blocks = data.get('parallel_blocks', 1)

    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        try:
            # Import threading to allow yielding from callback
            import queue
            event_queue = queue.Queue()
            generation_complete = {'done': False, 'result': None}

            def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
                """This gets called during generation - we need to send events immediately"""
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                # Put the update in the queue so it can be yielded immediately
                event_queue.put({'type': 'update', 'data': data})

            # Start generation in a separate thread so we can yield events as they come
            import threading

            def run_generation():
                try:
                    raw_output, response = chat_function(
                        model,
                        tokenizer,
                        instruction,
                        steps=steps,
                        block_size=block_size,
                        max_new_tokens=max_new_tokens,
                        temperature=0.8,
                        top_k=50,
                        top_p=0.9,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        verbose=False,
                        visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                    generation_complete['result'] = (raw_output, response)
                except Exception as e:
                    generation_complete['result'] = ('error', str(e))
                finally:
                    generation_complete['done'] = True
                    event_queue.put(None)  # Signal completion

            # Start generation thread
            gen_thread = threading.Thread(target=run_generation)
            gen_thread.daemon = True
            gen_thread.start()

            # Yield start event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n"

            # Yield events as they come from the queue
            while not generation_complete['done'] or not event_queue.empty():
                try:
                    event = event_queue.get(timeout=0.1)
                    if event is None:  # Completion signal
                        break
                    yield f"data: {json.dumps(event)}\n\n"
                except queue.Empty:
                    continue

            # Wait for thread to finish
            gen_thread.join(timeout=1.0)

            # Send final response
            if generation_complete['result']:
                raw_output, response = generation_complete['result']
                if raw_output == 'error':
                    yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
                else:
                    yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n"

        except Exception as e:
            import traceback
            traceback.print_exc()
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


@app.route('/api/test-stream', methods=['GET'])
def test_stream():
    """Test streaming endpoint."""
    def generate():
        for i in range(10):
            yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n"
            time.sleep(0.5)
        yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no'
        }
    )


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000, threaded=True)
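For reference, a minimal client sketch (not part of this upload) showing one way to consume the /api/generate-stream SSE endpoint defined in app.py above. The URL, port, and function name are assumptions: app.run uses port 5000, while the Dockerfile exposes 7860, so adjust to how the Space is actually served.

import json
import requests

def stream_generation(instruction, url="http://localhost:5000/api/generate-stream"):
    # Request body mirrors the fields read by generate_stream() above.
    payload = {
        "instruction": instruction,
        "steps": 64,
        "block_size": 128,
        "max_new_tokens": 128,
        "parallel_blocks": 1,
    }
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            # SSE events arrive as "data: {...}" lines separated by blanks.
            if not line or not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            if event.get("type") == "update":
                print(f"update: {event['data']['num_blocks']} block(s) in flight")
            elif event.get("type") == "complete":
                return event["response"]
            elif event.get("type") == "error":
                raise RuntimeError(event["error"])

if __name__ == "__main__":
    print(stream_generation("Explain block diffusion in one sentence."))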
checkpoints/model_fp32.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26b941c479671cff7d0d93fc1d30711ce717de1abedee1e30c0871a4874db79d
size 491091299
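The entry above is a Git LFS pointer rather than the weights themselves; the actual ~491 MB checkpoint is fetched when LFS resolves the pointer. A small sketch (not part of the upload; the local path is an assumption) for checking a downloaded copy against the pointer's hash and size:

import hashlib

# Values copied from the LFS pointer above.
EXPECTED_SHA256 = "26b941c479671cff7d0d93fc1d30711ce717de1abedee1e30c0871a4874db79d"
EXPECTED_SIZE = 491091299  # bytes

def verify_checkpoint(path="checkpoints/model_fp32.pt"):
    sha = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks to avoid loading the whole file into memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
            size += len(chunk)
    return sha.hexdigest() == EXPECTED_SHA256 and size == EXPECTED_SIZE

if __name__ == "__main__":
    print("checkpoint OK" if verify_checkpoint() else "checkpoint mismatch (LFS pointer not resolved?)")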
design.json
ADDED
@@ -0,0 +1,185 @@
{
  "design_system": {
    "name": "Cortex Luminance System",
    "description": "A physics-based design system combining soft aesthetic minimalism with strict luminance layering. It relies on lighting simulation (top highlights, bottom shadows) rather than diverse hues to create depth hierarchy.",
    "version": "1.0.0",
    "mode": "light",
    "philosophy": {
      "core_principle": "Depth through Luminance",
      "lighting_source": "Top-down (90 degrees)",
      "surface_material": "Matte white & Soft Glass",
      "accent_strategy": "Functional Purple (oklch 0.65 0.22 290)",
      "layering_logic": "Higher elevation = Higher lightness (or pure white) + Stronger Shadow. Lower elevation = Lower lightness + Inset Shadow."
    }
  },
  "tokens": {
    "colors": {
      "primitives": {
        "base_hue": "270 (Purple/Violet)",
        "neutral_hue": "265 (Cool Gray)"
      },
      "layers": {
        "bg_root": {
          "value": "linear-gradient(135deg, oklch(0.95 0.02 270) 0%, oklch(0.92 0.03 290) 100%)",
          "description": "Level 0: The ambient canvas. Corresponds to the blurry cloud/gradient background."
        },
        "bg_layer_1": {
          "value": "oklch(0.99 0.005 265)",
          "description": "Level 1: The main application window/sidebar surface. Almost white."
        },
        "bg_layer_2": {
          "value": "oklch(1.0 0 0)",
          "description": "Level 2: Cards, Floating Inputs, Modals. Pure White."
        },
        "bg_sunken": {
          "value": "oklch(0.96 0.01 265)",
          "description": "For inset elements (search bars, progress tracks). Slightly darker than layer 1 to simulate depth."
        }
      },
      "text": {
        "primary": "oklch(0.20 0.02 265)",
        "secondary": "oklch(0.55 0.03 265)",
        "accent": "oklch(0.65 0.22 290)"
      },
      "borders": {
        "subtle": "rgba(0, 0, 0, 0.06)",
        "highlight": "rgba(255, 255, 255, 0.8)"
      }
    },
    "typography": {
      "font_family": "Inter, SF Pro Display, system-ui, sans-serif",
      "weights": {
        "regular": 400,
        "medium": 500,
        "semibold": 600
      },
      "scale": {
        "h1": {
          "size": "32px",
          "weight": 600,
          "letter_spacing": "-0.02em"
        },
        "h2": {
          "size": "24px",
          "weight": 500,
          "letter_spacing": "-0.01em"
        },
        "body_lg": {
          "size": "16px",
          "weight": 400
        },
        "body_sm": {
          "size": "14px",
          "weight": 400
        },
        "caption": {
          "size": "12px",
          "weight": 500,
          "uppercase": false
        }
      }
    },
    "spacing": {
      "xs": "4px",
      "sm": "8px",
      "md": "16px",
      "lg": "24px",
      "xl": "32px",
      "container_padding": "20px"
    },
    "radii": {
      "sm": "8px",
      "md": "12px",
      "lg": "16px",
      "full": "9999px (Pill)"
    },
    "shadows": {
      "note": "Shadows must imply a top-down light source. Always pair drop-shadows with top-edge inset highlights.",
      "elevation_low": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 1px 2px 0 rgba(0, 0, 0, 0.05)",
        "use_case": "Interactive buttons, list items."
      },
      "elevation_medium": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03)",
        "use_case": "Standard Cards (Saved Prompts, Suggestions)."
      },
      "elevation_high": {
        "css_value": "box-shadow: inset 0 1px 0 0 rgba(255, 255, 255, 1), 0 10px 15px -3px rgba(0, 0, 0, 0.08), 0 4px 6px -2px rgba(0, 0, 0, 0.04)",
        "use_case": "Floating Input Area, Modals."
      },
      "inset_sunken": {
        "css_value": "box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06), inset 0 -1px 0 0 rgba(255, 255, 255, 0.5)",
        "use_case": "Search bars, tracks, unselected states."
      }
    }
  },
  "components": {
    "layout_structure": {
      "sidebar": {
        "width": "260px",
        "background": "bg_layer_1",
        "border_right": "1px solid borders.subtle",
        "padding": "md",
        "style": "Flat surface, low contrast."
      },
      "main_area": {
        "background": "bg_layer_2 (with large rounded corners) OR transparent over bg_root",
        "layout": "Flex-col, centered content, maximum width 900px."
      }
    },
    "buttons": {
      "primary": {
        "bg": "black (or dark purple)",
        "text": "white",
        "radius": "md",
        "shadow": "elevation_low",
        "lighting": "Subtle top gradient (lighter top) to show curvature."
      },
      "ghost": {
        "bg": "transparent",
        "hover_bg": "rgba(0,0,0,0.04)",
        "text": "text.secondary"
      },
      "new_chat": {
        "style": "Pill shape / Full radius",
        "bg": "#1A1A1A",
        "text": "white",
        "icon": "plus"
      }
    },
    "cards": {
      "prompt_card": {
        "bg": "bg_layer_2",
        "radius": "lg",
        "shadow": "elevation_medium",
        "border": "1px solid borders.subtle",
        "hover": "Transform Y -2px, increase shadow to elevation_high."
      }
    },
    "inputs": {
      "search_bar": {
        "style": "Sunken / Inset",
        "bg": "bg_sunken",
        "shadow": "inset_sunken",
        "radius": "md",
        "icon_color": "text.secondary"
      },
      "main_prompt_area": {
        "style": "Elevated Container",
        "bg": "white",
        "shadow": "elevation_high",
        "radius": "lg",
        "border": "1px solid rgba(0,0,0,0.04)"
      }
    },
    "navigation_items": {
      "base_style": "text.secondary, font-medium, md padding",
      "active_state": {
        "bg": "bg_layer_2",
        "text": "text.primary",
        "shadow": "elevation_low",
        "radius": "sm"
      }
    }
  }
}
infer-base.py
ADDED
@@ -0,0 +1,778 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from dataclasses import dataclass
import os
import math


# ============== Model Architecture ==============

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE) with NTK extrapolation."""

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.inv_freq = inv_freq

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]

        if self.inv_freq is None or self.inv_freq.device != x.device:
            self._update_freqs(x.device)

        cache_key = (seq_len, x.device, x.dtype)
        if cache_key in self._cache:
            return self._cache[cache_key]

        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]

        self._cache[cache_key] = (cos, sin)
        if len(self._cache) > 10:
            self._cache.pop(next(iter(self._cache)))

        return cos, sin


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to Q and K."""
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class DiffusionAttention(nn.Module):
    """Multi-head attention with GQA and Flash Attention support."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.use_flash_attn = config.use_flash_attn

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = freqs_cis
        cos = cos[:, :, :q_len, :]
        sin = sin[:, :, :q_len, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            cache_k, cache_v = past_kv
            k = torch.cat([cache_k, k], dim=2)
            v = torch.cat([cache_v, v], dim=2)

        current_kv = (k, v)

        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
            attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min

        output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )

        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        return self.o_proj(output), current_kv


class MLP(nn.Module):
    """Gated MLP with SiLU activation."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BlockDiffusionBlock(nn.Module):
    """Transformer block with pre-norm, attention, and MLP."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = DiffusionAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.use_activation_checkpointing = config.use_activation_checkpointing

    def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)

    def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states, new_kv


@dataclass
class ModelConfig:
    """Model architecture configuration."""
    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 2816
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    max_position_embeddings: int = 16384
    rms_norm_eps: float = 1e-6
    rope_theta: float = 100000.0
    pad_token_id: int = 0
    mask_token_id: int = 1
    use_flash_attn: bool = True
    use_activation_checkpointing: bool = False
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0


class DiffusionLLM(nn.Module):
    """Complete diffusion language model."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)

        self.layers = nn.ModuleList([BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings
        )

        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        bsz, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        new_kvs = []
        for i, layer in enumerate(self.layers):
            hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
            new_kvs.append(kv)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, new_kvs

    def get_num_params(self, trainable_only=True):
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.parameters())


# ============== Inference Functions ==============

def load_model(model_path: str, device: str = 'cuda'):
    """Load a saved model (fp16 or fp32) for inference."""
    print(f"Loading model from {model_path}...")

    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']

    model = DiffusionLLM(config)

    state_dict = checkpoint['model_state']
    state_dict = {k: v.float() for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    num_params = model.get_num_params() / 1e6
    file_size = os.path.getsize(model_path) / 1e6
    print(f"✓ Model loaded: {num_params:.1f}M params from {file_size:.1f} MB file")

    return model, config


def visualize_diffusion_state(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Visualize the current state of diffusion generation with multiple blocks.

    Args:
        mask_blocks: Either a single block tensor (1, block_size) or list of block tensors
        is_masked_list: Either a single mask tensor (1, block_size) or list of mask tensors
        block_colors: List of ANSI color codes for each block. If None, uses defaults.
    """
    import sys
    import os

    # Default colors for different blocks (green, cyan, yellow, magenta)
    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
    MASK_COLOR = '\033[90m'  # Gray for masked tokens
    RESET = '\033[0m'

    # Normalize inputs to lists
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]

    if block_colors is None:
        block_colors = DEFAULT_COLORS

    # Decode context (prompt + previously generated blocks) and replace newlines
    context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')

    # Build visualization for all blocks
    all_blocks_text = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        block_tokens = mask_block[0].tolist()
        block_color_tokens = []

        for i, token_id in enumerate(block_tokens):
            if is_masked[0, i]:
                # Use block-specific color for masked tokens to distinguish blocks
                block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
            else:
                # Decode individual token; use block color for revealed tokens
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
                block_color_tokens.append(f'{color}{token_text}{RESET}')

        all_blocks_text.append(''.join(block_color_tokens))

    # Join all blocks with a subtle separator
    blocks_combined = ''.join(all_blocks_text)

    # Clear entire terminal
    if clear:
        clear_cmd = 'cls' if os.name == 'nt' else 'clear'
        try:
            os.system(clear_cmd)
        except Exception:
            sys.stdout.write('\r\033[K')

    # Print legend for parallel blocks
    if len(mask_blocks) > 1:
        legend_parts = []
        for i in range(len(mask_blocks)):
            color = block_colors[i % len(block_colors)]
            legend_parts.append(f'{color}Block {i+1}{RESET}')
        print(f"Generating: {' | '.join(legend_parts)}\n")

    # Print the full context with colored blocks
    print(f"{context_text}{blocks_combined}", flush=True)


def demo_visualize_truncation():
    """Demo for visualize_diffusion_state without a full model.
    Simulates streaming output and verifies there is no line duplication when content exceeds terminal width.
    """
    class MockTokenizer:
        def __init__(self):
            # Map token id to token text (simple ASCII characters and spaces)
            self.vocab = {i: chr(65 + (i % 26)) for i in range(256)}
            self.vocab[32] = ' '
            self.eos_token = '\n'
            self.pad_token = ' '

        def decode(self, ids, skip_special_tokens=True):
            # ids can be tensor or list
            if isinstance(ids, torch.Tensor):
                ids = ids.tolist()
            if isinstance(ids, (list, tuple)):
                return ''.join(self.vocab.get(int(i) % 256, '?') for i in ids)
            return str(ids)

    tok = MockTokenizer()
    # Create a long context and a block that's also long
    # Make context exceed terminal width
    term_width = 80
    long_context_ids = torch.tensor([[i % 26 + 65 for i in range(120)]], dtype=torch.long)
    block_size = 32
    mask_block = torch.full((1, block_size), 32, dtype=torch.long)  # spaces
    is_masked = torch.ones(1, block_size, dtype=torch.bool)
    for i in range(0, block_size, 3):
        is_masked[0, i] = False
        mask_block[0, i] = 65 + (i % 26)

    print('\nRunning demo: long prompt + block to test truncation\n')
    for i in range(8):
        visualize_diffusion_state(tok, long_context_ids, [mask_block], [is_masked], ModelConfig(), clear=(i > 0))
        # rotate some tokens to simulate diffusion
        mask_block = torch.roll(mask_block, shifts=1, dims=1)
        time_delay = 0.08
        try:
            import time
            time.sleep(time_delay)
        except Exception:
            pass
    print('\n\nDemo completed.')


@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 16,
    block_size: int = 64,
    max_new_tokens: int = 256,
    device: str = 'cuda',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    visualize: bool = False,
    parallel_blocks: int = 1,  # Number of blocks to generate in parallel
):
    """Generate text using block diffusion with proper sampling and repetition control.

    Args:
        visualize: If True, stream output in real-time showing the diffusion effect.
        parallel_blocks: Number of blocks to generate in parallel (1-4 recommended).
    """
    import time
    model.eval()

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    config = model.module.config if hasattr(model, 'module') else model.config
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config

    num_blocks = max_new_tokens // block_size
    parallel_blocks = min(parallel_blocks, num_blocks)  # Can't parallelize more than total blocks

    if not visualize:
        if parallel_blocks > 1:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each ({parallel_blocks} blocks in parallel)...")
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each...")
    else:
        print(f"\n\033[94mStarting diffusion generation...\033[0m\n")
        print(prompt, end='', flush=True)

    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    # Process blocks in batches of parallel_blocks
    blocks_generated = 0
    while blocks_generated < num_blocks:
        # Determine how many blocks to generate this iteration
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)

        if current_parallel > 1:
            # Parallel block generation
            generated_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            # Concatenate all generated blocks to context
            for block in generated_blocks:
                context_ids = torch.cat([context_ids, block], dim=1)
                all_generated_tokens.update(block[0].tolist())

            if not visualize:
                print(f" Blocks {blocks_generated + 1}-{blocks_generated + current_parallel}/{num_blocks} complete")
            blocks_generated += current_parallel
        else:
            # Single block generation (original logic)
            mask_block, block_token_history = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )

            context_ids = torch.cat([context_ids, mask_block], dim=1)
            all_generated_tokens.update(mask_block[0].tolist())

            if not visualize:
                print(f" Block {blocks_generated + 1}/{num_blocks} complete")
            blocks_generated += 1

    if visualize:
        # Final newline after visualization
        print("\n")

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate a single block using diffusion."""
    mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []

    for step_idx in range(steps):
        full_input = torch.cat([context_ids, mask_block], dim=1)
        attention_mask = torch.ones_like(full_input, dtype=torch.float32)

        logits, _ = model(full_input, attention_mask=attention_mask)
        block_logits = logits[:, -block_size:, :]

        block_logits = _apply_sampling_controls(
            block_logits, context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(1, block_size)

        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        tokens_to_unmask = max(1, block_size // steps)
        if step_idx == steps - 1:
            tokens_to_unmask = is_masked.sum().item()

        if tokens_to_unmask > 0 and is_masked.sum() > 0:
            masked_confidence = confidence.clone()
            masked_confidence[~is_masked] = -1.0

            num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
            _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)

            for idx in top_indices:
                mask_block[0, idx] = sampled_tokens[0, idx]
                is_masked[0, idx] = False
                block_token_history.append(sampled_tokens[0, idx].item())
                all_generated_tokens.add(sampled_tokens[0, idx].item())

        if visualize:
            visualize_diffusion_state(tokenizer, context_ids, [mask_block], [is_masked], config, clear=(step_idx > 0))

    return mask_block, block_token_history


def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate multiple blocks in parallel using batched computation.

    Each block sees all previous blocks in the sequence, maintaining proper order:
    - Block 0: context + [block0]
    - Block 1: context + [block0] + [block1]
    - Block 2: context + [block0] + [block1] + [block2]
    - etc.

    This ensures sequential coherence while still benefiting from batched computation.
    """
    batch_size = num_parallel
    context_len = context_ids.shape[1]

    # Initialize mask blocks for all parallel blocks
    # Shape: (num_parallel, block_size)
    mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    for step_idx in range(steps):
        # Build inputs with proper sequential structure
        # Each batch item has context + all blocks up to and including its own position
        # Block i sees: context + block_0 + block_1 + ... + block_i

        # Create padded inputs - each batch item has different length
        # We'll pad to the longest sequence (which is the last block)
        max_seq_len = context_len + (num_parallel * block_size)

        # Build full input for each batch item
        full_inputs = []
        attention_masks = []

        for b in range(batch_size):
            # This block sees: context + all previous blocks + its own block
            seq_parts = [context_ids[0]]  # Start with context

            # Add all blocks from 0 to b (inclusive)
            for prev_b in range(b + 1):
                seq_parts.append(mask_blocks[prev_b])

            # Concatenate to form this batch item's input
            batch_input = torch.cat(seq_parts, dim=0)  # (seq_len,)
            current_len = batch_input.shape[0]

            # Pad to max_seq_len
            padding_needed = max_seq_len - current_len
            if padding_needed > 0:
                padding = torch.full((padding_needed,), config.pad_token_id, device=device)
                batch_input = torch.cat([batch_input, padding], dim=0)

            full_inputs.append(batch_input)

            # Create attention mask (1 for real tokens, 0 for padding)
            attn_mask = torch.zeros(max_seq_len, device=device)
            attn_mask[:current_len] = 1.0
            attention_masks.append(attn_mask)

        # Stack into batched tensors
        full_input = torch.stack(full_inputs, dim=0)  # (batch, max_seq_len)
        attention_mask = torch.stack(attention_masks, dim=0)  # (batch, max_seq_len)

        # Single forward pass for all blocks
|
| 594 |
+
logits, _ = model(full_input, attention_mask=attention_mask)
|
| 595 |
+
|
| 596 |
+
# Extract logits for each block's position
|
| 597 |
+
# Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
|
| 598 |
+
block_logits_list = []
|
| 599 |
+
for b in range(batch_size):
|
| 600 |
+
start_pos = context_len + (b * block_size)
|
| 601 |
+
end_pos = start_pos + block_size
|
| 602 |
+
block_logits_list.append(logits[b, start_pos:end_pos, :])
|
| 603 |
+
|
| 604 |
+
block_logits = torch.stack(block_logits_list, dim=0) # (batch, block_size, vocab)
|
| 605 |
+
|
| 606 |
+
# Apply sampling controls per batch item
|
| 607 |
+
for b in range(batch_size):
|
| 608 |
+
# Build context that includes previous blocks for repetition penalty
|
| 609 |
+
extended_context = context_ids
|
| 610 |
+
if b > 0:
|
| 611 |
+
prev_blocks = torch.cat([mask_blocks[pb:pb+1] for pb in range(b)], dim=1)
|
| 612 |
+
extended_context = torch.cat([context_ids, prev_blocks], dim=1)
|
| 613 |
+
|
| 614 |
+
block_logits[b:b+1] = _apply_sampling_controls(
|
| 615 |
+
block_logits[b:b+1],
|
| 616 |
+
extended_context,
|
| 617 |
+
mask_blocks[b:b+1],
|
| 618 |
+
is_masked[b:b+1],
|
| 619 |
+
repetition_penalty, temperature, top_k, top_p,
|
| 620 |
+
no_repeat_ngram_size, block_token_histories[b]
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
probs = F.softmax(block_logits, dim=-1)
|
| 624 |
+
probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 625 |
+
probs = probs.clamp(min=1e-10)
|
| 626 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 627 |
+
|
| 628 |
+
# Sample for all batches
|
| 629 |
+
sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
|
| 630 |
+
sampled_tokens = sampled_tokens.view(batch_size, block_size)
|
| 631 |
+
|
| 632 |
+
confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
|
| 633 |
+
|
| 634 |
+
tokens_to_unmask = max(1, block_size // steps)
|
| 635 |
+
if step_idx == steps - 1:
|
| 636 |
+
tokens_to_unmask = block_size # Unmask all remaining
|
| 637 |
+
|
| 638 |
+
# Unmask for each batch item
|
| 639 |
+
for b in range(batch_size):
|
| 640 |
+
if is_masked[b].sum() > 0:
|
| 641 |
+
masked_confidence = confidence[b].clone()
|
| 642 |
+
masked_confidence[~is_masked[b]] = -1.0
|
| 643 |
+
|
| 644 |
+
num_to_unmask = min(tokens_to_unmask, is_masked[b].sum().item())
|
| 645 |
+
if num_to_unmask > 0:
|
| 646 |
+
_, top_indices = torch.topk(masked_confidence, num_to_unmask)
|
| 647 |
+
|
| 648 |
+
for idx in top_indices:
|
| 649 |
+
mask_blocks[b, idx] = sampled_tokens[b, idx]
|
| 650 |
+
is_masked[b, idx] = False
|
| 651 |
+
block_token_histories[b].append(sampled_tokens[b, idx].item())
|
| 652 |
+
|
| 653 |
+
if visualize:
|
| 654 |
+
# Visualize all blocks with different colors
|
| 655 |
+
block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
|
| 656 |
+
is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
|
| 657 |
+
visualize_diffusion_state(
|
| 658 |
+
tokenizer, context_ids, block_list, is_masked_list,
|
| 659 |
+
config, clear=(step_idx > 0)
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
# Return list of generated blocks
|
| 663 |
+
return [mask_blocks[b:b+1] for b in range(batch_size)]
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def _apply_sampling_controls(
|
| 667 |
+
block_logits, context_ids, mask_block, is_masked,
|
| 668 |
+
repetition_penalty, temperature, top_k, top_p,
|
| 669 |
+
no_repeat_ngram_size, block_token_history
|
| 670 |
+
):
|
| 671 |
+
"""Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
|
| 672 |
+
if repetition_penalty != 1.0:
|
| 673 |
+
seen_tokens = set(context_ids[0].tolist())
|
| 674 |
+
for i in range(mask_block.shape[1]):
|
| 675 |
+
if not is_masked[0, i]:
|
| 676 |
+
seen_tokens.add(mask_block[0, i].item())
|
| 677 |
+
|
| 678 |
+
for token_id in seen_tokens:
|
| 679 |
+
if token_id < block_logits.shape[-1]:
|
| 680 |
+
if block_logits[0, :, token_id].mean() > 0:
|
| 681 |
+
block_logits[:, :, token_id] /= repetition_penalty
|
| 682 |
+
else:
|
| 683 |
+
block_logits[:, :, token_id] *= repetition_penalty
|
| 684 |
+
|
| 685 |
+
block_logits = block_logits / temperature
|
| 686 |
+
|
| 687 |
+
if top_k > 0:
|
| 688 |
+
top_k_logits, top_k_indices = torch.topk(block_logits, min(top_k, block_logits.size(-1)), dim=-1)
|
| 689 |
+
block_logits = torch.full_like(block_logits, float('-inf'))
|
| 690 |
+
block_logits.scatter_(-1, top_k_indices, top_k_logits)
|
| 691 |
+
|
| 692 |
+
if top_p < 1.0:
|
| 693 |
+
sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
|
| 694 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 695 |
+
|
| 696 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 697 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 698 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 699 |
+
|
| 700 |
+
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
| 701 |
+
block_logits[indices_to_remove] = float('-inf')
|
| 702 |
+
|
| 703 |
+
if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
|
| 704 |
+
recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size-1):])
|
| 705 |
+
full_history = context_ids[0].tolist() + block_token_history
|
| 706 |
+
for i in range(len(full_history) - no_repeat_ngram_size + 1):
|
| 707 |
+
if tuple(full_history[i:i+no_repeat_ngram_size-1]) == recent_ngram:
|
| 708 |
+
blocked_token = full_history[i + no_repeat_ngram_size - 1]
|
| 709 |
+
if blocked_token < block_logits.shape[-1]:
|
| 710 |
+
block_logits[:, :, blocked_token] = float('-inf')
|
| 711 |
+
|
| 712 |
+
# Safety check: if all logits are -inf, reset to uniform distribution
|
| 713 |
+
all_inf_mask = torch.isinf(block_logits).all(dim=-1)
|
| 714 |
+
if all_inf_mask.any():
|
| 715 |
+
block_logits[all_inf_mask] = 0.0
|
| 716 |
+
|
| 717 |
+
return block_logits
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
# ============== Main Entry Point ==============
|
| 721 |
+
|
| 722 |
+
def main():
|
| 723 |
+
"""Main inference function."""
|
| 724 |
+
# Configuration
|
| 725 |
+
model_path = "../extra-final-boss/checkpoints/model_fp32.pt"
|
| 726 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 727 |
+
|
| 728 |
+
print(f"Using device: {device}")
|
| 729 |
+
|
| 730 |
+
# Allow a quick demo mode to test visualization without loading the model
|
| 731 |
+
import sys
|
| 732 |
+
if len(sys.argv) > 1 and sys.argv[1] == 'demo':
|
| 733 |
+
demo_visualize_truncation()
|
| 734 |
+
return
|
| 735 |
+
|
| 736 |
+
# Load tokenizer
|
| 737 |
+
print("Loading tokenizer...")
|
| 738 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
| 739 |
+
if tokenizer.pad_token is None:
|
| 740 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 741 |
+
|
| 742 |
+
# Load model
|
| 743 |
+
model, config = load_model(model_path, device)
|
| 744 |
+
|
| 745 |
+
# Generate text
|
| 746 |
+
print("\n" + "=" * 50)
|
| 747 |
+
print("Text Generation")
|
| 748 |
+
print("=" * 50)
|
| 749 |
+
|
| 750 |
+
prompt = "Barrack Obama was born in "
|
| 751 |
+
print(f"Prompt: {prompt}\n")
|
| 752 |
+
|
| 753 |
+
# Set visualize=True to see real-time diffusion effect
|
| 754 |
+
visualize = True
|
| 755 |
+
parallel_blocks = 4  # Number of blocks to generate in parallel for extra speedup
|
| 756 |
+
|
| 757 |
+
generated = generate_block_diffusion(
|
| 758 |
+
model,
|
| 759 |
+
tokenizer,
|
| 760 |
+
prompt=prompt,
|
| 761 |
+
steps=64,
|
| 762 |
+
block_size=64,
|
| 763 |
+
max_new_tokens=512,
|
| 764 |
+
device=device,
|
| 765 |
+
temperature=1,
|
| 766 |
+
top_k=40,
|
| 767 |
+
top_p=0.9,
|
| 768 |
+
repetition_penalty=1.3,
|
| 769 |
+
no_repeat_ngram_size=3,
|
| 770 |
+
visualize=visualize,
|
| 771 |
+
parallel_blocks=parallel_blocks,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
print(f"\nGenerated text:\n{generated}")
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
if __name__ == "__main__":
|
| 778 |
+
main()
|
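For readers who want to drive the base script programmatically instead of editing `main()`, here is a minimal sketch that imports `infer-base.py` dynamically and calls `generate_block_diffusion` directly. The checkpoint path and the `importlib` loading pattern are assumptions (the latter mirrors what `infer-chat.py` does below); treat it as a sketch, not a definitive recipe.

# Minimal sketch: load infer-base.py as a module and generate from a prompt.
# Assumes infer-base.py and checkpoints/model_fp32.pt are in the working directory.
import importlib.util

import torch
from transformers import AutoTokenizer

spec = importlib.util.spec_from_file_location("infer_base", "infer-base.py")
infer_base = importlib.util.module_from_spec(spec)
spec.loader.exec_module(infer_base)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model, config = infer_base.load_model("checkpoints/model_fp32.pt", device)
text = infer_base.generate_block_diffusion(
    model, tokenizer,
    prompt="Barack Obama was born in ",
    steps=32, block_size=32, max_new_tokens=64,
    device=device, temperature=0.8, top_k=40, top_p=0.9,
    repetition_penalty=1.3, no_repeat_ngram_size=3,
    visualize=False, parallel_blocks=2,
)
print(text)

If the checkpoint pickles its config class, the same `__main__` attribute workaround used in `infer-chat.py`'s `main()` may be needed before calling `load_model`.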
infer-chat.py
ADDED
|
@@ -0,0 +1,656 @@
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import argparse
|
| 5 |
+
import importlib.util
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
# Tracks how many lines the last visualization printed so we can overwrite it
|
| 12 |
+
_visualize_last_lines = 0
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def try_import_infer_base(base_path: str):
|
| 16 |
+
"""Dynamically import `infer-base.py` as a module and return it, or None on failure."""
|
| 17 |
+
if not os.path.exists(base_path):
|
| 18 |
+
return None
|
| 19 |
+
try:
|
| 20 |
+
spec = importlib.util.spec_from_file_location("infer_base", base_path)
|
| 21 |
+
module = importlib.util.module_from_spec(spec)
|
| 22 |
+
spec.loader.exec_module(module)
|
| 23 |
+
return module
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"Warning: failed to import {base_path}: {e}")
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_finetuned_model(model_path: str, device: str = 'cuda'):
|
| 30 |
+
"""Load a saved fine-tuned model for inference."""
|
| 31 |
+
print(f"Loading model from {model_path}...")
|
| 32 |
+
|
| 33 |
+
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
| 34 |
+
config = checkpoint['config']
|
| 35 |
+
|
| 36 |
+
# Create model
|
| 37 |
+
model = DiffusionLLM(config)
|
| 38 |
+
|
| 39 |
+
# Load weights
|
| 40 |
+
state_dict = checkpoint['model_state']
|
| 41 |
+
state_dict = {k: v.float() for k, v in state_dict.items()}
|
| 42 |
+
model.load_state_dict(state_dict)
|
| 43 |
+
|
| 44 |
+
model = model.to(device)
|
| 45 |
+
model.eval()
|
| 46 |
+
|
| 47 |
+
num_params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 48 |
+
print(f"✓ Loaded model: {num_params:.1f}M parameters")
|
| 49 |
+
|
| 50 |
+
# Print training info if available
|
| 51 |
+
if 'step' in checkpoint:
|
| 52 |
+
print(f" Trained for {checkpoint['step']} steps")
|
| 53 |
+
if 'best_val_loss' in checkpoint:
|
| 54 |
+
print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
|
| 55 |
+
|
| 56 |
+
return model, config
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def generate_block_diffusion(
|
| 61 |
+
model,
|
| 62 |
+
tokenizer,
|
| 63 |
+
prompt: str,
|
| 64 |
+
steps: int = 32,
|
| 65 |
+
block_size: int = 32,
|
| 66 |
+
max_new_tokens: int = 128,
|
| 67 |
+
device: str = 'cuda',
|
| 68 |
+
temperature: float = 0.8,
|
| 69 |
+
top_k: int = 50,
|
| 70 |
+
top_p: float = 0.9,
|
| 71 |
+
repetition_penalty: float = 1.2,
|
| 72 |
+
no_repeat_ngram_size: int = 3,
|
| 73 |
+
verbose: bool = True,
|
| 74 |
+
visualize_fn=None,
|
| 75 |
+
parallel_blocks: int = 1,
|
| 76 |
+
):
|
| 77 |
+
"""
|
| 78 |
+
Generate text using block diffusion with sampling controls.
|
| 79 |
+
|
| 80 |
+
If `visualize_fn` is provided it will be called as:
|
| 81 |
+
visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=True)
|
| 82 |
+
|
| 83 |
+
Returns the decoded generated string (including prompt).
|
| 84 |
+
"""
|
| 85 |
+
model.eval()
|
| 86 |
+
|
| 87 |
+
# Encode prompt
|
| 88 |
+
prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
| 89 |
+
|
| 90 |
+
# Get model config
|
| 91 |
+
config = model.module.config if hasattr(model, 'module') else getattr(model, 'config', None)
|
| 92 |
+
if hasattr(model, '_orig_mod'):
|
| 93 |
+
config = model._orig_mod.config
|
| 94 |
+
|
| 95 |
+
if config is None:
|
| 96 |
+
raise RuntimeError("Could not determine model config")
|
| 97 |
+
|
| 98 |
+
num_blocks = max_new_tokens // block_size
|
| 99 |
+
parallel_blocks = min(parallel_blocks, num_blocks)
|
| 100 |
+
|
| 101 |
+
if verbose:
|
| 102 |
+
print(f"Generating {num_blocks} blocks of {block_size} tokens ({max_new_tokens} max_new_tokens)\n")
|
| 103 |
+
|
| 104 |
+
context_ids = prompt_ids
|
| 105 |
+
all_generated_tokens = set(prompt_ids[0].tolist())
|
| 106 |
+
|
| 107 |
+
blocks_generated = 0
|
| 108 |
+
while blocks_generated < num_blocks:
|
| 109 |
+
current_parallel = min(parallel_blocks, num_blocks - blocks_generated)
|
| 110 |
+
|
| 111 |
+
if current_parallel > 1:
|
| 112 |
+
new_blocks = _generate_parallel_blocks(
|
| 113 |
+
model, tokenizer, context_ids, config, device,
|
| 114 |
+
current_parallel, block_size, steps, temperature,
|
| 115 |
+
top_k, top_p, repetition_penalty, no_repeat_ngram_size,
|
| 116 |
+
all_generated_tokens, visualize_fn
|
| 117 |
+
)
|
| 118 |
+
for block in new_blocks:
|
| 119 |
+
context_ids = torch.cat([context_ids, block], dim=1)
|
| 120 |
+
blocks_generated += 1
|
| 121 |
+
else:
|
| 122 |
+
mask_block, block_token_history = _generate_single_block(
|
| 123 |
+
model, tokenizer, context_ids, config, device,
|
| 124 |
+
block_size, steps, temperature, top_k, top_p,
|
| 125 |
+
repetition_penalty, no_repeat_ngram_size,
|
| 126 |
+
all_generated_tokens, visualize_fn
|
| 127 |
+
)
|
| 128 |
+
context_ids = torch.cat([context_ids, mask_block], dim=1)
|
| 129 |
+
blocks_generated += 1
|
| 130 |
+
|
| 131 |
+
generated_ids = context_ids[0].tolist()
|
| 132 |
+
return tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _apply_sampling_controls(
|
| 136 |
+
block_logits, context_ids, mask_block, is_masked,
|
| 137 |
+
repetition_penalty, temperature, top_k, top_p,
|
| 138 |
+
no_repeat_ngram_size, block_token_history
|
| 139 |
+
):
|
| 140 |
+
"""Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
|
| 141 |
+
if repetition_penalty != 1.0:
|
| 142 |
+
seen_tokens = set(context_ids[0].tolist())
|
| 143 |
+
for i in range(mask_block.shape[1]):
|
| 144 |
+
if not is_masked[0, i]:
|
| 145 |
+
seen_tokens.add(mask_block[0, i].item())
|
| 146 |
+
|
| 147 |
+
for token_id in seen_tokens:
|
| 148 |
+
if token_id < block_logits.shape[-1]:
|
| 149 |
+
avg = block_logits[0, :, token_id].mean()
|
| 150 |
+
if avg > 0:
|
| 151 |
+
block_logits[:, :, token_id] /= repetition_penalty
|
| 152 |
+
else:
|
| 153 |
+
block_logits[:, :, token_id] *= repetition_penalty
|
| 154 |
+
|
| 155 |
+
block_logits = block_logits / temperature
|
| 156 |
+
|
| 157 |
+
if top_k > 0:
|
| 158 |
+
k = min(top_k, block_logits.size(-1))
|
| 159 |
+
top_k_logits, top_k_indices = torch.topk(block_logits, k, dim=-1)
|
| 160 |
+
filtered = torch.full_like(block_logits, float('-inf'))
|
| 161 |
+
filtered.scatter_(-1, top_k_indices, top_k_logits)
|
| 162 |
+
block_logits = filtered
|
| 163 |
+
|
| 164 |
+
if top_p < 1.0:
|
| 165 |
+
sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
|
| 166 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 167 |
+
|
| 168 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 169 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 170 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 171 |
+
|
| 172 |
+
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
| 173 |
+
block_logits[indices_to_remove] = float('-inf')
|
| 174 |
+
|
| 175 |
+
if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
|
| 176 |
+
recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size - 1):])
|
| 177 |
+
full_history = context_ids[0].tolist() + block_token_history
|
| 178 |
+
for i in range(len(full_history) - no_repeat_ngram_size + 1):
|
| 179 |
+
if tuple(full_history[i:i + no_repeat_ngram_size - 1]) == recent_ngram:
|
| 180 |
+
blocked_token = full_history[i + no_repeat_ngram_size - 1]
|
| 181 |
+
if blocked_token < block_logits.shape[-1]:
|
| 182 |
+
block_logits[:, :, blocked_token] = float('-inf')
|
| 183 |
+
|
| 184 |
+
# Safety: reset if all logits are -inf
|
| 185 |
+
all_inf_mask = torch.isinf(block_logits).all(dim=-1)
|
| 186 |
+
if all_inf_mask.any():
|
| 187 |
+
block_logits[all_inf_mask] = 0.0
|
| 188 |
+
|
| 189 |
+
return block_logits
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _generate_single_block(
|
| 193 |
+
model, tokenizer, context_ids, config, device,
|
| 194 |
+
block_size, steps, temperature, top_k, top_p,
|
| 195 |
+
repetition_penalty, no_repeat_ngram_size,
|
| 196 |
+
all_generated_tokens, visualize_fn=None
|
| 197 |
+
):
|
| 198 |
+
"""Generate a single block using diffusion."""
|
| 199 |
+
mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
|
| 200 |
+
is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
|
| 201 |
+
block_token_history = []
|
| 202 |
+
|
| 203 |
+
for step_idx in range(steps):
|
| 204 |
+
full_input = torch.cat([context_ids, mask_block], dim=1)
|
| 205 |
+
attention_mask = torch.ones_like(full_input, dtype=torch.float32)
|
| 206 |
+
|
| 207 |
+
logits, _ = model(full_input, attention_mask=attention_mask)
|
| 208 |
+
block_logits = logits[:, -block_size:, :]
|
| 209 |
+
|
| 210 |
+
block_logits = _apply_sampling_controls(
|
| 211 |
+
block_logits, context_ids, mask_block, is_masked,
|
| 212 |
+
repetition_penalty, temperature, top_k, top_p,
|
| 213 |
+
no_repeat_ngram_size, block_token_history
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
probs = F.softmax(block_logits, dim=-1)
|
| 217 |
+
probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 218 |
+
probs = probs.clamp(min=1e-10)
|
| 219 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 220 |
+
|
| 221 |
+
sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
|
| 222 |
+
sampled_tokens = sampled_tokens.view(1, block_size)
|
| 223 |
+
|
| 224 |
+
confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
|
| 225 |
+
|
| 226 |
+
tokens_to_unmask = max(1, block_size // steps)
|
| 227 |
+
if step_idx == steps - 1:
|
| 228 |
+
tokens_to_unmask = int(is_masked.sum().item())
|
| 229 |
+
|
| 230 |
+
if tokens_to_unmask > 0 and is_masked.sum() > 0:
|
| 231 |
+
masked_confidence = confidence.clone()
|
| 232 |
+
masked_confidence[~is_masked] = -1.0
|
| 233 |
+
|
| 234 |
+
num_to_unmask = min(int(tokens_to_unmask), int(is_masked.sum().item()))
|
| 235 |
+
_, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
|
| 236 |
+
|
| 237 |
+
for idx in top_indices:
|
| 238 |
+
idx = int(idx.item())
|
| 239 |
+
mask_block[0, idx] = sampled_tokens[0, idx]
|
| 240 |
+
is_masked[0, idx] = False
|
| 241 |
+
block_token_history.append(sampled_tokens[0, idx].item())
|
| 242 |
+
all_generated_tokens.add(sampled_tokens[0, idx].item())
|
| 243 |
+
|
| 244 |
+
if callable(visualize_fn):
|
| 245 |
+
try:
|
| 246 |
+
visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
|
| 247 |
+
except Exception:
|
| 248 |
+
pass
|
| 249 |
+
elif visualize_fn:
|
| 250 |
+
visualize_diffusion_state_local(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
|
| 251 |
+
|
| 252 |
+
return mask_block, block_token_history
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _generate_parallel_blocks(
|
| 256 |
+
model, tokenizer, context_ids, config, device,
|
| 257 |
+
num_parallel, block_size, steps, temperature,
|
| 258 |
+
top_k, top_p, repetition_penalty, no_repeat_ngram_size,
|
| 259 |
+
all_generated_tokens, visualize_fn=None
|
| 260 |
+
):
|
| 261 |
+
"""Generate multiple blocks in parallel using batched computation.
|
| 262 |
+
|
| 263 |
+
Each block sees all previous blocks in the sequence, maintaining proper order:
|
| 264 |
+
- Block 0: context + [block0]
|
| 265 |
+
- Block 1: context + [block0] + [block1]
|
| 266 |
+
- Block 2: context + [block0] + [block1] + [block2]
|
| 267 |
+
- etc.
|
| 268 |
+
|
| 269 |
+
This ensures sequential coherence while still benefiting from batched computation.
|
| 270 |
+
"""
|
| 271 |
+
batch_size = num_parallel
|
| 272 |
+
context_len = context_ids.shape[1]
|
| 273 |
+
|
| 274 |
+
# Initialize mask blocks for all parallel blocks
|
| 275 |
+
# Shape: (num_parallel, block_size)
|
| 276 |
+
mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
|
| 277 |
+
is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
|
| 278 |
+
block_token_histories = [[] for _ in range(batch_size)]
|
| 279 |
+
|
| 280 |
+
for step_idx in range(steps):
|
| 281 |
+
# Build inputs with proper sequential structure
|
| 282 |
+
# Each batch item has context + all previous blocks + its own block
|
| 283 |
+
# Block i sees: context + block_0 + block_1 + ... + block_i
|
| 284 |
+
|
| 285 |
+
# Create padded inputs - each batch item has different length
|
| 286 |
+
# We'll pad to the longest sequence (which is the last block)
|
| 287 |
+
max_seq_len = context_len + (num_parallel * block_size)
|
| 288 |
+
|
| 289 |
+
# Build full input for each batch item
|
| 290 |
+
full_inputs = []
|
| 291 |
+
attention_masks = []
|
| 292 |
+
|
| 293 |
+
for b in range(batch_size):
|
| 294 |
+
# This block sees: context + all previous blocks + its own block
|
| 295 |
+
seq_parts = [context_ids[0]] # Start with context
|
| 296 |
+
|
| 297 |
+
# Add all blocks from 0 to b (inclusive)
|
| 298 |
+
for prev_b in range(b + 1):
|
| 299 |
+
seq_parts.append(mask_blocks[prev_b])
|
| 300 |
+
|
| 301 |
+
# Concatenate to form this batch item's input
|
| 302 |
+
batch_input = torch.cat(seq_parts, dim=0) # (seq_len,)
|
| 303 |
+
current_len = batch_input.shape[0]
|
| 304 |
+
|
| 305 |
+
# Pad to max_seq_len
|
| 306 |
+
padding_needed = max_seq_len - current_len
|
| 307 |
+
if padding_needed > 0:
|
| 308 |
+
pad_token = config.pad_token_id if config.pad_token_id is not None else 0
|
| 309 |
+
padding = torch.full((padding_needed,), pad_token, device=device)
|
| 310 |
+
batch_input = torch.cat([batch_input, padding], dim=0)
|
| 311 |
+
|
| 312 |
+
full_inputs.append(batch_input)
|
| 313 |
+
|
| 314 |
+
# Create attention mask (1 for real tokens, 0 for padding)
|
| 315 |
+
attn_mask = torch.zeros(max_seq_len, device=device)
|
| 316 |
+
attn_mask[:current_len] = 1.0
|
| 317 |
+
attention_masks.append(attn_mask)
|
| 318 |
+
|
| 319 |
+
# Stack into batched tensors
|
| 320 |
+
full_input = torch.stack(full_inputs, dim=0) # (batch, max_seq_len)
|
| 321 |
+
attention_mask = torch.stack(attention_masks, dim=0) # (batch, max_seq_len)
|
| 322 |
+
|
| 323 |
+
# Single forward pass for all blocks
|
| 324 |
+
logits, _ = model(full_input, attention_mask=attention_mask)
|
| 325 |
+
|
| 326 |
+
# Extract logits for each block's position
|
| 327 |
+
# Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
|
| 328 |
+
block_logits_list = []
|
| 329 |
+
for b in range(batch_size):
|
| 330 |
+
start_pos = context_len + (b * block_size)
|
| 331 |
+
end_pos = start_pos + block_size
|
| 332 |
+
block_logits_list.append(logits[b, start_pos:end_pos, :])
|
| 333 |
+
|
| 334 |
+
block_logits = torch.stack(block_logits_list, dim=0) # (batch, block_size, vocab)
|
| 335 |
+
|
| 336 |
+
# Apply sampling controls per batch item
|
| 337 |
+
for b in range(batch_size):
|
| 338 |
+
# Build context that includes previous blocks for repetition penalty
|
| 339 |
+
extended_context = context_ids
|
| 340 |
+
if b > 0:
|
| 341 |
+
prev_blocks = mask_blocks[:b]
|
| 342 |
+
extended_context = torch.cat([context_ids, prev_blocks.view(1, -1)], dim=1)
|
| 343 |
+
|
| 344 |
+
block_logits[b:b+1] = _apply_sampling_controls(
|
| 345 |
+
block_logits[b:b+1],
|
| 346 |
+
extended_context,
|
| 347 |
+
mask_blocks[b:b+1],
|
| 348 |
+
is_masked[b:b+1],
|
| 349 |
+
repetition_penalty, temperature, top_k, top_p,
|
| 350 |
+
no_repeat_ngram_size, block_token_histories[b]
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
probs = F.softmax(block_logits, dim=-1)
|
| 354 |
+
probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 355 |
+
probs = probs.clamp(min=1e-10)
|
| 356 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 357 |
+
|
| 358 |
+
# Sample for all batches
|
| 359 |
+
sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
|
| 360 |
+
sampled_tokens = sampled_tokens.view(batch_size, block_size)
|
| 361 |
+
|
| 362 |
+
confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
|
| 363 |
+
|
| 364 |
+
tokens_to_unmask = max(1, block_size // steps)
|
| 365 |
+
if step_idx == steps - 1:
|
| 366 |
+
tokens_to_unmask = block_size # Unmask all remaining
|
| 367 |
+
|
| 368 |
+
# Unmask for each batch item
|
| 369 |
+
for b in range(batch_size):
|
| 370 |
+
if is_masked[b].sum() > 0:
|
| 371 |
+
masked_confidence = confidence[b]
|
| 372 |
+
masked_confidence = masked_confidence.clone()
|
| 373 |
+
masked_confidence[~is_masked[b]] = -1.0
|
| 374 |
+
|
| 375 |
+
num_to_unmask = min(int(tokens_to_unmask), int(is_masked[b].sum().item()))
|
| 376 |
+
_, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
|
| 377 |
+
|
| 378 |
+
for idx in top_indices:
|
| 379 |
+
idx = int(idx.item())
|
| 380 |
+
mask_blocks[b, idx] = sampled_tokens[b, idx]
|
| 381 |
+
is_masked[b, idx] = False
|
| 382 |
+
block_token_histories[b].append(sampled_tokens[b, idx].item())
|
| 383 |
+
all_generated_tokens.add(sampled_tokens[b, idx].item())
|
| 384 |
+
|
| 385 |
+
if callable(visualize_fn):
|
| 386 |
+
try:
|
| 387 |
+
block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
|
| 388 |
+
is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
|
| 389 |
+
visualize_fn(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
|
| 390 |
+
except Exception:
|
| 391 |
+
pass
|
| 392 |
+
elif visualize_fn:
|
| 393 |
+
block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
|
| 394 |
+
is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
|
| 395 |
+
visualize_diffusion_state_local(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
|
| 396 |
+
|
| 397 |
+
# Return list of generated blocks
|
| 398 |
+
return [mask_blocks[b:b+1] for b in range(batch_size)]
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def chat(model, tokenizer, instruction: str, parallel_blocks: int = 1, **kwargs):
|
| 402 |
+
"""Simple chat interface."""
|
| 403 |
+
device = next(model.parameters()).device
|
| 404 |
+
|
| 405 |
+
prompt = format_instruct_prompt(instruction)
|
| 406 |
+
|
| 407 |
+
generated = generate_block_diffusion(
|
| 408 |
+
model,
|
| 409 |
+
tokenizer,
|
| 410 |
+
prompt=prompt,
|
| 411 |
+
device=device,
|
| 412 |
+
parallel_blocks=parallel_blocks,
|
| 413 |
+
**kwargs
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Extract all assistant responses using ChatML tags
|
| 417 |
+
start_tag = "<|im_start|>assistant"
|
| 418 |
+
end_tag = "<|im_end|>"
|
| 419 |
+
resp_parts = []
|
| 420 |
+
pos = 0
|
| 421 |
+
while True:
|
| 422 |
+
start_idx = generated.find(start_tag, pos)
|
| 423 |
+
if start_idx == -1:
|
| 424 |
+
break
|
| 425 |
+
start_idx += len(start_tag)
|
| 426 |
+
end_idx = generated.find(end_tag, start_idx)
|
| 427 |
+
if end_idx == -1:
|
| 428 |
+
resp_parts.append(generated[start_idx:].strip())
|
| 429 |
+
break
|
| 430 |
+
resp_parts.append(generated[start_idx:end_idx].strip())
|
| 431 |
+
pos = end_idx + len(end_tag)
|
| 432 |
+
|
| 433 |
+
if resp_parts:
|
| 434 |
+
resp = "\n\n".join(p for p in resp_parts if p)
|
| 435 |
+
else:
|
| 436 |
+
# Fallback if no assistant tags found
|
| 437 |
+
resp = generated.replace("<|im_start|>assistant", "").replace("<|im_end|>", "").strip()
|
| 438 |
+
|
| 439 |
+
return generated, resp
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def format_instruct_prompt(instruction: str) -> str:
|
| 443 |
+
"""Format instruction using a simple ChatML-like template."""
|
| 444 |
+
return (
|
| 445 |
+
"<|im_start|>system\n"
|
| 446 |
+
"Answer this question truthfully<|im_end|>\n"
|
| 447 |
+
"<|im_start|>user\n"
|
| 448 |
+
f"{instruction}\n"
|
| 449 |
+
"<|im_end|>\n"
|
| 450 |
+
"<|im_start|>assistant\n"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def visualize_diffusion_state_local(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
|
| 455 |
+
"""Local visualization copied from infer-base.py to ensure consistent terminal output."""
|
| 456 |
+
import sys
|
| 457 |
+
import os
|
| 458 |
+
|
| 459 |
+
# Default colors for different blocks (green, cyan, yellow, magenta)
|
| 460 |
+
DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
|
| 461 |
+
MASK_COLOR = '\033[90m' # Gray for masked tokens
|
| 462 |
+
RESET = '\033[0m'
|
| 463 |
+
|
| 464 |
+
# Normalize inputs to lists
|
| 465 |
+
if not isinstance(mask_blocks, list):
|
| 466 |
+
mask_blocks = [mask_blocks]
|
| 467 |
+
is_masked_list = [is_masked_list]
|
| 468 |
+
|
| 469 |
+
if block_colors is None:
|
| 470 |
+
block_colors = DEFAULT_COLORS
|
| 471 |
+
|
| 472 |
+
# Decode context (prompt + previously generated blocks) and replace newlines
|
| 473 |
+
try:
|
| 474 |
+
context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
|
| 475 |
+
except Exception:
|
| 476 |
+
# Fallback to str
|
| 477 |
+
context_text = str(context_ids[0].tolist())
|
| 478 |
+
|
| 479 |
+
# Build visualization for all blocks
|
| 480 |
+
all_blocks_text = []
|
| 481 |
+
for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
|
| 482 |
+
color = block_colors[block_idx % len(block_colors)]
|
| 483 |
+
block_tokens = mask_block[0].tolist()
|
| 484 |
+
block_color_tokens = []
|
| 485 |
+
|
| 486 |
+
for i, token_id in enumerate(block_tokens):
|
| 487 |
+
if is_masked[0, i]:
|
| 488 |
+
# Masked tokens are shown as gray blocks
|
| 489 |
+
block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
|
| 490 |
+
else:
|
| 491 |
+
# Decode individual token; use block color for revealed tokens
|
| 492 |
+
try:
|
| 493 |
+
token_text = tokenizer.decode([token_id], skip_special_tokens=False)
|
| 494 |
+
except Exception:
|
| 495 |
+
token_text = str(int(token_id))
|
| 496 |
+
block_color_tokens.append(f'{color}{token_text}{RESET}')
|
| 497 |
+
|
| 498 |
+
all_blocks_text.append(''.join(block_color_tokens))
|
| 499 |
+
|
| 500 |
+
# Concatenate all block texts
|
| 501 |
+
blocks_combined = ''.join(all_blocks_text)
|
| 502 |
+
|
| 503 |
+
# Overwrite previous visualization area (if any) by moving cursor up and clearing lines.
|
| 504 |
+
# This prevents accumulation of repeated frames in terminals like VSCode integrated terminal.
|
| 505 |
+
global _visualize_last_lines
|
| 506 |
+
if clear and _visualize_last_lines > 0:
|
| 507 |
+
try:
|
| 508 |
+
# Move cursor up to the start of the previous block
|
| 509 |
+
sys.stdout.write(f'\x1b[{_visualize_last_lines}A')
|
| 510 |
+
# Clear each line that was previously printed
|
| 511 |
+
for _ in range(_visualize_last_lines):
|
| 512 |
+
sys.stdout.write('\x1b[2K') # Erase entire line
|
| 513 |
+
sys.stdout.write('\x1b[1B') # Move cursor down one line
|
| 514 |
+
# Move cursor back to the top of cleared region
|
| 515 |
+
sys.stdout.write(f'\x1b[{_visualize_last_lines}A')
|
| 516 |
+
sys.stdout.flush()
|
| 517 |
+
except Exception:
|
| 518 |
+
# Fallback to whole-screen clear
|
| 519 |
+
try:
|
| 520 |
+
sys.stdout.write('\x1b[2J\x1b[H')
|
| 521 |
+
sys.stdout.flush()
|
| 522 |
+
except Exception:
|
| 523 |
+
try:
|
| 524 |
+
clear_cmd = 'cls' if os.name == 'nt' else 'clear'
|
| 525 |
+
os.system(clear_cmd)
|
| 526 |
+
except Exception:
|
| 527 |
+
sys.stdout.write('\r\033[K')
|
| 528 |
+
sys.stdout.flush()
|
| 529 |
+
elif clear:
|
| 530 |
+
# No previous region to overwrite; do a simple ANSI clear to start fresh
|
| 531 |
+
try:
|
| 532 |
+
sys.stdout.write('\x1b[2J\x1b[H')
|
| 533 |
+
sys.stdout.flush()
|
| 534 |
+
except Exception:
|
| 535 |
+
try:
|
| 536 |
+
clear_cmd = 'cls' if os.name == 'nt' else 'clear'
|
| 537 |
+
os.system(clear_cmd)
|
| 538 |
+
except Exception:
|
| 539 |
+
sys.stdout.write('\r\033[K')
|
| 540 |
+
sys.stdout.flush()
|
| 541 |
+
|
| 542 |
+
# Print legend for parallel blocks
|
| 543 |
+
if len(mask_blocks) > 1:
|
| 544 |
+
legend_parts = []
|
| 545 |
+
for i in range(len(mask_blocks)):
|
| 546 |
+
color = block_colors[i % len(block_colors)]
|
| 547 |
+
legend_parts.append(f'{color}Block {i+1}{RESET}')
|
| 548 |
+
print(f"Generating: {' | '.join(legend_parts)}\n")
|
| 549 |
+
|
| 550 |
+
# Print the full context with colored blocks
|
| 551 |
+
# Ensure trailing newline so subsequent clears have predictable behavior
|
| 552 |
+
out_text = f"{context_text}{blocks_combined}\n"
|
| 553 |
+
try:
|
| 554 |
+
sys.stdout.write(out_text)
|
| 555 |
+
sys.stdout.flush()
|
| 556 |
+
except Exception:
|
| 557 |
+
print(out_text, flush=True)
|
| 558 |
+
|
| 559 |
+
# Update last-lines counter so next frame can overwrite this one
|
| 560 |
+
try:
|
| 561 |
+
_visualize_last_lines = out_text.count('\n') + (1 if len(mask_blocks) > 1 else 0) + 1
|
| 562 |
+
except Exception:
|
| 563 |
+
_visualize_last_lines = out_text.count('\n')
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def main():
|
| 567 |
+
base_path = os.path.join(os.path.dirname(__file__), "infer-base.py")
|
| 568 |
+
base_mod = try_import_infer_base(base_path)
|
| 569 |
+
|
| 570 |
+
if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
|
| 571 |
+
raise RuntimeError("DiffusionLLM not found in infer-base.py")
|
| 572 |
+
|
| 573 |
+
DiffusionLLM = base_mod.DiffusionLLM
|
| 574 |
+
|
| 575 |
+
# Workaround for torch.load pickling
|
| 576 |
+
try:
|
| 577 |
+
main_mod = sys.modules.get('__main__')
|
| 578 |
+
if main_mod is not None:
|
| 579 |
+
if hasattr(base_mod, 'ModelConfig'):
|
| 580 |
+
setattr(main_mod, 'ModelConfig', base_mod.ModelConfig)
|
| 581 |
+
setattr(main_mod, 'DiffusionLLM', DiffusionLLM)
|
| 582 |
+
except Exception:
|
| 583 |
+
pass
|
| 584 |
+
|
| 585 |
+
parser = argparse.ArgumentParser()
|
| 586 |
+
parser.add_argument("--model", type=str, default="./checkpoints/model_fp32.pt", help="Path to model checkpoint")
|
| 587 |
+
parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-0.5B", help="Tokenizer model id or path")
|
| 588 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 589 |
+
parser.add_argument("--visualize", action="store_true", default=False, help="Enable visualization during generation")
|
| 590 |
+
parser.add_argument("--steps", type=int, default=64)
|
| 591 |
+
parser.add_argument("--block_size", type=int, default=128)
|
| 592 |
+
parser.add_argument("--max_new_tokens", type=int, default=128)
|
| 593 |
+
parser.add_argument("--parallel_blocks", type=int, default=1, help="Number of blocks to generate in parallel")
|
| 594 |
+
args = parser.parse_args()
|
| 595 |
+
|
| 596 |
+
device = torch.device(args.device)
|
| 597 |
+
print(f"Using device: {device}")
|
| 598 |
+
|
| 599 |
+
# Load tokenizer
|
| 600 |
+
print("Loading tokenizer...")
|
| 601 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
| 602 |
+
if tokenizer.pad_token is None:
|
| 603 |
+
# set pad token if not present
|
| 604 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 605 |
+
|
| 606 |
+
# Load model
|
| 607 |
+
best_model_path = "checkpoints/best_model.pt"
|
| 608 |
+
if os.path.exists(best_model_path):
|
| 609 |
+
print("Loading best model...")
|
| 610 |
+
model, config = load_finetuned_model(best_model_path, device)
|
| 611 |
+
else:
|
| 612 |
+
model, config = load_finetuned_model(args.model, device)
|
| 613 |
+
|
| 614 |
+
# Use the local visualization implementation for consistency
|
| 615 |
+
visualize_fn = None
|
| 616 |
+
if args.visualize:
|
| 617 |
+
visualize_fn = visualize_diffusion_state_local
|
| 618 |
+
|
| 619 |
+
print("Ready. Type a message and press Enter (empty line to quit).\n")
|
| 620 |
+
|
| 621 |
+
while True:
|
| 622 |
+
try:
|
| 623 |
+
user_input = input("User: ").strip()
|
| 624 |
+
except (EOFError, KeyboardInterrupt):
|
| 625 |
+
print("\nExiting.")
|
| 626 |
+
break
|
| 627 |
+
if user_input == "":
|
| 628 |
+
print("Goodbye.")
|
| 629 |
+
break
|
| 630 |
+
|
| 631 |
+
raw_output, response = chat(
|
| 632 |
+
model,
|
| 633 |
+
tokenizer,
|
| 634 |
+
user_input,
|
| 635 |
+
steps=args.steps,
|
| 636 |
+
block_size=args.block_size,
|
| 637 |
+
max_new_tokens=args.max_new_tokens,
|
| 638 |
+
temperature=0.8,
|
| 639 |
+
top_k=50,
|
| 640 |
+
top_p=0.9,
|
| 641 |
+
repetition_penalty=1.2,
|
| 642 |
+
no_repeat_ngram_size=3,
|
| 643 |
+
verbose=False,
|
| 644 |
+
visualize_fn=visualize_fn,
|
| 645 |
+
parallel_blocks=args.parallel_blocks,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
print("\nRaw Output:\n")
|
| 649 |
+
print(raw_output)
|
| 650 |
+
print("\nAssistant:\n")
|
| 651 |
+
print(response)
|
| 652 |
+
print("\n" + ("=" * 60) + "\n")
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
if __name__ == "__main__":
|
| 656 |
+
main()
|
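To make the prompt formatting and response extraction in `chat()` easier to follow, here is a small self-contained round-trip sketch; the `raw` string is a made-up stand-in for a model generation, and `build_prompt`/`extract_assistant` are illustrative names that mirror the logic above.

# Illustrative only: mirrors format_instruct_prompt() and the tag parsing inside chat().
def build_prompt(instruction: str) -> str:
    return (
        "<|im_start|>system\n"
        "Answer this question truthfully<|im_end|>\n"
        "<|im_start|>user\n"
        f"{instruction}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

def extract_assistant(generated: str) -> str:
    start_tag, end_tag = "<|im_start|>assistant", "<|im_end|>"
    parts, pos = [], 0
    while True:
        start = generated.find(start_tag, pos)
        if start == -1:
            break
        start += len(start_tag)
        end = generated.find(end_tag, start)
        if end == -1:
            parts.append(generated[start:].strip())
            break
        parts.append(generated[start:end].strip())
        pos = end + len(end_tag)
    return "\n\n".join(p for p in parts if p)

# Pretend the model completed the prompt with "Paris" followed by a closing tag.
raw = build_prompt("What is the capital of France?") + "Paris<|im_end|>"
print(extract_assistant(raw))  # -> Paris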
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
flask>=2.0
|
| 2 |
+
transformers>=4.0.0
|
| 3 |
+
torch
|
| 4 |
+
sentencepiece
|
| 5 |
+
flask_cors
|
static/ai.mp4
ADDED
|
Binary file (77.4 kB)
|
static/index.html
ADDED
|
@@ -0,0 +1,156 @@
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8" />
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
| 6 |
+
<title>Diffusion LLM – Chat</title>
|
| 7 |
+
<!-- Tailwind CDN -->
|
| 8 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 9 |
+
<!-- Inter Font -->
|
| 10 |
+
<link rel="preconnect" href="https://fonts.gstatic.com">
|
| 11 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
| 12 |
+
<style>
|
| 13 |
+
html,body{font-family:Inter,ui-sans-serif,system-ui,-apple-system,'Segoe UI',Roboto,'Helvetica Neue',Arial}
|
| 14 |
+
/* custom slider */
|
| 15 |
+
input[type=range] {
|
| 16 |
+
-webkit-appearance: none;
|
| 17 |
+
width: 100%;
|
| 18 |
+
height: 6px;
|
| 19 |
+
border-radius: 5px;
|
| 20 |
+
background: #e0e0e0;
|
| 21 |
+
outline: none;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
input[type=range]::-webkit-slider-thumb {
|
| 25 |
+
-webkit-appearance: none;
|
| 26 |
+
appearance: none;
|
| 27 |
+
width: 16px;
|
| 28 |
+
height: 16px;
|
| 29 |
+
border-radius: 50%;
|
| 30 |
+
background: #6b21a8;
|
| 31 |
+
cursor: pointer;
|
| 32 |
+
box-shadow: 0 0 2px rgba(0,0,0,0.2);
|
| 33 |
+
transition: background 0.3s ease;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
input[type=range]::-webkit-slider-thumb:hover {
|
| 37 |
+
background: #7c2dbe;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
input[type=range]::-moz-range-thumb {
|
| 41 |
+
width: 16px;
|
| 42 |
+
height: 16px;
|
| 43 |
+
border-radius: 50%;
|
| 44 |
+
background: #6b21a8;
|
| 45 |
+
cursor: pointer;
|
| 46 |
+
box-shadow: 0 0 2px rgba(0,0,0,0.2);
|
| 47 |
+
transition: background 0.3s ease;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
input[type=range]::-moz-range-thumb:hover {
|
| 51 |
+
background: #7c2dbe;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
</style>
|
| 57 |
+
</head>
|
| 58 |
+
<body>
|
| 59 |
+
<div class="h-screen w-screen flex items-start gap-6 p-8 bg-gradient-to-br from-purple-50 to-purple-100">
|
| 60 |
+
|
| 61 |
+
<!-- Sidebar -->
|
| 62 |
+
<aside id="sidebar" class="w-64 h-full bg-white/90 flex flex-col items-center justify-between backdrop-blur-sm rounded-xl p-5 shadow-sm border border-gray-100">
|
| 63 |
+
<div>
|
| 64 |
+
<div class="flex items-center gap-3 mb-4">
|
| 65 |
+
<div class="w-9 h-9 rounded-md bg-gradient-to-br from-purple-200 to-purple-300"></div>
|
| 66 |
+
<div>
|
| 67 |
+
<div class="text-sm font-semibold text-slate-900">Cortex</div>
|
| 68 |
+
<div class="text-xs text-slate-500">Diffusion LLM</div>
|
| 69 |
+
</div>
|
| 70 |
+
</div>
|
| 71 |
+
|
| 72 |
+
<button id="new-chat" class="w-full inline-flex items-center justify-center gap-2 bg-black text-white py-2 rounded-full text-sm font-medium shadow-sm mb-4">+ New chat</button>
|
| 73 |
+
|
| 74 |
+
<nav class="w-full min-w-48 flex-1 flex flex-col gap-2 text-sm" id="chat-list" aria-label="Saved chats">
|
| 75 |
+
<!-- Chat items are dynamically injected here by JavaScript -->
|
| 76 |
+
</nav>
|
| 77 |
+
</div>
|
| 78 |
+
|
| 79 |
+
<div class="mt-6 text-xs text-slate-500 ">Signed in as <strong class="text-slate-700">you@example.com</strong></div>
|
| 80 |
+
</aside>
|
| 81 |
+
|
| 82 |
+
<!-- Main content -->
|
| 83 |
+
<main class="flex-1 flex items-center justify-center w-full h-full">
|
| 84 |
+
<div class="w-full bg-white rounded-2xl p-7 shadow-lg border border-gray-100 flex flex-col h-full">
|
| 85 |
+
|
| 86 |
+
<header class="flex items-center justify-between mb-3 border-b border-gray-200 pb-3">
|
| 87 |
+
<div class="flex items-center gap-3">
|
| 88 |
+
<button id="btn-toggle-sidebar" aria-label="Toggle sidebar" class="inline-flex items-center justify-center p-2 rounded-md bg-white shadow sm:hidden">☰</button>
|
| 89 |
+
<h1 id="app-title" class="text-lg font-semibold">Diffusion LLM Chat</h1>
|
| 90 |
+
</div>
|
| 91 |
+
|
| 92 |
+
<div class="flex items-center gap-3">
|
| 93 |
+
<button id="btn-load" class="bg-black text-white px-3 py-2 rounded-md text-sm font-medium">Load Model</button>
|
| 94 |
+
<span id="load-status" class="text-sm text-slate-500">Not loaded</span>
|
| 95 |
+
</div>
|
| 96 |
+
</header>
|
| 97 |
+
|
| 98 |
+
<section class="flex-1 flex flex-col overflow-hidden">
|
| 99 |
+
<div id="welcome" class="text-center py-6">
|
| 100 |
+
<div class="mx-auto w-24 h-24">
|
| 101 |
+
<video src="/static/ai.mp4" alt="Assistant Avatar" autoplay loop muted class="w-full h-full scale-[2] object-cover mix-blend-multiply" style="filter: hue-rotate(45deg)" />
|
| 102 |
+
</div>
|
| 103 |
+
<p class="mt-4 text-purple-600 font-medium">Hello, Jagrat Patel</p>
|
| 104 |
+
<h2 class="mt-2 text-2xl font-semibold text-slate-900">How can I assist you today?</h2>
|
| 105 |
+
|
| 106 |
+
<div class="mt-6 flex items-center justify-center gap-4 flex-wrap">
|
| 107 |
+
<button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Deeper Research <span class="block text-xs text-slate-500 mt-1">Ask for long-form, research-backed answers.</span></button>
|
| 108 |
+
<button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Saved prompts <span class="block text-xs text-slate-500 mt-1">Quickly reuse your favorite prompts.</span></button>
|
| 109 |
+
<button class="bg-white px-5 py-3 rounded-lg shadow-sm border text-sm hover:scale-105 hover:bg-purple-50 hover:border-purple-300 transition-all">Check Facts <span class="block text-xs text-slate-500 mt-1">Compare GDPR vs CCPA differences.</span></button>
|
| 110 |
+
</div>
|
| 111 |
+
</div>
|
| 112 |
+
|
| 113 |
+
<div id="chat" class="hidden flex-1 overflow-auto px-2 py-3" role="log" aria-live="polite">
|
| 114 |
+
<!-- messages injected here -->
|
| 115 |
+
</div>
|
| 116 |
+
</section>
|
| 117 |
+
|
| 118 |
+
<form id="prompt-form" class="mt-4 bg-white p-4 rounded-xl shadow-inner border border-gray-100" aria-label="Chat prompt">
|
| 119 |
+
<div class="mb-4 flex flex-row gap-4 flex-wrap items-center justify-between">
|
| 120 |
+
<div class="flex items-center gap-4 w-[24%]">
|
| 121 |
+
<label for="steps" class="text-sm font-medium text-slate-700">Steps:</label>
|
| 122 |
+
<input type="range" id="steps" min="1" max="100" value="64" class="flex-1">
|
| 123 |
+
<span id="steps-value" class="text-sm text-slate-500 w-8">64</span>
|
| 124 |
+
</div>
|
| 125 |
+
<div class="flex items-center gap-4 w-[24%]">
|
| 126 |
+
<label for="block_size" class="text-sm font-medium text-slate-700">Block Size:</label>
|
| 127 |
+
<input type="range" id="block_size" min="8" max="256" value="128" step="8" class="flex-1">
|
| 128 |
+
<span id="block_size-value" class="text-sm text-slate-500 w-8">128</span>
|
| 129 |
+
</div>
|
| 130 |
+
<div class="flex items-center gap-4 w-[24%]">
|
| 131 |
+
<label for="max_new_tokens" class="text-sm font-medium text-slate-700">Max New Tokens:</label>
|
| 132 |
+
<input type="range" id="max_new_tokens" min="32" max="1024" value="128" step="32" class="flex-1">
|
| 133 |
+
<span id="max_new_tokens-value" class="text-sm text-slate-500 w-8">128</span>
|
| 134 |
+
</div>
|
| 135 |
+
<div class="flex items-center gap-4 w-[24%]">
|
| 136 |
+
<label for="parallel_blocks" class="text-sm font-medium text-slate-700">Parallel Blocks:</label>
|
| 137 |
+
<input type="range" id="parallel_blocks" min="1" max="4" value="1" step="1" class="flex-1">
|
| 138 |
+
<span id="parallel_blocks-value" class="text-sm text-slate-500 w-8">1</span>
|
| 139 |
+
</div>
|
| 140 |
+
</div>
|
| 141 |
+
<div class="flex gap-3">
|
| 142 |
+
<textarea id="prompt" class="flex-1 resize-y rounded-lg border border-gray-200 p-3 text-sm focus:outline-none focus:ring-[1px] focus:ring-purple-500 focus:border-purple-500" placeholder="Ask me anything..." rows="2" aria-label="Message input"></textarea>
|
| 143 |
+
<div class="flex flex-col justify-between">
|
| 144 |
+
<button type="submit" id="btn-send" class="bg-black text-white px-4 py-2 rounded-md">Send</button>
|
| 145 |
+
</div>
|
| 146 |
+
</div>
|
| 147 |
+
</form>
|
| 148 |
+
|
| 149 |
+
<div class="mt-4 text-center text-xs text-slate-500">Model served by Flask. See README for run instructions.</div>
|
| 150 |
+
</div>
|
| 151 |
+
</main>
|
| 152 |
+
</div>
|
| 153 |
+
|
| 154 |
+
<script src="/static/main.js"></script>
|
| 155 |
+
</body>
|
| 156 |
+
</html>
|
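The chat page above talks to the Flask backend through a small JSON API; `static/main.js` (next) POSTs to `/api/load` to check or trigger model loading. As a rough illustration, the same check can be issued from Python. The base URL and port are assumptions about the local deployment, not something defined on this page.

# Assumes the Flask app is reachable locally; adjust BASE to match your deployment.
import json
from urllib.request import Request, urlopen

BASE = "http://localhost:7860"  # assumed host/port

req = Request(
    f"{BASE}/api/load",
    data=json.dumps({"check_only": True}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urlopen(req) as resp:
    status = json.loads(resp.read().decode("utf-8"))

print("model loaded:", status.get("loaded", False))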
static/main.js
ADDED
|
@@ -0,0 +1,346 @@
|
// Global state
let isModelLoaded = false;

// DOM elements, looked up once at startup
const els = {
  chat: document.getElementById("chat"),
  promptForm: document.getElementById("prompt-form"),
  promptInput: document.getElementById("prompt"),
  loadBtn: document.getElementById("btn-load"),
  testStreamBtn: document.getElementById("btn-test-stream"),
  status: document.getElementById("load-status"),
  sidebar: document.getElementById("sidebar"),
  sidebarToggle: document.getElementById("btn-toggle-sidebar"),
  chatList: document.getElementById("chat-list"),
  newChatBtn: document.getElementById("new-chat"),
  sendBtn: document.getElementById("btn-send"),
  steps: document.getElementById("steps"),
  block_size: document.getElementById("block_size"),
  max_new_tokens: document.getElementById("max_new_tokens"),
  parallel_blocks: document.getElementById("parallel_blocks"),
  stepsValue: document.getElementById("steps-value"),
  block_sizeValue: document.getElementById("block_size-value"),
  max_new_tokensValue: document.getElementById("max_new_tokens-value"),
  parallel_blocksValue: document.getElementById("parallel_blocks-value"),
};

// Keep the slider value labels in sync with their inputs
els.steps.addEventListener("input", () => {
  els.stepsValue.textContent = els.steps.value;
});
els.block_size.addEventListener("input", () => {
  els.block_sizeValue.textContent = els.block_size.value;
});
els.max_new_tokens.addEventListener("input", () => {
  els.max_new_tokensValue.textContent = els.max_new_tokens.value;
});
els.parallel_blocks.addEventListener("input", () => {
  els.parallel_blocksValue.textContent = els.parallel_blocks.value;
});

// --- Logic ---

// Ask the server whether the model is already in memory (check_only avoids
// triggering a load) and, if so, switch the UI straight to the ready state.
async function checkLoadStatus() {
  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: true }),
    });

    if (res.ok) {
      const data = await res.json();
      if (data.loaded) {
        isModelLoaded = true;
        els.status.textContent = "Ready";
        els.status.className = "text-sm text-green-600 font-medium";
        els.loadBtn.style.display = "none";
      }
    }
  } catch (e) {
    console.warn("Model check failed:", e);
  }
}
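
The startup check above and the Load button handler below both POST to the same /api/load route, using check_only to distinguish a status probe from an actual load. The real endpoint lives in app.py; a minimal sketch of what the client assumes (MODEL and load_model are hypothetical stand-ins, not names taken from this repo) could look like:

from flask import Flask, request, jsonify

app = Flask(__name__)
MODEL = None  # hypothetical module-level handle to the loaded model

@app.route("/api/load", methods=["POST"])
def api_load():
    global MODEL
    body = request.get_json(silent=True) or {}
    if body.get("check_only"):
        # Status probe: report whether the model is resident, load nothing
        return jsonify({"loaded": MODEL is not None})
    try:
        if MODEL is None:
            MODEL = load_model()  # hypothetical loader; the real one is in app.py
        return jsonify({"loaded": True})
    except Exception as e:
        # The client surfaces `message` in its error alert
        return jsonify({"message": str(e)}), 500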

// Manual model load: disable the button while the request is in flight and
// hide it for good once the model is resident.
els.loadBtn.addEventListener("click", async () => {
  els.loadBtn.disabled = true;
  els.status.textContent = "Loading Model (this may take time)...";
  els.status.className = "text-sm text-yellow-600 font-medium";

  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: false }),
    });
    const data = await res.json();

    if (res.ok) {
      isModelLoaded = true;
      els.status.textContent = "Model Loaded";
      els.status.className = "text-sm text-green-600 font-medium";
      els.loadBtn.style.display = "none";
    } else {
      throw new Error(data.message || "Load failed");
    }
  } catch (e) {
    els.status.textContent = "Error Loading";
    els.status.className = "text-sm text-red-500";
    alert("Error: " + e.message);
  } finally {
    els.loadBtn.disabled = false;
  }
});

els.promptForm.addEventListener("submit", async (e) => {
  e.preventDefault();

  const text = els.promptInput.value.trim();
  if (!text) return;

  // UI updates
  addMessage("user", text);
  els.promptInput.value = "";

  // Create the assistant bubble that the stream will fill in
  const assistantBubble = addMessage("assistant", "");
  const contentPre = assistantBubble.querySelector(".content");
  const textContent = contentPre.querySelector(".text-content");

  const visualizationDiv = document.createElement("div");
  visualizationDiv.className = "visualization mb-2 font-mono text-xs";

  // Loading spinner (SVG)
  const spinner = document.createElement("div");
  spinner.className = "flex items-center gap-2 text-slate-400";
  spinner.innerHTML = `
    <svg class="animate-spin h-4 w-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
      <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
      <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v4a4 4 0 00-4 4H4z"></path>
    </svg>
    <span class="text-xs">Generating...</span>
  `;
  visualizationDiv.appendChild(spinner);

  contentPre.insertBefore(visualizationDiv, textContent);

  // Disable input while generating
  els.sendBtn.disabled = true;
  els.sendBtn.textContent = "Generating...";
  els.promptInput.disabled = true;

  // Stream the generation as server-sent events
  try {
    const res = await fetch("/api/generate-stream", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        instruction: text,
        steps: parseInt(els.steps.value, 10),
        block_size: parseInt(els.block_size.value, 10),
        max_new_tokens: parseInt(els.max_new_tokens.value, 10),
        parallel_blocks: parseInt(els.parallel_blocks.value, 10),
      }),
    });

    if (!res.ok) {
      throw new Error(`Server Error ${res.status}`);
    }

    const reader = res.body.getReader();
    const decoder = new TextDecoder();
    let buffer = "";

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop(); // Keep the trailing incomplete line in the buffer

      for (const line of lines) {
        if (line.startsWith("data: ")) {
          const jsonStr = line.slice(6);
          if (jsonStr.trim()) {
            try {
              const data = JSON.parse(jsonStr);
              handleStreamEvent(data, visualizationDiv, textContent);
            } catch (e) {
              console.error("Failed to parse SSE data:", e);
            }
          }
        }
      }
    }
  } catch (error) {
    if (textContent) textContent.textContent = `Error: ${error.message}`;
  } finally {
    els.sendBtn.disabled = false;
    els.sendBtn.textContent = "Send";
    els.promptInput.disabled = false;
  }
});
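
The submit handler above expects /api/generate-stream to emit newline-delimited server-sent events of the form "data: <json>", with event types start, update, complete, and error. A minimal Flask sketch compatible with that protocol, assuming a hypothetical run_generation generator (the actual implementation is in app.py):

import json
from flask import Flask, request, Response

app = Flask(__name__)

@app.route("/api/generate-stream", methods=["POST"])
def generate_stream():
    params = request.get_json()

    def sse(payload):
        # One SSE frame: the client splits on "\n" and strips the "data: " prefix
        return "data: " + json.dumps(payload) + "\n\n"

    def events():
        yield sse({"type": "start"})
        final_text = ""
        try:
            # run_generation is a hypothetical generator yielding
            # (visualization_snapshot, text_so_far) pairs per step
            for viz, final_text in run_generation(**params):
                yield sse({"type": "update", "data": viz})
            yield sse({"type": "complete", "response": final_text})
        except Exception as e:
            yield sse({"type": "error", "error": str(e)})

    return Response(events(), mimetype="text/event-stream")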

// Dispatch one parsed SSE event to the assistant bubble
function handleStreamEvent(data, visualizationDiv, textContent) {
  if (data.type === "start") {
    textContent.textContent = "";
  } else if (data.type === "update") {
    // Render the intermediate generation state
    renderVisualization(data.data, visualizationDiv);
    scrollToBottom();
  } else if (data.type === "complete") {
    // Clear the visualization and show the final response
    visualizationDiv.innerHTML = "";
    textContent.textContent = data.response || "No response";
    scrollToBottom();
  } else if (data.type === "error") {
    textContent.textContent = `Error: ${data.error}`;
  }
}

// Draw the current context plus the in-progress blocks, one color per block
function renderVisualization(vizData, container) {
  // Clear previous content
  container.innerHTML = "";

  // Show context
  const contextDiv = document.createElement("div");
  contextDiv.className = "text-slate-600 mb-1";
  contextDiv.textContent = vizData.context;
  container.appendChild(contextDiv);

  // Show blocks
  const blocksDiv = document.createElement("div");
  blocksDiv.classList.add("flex", "flex-wrap", "gap-0");

  const blockColors = ["text-green-600", "text-cyan-600", "text-yellow-600", "text-purple-600"];

  vizData.blocks.forEach((block, blockIdx) => {
    const blockSpan = document.createElement("span");
    blockSpan.className = blockColors[blockIdx % blockColors.length];

    block.tokens.forEach((token) => {
      if (token.type === "masked") {
        const maskedSpan = document.createElement("span");
        maskedSpan.className = blockColors[blockIdx % blockColors.length];
        maskedSpan.innerText = token.text + " ";
        blockSpan.appendChild(maskedSpan);
      } else {
        const textNode = document.createTextNode(token.text);
        blockSpan.appendChild(textNode);
      }
    });

    blocksDiv.appendChild(blockSpan);
  });

  container.appendChild(blocksDiv);

  // Add a legend when more than one block is generated in parallel
  if (vizData.num_blocks > 1) {
    const legendDiv = document.createElement("div");
    legendDiv.className = "text-xs text-slate-500 mt-1";
    const legends = [];
    for (let i = 0; i < vizData.num_blocks; i++) {
      legends.push(`Block ${i + 1}`);
    }
    legendDiv.textContent = `Generating: ${legends.join(" | ")}`;
    container.appendChild(legendDiv);
  }
}
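
From renderVisualization above, the data field of each update event carries the context decoded so far, the list of in-progress blocks with their tokens, and num_blocks for the legend. An illustrative payload (field names come from the client code; the concrete values, including the "[MASK]" text, are assumptions):

update_payload = {
    "context": "decoded text accepted so far",
    "num_blocks": 2,  # a legend is rendered when this is > 1
    "blocks": [
        {"tokens": [
            {"type": "text", "text": "there "},    # any type other than "masked" renders as plain text
            {"type": "masked", "text": "[MASK]"},  # masked tokens render in the block's color
        ]},
        {"tokens": [
            {"type": "masked", "text": "[MASK]"},
        ]},
    ],
}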

// --- UI Helpers ---

// Append a chat bubble for `role` ("user" or "assistant") and return it so
// the caller can stream content into it.
function addMessage(role, text) {
  const wrapper = document.createElement("div");
  wrapper.className = "mb-6 max-w-[100%] flex flex-col";

  const bubble = document.createElement("div");
  const isUser = role === "user";

  bubble.className = isUser
    ? "self-end bg-slate-900 text-white p-4 rounded-2xl rounded-tr-sm max-w-[85%]"
    : "self-start bg-white border border-gray-200 text-slate-800 p-4 rounded-2xl rounded-tl-sm max-w-[65%] whitespace-pre-wrap overflow-x-auto shadow-sm flex flex-wrap";

  // Main content container that holds the response text
  const pre = document.createElement("div");
  pre.className = "content whitespace-pre-wrap font-sans text-sm leading-relaxed";

  // The actual text content
  const textSpan = document.createElement("span");
  textSpan.className = "text-content";
  textSpan.textContent = text;

  pre.appendChild(textSpan);
  bubble.appendChild(pre);
  wrapper.appendChild(bubble);
  els.chat.appendChild(wrapper);
  scrollToBottom();

  // Hide the welcome screen once the first message appears
  const welcome = document.getElementById("welcome");
  if (welcome) {
    welcome.classList.add("hidden");
  }
  els.chat.classList.remove("hidden");

  return bubble;
}

function scrollToBottom() {
  els.chat.scrollTop = els.chat.scrollHeight;
}

// Sidebar toggle
els.sidebarToggle.addEventListener("click", () => {
  els.sidebar.classList.toggle("-translate-x-full");
});

// New chat button: reset the conversation view
els.newChatBtn.addEventListener("click", () => {
  // Clear chat
  els.chat.innerHTML = "";
  els.chat.classList.add("hidden");

  // Show welcome screen
  const welcome = document.getElementById("welcome");
  if (welcome) {
    welcome.classList.remove("hidden");
  }

  // Clear input
  els.promptInput.value = "";
});

// Initialize: check whether the model is already resident; if not, kick off
// a load automatically (same flow as the manual Load button, minus the alert).
(async () => {
  await checkLoadStatus();
  if (!isModelLoaded) {
    els.loadBtn.disabled = true;
    els.status.textContent = "Loading Model (this may take time)...";
    els.status.className = "text-sm text-yellow-600 font-medium";

    try {
      const res = await fetch("/api/load", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ check_only: false }),
      });
      const data = await res.json();

      if (res.ok) {
        isModelLoaded = true;
        els.status.textContent = "Model Loaded";
        els.status.className = "text-sm text-green-600 font-medium";
        els.loadBtn.style.display = "none";
      } else {
        throw new Error(data.message || "Load failed");
      }
    } catch (e) {
      els.status.textContent = "Error Loading";
      els.status.className = "text-sm text-red-500";
    } finally {
      els.loadBtn.disabled = false;
    }
  }
})();

// Start with the chat pane hidden; the welcome screen shows instead
els.chat.classList.add("hidden");