import os
import sys
import json
import time
import importlib.util
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import torch
from transformers import AutoTokenizer
import threading
import queue

app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app)

# Global state
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None
optimized_pipeline = None  # For ONNX/IPEX

# ==================== CONFIGURATION ====================
USE_ONNX_RUNTIME = False   # Set True for ONNX Runtime (fastest)
USE_IPEX = False           # Set True for Intel CPUs with IPEX installed
USE_TORCH_COMPILE = True   # Set True for PyTorch 2.0+ (good default)
QUANTIZE_MODEL = True      # INT8 dynamic quantization (BF16 under IPEX)
WARMUP_ITERATIONS = 3      # Warmup runs for stable performance
# ========================================================
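# Optional override sketch (hedged): the env-var names below are illustrative,
# not part of the original configuration. Uncommenting would let the flags be
# flipped without editing this file:
#   USE_ONNX_RUNTIME = os.environ.get("USE_ONNX_RUNTIME", "0") == "1"
#   USE_IPEX = os.environ.get("USE_IPEX", "0") == "1"
#   USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "1") == "1"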
def find_file(filename, search_dirs=None):
    """Find a file in the current directory, parent directory, or working directory."""
    if search_dirs is None:
        search_dirs = [
            os.path.dirname(__file__),                   # Directory of this file
            os.path.dirname(os.path.dirname(__file__)),  # Parent directory
            os.getcwd(),                                 # Working directory
        ]
    for directory in search_dirs:
        filepath = os.path.join(directory, filename)
        if os.path.exists(filepath):
            print(f"Found {filename} at: {filepath}")
            return filepath
    return None


def try_import_module(filepath, module_name):
    """Dynamically import a Python file as a module."""
    if not filepath or not os.path.exists(filepath):
        return None
    try:
        # Make sibling modules importable from the file's directory
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)
        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None:
            print(f"Could not create spec for {filepath}")
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        print(f"Successfully imported {module_name} from {filepath}")
        return module
    except Exception as e:
        print(f"Error importing {filepath}: {e}")
        if __debug__:
            import traceback
            traceback.print_exc()
        return None
def configure_cpu_optimization():
    """Configure CPU threading for maximum inference throughput."""
    print("\n" + "=" * 60)
    print("CPU OPTIMIZATION CONFIGURATION")
    print("=" * 60)
    # Get CPU info; assume SMT/hyper-threading, so physical cores ~= logical / 2
    cpu_count = os.cpu_count()
    physical_cores = cpu_count // 2 if cpu_count else cpu_count
    # Prefer one thread per physical core for compute-bound inference
    threads = physical_cores or cpu_count or 1
    torch.set_num_threads(threads)
    torch.set_num_interop_threads(threads)
    # Environment variables for MKL/OMP (ideally set before heavy numeric work)
    os.environ["OMP_NUM_THREADS"] = str(threads)
    os.environ["MKL_NUM_THREADS"] = str(threads)
    os.environ["NUMEXPR_NUM_THREADS"] = str(threads)
    print(f"✓ CPU cores: {cpu_count} ({physical_cores} physical)")
    print(f"✓ Torch threads: {threads}")
    print(f"✓ OMP/MKL threads: {threads}")
    # Intel-specific quantization backend
    if "intel" in torch.__version__.lower() or USE_IPEX:
        torch.backends.quantized.engine = 'fbgemm'
        print("✓ Intel FBGEMM quantization backend enabled")
    print("=" * 60 + "\n")
def quantize_model(model):
    """Apply quantization for faster CPU inference."""
    if not QUANTIZE_MODEL:
        return model
    print("\nApplying quantization...")
    try:
        if hasattr(torch, 'ao') and hasattr(torch.ao, 'quantization'):
            # Dynamic quantization: weights stored as INT8, activations
            # quantized on the fly; no calibration data needed.
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d},
                dtype=torch.qint8
            )
            print("✓ Applied INT8 dynamic quantization")
        elif USE_IPEX:
            # IPEX path: cast to BF16 instead of INT8
            model = model.to(torch.bfloat16)
            print("✓ Applied BF16 precision (IPEX)")
        return model
    except Exception as e:
        print(f"✗ Quantization failed: {e}")
        return model
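# Inspection sketch (hedged; the exact module path varies across PyTorch
# versions): dynamic quantization swaps eligible nn.Linear layers for
# quantized equivalents, which can be confirmed with something like:
#   import torch.ao.nn.quantized.dynamic as nnqd
#   print(any(isinstance(m, nnqd.Linear) for m in model.modules()))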
def compile_model(model):
    """Compile the model for maximum speed via ONNX Runtime, IPEX, or torch.compile."""
    global USE_TORCH_COMPILE, USE_ONNX_RUNTIME, USE_IPEX
    print("\nCompiling model...")
    # Option 1: ONNX Runtime (typically the fastest on CPU)
    if USE_ONNX_RUNTIME:
        try:
            import onnxruntime as ort
            # Export to ONNX (one-time cost, but worth it)
            onnx_path = "model_optimized.onnx"
            if not os.path.exists(onnx_path):
                print("Exporting to ONNX format...")
                dummy_input = torch.randint(0, 100, (1, 128))  # Adjust shape as needed
                torch.onnx.export(
                    model,
                    dummy_input,
                    onnx_path,
                    input_names=['input_ids'],
                    output_names=['logits'],
                    dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
                    opset_version=16
                )
            # Create an ONNX Runtime session with full graph optimization
            sess_options = ort.SessionOptions()
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            sess_options.enable_cpu_mem_arena = True
            sess_options.enable_mem_pattern = True
            sess_options.intra_op_num_threads = torch.get_num_threads()  # Use all configured threads
            provider = 'CPUExecutionProvider'
            compiled_model = ort.InferenceSession(onnx_path, sess_options, providers=[provider])
            print("✓ ONNX Runtime compilation complete")
            # Caveat: an InferenceSession is not call-compatible with a torch
            # module; callers that do `model(...)` would need an adapter.
            return compiled_model
        except Exception as e:
            print(f"✗ ONNX Runtime failed: {e}, falling back to torch.compile")
            USE_ONNX_RUNTIME = False
    # Option 2: Intel IPEX
    if USE_IPEX:
        try:
            import intel_extension_for_pytorch as ipex
            model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
            print("✓ Intel IPEX O3 optimization applied")
            return model
        except Exception as e:
            print(f"✗ IPEX failed: {e}")
            USE_IPEX = False
    # Option 3: torch.compile (PyTorch 2.0+)
    if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
        try:
            model = torch.compile(
                model,
                mode="max-autotune",  # Longest compile time, best steady-state speed
                fullgraph=True,
                backend="inductor"
            )
            print("✓ torch.compile (max-autotune) applied")
        except Exception as e:
            print(f"✗ torch.compile failed: {e}, using eager mode")
    return model
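# If the ONNX path is enabled, downstream code that calls `model(...)` needs a
# thin adapter around the InferenceSession. A minimal sketch (the wrapper name
# and output layout are assumptions, not part of the original code):
#
#   class OnnxModelAdapter:
#       def __init__(self, session):
#           self.session = session
#       def __call__(self, input_ids):
#           (logits,) = self.session.run(['logits'], {'input_ids': input_ids.numpy()})
#           return torch.from_numpy(logits)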
def warmup_model(model, tokenizer, chat_func):
    """Run a few short generations so lazy initialization and compilation
    (e.g. torch.compile autotuning) happen before the first real request."""
    if WARMUP_ITERATIONS == 0:
        return
    print("\nWarming up model...")
    start_time = time.time()
    try:
        with torch.inference_mode():
            for i in range(WARMUP_ITERATIONS):
                _ = chat_func(
                    model, tokenizer, "Hello world",
                    steps=8, block_size=32, max_new_tokens=16,
                    temperature=0.8, top_k=50, top_p=0.9,
                    repetition_penalty=1.2, no_repeat_ngram_size=3,
                    verbose=False, visualize_fn=None, parallel_blocks=2
                )
                print(f"  Warmup {i+1}/{WARMUP_ITERATIONS} complete")
    except Exception as e:
        print(f"✗ Warmup failed: {e}")
    print(f"✓ Warmup finished in {time.time() - start_time:.2f}s")
def load_model_internal():
    """Load the model and tokenizer with CPU-oriented optimizations."""
    global model, tokenizer, config, device, DiffusionLLM, chat_function, optimized_pipeline
    if model is not None:
        return True
    try:
        print("=" * 60)
        print("ULTRA-FAST CPU MODEL LOADING")
        print("=" * 60)
        # Configure CPU threading first
        configure_cpu_optimization()
        # Import the model and chat modules
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py")
        base_mod = try_import_module(base_path, "infer_base")
        if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
            raise RuntimeError("DiffusionLLM class not found in infer-base.py")
        DiffusionLLM = base_mod.DiffusionLLM
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")
        chat_mod = try_import_module(chat_path, "infer_chat")
        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("chat function not found in infer-chat.py")
        chat_function = chat_mod.chat
        # Pickle workaround (from the earlier revision of this file): register
        # the classes under __main__ so torch.load can unpickle checkpoints
        # saved from a script's __main__ namespace.
        try:
            if hasattr(base_mod, 'ModelConfig'):
                sys.modules['__main__'].ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].DiffusionLLM = DiffusionLLM
        except Exception as e:
            print(f"Warning: could not set up pickle workaround: {e}")
        # Device: this server targets CPU inference
        device = torch.device("cpu")
        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2.5-0.5B",
            use_fast=True,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Fast tokenizer loaded")
        # Find the model checkpoint
        checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
        model_path = None
        for checkpoint_dir in checkpoint_dirs:
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt")
            if os.path.exists(best_path):
                model_path = best_path
                break
            elif os.path.exists(fp32_path):
                model_path = fp32_path
                break
        if model_path is None:
            raise RuntimeError("Model checkpoint not found")
        print(f"\nLoading model from: {model_path}")
        # Load checkpoint
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        config = checkpoint['config']
        # Create and load model
        model = DiffusionLLM(config)
        state_dict = checkpoint['model_state']
        # Convert to float32 for CPU (the IPEX path casts to BF16 later instead)
        if not USE_IPEX:
            state_dict = {k: v.float() for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.eval()
        # Apply optimizations: quantize, move to device, then compile
        model = quantize_model(model)
        model = model.to(device)
        model = compile_model(model)
        # Warmup
        warmup_model(model, tokenizer, chat_function)
        # Print summary
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"\n{'=' * 60}")
        print("✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED FOR CPU ✓✓✓")
        print(f"{'=' * 60}")
        print(f"Framework: {'ONNX Runtime' if USE_ONNX_RUNTIME else 'IPEX' if USE_IPEX else 'PyTorch'}")
        print(f"Parameters: {num_params:.1f}M")
        print(f"CPU threads: {torch.get_num_threads()}")
        print(f"Quantization: {'INT8' if QUANTIZE_MODEL else 'BF16' if USE_IPEX else 'FP32'}")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        print(f"{'=' * 60}\n")
        return True
    except Exception as e:
        print(f"\nERROR LOADING MODEL: {e}")
        if __debug__:
            import traceback
            traceback.print_exc()
        return False
def create_streaming_visualizer():
    """Create a visualizer that returns structured updates for SSE streaming
    instead of printing to the terminal."""
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # Normalize inputs to lists
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]
        # Decode the context only once for efficiency
        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            context_text = str(context_ids[0].tolist())
        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            block_data = []
            # Collect revealed tokens so they can be batch-decoded in one call
            token_ids_to_decode = []
            positions = []
            for i, token_id in enumerate(block_tokens):
                if not is_masked[0, i]:
                    token_ids_to_decode.append(token_id)
                    positions.append(i)
            try:
                decoded_tokens = tok.batch_decode(
                    [[tid] for tid in token_ids_to_decode],
                    skip_special_tokens=False
                )
            except Exception:
                decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
            # Reconstruct the block, interleaving masked placeholders and revealed text
            decoded_idx = 0
            for i, token_id in enumerate(block_tokens):
                if is_masked[0, i]:
                    block_data.append({'type': 'masked', 'text': '███'})
                else:
                    block_data.append({
                        'type': 'revealed',
                        'text': decoded_tokens[decoded_idx]
                    })
                    decoded_idx += 1
            all_blocks.append({'block_index': block_idx, 'tokens': block_data})
        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks)
        }
    return visualizer
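# Example of one update payload produced by the visualizer (illustrative values):
#   {'context': 'User: Hello', 'num_blocks': 1,
#    'blocks': [{'block_index': 0,
#                'tokens': [{'type': 'revealed', 'text': 'Hi'},
#                           {'type': 'masked', 'text': '███'}]}]}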
@app.route('/')
def index():
    """Serve the main HTML page."""
    return app.send_static_file('index.html')


@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model, or report whether it is already loaded."""
    data = request.json or {}
    check_only = data.get('check_only', False)
    global model
    if check_only:
        return jsonify({
            'loaded': model is not None,
            'message': 'Model is loaded' if model is not None else 'Model not loaded'
        })
    if model is not None:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded'
        })
    success = load_model_internal()
    if success:
        return jsonify({
            'loaded': True,
            'message': 'Model loaded with ultra-fast CPU optimizations'
        })
    else:
        return jsonify({
            'loaded': False,
            'message': 'Failed to load model. Check server logs.'
        }), 500
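# Example request (assuming the server runs locally on port 7860):
#   curl -X POST http://localhost:7860/api/load \
#        -H 'Content-Type: application/json' -d '{"check_only": false}'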
@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate a response without streaming, optimized for minimal latency."""
    global model, tokenizer, device, chat_function
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400
    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 16)  # Reduced default for speed
    block_size = data.get('block_size', 32)
    max_new_tokens = data.get('max_new_tokens', 64)
    parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400
    try:
        # Fast path: inference mode disables autograd tracking entirely
        with torch.inference_mode():
            raw_output, response = chat_function(
                model,
                tokenizer,
                instruction,
                steps=steps,
                block_size=block_size,
                max_new_tokens=max_new_tokens,
                temperature=0.7,  # Slightly lower for faster sampling
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                verbose=False,
                visualize_fn=None,
                parallel_blocks=parallel_blocks,
            )
        return jsonify({
            'response': response,
            'raw_output': raw_output
        })
    except Exception as e:
        if __debug__:
            import traceback
            traceback.print_exc()
        return jsonify({'error': str(e)}), 500
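# Example request:
#   curl -X POST http://localhost:7860/api/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"instruction": "Write a haiku", "max_new_tokens": 64}'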
@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate a response with streaming visualization, using a bounded queue."""
    global model, tokenizer, chat_function
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400
    data = request.json
    instruction = data.get('instruction', '')
    steps = data.get('steps', 16)
    block_size = data.get('block_size', 32)
    max_new_tokens = data.get('max_new_tokens', 64)
    parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        event_queue = queue.Queue(maxsize=100)  # Bounded to cap memory use
        generation_complete = {'done': False, 'result': None}

        def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
            # Called from the generation thread on every denoising step
            try:
                visualizer = create_streaming_visualizer()
                data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
                event_queue.put({'type': 'update', 'data': data}, block=False)
            except queue.Full:
                pass  # Drop updates if the queue is full (prevents memory issues)

        def run_generation():
            try:
                with torch.inference_mode():
                    raw_output, response = chat_function(
                        model, tokenizer, instruction,
                        steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
                        temperature=0.7, top_k=50, top_p=0.9,
                        repetition_penalty=1.2, no_repeat_ngram_size=3,
                        verbose=False, visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                generation_complete['result'] = (raw_output, response)
            except Exception as e:
                generation_complete['result'] = ('error', str(e))
            finally:
                generation_complete['done'] = True
                event_queue.put(None)  # Sentinel: signal completion

        # Run generation in a background thread so events can be yielded as they arrive
        gen_thread = threading.Thread(target=run_generation)
        gen_thread.daemon = True
        gen_thread.start()
        yield f"data: {json.dumps({'type': 'start'})}\n\n"
        while not generation_complete['done'] or not event_queue.empty():
            try:
                event = event_queue.get(timeout=0.1)
                if event is None:  # Completion sentinel
                    break
                yield f"data: {json.dumps(event)}\n\n"
            except queue.Empty:
                continue
        gen_thread.join(timeout=2.0)
        if generation_complete['result']:
            raw_output, response = generation_complete['result']
            if raw_output == 'error':
                yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
            else:
                yield f"data: {json.dumps({'type': 'complete', 'response': response})}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',  # Disable nginx buffering for SSE
            'Connection': 'keep-alive'
        }
    )
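# Example SSE consumption (curl's -N disables buffering so events print live):
#   curl -N -X POST http://localhost:7860/api/generate-stream \
#        -H 'Content-Type: application/json' \
#        -d '{"instruction": "Tell me a story"}'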
# Route path is assumed from the function's purpose; it did not appear in the
# earlier revision of this file.
@app.route('/api/status', methods=['GET'])
def status():
    """Report system status and which optimizations are enabled."""
    return jsonify({
        'model_loaded': model is not None,
        'torch_threads': torch.get_num_threads(),
        'interop_threads': torch.get_num_interop_threads(),
        'cpu_count': os.cpu_count(),
        'optimizations': {
            'onnx_runtime': USE_ONNX_RUNTIME,
            'ipex': USE_IPEX,
            'torch_compile': USE_TORCH_COMPILE,
            'quantization': QUANTIZE_MODEL
        }
    })
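# Example:
#   curl http://localhost:7860/api/status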
if __name__ == '__main__':
    print("\n" + "=" * 70)
    print("ULTRA-FAST CPU INFERENCE SERVER")
    print("=" * 70)
    print("Available optimizations:")
    print("  • ONNX Runtime (best):", USE_ONNX_RUNTIME)
    print("  • Intel IPEX:", USE_IPEX)
    print("  • torch.compile:", USE_TORCH_COMPILE)
    print("  • Quantization:", QUANTIZE_MODEL)
    print("  • Multi-threading")
    print("  • Inference mode")
    print("  • Fast tokenizer")
    print("  • Memory layout optimization")
    print("\nTo install ONNX Runtime: pip install onnxruntime")
    print("To install Intel IPEX:   pip install intel-extension-for-pytorch")
    print("=" * 70 + "\n")
    app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)