# import os # import sys # import json # import time # import importlib.util # from pathlib import Path # from flask import Flask, request, jsonify, Response, stream_with_context # from flask_cors import CORS # import torch # from transformers import AutoTokenizer # app = Flask(__name__, static_folder='static', static_url_path='/static') # CORS(app) # # Global state # model = None # tokenizer = None # config = None # device = None # DiffusionLLM = None # chat_function = None # def find_file(filename, search_dirs=None): # """Find a file in current directory or parent directories.""" # if search_dirs is None: # search_dirs = [ # os.path.dirname(__file__), # Current directory # os.path.dirname(os.path.dirname(__file__)), # Parent directory # os.getcwd(), # Working directory # ] # for directory in search_dirs: # filepath = os.path.join(directory, filename) # if os.path.exists(filepath): # print(f"Found {filename} at: {filepath}") # return filepath # return None # def try_import_module(filepath, module_name): # """Dynamically import a Python file as a module.""" # if not filepath or not os.path.exists(filepath): # return None # try: # # Add the directory to sys.path # module_dir = os.path.dirname(filepath) # if module_dir not in sys.path: # sys.path.insert(0, module_dir) # spec = importlib.util.spec_from_file_location(module_name, filepath) # if spec is None: # print(f"Could not create spec for {filepath}") # return None # module = importlib.util.module_from_spec(spec) # sys.modules[module_name] = module # spec.loader.exec_module(module) # print(f"Successfully imported {module_name} from {filepath}") # return module # except Exception as e: # print(f"Error importing {filepath}: {e}") # import traceback # traceback.print_exc() # return None # def load_model_internal(): # """Load the model and tokenizer.""" # global model, tokenizer, config, device, DiffusionLLM, chat_function # if model is not None: # return True # try: # print("=" * 60) # print("Starting model loading process...") # print("=" * 60) # # Find and import infer-base.py # base_path = find_file("infer-base.py") # if base_path is None: # raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.") # print(f"\nImporting infer-base.py from: {base_path}") # base_mod = try_import_module(base_path, "infer_base") # if base_mod is None: # raise RuntimeError("Failed to import infer-base.py") # # Check for DiffusionLLM class # if not hasattr(base_mod, 'DiffusionLLM'): # print("Available attributes in infer_base:", dir(base_mod)) # raise RuntimeError("DiffusionLLM class not found in infer-base.py") # DiffusionLLM = base_mod.DiffusionLLM # print("✓ Successfully loaded DiffusionLLM class") # # Find and import infer-chat.py # chat_path = find_file("infer-chat.py") # if chat_path is None: # raise RuntimeError("Could not find infer-chat.py") # print(f"\nImporting infer-chat.py from: {chat_path}") # chat_mod = try_import_module(chat_path, "infer_chat") # if chat_mod is None or not hasattr(chat_mod, 'chat'): # raise RuntimeError("Failed to import chat function from infer-chat.py") # chat_function = chat_mod.chat # print("✓ Successfully loaded chat function") # # Setup pickling workaround for torch.load # try: # if hasattr(base_mod, 'ModelConfig'): # sys.modules['__main__'].ModelConfig = base_mod.ModelConfig # sys.modules['__main__'].DiffusionLLM = DiffusionLLM # print("✓ Configured pickle support for model loading") # except Exception as e: # print(f"Warning: Could not setup pickle workaround: {e}") # # Set device # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"\n✓ Using device: {device}") # # Load tokenizer # print("\nLoading tokenizer...") # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") # if tokenizer.pad_token is None: # tokenizer.pad_token = tokenizer.eos_token # print("✓ Tokenizer loaded") # # Find model checkpoint # checkpoint_dirs = [ # "checkpoints", # "../checkpoints", # "./checkpoints", # os.path.join(os.path.dirname(__file__), "checkpoints"), # os.path.join(os.path.dirname(__file__), "../checkpoints"), # ] # model_path = None # for checkpoint_dir in checkpoint_dirs: # best_path = os.path.join(checkpoint_dir, "best_model.pt") # fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt") # if os.path.exists(best_path): # model_path = best_path # break # elif os.path.exists(fp32_path): # model_path = fp32_path # break # if model_path is None: # raise RuntimeError( # "Could not find model checkpoint. Looking for:\n" # " - checkpoints/best_model.pt\n" # " - checkpoints/model_fp32.pt\n" # f"Searched directories: {checkpoint_dirs}" # ) # print(f"\n✓ Found model checkpoint: {model_path}") # print("Loading model weights (this may take a minute)...") # # Load model # checkpoint = torch.load(model_path, map_location=device, weights_only=False) # config = checkpoint['config'] # print("Creating model...") # model = DiffusionLLM(config) # print("Loading state dict...") # state_dict = checkpoint['model_state'] # state_dict = {k: v.float() for k, v in state_dict.items()} # model.load_state_dict(state_dict) # model = model.to(device) # model.eval() # num_params = sum(p.numel() for p in model.parameters()) / 1e6 # print(f"\n{'=' * 60}") # print(f"✓✓✓ MODEL LOADED SUCCESSFULLY ✓✓✓") # print(f"{'=' * 60}") # print(f"Parameters: {num_params:.1f}M") # if 'step' in checkpoint: # print(f"Training steps: {checkpoint['step']}") # if 'best_val_loss' in checkpoint: # print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}") # print(f"{'=' * 60}\n") # return True # except Exception as e: # print("\n" + "=" * 60) # print("ERROR LOADING MODEL") # print("=" * 60) # print(f"Error: {e}") # import traceback # traceback.print_exc() # print("=" * 60 + "\n") # return False # def create_streaming_visualizer(): # """Create a visualizer that yields SSE events instead of printing to terminal.""" # def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): # # Normalize inputs to lists # if not isinstance(mask_blocks, list): # mask_blocks = [mask_blocks] # is_masked_list = [is_masked_list] # # Decode context # try: # context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ') # except Exception: # context_text = str(context_ids[0].tolist()) # # Build blocks visualization # all_blocks = [] # for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)): # block_tokens = mask_block[0].tolist() # block_data = [] # for i, token_id in enumerate(block_tokens): # if is_masked[0, i]: # block_data.append({ # 'type': 'masked', # 'text': '███' # }) # else: # try: # token_text = tok.decode([token_id], skip_special_tokens=False) # except Exception: # token_text = str(int(token_id)) # block_data.append({ # 'type': 'revealed', # 'text': token_text # }) # all_blocks.append({ # 'block_index': block_idx, # 'tokens': block_data # }) # # Return data structure that will be sent as SSE # return { # 'context': context_text, # 'blocks': all_blocks, # 'num_blocks': len(mask_blocks) # } # return visualizer # @app.route('/') # def index(): # """Serve the main HTML page.""" # return app.send_static_file('index.html') # @app.route('/api/load', methods=['POST']) # def load_model_endpoint(): # """Load the model.""" # data = request.json or {} # check_only = data.get('check_only', False) # global model # if check_only: # return jsonify({ # 'loaded': model is not None, # 'message': 'Model is loaded' if model is not None else 'Model not loaded' # }) # if model is not None: # return jsonify({ # 'loaded': True, # 'message': 'Model already loaded' # }) # success = load_model_internal() # if success: # return jsonify({ # 'loaded': True, # 'message': 'Model loaded successfully' # }) # else: # return jsonify({ # 'loaded': False, # 'message': 'Failed to load model. Check server logs for details.' # }), 500 # @app.route('/api/generate', methods=['POST']) # def generate(): # """Generate response without streaming.""" # global model, tokenizer, config, device, chat_function # if model is None: # return jsonify({'error': 'Model not loaded'}), 400 # if chat_function is None: # return jsonify({'error': 'Chat function not available'}), 400 # data = request.json # instruction = data.get('instruction', '') # steps = data.get('steps', 64) # block_size = data.get('block_size', 128) # max_new_tokens = data.get('max_new_tokens', 128) # parallel_blocks = data.get('parallel_blocks', 1) # if not instruction: # return jsonify({'error': 'No instruction provided'}), 400 # try: # # Generate response # raw_output, response = chat_function( # model, # tokenizer, # instruction, # steps=steps, # block_size=block_size, # max_new_tokens=max_new_tokens, # temperature=0.8, # top_k=50, # top_p=0.9, # repetition_penalty=1.2, # no_repeat_ngram_size=3, # verbose=False, # visualize_fn=None, # parallel_blocks=parallel_blocks, # ) # return jsonify({ # 'response': response, # 'raw_output': raw_output # }) # except Exception as e: # import traceback # traceback.print_exc() # return jsonify({'error': str(e)}), 500 # @app.route('/api/generate-stream', methods=['POST']) # def generate_stream(): # """Generate response with streaming visualization.""" # global model, tokenizer, config, device, chat_function # if model is None: # return jsonify({'error': 'Model not loaded'}), 400 # if chat_function is None: # return jsonify({'error': 'Chat function not available'}), 400 # data = request.json # instruction = data.get('instruction', '') # steps = data.get('steps', 64) # block_size = data.get('block_size', 128) # max_new_tokens = data.get('max_new_tokens', 128) # parallel_blocks = data.get('parallel_blocks', 1) # if not instruction: # return jsonify({'error': 'No instruction provided'}), 400 # def generate_events(): # try: # # Import threading to allow yielding from callback # import queue # event_queue = queue.Queue() # generation_complete = {'done': False, 'result': None} # def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): # """This gets called during generation - we need to send events immediately""" # visualizer = create_streaming_visualizer() # data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear) # # Put the update in the queue so it can be yielded immediately # event_queue.put({'type': 'update', 'data': data}) # # Start generation in a separate thread so we can yield events as they come # import threading # def run_generation(): # try: # raw_output, response = chat_function( # model, # tokenizer, # instruction, # steps=steps, # block_size=block_size, # max_new_tokens=max_new_tokens, # temperature=0.8, # top_k=50, # top_p=0.9, # repetition_penalty=1.2, # no_repeat_ngram_size=3, # verbose=False, # visualize_fn=streaming_visualizer, # parallel_blocks=parallel_blocks, # ) # generation_complete['result'] = (raw_output, response) # except Exception as e: # generation_complete['result'] = ('error', str(e)) # finally: # generation_complete['done'] = True # event_queue.put(None) # Signal completion # # Start generation thread # gen_thread = threading.Thread(target=run_generation) # gen_thread.daemon = True # gen_thread.start() # # Yield start event # yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n" # # Yield events as they come from the queue # while not generation_complete['done'] or not event_queue.empty(): # try: # event = event_queue.get(timeout=0.1) # if event is None: # Completion signal # break # yield f"data: {json.dumps(event)}\n\n" # except queue.Empty: # continue # # Wait for thread to finish # gen_thread.join(timeout=1.0) # # Send final response # if generation_complete['result']: # raw_output, response = generation_complete['result'] # if raw_output == 'error': # yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n" # else: # yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n" # except Exception as e: # import traceback # traceback.print_exc() # yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n" # return Response( # stream_with_context(generate_events()), # mimetype='text/event-stream', # headers={ # 'Cache-Control': 'no-cache', # 'X-Accel-Buffering': 'no' # } # ) # @app.route('/api/test-stream', methods=['GET']) # def test_stream(): # """Test streaming endpoint.""" # def generate(): # for i in range(10): # yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n" # time.sleep(0.5) # yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n" # return Response( # stream_with_context(generate()), # mimetype='text/event-stream', # headers={ # 'Cache-Control': 'no-cache', # 'X-Accel-Buffering': 'no' # } # ) # if __name__ == '__main__': # app.run(debug=True, host='0.0.0.0', port=7860, threaded=True) import os import sys import json import time import importlib.util from pathlib import Path from flask import Flask, request, jsonify, Response, stream_with_context from flask_cors import CORS import torch from transformers import AutoTokenizer import threading import queue app = Flask(__name__, static_folder='static', static_url_path='/static') CORS(app) # Global state model = None tokenizer = None config = None device = None DiffusionLLM = None chat_function = None optimized_pipeline = None # For ONNX/IPEX # ==================== CONFIGURATION ==================== USE_ONNX_RUNTIME = False # Set True for ONNX (fastest) USE_IPEX = False # Set True for Intel CPUs USE_TORCH_COMPILE = True # Set True for PyTorch 2.0+ (good default) QUANTIZE_MODEL = True # INT8/BF16 quantization WARMUP_ITERATIONS = 3 # Warmup for stable performance # ======================================================= def find_file(filename, search_dirs=None): """Find a file in current directory or parent directories.""" if search_dirs is None: search_dirs = [ os.path.dirname(__file__), os.path.dirname(os.path.dirname(__file__)), os.getcwd(), ] for directory in search_dirs: filepath = os.path.join(directory, filename) if os.path.exists(filepath): print(f"Found {filename} at: {filepath}") return filepath return None def try_import_module(filepath, module_name): """Dynamically import a Python file as a module.""" if not filepath or not os.path.exists(filepath): return None try: module_dir = os.path.dirname(filepath) if module_dir not in sys.path: sys.path.insert(0, module_dir) spec = importlib.util.spec_from_file_location(module_name, filepath) if spec is None: print(f"Could not create spec for {filepath}") return None module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) print(f"Successfully imported {module_name} from {filepath}") return module except Exception as e: print(f"Error importing {filepath}: {e}") if __debug__: import traceback traceback.print_exc() return None def configure_cpu_optimization(): """Configure CPU for maximum performance.""" print("\n" + "=" * 60) print("CPU OPTIMIZATION CONFIGURATION") print("=" * 60) # Get CPU info cpu_count = os.cpu_count() physical_cores = cpu_count // 2 if cpu_count else cpu_count # Optimal thread settings threads = physical_cores or cpu_count or 1 torch.set_num_threads(threads) torch.set_num_interop_threads(threads) # Environment variables for MKL/OMP os.environ["OMP_NUM_THREADS"] = str(threads) os.environ["MKL_NUM_THREADS"] = str(threads) os.environ["NUMEXPR_NUM_THREADS"] = str(threads) print(f"✓ CPU Cores: {cpu_count} ({physical_cores} physical)") print(f"✓ Threads: {threads}") print(f"✓ OMP/MKL threads: {threads}") # Intel-specific optimizations if "intel" in torch.__version__.lower() or USE_IPEX: torch.backends.quantized.engine = 'fbgemm' print("✓ Intel FBGEMM backend enabled") print("=" * 60 + "\n") def quantize_model(model): """Apply quantization for faster inference.""" if not QUANTIZE_MODEL: return model print("\nApplying quantization...") try: # Use torch.compile with quantization if available if hasattr(torch, 'ao') and hasattr(torch.ao, 'quantization'): # Dynamic quantization (fastest, no calibration needed) model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d}, dtype=torch.qint8 ) print("✓ Applied INT8 dynamic quantization") elif USE_IPEX: # IPEX BF16 optimization model = model.to(torch.bfloat16) print("✓ Applied BF16 precision (IPEX)") return model except Exception as e: print(f"⚠ Quantization failed: {e}") return model def compile_model(model): """Compile model for maximum speed.""" global USE_TORCH_COMPILE, USE_ONNX_RUNTIME print("\nCompiling model...") # Option 1: ONNX Runtime (BEST performance) if USE_ONNX_RUNTIME: try: import onnxruntime as ort # Export to ONNX (one-time cost, but worth it) onnx_path = "model_optimized.onnx" if not os.path.exists(onnx_path): print("Exporting to ONNX format...") dummy_input = torch.randint(0, 100, (1, 128)) # Adjust shape as needed torch.onnx.export( model, dummy_input, onnx_path, input_names=['input_ids'], output_names=['logits'], dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}}, opset_version=16 ) # Create ONNX Runtime session sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.enable_cpu_mem_arena = True sess_options.enable_mem_pattern = True # Use all cores sess_options.intra_op_num_threads = torch.get_num_threads() provider = 'CPUExecutionProvider' compiled_model = ort.InferenceSession(onnx_path, sess_options, providers=[provider]) print("✓ ONNX Runtime compilation complete") return compiled_model except Exception as e: print(f"⚠ ONNX Runtime failed: {e}, using torch.compile") USE_ONNX_RUNTIME = False # Option 2: Intel IPEX if USE_IPEX: try: import intel_extension_for_pytorch as ipex model = ipex.optimize(model, dtype=torch.bfloat16, level="O3") print("✓ Intel IPEX O3 optimization applied") return model except Exception as e: print(f"⚠ IPEX failed: {e}") USE_IPEX = False # Option 3: torch.compile (PyTorch 2.0+) if USE_TORCH_COMPILE and hasattr(torch, 'compile'): try: # Use "max-autotune" for best performance model = torch.compile( model, mode="max-autotune", fullgraph=True, backend="inductor" ) print("✓ torch.compile (max-autotune) applied") except Exception as e: print(f"⚠ torch.compile failed: {e}, using eager mode") return model def warmup_model(model, tokenizer, chat_func): """Warmup the model for consistent performance.""" if WARMUP_ITERATIONS == 0: return print("\nWarming up model...") start_time = time.time() try: with torch.inference_mode(): for i in range(WARMUP_ITERATIONS): _ = chat_func( model, tokenizer, "Hello world", steps=8, block_size=32, max_new_tokens=16, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3, verbose=False, visualize_fn=None, parallel_blocks=2 ) print(f" Warmup {i+1}/{WARMUP_ITERATIONS} complete") except Exception as e: print(f"⚠ Warmup failed: {e}") print(f"✓ Warmup finished in {time.time() - start_time:.2f}s") def load_model_internal(): """Load the model with ultra-fast optimizations.""" global model, tokenizer, config, device, DiffusionLLM, chat_function, optimized_pipeline if model is not None: return True try: print("=" * 60) print("ULTRA-FAST CPU MODEL LOADING") print("=" * 60) # Configure CPU configure_cpu_optimization() # Import modules base_path = find_file("infer-base.py") if base_path is None: raise RuntimeError("Could not find infer-base.py") base_mod = try_import_module(base_path, "infer_base") if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'): raise RuntimeError("DiffusionLLM class not found") DiffusionLLM = base_mod.DiffusionLLM chat_path = find_file("infer-chat.py") if chat_path is None: raise RuntimeError("Could not find infer-chat.py") chat_mod = try_import_module(chat_path, "infer_chat") if chat_mod is None or not hasattr(chat_mod, 'chat'): raise RuntimeError("Chat function not found") chat_function = chat_mod.chat # Device device = torch.device("cpu") # Load tokenizer print("\nLoading tokenizer...") tokenizer = AutoTokenizer.from_pretrained( "Qwen/Qwen2.5-0.5B", use_fast=True, trust_remote_code=True ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("✓ Fast tokenizer loaded") # Find model checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"] model_path = None for checkpoint_dir in checkpoint_dirs: best_path = os.path.join(checkpoint_dir, "best_model.pt") fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt") if os.path.exists(best_path): model_path = best_path break elif os.path.exists(fp32_path): model_path = fp32_path if model_path is None: raise RuntimeError("Model checkpoint not found") print(f"\nLoading model from: {model_path}") # Load checkpoint checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) config = checkpoint['config'] # Create and load model model = DiffusionLLM(config) state_dict = checkpoint['model_state'] # Convert to float32 for CPU if not USE_IPEX: # Keep FP32 for ONNX, use BF16 for IPEX state_dict = {k: v.float() for k, v in state_dict.items()} model.load_state_dict(state_dict) model.eval() # Apply optimizations model = quantize_model(model) model = model.to(device) model = compile_model(model) # Warmup warmup_model(model, tokenizer, chat_function) # Print summary num_params = sum(p.numel() for p in model.parameters()) / 1e6 print(f"\n{'=' * 60}") print(f"✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED FOR CPU ✓✓✓") print(f"{'=' * 60}") print(f"Framework: {'ONNX Runtime' if USE_ONNX_RUNTIME else 'IPEX' if USE_IPEX else 'PyTorch'}") print(f"Parameters: {num_params:.1f}M") print(f"CPU Threads: {torch.get_num_threads()}") print(f"Quantization: {'INT8' if QUANTIZE_MODEL else 'BF16' if USE_IPEX else 'FP32'}") if 'step' in checkpoint: print(f"Training steps: {checkpoint['step']}") print(f"{'=' * 60}\n") return True except Exception as e: print(f"\nERROR LOADING MODEL: {e}") if __debug__: import traceback traceback.print_exc() return False def create_streaming_visualizer(): """Create optimized visualizer.""" def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): if not isinstance(mask_blocks, list): mask_blocks = [mask_blocks] is_masked_list = [is_masked_list] try: # Decode only once for efficiency context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ') except Exception: context_text = str(context_ids[0].tolist()) all_blocks = [] for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)): block_tokens = mask_block[0].tolist() block_data = [] # Batch decode for speed token_ids_to_decode = [] positions = [] for i, token_id in enumerate(block_tokens): if not is_masked[0, i]: token_ids_to_decode.append(token_id) positions.append(i) try: decoded_tokens = tok.batch_decode( [[tid] for tid in token_ids_to_decode], skip_special_tokens=False ) except Exception: decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode] # Reconstruct block decoded_idx = 0 for i, token_id in enumerate(block_tokens): if is_masked[0, i]: block_data.append({'type': 'masked', 'text': '███'}) else: block_data.append({ 'type': 'revealed', 'text': decoded_tokens[decoded_idx] }) decoded_idx += 1 all_blocks.append({'block_index': block_idx, 'tokens': block_data}) return { 'context': context_text, 'blocks': all_blocks, 'num_blocks': len(mask_blocks) } return visualizer @app.route('/') def index(): """Serve the main HTML page.""" return app.send_static_file('index.html') @app.route('/api/load', methods=['POST']) def load_model_endpoint(): """Load the model.""" data = request.json or {} check_only = data.get('check_only', False) global model if check_only: return jsonify({ 'loaded': model is not None, 'message': 'Model is loaded' if model is not None else 'Model not loaded' }) if model is not None: return jsonify({ 'loaded': True, 'message': 'Model already loaded' }) success = load_model_internal() if success: return jsonify({ 'loaded': True, 'message': 'Model loaded with ultra-fast CPU optimizations' }) else: return jsonify({ 'loaded': False, 'message': 'Failed to load model. Check server logs.' }), 500 @app.route('/api/generate', methods=['POST']) def generate(): """Generate response - optimized for minimal latency.""" global model, tokenizer, device, chat_function if model is None: return jsonify({'error': 'Model not loaded'}), 400 if chat_function is None: return jsonify({'error': 'Chat function not available'}), 400 data = request.json instruction = data.get('instruction', '') steps = data.get('steps', 16) # Further reduced for speed block_size = data.get('block_size', 32) max_new_tokens = data.get('max_new_tokens', 64) parallel_blocks = data.get('parallel_blocks', torch.get_num_threads()) if not instruction: return jsonify({'error': 'No instruction provided'}), 400 try: # Fast path: no overhead with torch.inference_mode(): raw_output, response = chat_function( model, tokenizer, instruction, steps=steps, block_size=block_size, max_new_tokens=max_new_tokens, temperature=0.7, # Slightly lower for faster sampling top_k=50, top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3, verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks, ) return jsonify({ 'response': response, 'raw_output': raw_output }) except Exception as e: if __debug__: import traceback traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/api/generate-stream', methods=['POST']) def generate_stream(): """Generate with streaming - optimized with queue.""" global model, tokenizer, chat_function if model is None: return jsonify({'error': 'Model not loaded'}), 400 if chat_function is None: return jsonify({'error': 'Chat function not available'}), 400 data = request.json instruction = data.get('instruction', '') steps = data.get('steps', 16) block_size = data.get('block_size', 32) max_new_tokens = data.get('max_new_tokens', 64) parallel_blocks = data.get('parallel_blocks', torch.get_num_threads()) if not instruction: return jsonify({'error': 'No instruction provided'}), 400 def generate_events(): event_queue = queue.Queue(maxsize=100) # Limit queue size generation_complete = {'done': False, 'result': None} def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): try: visualizer = create_streaming_visualizer() data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear) event_queue.put({'type': 'update', 'data': data}, block=False) except queue.Full: pass # Drop updates if queue full (prevents memory issues) def run_generation(): try: with torch.inference_mode(): raw_output, response = chat_function( model, tokenizer, instruction, steps=steps, block_size=block_size, max_new_tokens=max_new_tokens, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3, verbose=False, visualize_fn=streaming_visualizer, parallel_blocks=parallel_blocks, ) generation_complete['result'] = (raw_output, response) except Exception as e: generation_complete['result'] = ('error', str(e)) finally: generation_complete['done'] = True event_queue.put(None) import threading gen_thread = threading.Thread(target=run_generation) gen_thread.daemon = True gen_thread.start() yield f"data: {json.dumps({'type': 'start'})}\n\n" while not generation_complete['done'] or not event_queue.empty(): try: event = event_queue.get(timeout=0.1) if event is None: break yield f"data: {json.dumps(event)}\n\n" except queue.Empty: continue gen_thread.join(timeout=2.0) if generation_complete['result']: raw_output, response = generation_complete['result'] if raw_output == 'error': yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n" else: yield f"data: {json.dumps({'type': 'complete', 'response': response})}\n\n" return Response( stream_with_context(generate_events()), mimetype='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no', 'Connection': 'keep-alive' } ) @app.route('/api/status', methods=['GET']) def status(): """Get system status.""" return jsonify({ 'model_loaded': model is not None, 'torch_threads': torch.get_num_threads(), 'interop_threads': torch.get_num_interop_threads(), 'cpu_count': os.cpu_count(), 'optimizations': { 'onnx_runtime': USE_ONNX_RUNTIME, 'ipex': USE_IPEX, 'torch_compile': USE_TORCH_COMPILE, 'quantization': QUANTIZE_MODEL } }) if __name__ == '__main__': print("\n" + "=" * 70) print("ULTRA-FAST CPU INFERENCE SERVER") print("=" * 70) print("Available optimizations:") print(" ✓ ONNX Runtime (best):", USE_ONNX_RUNTIME) print(" ✓ Intel IPEX:", USE_IPEX) print(" ✓ torch.compile:", USE_TORCH_COMPILE) print(" ✓ Quantization:", QUANTIZE_MODEL) print(" ✓ Multi-threading") print(" ✓ Inference mode") print(" ✓ Fast tokenizer") print(" ✓ Memory layout optimization") print("\nTo install ONNX Runtime: pip install onnxruntime") print("To install Intel IPEX: pip install intel-extension-for-pytorch") print("=" * 70 + "\n") app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)