Spaces:
Sleeping
Sleeping
| # import os | |
| # import sys | |
| # import json | |
| # import time | |
| # import importlib.util | |
| # from pathlib import Path | |
| # from flask import Flask, request, jsonify, Response, stream_with_context | |
| # from flask_cors import CORS | |
| # import torch | |
| # from transformers import AutoTokenizer | |
| # app = Flask(__name__, static_folder='static', static_url_path='/static') | |
| # CORS(app) | |
| # # Global state | |
| # model = None | |
| # tokenizer = None | |
| # config = None | |
| # device = None | |
| # DiffusionLLM = None | |
| # chat_function = None | |
| # def find_file(filename, search_dirs=None): | |
| # """Find a file in current directory or parent directories.""" | |
| # if search_dirs is None: | |
| # search_dirs = [ | |
| # os.path.dirname(__file__), # Current directory | |
| # os.path.dirname(os.path.dirname(__file__)), # Parent directory | |
| # os.getcwd(), # Working directory | |
| # ] | |
| # for directory in search_dirs: | |
| # filepath = os.path.join(directory, filename) | |
| # if os.path.exists(filepath): | |
| # print(f"Found {filename} at: {filepath}") | |
| # return filepath | |
| # return None | |
| # def try_import_module(filepath, module_name): | |
| # """Dynamically import a Python file as a module.""" | |
| # if not filepath or not os.path.exists(filepath): | |
| # return None | |
| # try: | |
| # # Add the directory to sys.path | |
| # module_dir = os.path.dirname(filepath) | |
| # if module_dir not in sys.path: | |
| # sys.path.insert(0, module_dir) | |
| # spec = importlib.util.spec_from_file_location(module_name, filepath) | |
| # if spec is None: | |
| # print(f"Could not create spec for {filepath}") | |
| # return None | |
| # module = importlib.util.module_from_spec(spec) | |
| # sys.modules[module_name] = module | |
| # spec.loader.exec_module(module) | |
| # print(f"Successfully imported {module_name} from {filepath}") | |
| # return module | |
| # except Exception as e: | |
| # print(f"Error importing {filepath}: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # return None | |
| # def load_model_internal(): | |
| # """Load the model and tokenizer.""" | |
| # global model, tokenizer, config, device, DiffusionLLM, chat_function | |
| # if model is not None: | |
| # return True | |
| # try: | |
| # print("=" * 60) | |
| # print("Starting model loading process...") | |
| # print("=" * 60) | |
| # # Find and import infer-base.py | |
| # base_path = find_file("infer-base.py") | |
| # if base_path is None: | |
| # raise RuntimeError("Could not find infer-base.py. Make sure it's in the same directory as app.py or parent directory.") | |
| # print(f"\nImporting infer-base.py from: {base_path}") | |
| # base_mod = try_import_module(base_path, "infer_base") | |
| # if base_mod is None: | |
| # raise RuntimeError("Failed to import infer-base.py") | |
| # # Check for DiffusionLLM class | |
| # if not hasattr(base_mod, 'DiffusionLLM'): | |
| # print("Available attributes in infer_base:", dir(base_mod)) | |
| # raise RuntimeError("DiffusionLLM class not found in infer-base.py") | |
| # DiffusionLLM = base_mod.DiffusionLLM | |
| # print("β Successfully loaded DiffusionLLM class") | |
| # # Find and import infer-chat.py | |
| # chat_path = find_file("infer-chat.py") | |
| # if chat_path is None: | |
| # raise RuntimeError("Could not find infer-chat.py") | |
| # print(f"\nImporting infer-chat.py from: {chat_path}") | |
| # chat_mod = try_import_module(chat_path, "infer_chat") | |
| # if chat_mod is None or not hasattr(chat_mod, 'chat'): | |
| # raise RuntimeError("Failed to import chat function from infer-chat.py") | |
| # chat_function = chat_mod.chat | |
| # print("β Successfully loaded chat function") | |
| # # Setup pickling workaround for torch.load | |
| # try: | |
| # if hasattr(base_mod, 'ModelConfig'): | |
| # sys.modules['__main__'].ModelConfig = base_mod.ModelConfig | |
| # sys.modules['__main__'].DiffusionLLM = DiffusionLLM | |
| # print("β Configured pickle support for model loading") | |
| # except Exception as e: | |
| # print(f"Warning: Could not setup pickle workaround: {e}") | |
| # # Set device | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # print(f"\nβ Using device: {device}") | |
| # # Load tokenizer | |
| # print("\nLoading tokenizer...") | |
| # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") | |
| # if tokenizer.pad_token is None: | |
| # tokenizer.pad_token = tokenizer.eos_token | |
| # print("β Tokenizer loaded") | |
| # # Find model checkpoint | |
| # checkpoint_dirs = [ | |
| # "checkpoints", | |
| # "../checkpoints", | |
| # "./checkpoints", | |
| # os.path.join(os.path.dirname(__file__), "checkpoints"), | |
| # os.path.join(os.path.dirname(__file__), "../checkpoints"), | |
| # ] | |
| # model_path = None | |
| # for checkpoint_dir in checkpoint_dirs: | |
| # best_path = os.path.join(checkpoint_dir, "best_model.pt") | |
| # fp32_path = os.path.join(checkpoint_dir, "model_fp32.pt") | |
| # if os.path.exists(best_path): | |
| # model_path = best_path | |
| # break | |
| # elif os.path.exists(fp32_path): | |
| # model_path = fp32_path | |
| # break | |
| # if model_path is None: | |
| # raise RuntimeError( | |
| # "Could not find model checkpoint. Looking for:\n" | |
| # " - checkpoints/best_model.pt\n" | |
| # " - checkpoints/model_fp32.pt\n" | |
| # f"Searched directories: {checkpoint_dirs}" | |
| # ) | |
| # print(f"\nβ Found model checkpoint: {model_path}") | |
| # print("Loading model weights (this may take a minute)...") | |
| # # Load model | |
| # checkpoint = torch.load(model_path, map_location=device, weights_only=False) | |
| # config = checkpoint['config'] | |
| # print("Creating model...") | |
| # model = DiffusionLLM(config) | |
| # print("Loading state dict...") | |
| # state_dict = checkpoint['model_state'] | |
| # state_dict = {k: v.float() for k, v in state_dict.items()} | |
| # model.load_state_dict(state_dict) | |
| # model = model.to(device) | |
| # model.eval() | |
| # num_params = sum(p.numel() for p in model.parameters()) / 1e6 | |
| # print(f"\n{'=' * 60}") | |
| # print(f"βββ MODEL LOADED SUCCESSFULLY βββ") | |
| # print(f"{'=' * 60}") | |
| # print(f"Parameters: {num_params:.1f}M") | |
| # if 'step' in checkpoint: | |
| # print(f"Training steps: {checkpoint['step']}") | |
| # if 'best_val_loss' in checkpoint: | |
| # print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}") | |
| # print(f"{'=' * 60}\n") | |
| # return True | |
| # except Exception as e: | |
| # print("\n" + "=" * 60) | |
| # print("ERROR LOADING MODEL") | |
| # print("=" * 60) | |
| # print(f"Error: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # print("=" * 60 + "\n") | |
| # return False | |
| # def create_streaming_visualizer(): | |
| # """Create a visualizer that yields SSE events instead of printing to terminal.""" | |
| # def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): | |
| # # Normalize inputs to lists | |
| # if not isinstance(mask_blocks, list): | |
| # mask_blocks = [mask_blocks] | |
| # is_masked_list = [is_masked_list] | |
| # # Decode context | |
| # try: | |
| # context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ') | |
| # except Exception: | |
| # context_text = str(context_ids[0].tolist()) | |
| # # Build blocks visualization | |
| # all_blocks = [] | |
| # for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)): | |
| # block_tokens = mask_block[0].tolist() | |
| # block_data = [] | |
| # for i, token_id in enumerate(block_tokens): | |
| # if is_masked[0, i]: | |
| # block_data.append({ | |
| # 'type': 'masked', | |
| # 'text': 'βββ' | |
| # }) | |
| # else: | |
| # try: | |
| # token_text = tok.decode([token_id], skip_special_tokens=False) | |
| # except Exception: | |
| # token_text = str(int(token_id)) | |
| # block_data.append({ | |
| # 'type': 'revealed', | |
| # 'text': token_text | |
| # }) | |
| # all_blocks.append({ | |
| # 'block_index': block_idx, | |
| # 'tokens': block_data | |
| # }) | |
| # # Return data structure that will be sent as SSE | |
| # return { | |
| # 'context': context_text, | |
| # 'blocks': all_blocks, | |
| # 'num_blocks': len(mask_blocks) | |
| # } | |
| # return visualizer | |
| # @app.route('/') | |
| # def index(): | |
| # """Serve the main HTML page.""" | |
| # return app.send_static_file('index.html') | |
| # @app.route('/api/load', methods=['POST']) | |
| # def load_model_endpoint(): | |
| # """Load the model.""" | |
| # data = request.json or {} | |
| # check_only = data.get('check_only', False) | |
| # global model | |
| # if check_only: | |
| # return jsonify({ | |
| # 'loaded': model is not None, | |
| # 'message': 'Model is loaded' if model is not None else 'Model not loaded' | |
| # }) | |
| # if model is not None: | |
| # return jsonify({ | |
| # 'loaded': True, | |
| # 'message': 'Model already loaded' | |
| # }) | |
| # success = load_model_internal() | |
| # if success: | |
| # return jsonify({ | |
| # 'loaded': True, | |
| # 'message': 'Model loaded successfully' | |
| # }) | |
| # else: | |
| # return jsonify({ | |
| # 'loaded': False, | |
| # 'message': 'Failed to load model. Check server logs for details.' | |
| # }), 500 | |
| # @app.route('/api/generate', methods=['POST']) | |
| # def generate(): | |
| # """Generate response without streaming.""" | |
| # global model, tokenizer, config, device, chat_function | |
| # if model is None: | |
| # return jsonify({'error': 'Model not loaded'}), 400 | |
| # if chat_function is None: | |
| # return jsonify({'error': 'Chat function not available'}), 400 | |
| # data = request.json | |
| # instruction = data.get('instruction', '') | |
| # steps = data.get('steps', 64) | |
| # block_size = data.get('block_size', 128) | |
| # max_new_tokens = data.get('max_new_tokens', 128) | |
| # parallel_blocks = data.get('parallel_blocks', 1) | |
| # if not instruction: | |
| # return jsonify({'error': 'No instruction provided'}), 400 | |
| # try: | |
| # # Generate response | |
| # raw_output, response = chat_function( | |
| # model, | |
| # tokenizer, | |
| # instruction, | |
| # steps=steps, | |
| # block_size=block_size, | |
| # max_new_tokens=max_new_tokens, | |
| # temperature=0.8, | |
| # top_k=50, | |
| # top_p=0.9, | |
| # repetition_penalty=1.2, | |
| # no_repeat_ngram_size=3, | |
| # verbose=False, | |
| # visualize_fn=None, | |
| # parallel_blocks=parallel_blocks, | |
| # ) | |
| # return jsonify({ | |
| # 'response': response, | |
| # 'raw_output': raw_output | |
| # }) | |
| # except Exception as e: | |
| # import traceback | |
| # traceback.print_exc() | |
| # return jsonify({'error': str(e)}), 500 | |
| # @app.route('/api/generate-stream', methods=['POST']) | |
| # def generate_stream(): | |
| # """Generate response with streaming visualization.""" | |
| # global model, tokenizer, config, device, chat_function | |
| # if model is None: | |
| # return jsonify({'error': 'Model not loaded'}), 400 | |
| # if chat_function is None: | |
| # return jsonify({'error': 'Chat function not available'}), 400 | |
| # data = request.json | |
| # instruction = data.get('instruction', '') | |
| # steps = data.get('steps', 64) | |
| # block_size = data.get('block_size', 128) | |
| # max_new_tokens = data.get('max_new_tokens', 128) | |
| # parallel_blocks = data.get('parallel_blocks', 1) | |
| # if not instruction: | |
| # return jsonify({'error': 'No instruction provided'}), 400 | |
| # def generate_events(): | |
| # try: | |
| # # Import threading to allow yielding from callback | |
| # import queue | |
| # event_queue = queue.Queue() | |
| # generation_complete = {'done': False, 'result': None} | |
| # def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True): | |
| # """This gets called during generation - we need to send events immediately""" | |
| # visualizer = create_streaming_visualizer() | |
| # data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear) | |
| # # Put the update in the queue so it can be yielded immediately | |
| # event_queue.put({'type': 'update', 'data': data}) | |
| # # Start generation in a separate thread so we can yield events as they come | |
| # import threading | |
| # def run_generation(): | |
| # try: | |
| # raw_output, response = chat_function( | |
| # model, | |
| # tokenizer, | |
| # instruction, | |
| # steps=steps, | |
| # block_size=block_size, | |
| # max_new_tokens=max_new_tokens, | |
| # temperature=0.8, | |
| # top_k=50, | |
| # top_p=0.9, | |
| # repetition_penalty=1.2, | |
| # no_repeat_ngram_size=3, | |
| # verbose=False, | |
| # visualize_fn=streaming_visualizer, | |
| # parallel_blocks=parallel_blocks, | |
| # ) | |
| # generation_complete['result'] = (raw_output, response) | |
| # except Exception as e: | |
| # generation_complete['result'] = ('error', str(e)) | |
| # finally: | |
| # generation_complete['done'] = True | |
| # event_queue.put(None) # Signal completion | |
| # # Start generation thread | |
| # gen_thread = threading.Thread(target=run_generation) | |
| # gen_thread.daemon = True | |
| # gen_thread.start() | |
| # # Yield start event | |
| # yield f"data: {json.dumps({'type': 'start', 'message': 'Generation started'})}\n\n" | |
| # # Yield events as they come from the queue | |
| # while not generation_complete['done'] or not event_queue.empty(): | |
| # try: | |
| # event = event_queue.get(timeout=0.1) | |
| # if event is None: # Completion signal | |
| # break | |
| # yield f"data: {json.dumps(event)}\n\n" | |
| # except queue.Empty: | |
| # continue | |
| # # Wait for thread to finish | |
| # gen_thread.join(timeout=1.0) | |
| # # Send final response | |
| # if generation_complete['result']: | |
| # raw_output, response = generation_complete['result'] | |
| # if raw_output == 'error': | |
| # yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n" | |
| # else: | |
| # yield f"data: {json.dumps({'type': 'complete', 'response': response, 'raw_output': raw_output})}\n\n" | |
| # except Exception as e: | |
| # import traceback | |
| # traceback.print_exc() | |
| # yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n" | |
| # return Response( | |
| # stream_with_context(generate_events()), | |
| # mimetype='text/event-stream', | |
| # headers={ | |
| # 'Cache-Control': 'no-cache', | |
| # 'X-Accel-Buffering': 'no' | |
| # } | |
| # ) | |
| # @app.route('/api/test-stream', methods=['GET']) | |
| # def test_stream(): | |
| # """Test streaming endpoint.""" | |
| # def generate(): | |
| # for i in range(10): | |
| # yield f"data: {json.dumps({'message': f'Test message {i+1}'})}\n\n" | |
| # time.sleep(0.5) | |
| # yield f"data: {json.dumps({'message': 'Stream complete'})}\n\n" | |
| # return Response( | |
| # stream_with_context(generate()), | |
| # mimetype='text/event-stream', | |
| # headers={ | |
| # 'Cache-Control': 'no-cache', | |
| # 'X-Accel-Buffering': 'no' | |
| # } | |
| # ) | |
| # if __name__ == '__main__': | |
| # app.run(debug=True, host='0.0.0.0', port=7860, threaded=True) | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import importlib.util | |
| from pathlib import Path | |
| from flask import Flask, request, jsonify, Response, stream_with_context | |
| from flask_cors import CORS | |
| import torch | |
| from transformers import AutoTokenizer | |
| import threading | |
| import queue | |
| import warnings | |
| # ============ CRITICAL: CONFIGURE THREADS BEFORE TORCH OPERATIONS ============ | |
| # Must be set IMMEDIATELY at module import time | |
def setup_cpu_threads():
    """Cap the thread pools used by PyTorch and its math backends.

    Half of the logical CPU count reported by ``os.cpu_count()`` is used
    (minimum 1). The same number is exported through the OpenMP / MKL /
    NumExpr environment variables and applied to PyTorch's intra-op and
    inter-op pools.

    Returns:
        int: the thread count that was configured.
    """
    logical_cpus = os.cpu_count() or 1
    # Halving presumably targets hyper-threaded hosts, where two logical
    # CPUs share one core -- TODO confirm for the deployment machine.
    physical_cores = 1 if logical_cpus <= 1 else logical_cpus // 2

    # Environment variables go first, exactly as the original ordering did.
    thread_setting = str(physical_cores)
    for env_name in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
        os.environ[env_name] = thread_setting

    try:
        torch.set_num_threads(physical_cores)
        torch.set_num_interop_threads(physical_cores)
    except RuntimeError as e:
        # Downgraded to a warning so importing the module never fails here.
        warnings.warn(f"Could not set threads: {e} (already initialized)")

    print(f"β CPU threads configured: {physical_cores} physical cores")
    return physical_cores
# Call immediately
# Runs at import time; the returned count is reused as the default
# parallelism for warmup and for the generation endpoints.
PHYSICAL_CORES = setup_cpu_threads()
# ============================================================================
# Configuration flags
# USE_ONNX_RUNTIME / USE_IPEX / USE_TORCH_COMPILE are read by compile_model(),
# QUANTIZE_MODEL by quantize_model(), WARMUP_ITERATIONS by warmup_model().
USE_ONNX_RUNTIME = False
USE_IPEX = False
USE_TORCH_COMPILE = True
QUANTIZE_MODEL = True
WARMUP_ITERATIONS = 3
app = Flask(__name__, static_folder='static', static_url_path='/static')
CORS(app, resources={r"/api/*": {"origins": "*"}})  # More permissive for testing
# Global state
# All of these start as None and are filled in by load_model_internal().
model = None
tokenizer = None
config = None
device = None
DiffusionLLM = None
chat_function = None
ModelConfig = None  # Will be imported from infer-base.py
def find_file(filename, search_dirs=None):
    """Return the first existing path to ``filename`` in the candidate directories.

    Args:
        filename: file name to look for.
        search_dirs: directories to probe in order. When ``None``, the
            directory of this file, its parent, and the current working
            directory are tried.

    Returns:
        The joined path of the first match, or ``None`` if nothing matched.
    """
    if search_dirs is None:
        here = os.path.dirname(__file__)
        search_dirs = [here, os.path.dirname(here), os.getcwd()]

    candidates = (os.path.join(folder, filename) for folder in search_dirs)
    for candidate in candidates:
        if os.path.exists(candidate):
            print(f"Found {filename} at: {candidate}")
            return candidate
    return None
def try_import_module(filepath, module_name):
    """Import the Python file at ``filepath`` under the name ``module_name``.

    The file's directory is prepended to ``sys.path`` (if absent) so the
    module can import its siblings, and the module is registered in
    ``sys.modules`` before execution so self-references resolve.

    Args:
        filepath: path to the ``.py`` file; falsy or non-existent paths
            yield ``None``.
        module_name: key to register the module under in ``sys.modules``.

    Returns:
        The imported module, or ``None`` on any failure. Failures are
        printed, never raised.
    """
    if not filepath or not os.path.exists(filepath):
        return None

    missing = object()
    # Remember what was registered before so a failed import can be undone.
    previous = sys.modules.get(module_name, missing)
    registered = False
    try:
        module_dir = os.path.dirname(filepath)
        if module_dir not in sys.path:
            sys.path.insert(0, module_dir)
        spec = importlib.util.spec_from_file_location(module_name, filepath)
        if spec is None or spec.loader is None:
            print(f"Could not create spec for {filepath}")
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        registered = True
        spec.loader.exec_module(module)
        print(f"β Successfully imported {module_name}")
        return module
    except Exception as e:
        # Do not leave a half-executed module behind: later code could pick
        # it up from sys.modules and fail in confusing ways.
        if registered:
            if previous is missing:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
        print(f"β Error importing {filepath}: {e}")
        if __debug__:
            import traceback
            traceback.print_exc()
        return None
def quantize_model(model):
    """Return an INT8 dynamically-quantized copy of ``model`` when enabled.

    Only ``torch.nn.Linear`` layers are targeted. The previous set also
    listed ``nn.Embedding`` and ``nn.Conv1d``: with ``dtype=torch.qint8``
    the embedding conversion asserts (it requires the float-qparams
    weight-only qconfig), which aborted the whole call and silently left
    the model in FP32, and ``nn.Conv1d`` has no dynamic-quantization
    mapping at all.

    Args:
        model: the float ``torch.nn.Module`` to quantize.

    Returns:
        The quantized model, or ``model`` unchanged when ``QUANTIZE_MODEL``
        is off or quantization raises.
    """
    if not QUANTIZE_MODEL:
        return model

    print("\nApplying quantization...")
    try:
        # Dynamic quantization - no calibration needed; quantize_dynamic
        # works on a copy, so ``model`` is still intact if this raises.
        quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )
    except Exception as e:
        print(f"β Quantization failed: {e}")
        return model

    print("β INT8 dynamic quantization applied")
    return quantized
def compile_model(model):
    """Run ``model`` through the first enabled optimization backend.

    Backends are tried in the order ONNX Runtime, Intel IPEX, then
    ``torch.compile``. A backend that raises is reported and the next one
    is tried.

    Returns:
        * an ``onnxruntime.InferenceSession`` when the ONNX path succeeds
          (note: not a ``torch.nn.Module``),
        * the IPEX-optimized module when the IPEX path succeeds,
        * otherwise the (possibly ``torch.compile``-wrapped) model.
    """
    print("\nCompiling model...")

    # ONNX Runtime (BEST performance)
    if USE_ONNX_RUNTIME:
        try:
            import onnxruntime as ort

            onnx_path = "model_optimized.onnx"
            # Reuse an earlier export when the file is already on disk.
            if not os.path.exists(onnx_path):
                print("Exporting model to ONNX format...")
                sample_ids = torch.randint(0, 100, (1, 64))
                export_kwargs = {
                    'input_names': ['input_ids'],
                    'output_names': ['logits'],
                    'dynamic_axes': {'input_ids': {0: 'batch', 1: 'sequence'}},
                    'opset_version': 16,
                }
                torch.onnx.export(model, sample_ids, onnx_path, **export_kwargs)

            options = ort.SessionOptions()
            options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            options.intra_op_num_threads = PHYSICAL_CORES
            return ort.InferenceSession(onnx_path, options)
        except Exception as e:
            print(f"β ONNX Runtime failed: {e}")

    # Intel IPEX
    if USE_IPEX:
        try:
            import intel_extension_for_pytorch as ipex

            model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
            print("β Intel IPEX optimization applied")
            return model
        except Exception as e:
            print(f"β IPEX failed: {e}")

    # torch.compile (only on PyTorch builds that provide it)
    compile_available = hasattr(torch, 'compile')
    if USE_TORCH_COMPILE and compile_available:
        try:
            model = torch.compile(model, mode="max-autotune")
            print("β torch.compile applied")
        except Exception as e:
            print(f"β torch.compile failed: {e}")

    return model
def warmup_model(model, tokenizer, chat_func):
    """Run a few tiny generations so first-request latency is paid up front.

    Args:
        model: the loaded (and optionally optimized) model.
        tokenizer: tokenizer passed through to ``chat_func``.
        chat_func: the chat/generation callable; invoked
            ``WARMUP_ITERATIONS`` times with a short fixed prompt.

    A failing warmup is reported and swallowed; it never prevents the
    model from being served. The completion line is printed only when
    every iteration succeeded, so the log no longer shows "complete"
    right after "failed".
    """
    if WARMUP_ITERATIONS <= 0:
        return

    print("\nWarming up model...")
    start = time.time()
    try:
        with torch.inference_mode():
            for i in range(WARMUP_ITERATIONS):
                chat_func(
                    model, tokenizer, "Hello",
                    steps=4, block_size=16, max_new_tokens=8,
                    temperature=0.7, top_k=50, top_p=0.9,
                    repetition_penalty=1.2, no_repeat_ngram_size=3,
                    verbose=False, visualize_fn=None, parallel_blocks=PHYSICAL_CORES,
                )
                print(f" Warmup {i+1}/{WARMUP_ITERATIONS}...")
    except Exception as e:
        print(f"β Warmup failed: {e}")
    else:
        print(f"β Warmup complete ({time.time() - start:.2f}s)")
def _first_existing_checkpoint(checkpoint_dirs, filename):
    """Return the first ``<dir>/<filename>`` that exists on disk, else None."""
    for checkpoint_dir in checkpoint_dirs:
        candidate = os.path.join(checkpoint_dir, filename)
        if os.path.exists(candidate):
            return candidate
    return None


def load_model_internal():
    """Load modules, tokenizer and checkpoint, optimize the model, publish it.

    Steps: import ``infer-base.py`` / ``infer-chat.py``, register
    ``ModelConfig`` on ``__main__`` for unpickling, load the tokenizer,
    locate and load the checkpoint, build the model, then quantize,
    compile and warm it up.

    The globals ``model`` and ``config`` are assigned only after every
    step succeeded. Previously ``model`` was set as soon as the network
    was constructed, so a later failure left a half-initialised model
    behind and every subsequent call reported it as loaded.

    Returns:
        bool: ``True`` when a model is available (already loaded or loaded
        now), ``False`` when loading failed. Errors are printed, not raised.
    """
    global model, tokenizer, config, device, DiffusionLLM, chat_function, ModelConfig
    if model is not None:
        return True
    try:
        print("\n" + "=" * 70)
        print("ULTRA-FAST CPU MODEL LOADING")
        print("=" * 70)

        # FIRST: Import modules to get ModelConfig
        print("\n1. Loading modules...")
        base_path = find_file("infer-base.py")
        if base_path is None:
            raise RuntimeError("Could not find infer-base.py")
        base_mod = try_import_module(base_path, "infer_base")
        if base_mod is None:
            raise RuntimeError("Failed to import infer-base.py")

        # CRITICAL: Register ModelConfig for pickle -- the checkpoint is
        # unpickled below and looks the class up on __main__.
        if hasattr(base_mod, 'ModelConfig'):
            ModelConfig = base_mod.ModelConfig
            sys.modules['__main__'].ModelConfig = ModelConfig
            print("β ModelConfig registered for pickle")
        if not hasattr(base_mod, 'DiffusionLLM'):
            raise RuntimeError("DiffusionLLM class not found")
        DiffusionLLM = base_mod.DiffusionLLM
        print("β DiffusionLLM loaded")

        # Import chat function
        chat_path = find_file("infer-chat.py")
        if chat_path is None:
            raise RuntimeError("Could not find infer-chat.py")
        chat_mod = try_import_module(chat_path, "infer_chat")
        if chat_mod is None or not hasattr(chat_mod, 'chat'):
            raise RuntimeError("Chat function not found")
        chat_function = chat_mod.chat
        print("β Chat function loaded")

        # Device
        device = torch.device("cpu")

        # Load tokenizer
        print("\n2. Loading tokenizer...")
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # hub repo to run locally -- confirm it is actually required.
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2.5-0.5B",
            use_fast=True,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("β Fast tokenizer ready")

        # Find model: best_model.pt in any directory wins; model_fp32.pt is
        # the fallback. This selects the same files as the earlier
        # single-loop search, with the precedence made explicit.
        checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
        model_path = (
            _first_existing_checkpoint(checkpoint_dirs, "best_model.pt")
            or _first_existing_checkpoint(checkpoint_dirs, "model_fp32.pt")
        )
        if model_path is None:
            raise RuntimeError("Model checkpoint not found")
        print(f"\n3. Loading checkpoint: {model_path}")

        # CRITICAL: Load checkpoint AFTER ModelConfig is registered.
        # SECURITY: weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        loaded_config = checkpoint['config']

        # Create model (kept local until the whole pipeline has succeeded)
        print("4. Building model...")
        net = DiffusionLLM(loaded_config)

        # Load weights
        state_dict = checkpoint['model_state']
        if not USE_IPEX:
            state_dict = {k: v.float() for k, v in state_dict.items()}
        net.load_state_dict(state_dict)
        net.eval()
        net = net.to(device)

        # Count parameters on the plain float module: after quantize/compile
        # the object may hide its weights or not be a torch module at all.
        num_params = sum(p.numel() for p in net.parameters()) / 1e6

        # Apply optimizations
        net = quantize_model(net)
        net = compile_model(net)

        # Warmup
        warmup_model(net, tokenizer, chat_function)

        # Print summary
        framework = "ONNX Runtime" if USE_ONNX_RUNTIME else "IPEX" if USE_IPEX else "PyTorch"
        precision = "INT8" if QUANTIZE_MODEL and not USE_IPEX else "BF16" if USE_IPEX else "FP32"
        print("\n" + "=" * 70)
        print(f"βββ MODEL LOADED & ULTRA-OPTIMIZED ({framework} + {precision}) βββ")
        print("=" * 70)
        print(f"Parameters: {num_params:.1f}M")
        print(f"CPU Threads: {PHYSICAL_CORES}")
        if 'step' in checkpoint:
            print(f"Training steps: {checkpoint['step']}")
        if 'best_val_loss' in checkpoint:
            print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
        print("=" * 70 + "\n")

        # Publish last, so "model is not None" always means "fully loaded".
        config = loaded_config
        model = net
        return True
    except Exception as e:
        print(f"\nβ ERROR LOADING MODEL: {e}")
        if __debug__:
            import traceback
            traceback.print_exc()
        print("=" * 70 + "\n")
        return False
def create_streaming_visualizer():
    """Build a callback that turns generation state into a JSON-friendly dict.

    Returns:
        A function ``visualizer(tok, context_ids, mask_blocks,
        is_masked_list, cfg, clear=True)`` that returns::

            {'context': str, 'blocks': [...], 'num_blocks': int}

        Each entry of ``blocks`` is ``{'block_index': int, 'tokens': [...]}``
        and each token is ``{'type': 'masked' | 'revealed', 'text': str}``.
        ``cfg`` and ``clear`` are accepted for signature compatibility and
        are not used.
    """
    def visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
        # A single block may be passed bare; normalise both arguments to lists.
        if not isinstance(mask_blocks, list):
            mask_blocks = [mask_blocks]
            is_masked_list = [is_masked_list]

        try:
            context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
        except Exception:
            # Fall back to the raw ids so one bad token cannot break streaming.
            context_text = str(context_ids[0].tolist())

        all_blocks = []
        for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
            block_tokens = mask_block[0].tolist()
            # Read the mask once per position (assumes row 0 is the only
            # batch entry, as the indexing below already implies).
            hidden_flags = [bool(is_masked[0, i]) for i in range(len(block_tokens))]

            # Efficient batch decoding of the revealed tokens only.
            revealed_ids = [
                token_id
                for token_id, hidden in zip(block_tokens, hidden_flags)
                if not hidden
            ]
            try:
                decoded_tokens = tok.batch_decode(
                    [[tid] for tid in revealed_ids],
                    skip_special_tokens=False,
                )
            except Exception:
                decoded_tokens = [str(int(tid)) for tid in revealed_ids]

            # Reconstruct the block in position order.
            block_data = []
            decoded_idx = 0
            for hidden in hidden_flags:
                if hidden:
                    block_data.append({'type': 'masked', 'text': 'βββ'})
                else:
                    block_data.append({
                        'type': 'revealed',
                        'text': decoded_tokens[decoded_idx],
                    })
                    decoded_idx += 1
            all_blocks.append({'block_index': block_idx, 'tokens': block_data})

        return {
            'context': context_text,
            'blocks': all_blocks,
            'num_blocks': len(mask_blocks),
        }

    return visualizer
# The route registration was missing, so Flask never served this view.
# The path matches the earlier (commented-out) version of this handler.
@app.route('/')
def index():
    """Serve the main HTML page (``static/index.html``)."""
    return app.send_static_file('index.html')
# The route registration was missing, so Flask never served this view.
# The path/method match the earlier (commented-out) version of this handler.
@app.route('/api/load', methods=['POST'])
def load_model_endpoint():
    """Load the model, or report whether it is loaded.

    Request JSON (optional): ``{"check_only": bool}``. With
    ``check_only`` true the current state is reported without loading.

    Returns:
        JSON ``{"loaded": bool, "message": str}``; status 500 when a load
        was attempted and failed.
    """
    # silent=True: a POST without a JSON body means "no options" instead of
    # being rejected before this handler's own logic runs.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    is_loaded = model is not None
    if data.get('check_only', False):
        return jsonify({
            'loaded': is_loaded,
            'message': 'Model is loaded' if is_loaded else 'Model not loaded',
        })

    if is_loaded:
        return jsonify({
            'loaded': True,
            'message': 'Model already loaded',
        })

    if load_model_internal():
        return jsonify({
            'loaded': True,
            'message': 'Model loaded with ultra-fast CPU optimizations',
        })

    return jsonify({
        'loaded': False,
        'message': 'Failed to load model. Check server logs.',
    }), 500
# The route registration was missing, so Flask never served this view.
# The path/method match the earlier (commented-out) version of this handler.
@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate a response in one shot (no streaming).

    Request JSON: ``instruction`` (required), plus optional ``steps``,
    ``block_size``, ``max_new_tokens`` and ``parallel_blocks``.

    Returns:
        JSON ``{"response", "raw_output"}`` on success; ``{"error"}`` with
        status 400 for a missing model/instruction, or 500 when generation
        raises.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    # A missing or non-object body is treated as "no fields", which falls
    # through to the 400 below instead of crashing on ``None.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    instruction = data.get('instruction', '')
    steps = data.get('steps', 16)  # Minimal for speed
    block_size = data.get('block_size', 32)
    max_new_tokens = data.get('max_new_tokens', 64)
    parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    try:
        with torch.inference_mode():
            raw_output, response = chat_function(
                model, tokenizer, instruction,
                steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
                temperature=0.7, top_k=50, top_p=0.9,
                repetition_penalty=1.2, no_repeat_ngram_size=3,
                verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks,
            )
        return jsonify({'response': response, 'raw_output': raw_output})
    except Exception as e:
        if __debug__:
            import traceback
            traceback.print_exc()
        return jsonify({'error': str(e)}), 500
# The route registration was missing, so Flask never served this view.
# The path/method match the earlier (commented-out) version of this handler.
@app.route('/api/generate-stream', methods=['POST'])
def generate_stream():
    """Generate a response and stream progress as Server-Sent Events.

    Request JSON: ``instruction`` (required), plus optional ``steps``,
    ``block_size``, ``max_new_tokens`` and ``parallel_blocks``.

    Event payloads (each sent as ``data: <json>``):
        * ``{"type": "start", "ts": float}``
        * ``{"type": "update", "data": {...}}`` -- one per visualizer call;
          frames are dropped when the client reads too slowly.
        * ``{"type": "complete" | "error", "response": ..., "error": ...}``
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 400
    # Same guard as the non-streaming endpoint.
    if chat_function is None:
        return jsonify({'error': 'Chat function not available'}), 400

    # A missing or non-object body is treated as "no fields".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    instruction = data.get('instruction', '')
    steps = data.get('steps', 16)
    block_size = data.get('block_size', 32)
    max_new_tokens = data.get('max_new_tokens', 64)
    parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
    if not instruction:
        return jsonify({'error': 'No instruction provided'}), 400

    def generate_events():
        event_queue = queue.Queue(maxsize=50)  # Limited queue
        generation_complete = {'done': False, 'result': None, 'error': None}
        visualizer = create_streaming_visualizer()

        def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
            frame = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
            try:
                event_queue.put_nowait({'type': 'update', 'data': frame})
            except queue.Full:
                pass  # Drop frames if too slow

        def run_generation():
            try:
                with torch.inference_mode():
                    raw_output, response = chat_function(
                        model, tokenizer, instruction,
                        steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
                        temperature=0.7, top_k=50, top_p=0.9,
                        repetition_penalty=1.2, no_repeat_ngram_size=3,
                        verbose=False, visualize_fn=streaming_visualizer,
                        parallel_blocks=parallel_blocks,
                    )
                generation_complete['result'] = (raw_output, response)
            except Exception as e:
                generation_complete['error'] = str(e)
            finally:
                generation_complete['done'] = True
                # Never block here: if the client went away with a full
                # queue, a blocking put would hang this thread forever.
                # The reader also stops on the 'done' flag, so a dropped
                # sentinel is harmless.
                try:
                    event_queue.put_nowait(None)
                except queue.Full:
                    pass

        # Start generation thread
        gen_thread = threading.Thread(target=run_generation, daemon=True)
        gen_thread.start()

        yield f"data: {json.dumps({'type': 'start', 'ts': time.time()})}\n\n"

        # Stream events until the sentinel arrives or the worker is done
        # and the queue has been drained.
        while not generation_complete['done'] or not event_queue.empty():
            try:
                event = event_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if event is None:
                break
            yield f"data: {json.dumps(event)}\n\n"

        gen_thread.join(timeout=2.0)

        # Send final result. The payload is built first so the f-string
        # stays on one line (multi-line replacement fields in a
        # single-quoted f-string are a SyntaxError before Python 3.12).
        error_text = generation_complete['error']
        if error_text is not None:
            final_event = {'type': 'error', 'response': error_text, 'error': error_text}
            yield f"data: {json.dumps(final_event)}\n\n"
        elif generation_complete['result']:
            _raw_output, response = generation_complete['result']
            final_event = {'type': 'complete', 'response': response, 'error': None}
            yield f"data: {json.dumps(final_event)}\n\n"

    return Response(
        stream_with_context(generate_events()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
            'Connection': 'keep-alive',
        },
    )
# The route registration was missing, so Flask never served this view.
# NOTE(review): '/api/status' is chosen to match the other /api/* endpoints;
# confirm the path against the frontend.
@app.route('/api/status', methods=['GET'])
def status():
    """Return load state, thread configuration and optimization flags as JSON."""
    optimizations = {
        'onnx_runtime': USE_ONNX_RUNTIME,
        'ipex': USE_IPEX,
        'torch_compile': USE_TORCH_COMPILE,
        'quantization': QUANTIZE_MODEL,
        'warmup_iterations': WARMUP_ITERATIONS,
    }
    return jsonify({
        'model_loaded': model is not None,
        'cpu_cores': os.cpu_count(),
        'physical_cores': PHYSICAL_CORES,
        'torch_threads': torch.get_num_threads(),
        'interop_threads': torch.get_num_interop_threads(),
        'optimizations': optimizations,
    })
if __name__ == '__main__':
    # Startup banner, then the server on all interfaces at port 7860 with
    # one thread per request.
    rule = "=" * 70
    print("\n" + rule)
    print("ULTRA-FAST CPU INFERENCE SERVER v2.0")
    print(rule)
    print(f"CPU Configuration: {PHYSICAL_CORES} physical cores")
    print(f"Optimizations: ONNX={USE_ONNX_RUNTIME} | IPEX={USE_IPEX} | Compile={USE_TORCH_COMPILE}")
    print(rule + "\n")
    app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)