# diffusionGPT / app.py
import os
import sys
import json
import time
import importlib.util
from pathlib import Path
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer
import threading
import queue
import warnings
# ============ CRITICAL: CONFIGURE THREADS BEFORE TORCH OPERATIONS ============
# Must be set IMMEDIATELY at module import time
def setup_cpu_threads():
"""Configure CPU threads BEFORE any PyTorch parallel work starts."""
    cpu_count = os.cpu_count() or 1
    # Assume 2 logical threads per physical core (SMT); single-core machines fall back to 1
    physical_cores = cpu_count // 2 if cpu_count > 1 else 1
# Set environment variables FIRST
os.environ["OMP_NUM_THREADS"] = str(physical_cores)
os.environ["MKL_NUM_THREADS"] = str(physical_cores)
os.environ["NUMEXPR_NUM_THREADS"] = str(physical_cores)
# Set PyTorch threads BEFORE any operations
try:
torch.set_num_threads(physical_cores)
torch.set_num_interop_threads(physical_cores)
except RuntimeError as e:
warnings.warn(f"Could not set threads: {e} (already initialized)")
print(f"βœ“ CPU threads configured: {physical_cores} physical cores")
return physical_cores
# Call immediately
PHYSICAL_CORES = setup_cpu_threads()
# ============================================================================
# Configuration flags
USE_ONNX_RUNTIME = False   # Export to ONNX and run inference with onnxruntime instead of PyTorch
USE_IPEX = False           # Optimize with Intel Extension for PyTorch (bfloat16)
USE_TORCH_COMPILE = True   # JIT-compile the model with torch.compile
QUANTIZE_MODEL = True      # Apply INT8 dynamic quantization to linear layers
WARMUP_ITERATIONS = 3      # Number of dummy generations to run at load time
app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app, resources={r"/api/*": {"origins": "*"}}) # More permissive for testing
# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None
ModelConfig = None # Will be imported from infer-base.py
def find_file(filename, search_dirs=None):
"""Find a file in current directory or parent directories."""
if search_dirs is None:
search_dirs = [
os.path.dirname(__file__),
os.path.dirname(os.path.dirname(__file__)),
os.getcwd(),
]
for directory in search_dirs:
filepath = os.path.join(directory, filename)
if os.path.exists(filepath):
print(f"Found {filename} at: {filepath}")
return filepath
return None
def try_import_module(filepath, module_name):
"""Dynamically import a Python file as a module."""
if not filepath or not os.path.exists(filepath):
return None
try:
module_dir = os.path.dirname(filepath)
if module_dir not in sys.path:
sys.path.insert(0, module_dir)
spec = importlib.util.spec_from_file_location(module_name, filepath)
if spec is None:
print(f"Could not create spec for {filepath}")
return None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
print(f"βœ“ Successfully imported {module_name}")
return module
except Exception as e:
print(f"βœ— Error importing {filepath}: {e}")
if __debug__:
import traceback
traceback.print_exc()
return None
def quantize_model(model):
"""Apply quantization for faster inference."""
if not QUANTIZE_MODEL:
return model
print("\nApplying quantization...")
try:
        # Dynamic quantization: no calibration needed; weights are stored as INT8 and
        # activations are quantized on the fly. Only nn.Linear is targeted here --
        # dynamic quantization does not support Conv1d, and Embedding layers require
        # a float_qparams qconfig, so including them would make this call fail.
        model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
print("βœ“ INT8 dynamic quantization applied")
return model
except Exception as e:
print(f"⚠ Quantization failed: {e}")
return model
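# Note: dynamic quantization keeps the quantized layers' weights as INT8 and
# quantizes activations at inference time, which mainly speeds up the
# matmul-heavy parts of CPU inference at a small potential cost in accuracy.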
def compile_model(model):
"""Compile model for maximum speed."""
print("\nCompiling model...")
# ONNX Runtime (BEST performance)
if USE_ONNX_RUNTIME:
try:
import onnxruntime as ort
onnx_path = "model_optimized.onnx"
# Export if not exists
if not os.path.exists(onnx_path):
print("Exporting model to ONNX format...")
dummy_input = torch.randint(0, 100, (1, 64))
torch.onnx.export(
model, dummy_input, onnx_path,
input_names=['input_ids'],
output_names=['logits'],
dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
opset_version=16
)
# Create optimized session
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = PHYSICAL_CORES
return ort.InferenceSession(onnx_path, sess_options)
except Exception as e:
print(f"⚠ ONNX Runtime failed: {e}")
# Intel IPEX
if USE_IPEX:
try:
import intel_extension_for_pytorch as ipex
            model = ipex.optimize(model, dtype=torch.bfloat16, level="O1")  # IPEX supports levels "O0"/"O1"
print("βœ“ Intel IPEX optimization applied")
return model
except Exception as e:
print(f"⚠ IPEX failed: {e}")
# torch.compile
if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
try:
model = torch.compile(model, mode="max-autotune")
print("βœ“ torch.compile applied")
except Exception as e:
print(f"⚠ torch.compile failed: {e}")
return model
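# torch.compile builds its optimized graph lazily on the first forward pass,
# so the warmup below also pays that one-time compilation cost before real
# requests arrive.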
def warmup_model(model, tokenizer, chat_func):
"""Warmup the model."""
if WARMUP_ITERATIONS == 0:
return
print("\nWarming up model...")
start = time.time()
try:
with torch.inference_mode():
for i in range(WARMUP_ITERATIONS):
chat_func(
model, tokenizer, "Hello",
steps=4, block_size=16, max_new_tokens=8,
temperature=0.7, top_k=50, top_p=0.9,
repetition_penalty=1.2, no_repeat_ngram_size=3,
verbose=False, visualize_fn=None, parallel_blocks=PHYSICAL_CORES
)
print(f" Warmup {i+1}/{WARMUP_ITERATIONS}...")
except Exception as e:
print(f"⚠ Warmup failed: {e}")
print(f"βœ“ Warmup complete ({time.time() - start:.2f}s)")
def load_model_internal():
"""Load model with ultra-fast optimizations."""
global model, tokenizer, config, device, DiffusionLLM, chat_function, ModelConfig
if model is not None:
return True
try:
print("\n" + "=" * 70)
print("ULTRA-FAST CPU MODEL LOADING")
print("=" * 70)
# FIRST: Import modules to get ModelConfig
print("\n1. Loading modules...")
base_path = find_file("infer-base.py")
if base_path is None:
raise RuntimeError("Could not find infer-base.py")
base_mod = try_import_module(base_path, "infer_base")
if base_mod is None:
raise RuntimeError("Failed to import infer-base.py")
# CRITICAL: Register ModelConfig for pickle
if hasattr(base_mod, 'ModelConfig'):
ModelConfig = base_mod.ModelConfig
sys.modules['__main__'].ModelConfig = ModelConfig
print("βœ“ ModelConfig registered for pickle")
if not hasattr(base_mod, 'DiffusionLLM'):
raise RuntimeError("DiffusionLLM class not found")
DiffusionLLM = base_mod.DiffusionLLM
print("βœ“ DiffusionLLM loaded")
# Import chat function
chat_path = find_file("infer-chat.py")
if chat_path is None:
raise RuntimeError("Could not find infer-chat.py")
chat_mod = try_import_module(chat_path, "infer_chat")
if chat_mod is None or not hasattr(chat_mod, 'chat'):
raise RuntimeError("Chat function not found")
chat_function = chat_mod.chat
print("βœ“ Chat function loaded")
# Device
device = torch.device("cpu")
# Load tokenizer
print("\n2. Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2.5-0.5B",
use_fast=True,
trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("βœ“ Fast tokenizer ready")
# Find model
checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
model_path = None
for checkpoint_dir in checkpoint_dirs:
best_path = os.path.join(checkpoint_dir, "best_model.pt")
fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break
if model_path is None:
raise RuntimeError("Model checkpoint not found")
print(f"\n3. Loading checkpoint: {model_path}")
# CRITICAL: Load checkpoint AFTER ModelConfig is registered
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
config = checkpoint['config']
# Create model
print("4. Building model...")
model = DiffusionLLM(config)
# Load weights
state_dict = checkpoint['model_state']
if not USE_IPEX:
state_dict = {k: v.float() for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
model = model.to(device)
# Apply optimizations
model = quantize_model(model)
model = compile_model(model)
# Warmup
warmup_model(model, tokenizer, chat_function)
# Print summary
num_params = sum(p.numel() for p in model.parameters()) / 1e6
framework = "ONNX Runtime" if USE_ONNX_RUNTIME else "IPEX" if USE_IPEX else "PyTorch"
precision = "INT8" if QUANTIZE_MODEL and not USE_IPEX else "BF16" if USE_IPEX else "FP32"
print("\n" + "=" * 70)
print(f"βœ“βœ“βœ“ MODEL LOADED & ULTRA-OPTIMIZED ({framework} + {precision}) βœ“βœ“βœ“")
print("=" * 70)
print(f"Parameters: {num_params:.1f}M")
print(f"CPU Threads: {PHYSICAL_CORES}")
if 'step' in checkpoint:
print(f"Training steps: {checkpoint['step']}")
if 'best_val_loss' in checkpoint:
print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
print("=" * 70 + "\n")
return True
except Exception as e:
print(f"\nβœ— ERROR LOADING MODEL: {e}")
if __debug__:
import traceback
traceback.print_exc()
print("=" * 70 + "\n")
return False
def create_streaming_visualizer():
"""Create optimized visualizer."""
def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
if not isinstance(mask_blocks, list):
mask_blocks = [mask_blocks]
is_masked_list = [is_masked_list]
try:
context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
except Exception:
context_text = str(context_ids[0].tolist())
all_blocks = []
for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
block_tokens = mask_block[0].tolist()
block_data = []
# Efficient batch decoding
token_ids_to_decode = []
positions = []
for i, token_id in enumerate(block_tokens):
if not is_masked[0, i]:
token_ids_to_decode.append(token_id)
positions.append(i)
try:
decoded_tokens = tok.batch_decode(
[[tid] for tid in token_ids_to_decode],
skip_special_tokens=False
)
except Exception:
decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
# Reconstruct
decoded_idx = 0
for i, token_id in enumerate(block_tokens):
if is_masked[0, i]:
block_data.append({'type': 'masked', 'text': 'β–ˆβ–ˆβ–ˆ'})
else:
block_data.append({
'type': 'revealed',
'text': decoded_tokens[decoded_idx]
})
decoded_idx += 1
all_blocks.append({'block_index': block_idx, 'tokens': block_data})
return {
'context': context_text,
'blocks': all_blocks,
'num_blocks': len(mask_blocks)
}
return visualizer
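# Shape of the payload produced by the visualizer above (sent to clients as the
# 'data' field of an SSE 'update' event):
#   {
#     "context": "<decoded context text>",
#     "blocks": [
#       {"block_index": 0,
#        "tokens": [{"type": "masked", "text": "β–ˆβ–ˆβ–ˆ"},
#                   {"type": "revealed", "text": "<decoded token>"}, ...]},
#       ...
#     ],
#     "num_blocks": <number of mask blocks>
#   }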
@app.route('/')
def index():
"""Serve the main HTML page."""
return app.send_static_file('index.html')
@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
"""Load the model."""
global model
data = request.json or {}
check_only = data.get('check_only', False)
if check_only:
return jsonify({
'loaded': model is not None,
'message': 'Model is loaded' if model is not None else 'Model not loaded'
})
if model is not None:
return jsonify({
'loaded': True,
'message': 'Model already loaded'
})
success = load_model_internal()
if success:
return jsonify({
'loaded': True,
'message': 'Model loaded with ultra-fast CPU optimizations'
})
else:
return jsonify({
'loaded': False,
'message': 'Failed to load model. Check server logs.'
}), 500
@app.route('/api/generate', methods=['POST'])
def generate():
"""Generate response - ultra-fast path."""
global model, tokenizer, chat_function
if model is None:
return jsonify({'error': 'Model not loaded'}), 400
if chat_function is None:
return jsonify({'error': 'Chat function not available'}), 400
data = request.json
instruction = data.get('instruction', '')
steps = data.get('steps', 16) # Minimal for speed
block_size = data.get('block_size', 32)
max_new_tokens = data.get('max_new_tokens', 64)
parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
if not instruction:
return jsonify({'error': 'No instruction provided'}), 400
try:
with torch.inference_mode():
raw_output, response = chat_function(
model, tokenizer, instruction,
steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
temperature=0.7, top_k=50, top_p=0.9,
repetition_penalty=1.2, no_repeat_ngram_size=3,
verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks,
)
return jsonify({'response': response, 'raw_output': raw_output})
except Exception as e:
if __debug__:
import traceback
traceback.print_exc()
return jsonify({'error': str(e)}), 500
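# Illustrative request body for POST /api/generate (the defaults shown are the
# ones used above; the instruction text is just an example):
#   {"instruction": "Explain diffusion language models",
#    "steps": 16, "block_size": 32, "max_new_tokens": 64,
#    "parallel_blocks": <PHYSICAL_CORES>}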
@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
"""Generate with streaming - optimized."""
global model, tokenizer, chat_function
if model is None:
return jsonify({'error': 'Model not loaded'}), 400
data = request.json
instruction = data.get('instruction', '')
steps = data.get('steps', 16)
block_size = data.get('block_size', 32)
max_new_tokens = data.get('max_new_tokens', 64)
parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
if not instruction:
return jsonify({'error': 'No instruction provided'}), 400
def generate_events():
event_queue = queue.Queue(maxsize=50) # Limited queue
generation_complete = {'done': False, 'result': None}
def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
try:
visualizer = create_streaming_visualizer()
data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                event_queue.put_nowait({'type': 'update', 'data': data})  # non-blocking; queue.Full is handled below
except queue.Full:
pass # Drop frames if too slow
def run_generation():
try:
with torch.inference_mode():
raw_output, response = chat_function(
model, tokenizer, instruction,
steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
temperature=0.7, top_k=50, top_p=0.9,
repetition_penalty=1.2, no_repeat_ngram_size=3,
verbose=False, visualize_fn=streaming_visualizer,
parallel_blocks=parallel_blocks,
)
generation_complete['result'] = (raw_output, response)
except Exception as e:
generation_complete['result'] = ('error', str(e))
finally:
generation_complete['done'] = True
event_queue.put(None)
# Start generation thread
gen_thread = threading.Thread(target=run_generation, daemon=True)
gen_thread.start()
yield f"data: {json.dumps({'type': 'start', 'ts': time.time()})}\n\n"
# Stream events
while not generation_complete['done'] or not event_queue.empty():
try:
event = event_queue.get(timeout=0.1)
if event is None:
break
yield f"data: {json.dumps(event)}\n\n"
except queue.Empty:
continue
gen_thread.join(timeout=2.0)
# Send final result
        if generation_complete['result']:
            raw_output, response = generation_complete['result']
            is_error = raw_output == 'error'
            payload = {'type': 'error' if is_error else 'complete',
                       'response': response,
                       'error': response if is_error else None}
            yield f"data: {json.dumps(payload)}\n\n"
return Response(
stream_with_context(generate_events()),
mimetype='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'X-Accel-Buffering': 'no',
'Connection': 'keep-alive'
}
)
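# Event sequence emitted on this SSE stream:
#   {"type": "start", "ts": ...}           -- once, when generation begins
#   {"type": "update", "data": {...}}      -- one per visualizer callback (frames may be dropped if the queue is full)
#   {"type": "complete", "response": ...}  -- on success, or {"type": "error", "error": ...} on failure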
@app.route('/api/status', methods=['GET'])
def status():
"""Get detailed status."""
return jsonify({
'model_loaded': model is not None,
'cpu_cores': os.cpu_count(),
'physical_cores': PHYSICAL_CORES,
'torch_threads': torch.get_num_threads(),
'interop_threads': torch.get_num_interop_threads(),
'optimizations': {
'onnx_runtime': USE_ONNX_RUNTIME,
'ipex': USE_IPEX,
'torch_compile': USE_TORCH_COMPILE,
'quantization': QUANTIZE_MODEL,
'warmup_iterations': WARMUP_ITERATIONS
}
})
if __name__ == '__main__':
print("\n" + "=" * 70)
print("ULTRA-FAST CPU INFERENCE SERVER v2.0")
print("=" * 70)
print(f"CPU Configuration: {PHYSICAL_CORES} physical cores")
print(f"Optimizations: ONNX={USE_ONNX_RUNTIME} | IPEX={USE_IPEX} | Compile={USE_TORCH_COMPILE}")
print("=" * 70 + "\n")
app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
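# Minimal client sketch (illustrative only; assumes the `requests` package is
# installed and the server is reachable at http://localhost:7860):
#
#   import requests
#   requests.post("http://localhost:7860/api/load", json={})
#   r = requests.post("http://localhost:7860/api/generate",
#                     json={"instruction": "Write a haiku about CPUs",
#                           "max_new_tokens": 64})
#   print(r.json()["response"])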