thejagstudio committed
Commit 527cb39 · verified · 1 Parent(s): 01cf381

Update app.py

Files changed (1)
  1. app.py +148 -187
app.py CHANGED
@@ -487,9 +487,43 @@ import torch
487
  from transformers import AutoTokenizer
488
  import threading
489
  import queue
490
 
491
  app = Flask(__name__, static_folder='static', static_url_path='/static')
492
- CORS(app)
493
 
494
  # Global state
495
  model = None
@@ -498,15 +532,7 @@ config = None
498
  device = None
499
  DiffusionLLM = None
500
  chat_function = None
501
- optimized_pipeline = None # For ONNX/IPEX
502
-
503
- # ==================== CONFIGURATION ====================
504
- USE_ONNX_RUNTIME = False # Set True for ONNX (fastest)
505
- USE_IPEX = False # Set True for Intel CPUs
506
- USE_TORCH_COMPILE = True # Set True for PyTorch 2.0+ (good default)
507
- QUANTIZE_MODEL = True # INT8/BF16 quantization
508
- WARMUP_ITERATIONS = 3 # Warmup for stable performance
509
- # =======================================================
510
 
511
  def find_file(filename, search_dirs=None):
512
  """Find a file in current directory or parent directories."""
@@ -531,53 +557,25 @@ def try_import_module(filepath, module_name):
531
  module_dir = os.path.dirname(filepath)
532
  if module_dir not in sys.path:
533
  sys.path.insert(0, module_dir)
 
534
  spec = importlib.util.spec_from_file_location(module_name, filepath)
535
  if spec is None:
536
  print(f"Could not create spec for {filepath}")
537
  return None
 
538
  module = importlib.util.module_from_spec(spec)
539
  sys.modules[module_name] = module
540
  spec.loader.exec_module(module)
541
- print(f"Successfully imported {module_name} from {filepath}")
 
542
  return module
543
  except Exception as e:
544
- print(f"Error importing {filepath}: {e}")
545
  if __debug__:
546
  import traceback
547
  traceback.print_exc()
548
  return None
549
 
550
- def configure_cpu_optimization():
551
- """Configure CPU for maximum performance."""
552
- print("\n" + "=" * 60)
553
- print("CPU OPTIMIZATION CONFIGURATION")
554
- print("=" * 60)
555
-
556
- # Get CPU info
557
- cpu_count = os.cpu_count()
558
- physical_cores = cpu_count // 2 if cpu_count else cpu_count
559
-
560
- # Optimal thread settings
561
- threads = physical_cores or cpu_count or 1
562
- torch.set_num_threads(threads)
563
- torch.set_num_interop_threads(threads)
564
-
565
- # Environment variables for MKL/OMP
566
- os.environ["OMP_NUM_THREADS"] = str(threads)
567
- os.environ["MKL_NUM_THREADS"] = str(threads)
568
- os.environ["NUMEXPR_NUM_THREADS"] = str(threads)
569
-
570
- print(f"✓ CPU Cores: {cpu_count} ({physical_cores} physical)")
571
- print(f"✓ Threads: {threads}")
572
- print(f"✓ OMP/MKL threads: {threads}")
573
-
574
- # Intel-specific optimizations
575
- if "intel" in torch.__version__.lower() or USE_IPEX:
576
- torch.backends.quantized.engine = 'fbgemm'
577
- print("✓ Intel FBGEMM backend enabled")
578
-
579
- print("=" * 60 + "\n")
580
-
581
  def quantize_model(model):
582
  """Apply quantization for faster inference."""
583
  if not QUANTIZE_MODEL:
@@ -585,20 +583,13 @@ def quantize_model(model):
585
 
586
  print("\nApplying quantization...")
587
  try:
588
- # Use torch.compile with quantization if available
589
- if hasattr(torch, 'ao') and hasattr(torch.ao, 'quantization'):
590
- # Dynamic quantization (fastest, no calibration needed)
591
- model = torch.quantization.quantize_dynamic(
592
- model,
593
- {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d},
594
- dtype=torch.qint8
595
- )
596
- print("✓ Applied INT8 dynamic quantization")
597
- elif USE_IPEX:
598
- # IPEX BF16 optimization
599
- model = model.to(torch.bfloat16)
600
- print("✓ Applied BF16 precision (IPEX)")
601
-
602
  return model
603
  except Exception as e:
604
  print(f"⚠ Quantization failed: {e}")
@@ -606,124 +597,114 @@ def quantize_model(model):
606
 
607
  def compile_model(model):
608
  """Compile model for maximum speed."""
609
- global USE_TORCH_COMPILE, USE_ONNX_RUNTIME
610
-
611
  print("\nCompiling model...")
612
 
613
- # Option 1: ONNX Runtime (BEST performance)
614
  if USE_ONNX_RUNTIME:
615
  try:
616
  import onnxruntime as ort
617
- # Export to ONNX (one-time cost, but worth it)
618
  onnx_path = "model_optimized.onnx"
 
 
619
  if not os.path.exists(onnx_path):
620
- print("Exporting to ONNX format...")
621
- dummy_input = torch.randint(0, 100, (1, 128)) # Adjust shape as needed
622
  torch.onnx.export(
623
- model,
624
- dummy_input,
625
- onnx_path,
626
  input_names=['input_ids'],
627
  output_names=['logits'],
628
  dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
629
  opset_version=16
630
  )
631
 
632
- # Create ONNX Runtime session
633
  sess_options = ort.SessionOptions()
634
  sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
635
- sess_options.enable_cpu_mem_arena = True
636
- sess_options.enable_mem_pattern = True
637
 
638
- # Use all cores
639
- sess_options.intra_op_num_threads = torch.get_num_threads()
640
-
641
- provider = 'CPUExecutionProvider'
642
- compiled_model = ort.InferenceSession(onnx_path, sess_options, providers=[provider])
643
-
644
- print("✓ ONNX Runtime compilation complete")
645
- return compiled_model
646
  except Exception as e:
647
- print(f"⚠ ONNX Runtime failed: {e}, using torch.compile")
648
- USE_ONNX_RUNTIME = False
649
 
650
- # Option 2: Intel IPEX
651
  if USE_IPEX:
652
  try:
653
  import intel_extension_for_pytorch as ipex
654
  model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
655
- print("✓ Intel IPEX O3 optimization applied")
656
  return model
657
  except Exception as e:
658
  print(f"⚠ IPEX failed: {e}")
659
- USE_IPEX = False
660
 
661
- # Option 3: torch.compile (PyTorch 2.0+)
662
  if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
663
  try:
664
- # Use "max-autotune" for best performance
665
- model = torch.compile(
666
- model,
667
- mode="max-autotune",
668
- fullgraph=True,
669
- backend="inductor"
670
- )
671
- print("✓ torch.compile (max-autotune) applied")
672
  except Exception as e:
673
- print(f"⚠ torch.compile failed: {e}, using eager mode")
674
 
675
  return model
676
 
677
  def warmup_model(model, tokenizer, chat_func):
678
- """Warmup the model for consistent performance."""
679
  if WARMUP_ITERATIONS == 0:
680
  return
681
 
682
  print("\nWarming up model...")
683
- start_time = time.time()
684
 
685
  try:
686
  with torch.inference_mode():
687
  for i in range(WARMUP_ITERATIONS):
688
- _ = chat_func(
689
- model, tokenizer, "Hello world",
690
- steps=8, block_size=32, max_new_tokens=16,
691
- temperature=0.8, top_k=50, top_p=0.9,
692
  repetition_penalty=1.2, no_repeat_ngram_size=3,
693
- verbose=False, visualize_fn=None, parallel_blocks=2
694
  )
695
- print(f" Warmup {i+1}/{WARMUP_ITERATIONS} complete")
696
  except Exception as e:
697
  print(f"⚠ Warmup failed: {e}")
698
 
699
- print(f"✓ Warmup finished in {time.time() - start_time:.2f}s")
700
 
701
  def load_model_internal():
702
- """Load the model with ultra-fast optimizations."""
703
- global model, tokenizer, config, device, DiffusionLLM, chat_function, optimized_pipeline
704
 
705
  if model is not None:
706
  return True
707
 
708
  try:
709
- print("=" * 60)
710
  print("ULTRA-FAST CPU MODEL LOADING")
711
- print("=" * 60)
712
 
713
- # Configure CPU
714
- configure_cpu_optimization()
715
-
716
- # Import modules
717
  base_path = find_file("infer-base.py")
718
  if base_path is None:
719
  raise RuntimeError("Could not find infer-base.py")
720
 
721
  base_mod = try_import_module(base_path, "infer_base")
722
- if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
723
  raise RuntimeError("DiffusionLLM class not found")
724
 
725
  DiffusionLLM = base_mod.DiffusionLLM
 
726
 
 
727
  chat_path = find_file("infer-chat.py")
728
  if chat_path is None:
729
  raise RuntimeError("Could not find infer-chat.py")
@@ -733,12 +714,13 @@ def load_model_internal():
733
  raise RuntimeError("Chat function not found")
734
 
735
  chat_function = chat_mod.chat
 
736
 
737
  # Device
738
  device = torch.device("cpu")
739
 
740
  # Load tokenizer
741
- print("\nLoading tokenizer...")
742
  tokenizer = AutoTokenizer.from_pretrained(
743
  "Qwen/Qwen2.5-0.5B",
744
  use_fast=True,
@@ -746,7 +728,7 @@ def load_model_internal():
746
  )
747
  if tokenizer.pad_token is None:
748
  tokenizer.pad_token = tokenizer.eos_token
749
- print("✓ Fast tokenizer loaded")
750
 
751
  # Find model
752
  checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
@@ -763,26 +745,27 @@ def load_model_internal():
763
  if model_path is None:
764
  raise RuntimeError("Model checkpoint not found")
765
 
766
- print(f"\nLoading model from: {model_path}")
767
 
768
- # Load checkpoint
769
  checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
770
  config = checkpoint['config']
771
 
772
- # Create and load model
 
773
  model = DiffusionLLM(config)
774
- state_dict = checkpoint['model_state']
775
 
776
- # Convert to float32 for CPU
777
- if not USE_IPEX: # Keep FP32 for ONNX, use BF16 for IPEX
 
778
  state_dict = {k: v.float() for k, v in state_dict.items()}
779
 
780
  model.load_state_dict(state_dict)
781
  model.eval()
 
782
 
783
  # Apply optimizations
784
  model = quantize_model(model)
785
- model = model.to(device)
786
  model = compile_model(model)
787
 
788
  # Warmup
@@ -790,24 +773,28 @@ def load_model_internal():
790
 
791
  # Print summary
792
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
793
- print(f"\n{'=' * 60}")
794
- print(f"✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED FOR CPU ✓✓✓")
795
- print(f"{'=' * 60}")
796
- print(f"Framework: {'ONNX Runtime' if USE_ONNX_RUNTIME else 'IPEX' if USE_IPEX else 'PyTorch'}")
 
 
797
  print(f"Parameters: {num_params:.1f}M")
798
- print(f"CPU Threads: {torch.get_num_threads()}")
799
- print(f"Quantization: {'INT8' if QUANTIZE_MODEL else 'BF16' if USE_IPEX else 'FP32'}")
800
  if 'step' in checkpoint:
801
  print(f"Training steps: {checkpoint['step']}")
802
- print(f"{'=' * 60}\n")
 
 
803
 
804
  return True
805
 
806
  except Exception as e:
807
- print(f"\nERROR LOADING MODEL: {e}")
808
  if __debug__:
809
  import traceback
810
  traceback.print_exc()
 
811
  return False
812
 
813
  def create_streaming_visualizer():
@@ -818,7 +805,6 @@ def create_streaming_visualizer():
818
  is_masked_list = [is_masked_list]
819
 
820
  try:
821
- # Decode only once for efficiency
822
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
823
  except Exception:
824
  context_text = str(context_ids[0].tolist())
@@ -828,10 +814,9 @@ def create_streaming_visualizer():
828
  block_tokens = mask_block[0].tolist()
829
  block_data = []
830
 
831
- # Batch decode for speed
832
  token_ids_to_decode = []
833
  positions = []
834
-
835
  for i, token_id in enumerate(block_tokens):
836
  if not is_masked[0, i]:
837
  token_ids_to_decode.append(token_id)
@@ -845,7 +830,7 @@ def create_streaming_visualizer():
845
  except Exception:
846
  decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
847
 
848
- # Reconstruct block
849
  decoded_idx = 0
850
  for i, token_id in enumerate(block_tokens):
851
  if is_masked[0, i]:
@@ -875,11 +860,11 @@ def index():
875
  @app.route('/api/load', methods=['POST'])
876
  def load_model_endpoint():
877
  """Load the model."""
 
 
878
  data = request.json or {}
879
  check_only = data.get('check_only', False)
880
 
881
- global model
882
-
883
  if check_only:
884
  return jsonify({
885
  'loaded': model is not None,
@@ -906,8 +891,8 @@ def load_model_endpoint():
906
 
907
  @app.route('/api/generate', methods=['POST'])
908
  def generate():
909
- """Generate response - optimized for minimal latency."""
910
- global model, tokenizer, device, chat_function
911
 
912
  if model is None:
913
  return jsonify({'error': 'Model not loaded'}), 400
@@ -917,38 +902,25 @@ def generate():
917
 
918
  data = request.json
919
  instruction = data.get('instruction', '')
920
- steps = data.get('steps', 16) # Further reduced for speed
921
  block_size = data.get('block_size', 32)
922
  max_new_tokens = data.get('max_new_tokens', 64)
923
- parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
924
 
925
  if not instruction:
926
  return jsonify({'error': 'No instruction provided'}), 400
927
 
928
  try:
929
- # Fast path: no overhead
930
  with torch.inference_mode():
931
  raw_output, response = chat_function(
932
- model,
933
- tokenizer,
934
- instruction,
935
- steps=steps,
936
- block_size=block_size,
937
- max_new_tokens=max_new_tokens,
938
- temperature=0.7, # Slightly lower for faster sampling
939
- top_k=50,
940
- top_p=0.9,
941
- repetition_penalty=1.2,
942
- no_repeat_ngram_size=3,
943
- verbose=False,
944
- visualize_fn=None,
945
- parallel_blocks=parallel_blocks,
946
  )
947
 
948
- return jsonify({
949
- 'response': response,
950
- 'raw_output': raw_output
951
- })
952
 
953
  except Exception as e:
954
  if __debug__:
@@ -958,36 +930,33 @@ def generate():
958
 
959
  @app.route('/api/generate-stream', methods=['POST'])
960
  def generate_stream():
961
- """Generate with streaming - optimized with queue."""
962
  global model, tokenizer, chat_function
963
 
964
  if model is None:
965
  return jsonify({'error': 'Model not loaded'}), 400
966
 
967
- if chat_function is None:
968
- return jsonify({'error': 'Chat function not available'}), 400
969
-
970
  data = request.json
971
  instruction = data.get('instruction', '')
972
  steps = data.get('steps', 16)
973
  block_size = data.get('block_size', 32)
974
  max_new_tokens = data.get('max_new_tokens', 64)
975
- parallel_blocks = data.get('parallel_blocks', torch.get_num_threads())
976
 
977
  if not instruction:
978
  return jsonify({'error': 'No instruction provided'}), 400
979
 
980
  def generate_events():
981
- event_queue = queue.Queue(maxsize=100) # Limit queue size
982
  generation_complete = {'done': False, 'result': None}
983
 
984
  def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
985
  try:
986
  visualizer = create_streaming_visualizer()
987
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
988
- event_queue.put({'type': 'update', 'data': data}, block=False)
989
  except queue.Full:
990
- pass # Drop updates if queue full (prevents memory issues)
991
 
992
  def run_generation():
993
  try:
@@ -1007,13 +976,13 @@ def generate_stream():
1007
  generation_complete['done'] = True
1008
  event_queue.put(None)
1009
 
1010
- import threading
1011
- gen_thread = threading.Thread(target=run_generation)
1012
- gen_thread.daemon = True
1013
  gen_thread.start()
1014
 
1015
- yield f"data: {json.dumps({'type': 'start'})}\n\n"
1016
 
 
1017
  while not generation_complete['done'] or not event_queue.empty():
1018
  try:
1019
  event = event_queue.get(timeout=0.1)
@@ -1025,12 +994,11 @@ def generate_stream():
1025
 
1026
  gen_thread.join(timeout=2.0)
1027
 
 
1028
  if generation_complete['result']:
1029
  raw_output, response = generation_complete['result']
1030
- if raw_output == 'error':
1031
- yield f"data: {json.dumps({'type': 'error', 'error': response})}\n\n"
1032
- else:
1033
- yield f"data: {json.dumps({'type': 'complete', 'response': response})}\n\n"
1034
 
1035
  return Response(
1036
  stream_with_context(generate_events()),
@@ -1044,35 +1012,28 @@ def generate_stream():
1044
 
1045
  @app.route('/api/status', methods=['GET'])
1046
  def status():
1047
- """Get system status."""
1048
  return jsonify({
1049
  'model_loaded': model is not None,
 
 
1050
  'torch_threads': torch.get_num_threads(),
1051
  'interop_threads': torch.get_num_interop_threads(),
1052
- 'cpu_count': os.cpu_count(),
1053
  'optimizations': {
1054
  'onnx_runtime': USE_ONNX_RUNTIME,
1055
  'ipex': USE_IPEX,
1056
  'torch_compile': USE_TORCH_COMPILE,
1057
- 'quantization': QUANTIZE_MODEL
 
1058
  }
1059
  })
1060
 
1061
  if __name__ == '__main__':
1062
  print("\n" + "=" * 70)
1063
- print("ULTRA-FAST CPU INFERENCE SERVER")
1064
  print("=" * 70)
1065
- print("Available optimizations:")
1066
- print(" ONNX Runtime (best):", USE_ONNX_RUNTIME)
1067
- print(" ✓ Intel IPEX:", USE_IPEX)
1068
- print(" ✓ torch.compile:", USE_TORCH_COMPILE)
1069
- print(" ✓ Quantization:", QUANTIZE_MODEL)
1070
- print(" ✓ Multi-threading")
1071
- print(" ✓ Inference mode")
1072
- print(" ✓ Fast tokenizer")
1073
- print(" ✓ Memory layout optimization")
1074
- print("\nTo install ONNX Runtime: pip install onnxruntime")
1075
- print("To install Intel IPEX: pip install intel-extension-for-pytorch")
1076
  print("=" * 70 + "\n")
1077
 
1078
  app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
 
487
  from transformers import AutoTokenizer
488
  import threading
489
  import queue
490
+ import warnings
491
+
492
+ # ============ CRITICAL: CONFIGURE THREADS BEFORE TORCH OPERATIONS ============
493
+ # Must be set IMMEDIATELY at module import time
494
+ def setup_cpu_threads():
495
+ """Configure CPU threads BEFORE any PyTorch parallel work starts."""
496
+ cpu_count = os.cpu_count() or 1
497
+ physical_cores = cpu_count // 2 if cpu_count > 1 else 1
498
+
499
+ # Set environment variables FIRST
500
+ os.environ["OMP_NUM_THREADS"] = str(physical_cores)
501
+ os.environ["MKL_NUM_THREADS"] = str(physical_cores)
502
+ os.environ["NUMEXPR_NUM_THREADS"] = str(physical_cores)
503
+
504
+ # Set PyTorch threads BEFORE any operations
505
+ try:
506
+ torch.set_num_threads(physical_cores)
507
+ torch.set_num_interop_threads(physical_cores)
508
+ except RuntimeError as e:
509
+ warnings.warn(f"Could not set threads: {e} (already initialized)")
510
+
511
+ print(f"✓ CPU threads configured: {physical_cores} physical cores")
512
+ return physical_cores
513
+
514
+ # Call immediately
515
+ PHYSICAL_CORES = setup_cpu_threads()
516
+ # ============================================================================
517
+
518
+ # Configuration flags
519
+ USE_ONNX_RUNTIME = False
520
+ USE_IPEX = False
521
+ USE_TORCH_COMPILE = True
522
+ QUANTIZE_MODEL = True
523
+ WARMUP_ITERATIONS = 3
524
 
525
  app = Flask(__name__, static_folder='static', static_url_path='/static')
526
+ CORS(app, resources={r"/api/*": {"origins": "*"}}) # More permissive for testing
527
 
528
  # Global state
529
  model = None
 
532
  device = None
533
  DiffusionLLM = None
534
  chat_function = None
535
+ ModelConfig = None # Will be imported from infer-base.py
536
 
537
  def find_file(filename, search_dirs=None):
538
  """Find a file in current directory or parent directories."""
 
557
  module_dir = os.path.dirname(filepath)
558
  if module_dir not in sys.path:
559
  sys.path.insert(0, module_dir)
560
+
561
  spec = importlib.util.spec_from_file_location(module_name, filepath)
562
  if spec is None:
563
  print(f"Could not create spec for {filepath}")
564
  return None
565
+
566
  module = importlib.util.module_from_spec(spec)
567
  sys.modules[module_name] = module
568
  spec.loader.exec_module(module)
569
+
570
+ print(f"✓ Successfully imported {module_name}")
571
  return module
572
  except Exception as e:
573
+ print(f"Error importing {filepath}: {e}")
574
  if __debug__:
575
  import traceback
576
  traceback.print_exc()
577
  return None
578
 
579
  def quantize_model(model):
580
  """Apply quantization for faster inference."""
581
  if not QUANTIZE_MODEL:
 
583
 
584
  print("\nApplying quantization...")
585
  try:
586
+ # Dynamic quantization - no calibration needed, works on any model
587
+ model = torch.quantization.quantize_dynamic(
588
+ model,
589
+ {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding},
590
+ dtype=torch.qint8
591
+ )
592
+ print("✓ INT8 dynamic quantization applied")
593
  return model
594
  except Exception as e:
595
  print(f"⚠ Quantization failed: {e}")
 
597
 
598
  def compile_model(model):
599
  """Compile model for maximum speed."""
 
 
600
  print("\nCompiling model...")
601
 
602
+ # ONNX Runtime (BEST performance)
603
  if USE_ONNX_RUNTIME:
604
  try:
605
  import onnxruntime as ort
 
606
  onnx_path = "model_optimized.onnx"
607
+
608
+ # Export if not exists
609
  if not os.path.exists(onnx_path):
610
+ print("Exporting model to ONNX format...")
611
+ dummy_input = torch.randint(0, 100, (1, 64))
612
  torch.onnx.export(
613
+ model, dummy_input, onnx_path,
 
 
614
  input_names=['input_ids'],
615
  output_names=['logits'],
616
  dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
617
  opset_version=16
618
  )
619
 
620
+ # Create optimized session
621
  sess_options = ort.SessionOptions()
622
  sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
623
+ sess_options.intra_op_num_threads = PHYSICAL_CORES
 
624
 
625
+ return ort.InferenceSession(onnx_path, sess_options)
626
  except Exception as e:
627
+ print(f"⚠ ONNX Runtime failed: {e}")
 
628
 
629
+ # Intel IPEX
630
  if USE_IPEX:
631
  try:
632
  import intel_extension_for_pytorch as ipex
633
  model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
634
+ print("✓ Intel IPEX optimization applied")
635
  return model
636
  except Exception as e:
637
  print(f"⚠ IPEX failed: {e}")
 
638
 
639
+ # torch.compile
640
  if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
641
  try:
642
+ model = torch.compile(model, mode="max-autotune")
643
+ print("✓ torch.compile applied")
644
  except Exception as e:
645
+ print(f"⚠ torch.compile failed: {e}")
646
 
647
  return model
648
 
649
  def warmup_model(model, tokenizer, chat_func):
650
+ """Warmup the model."""
651
  if WARMUP_ITERATIONS == 0:
652
  return
653
 
654
  print("\nWarming up model...")
655
+ start = time.time()
656
 
657
  try:
658
  with torch.inference_mode():
659
  for i in range(WARMUP_ITERATIONS):
660
+ chat_func(
661
+ model, tokenizer, "Hello",
662
+ steps=4, block_size=16, max_new_tokens=8,
663
+ temperature=0.7, top_k=50, top_p=0.9,
664
  repetition_penalty=1.2, no_repeat_ngram_size=3,
665
+ verbose=False, visualize_fn=None, parallel_blocks=PHYSICAL_CORES
666
  )
667
+ print(f" Warmup {i+1}/{WARMUP_ITERATIONS}...")
668
  except Exception as e:
669
  print(f"⚠ Warmup failed: {e}")
670
 
671
+ print(f"✓ Warmup complete ({time.time() - start:.2f}s)")
672
 
673
  def load_model_internal():
674
+ """Load model with ultra-fast optimizations."""
675
+ global model, tokenizer, config, device, DiffusionLLM, chat_function, ModelConfig
676
 
677
  if model is not None:
678
  return True
679
 
680
  try:
681
+ print("\n" + "=" * 70)
682
  print("ULTRA-FAST CPU MODEL LOADING")
683
+ print("=" * 70)
684
 
685
+ # FIRST: Import modules to get ModelConfig
686
+ print("\n1. Loading modules...")
 
 
687
  base_path = find_file("infer-base.py")
688
  if base_path is None:
689
  raise RuntimeError("Could not find infer-base.py")
690
 
691
  base_mod = try_import_module(base_path, "infer_base")
692
+ if base_mod is None:
693
+ raise RuntimeError("Failed to import infer-base.py")
694
+
695
+ # CRITICAL: Register ModelConfig for pickle
696
+ if hasattr(base_mod, 'ModelConfig'):
697
+ ModelConfig = base_mod.ModelConfig
698
+ sys.modules['__main__'].ModelConfig = ModelConfig
699
+ print("✓ ModelConfig registered for pickle")
700
+
701
+ if not hasattr(base_mod, 'DiffusionLLM'):
702
  raise RuntimeError("DiffusionLLM class not found")
703
 
704
  DiffusionLLM = base_mod.DiffusionLLM
705
+ print("✓ DiffusionLLM loaded")
706
 
707
+ # Import chat function
708
  chat_path = find_file("infer-chat.py")
709
  if chat_path is None:
710
  raise RuntimeError("Could not find infer-chat.py")
 
714
  raise RuntimeError("Chat function not found")
715
 
716
  chat_function = chat_mod.chat
717
+ print("✓ Chat function loaded")
718
 
719
  # Device
720
  device = torch.device("cpu")
721
 
722
  # Load tokenizer
723
+ print("\n2. Loading tokenizer...")
724
  tokenizer = AutoTokenizer.from_pretrained(
725
  "Qwen/Qwen2.5-0.5B",
726
  use_fast=True,
 
728
  )
729
  if tokenizer.pad_token is None:
730
  tokenizer.pad_token = tokenizer.eos_token
731
+ print("✓ Fast tokenizer ready")
732
 
733
  # Find model
734
  checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
 
745
  if model_path is None:
746
  raise RuntimeError("Model checkpoint not found")
747
 
748
+ print(f"\n3. Loading checkpoint: {model_path}")
749
 
750
+ # CRITICAL: Load checkpoint AFTER ModelConfig is registered
751
  checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
752
  config = checkpoint['config']
753
 
754
+ # Create model
755
+ print("4. Building model...")
756
  model = DiffusionLLM(config)
 
757
 
758
+ # Load weights
759
+ state_dict = checkpoint['model_state']
760
+ if not USE_IPEX:
761
  state_dict = {k: v.float() for k, v in state_dict.items()}
762
 
763
  model.load_state_dict(state_dict)
764
  model.eval()
765
+ model = model.to(device)
766
 
767
  # Apply optimizations
768
  model = quantize_model(model)
 
769
  model = compile_model(model)
770
 
771
  # Warmup
 
773
 
774
  # Print summary
775
  num_params = sum(p.numel() for p in model.parameters()) / 1e6
776
+ framework = "ONNX Runtime" if USE_ONNX_RUNTIME else "IPEX" if USE_IPEX else "PyTorch"
777
+ precision = "INT8" if QUANTIZE_MODEL and not USE_IPEX else "BF16" if USE_IPEX else "FP32"
778
+
779
+ print("\n" + "=" * 70)
780
+ print(f"✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED ({framework} + {precision}) ✓✓✓")
781
+ print("=" * 70)
782
  print(f"Parameters: {num_params:.1f}M")
783
+ print(f"CPU Threads: {PHYSICAL_CORES}")
 
784
  if 'step' in checkpoint:
785
  print(f"Training steps: {checkpoint['step']}")
786
+ if 'best_val_loss' in checkpoint:
787
+ print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
788
+ print("=" * 70 + "\n")
789
 
790
  return True
791
 
792
  except Exception as e:
793
+ print(f"\n✗ ERROR LOADING MODEL: {e}")
794
  if __debug__:
795
  import traceback
796
  traceback.print_exc()
797
+ print("=" * 70 + "\n")
798
  return False
799
 
800
  def create_streaming_visualizer():
 
805
  is_masked_list = [is_masked_list]
806
 
807
  try:
 
808
  context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
809
  except Exception:
810
  context_text = str(context_ids[0].tolist())
 
814
  block_tokens = mask_block[0].tolist()
815
  block_data = []
816
 
817
+ # Efficient batch decoding
818
  token_ids_to_decode = []
819
  positions = []
 
820
  for i, token_id in enumerate(block_tokens):
821
  if not is_masked[0, i]:
822
  token_ids_to_decode.append(token_id)
 
830
  except Exception:
831
  decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
832
 
833
+ # Reconstruct
834
  decoded_idx = 0
835
  for i, token_id in enumerate(block_tokens):
836
  if is_masked[0, i]:
 
860
  @app.route('/api/load', methods=['POST'])
861
  def load_model_endpoint():
862
  """Load the model."""
863
+ global model
864
+
865
  data = request.json or {}
866
  check_only = data.get('check_only', False)
867
 
 
 
868
  if check_only:
869
  return jsonify({
870
  'loaded': model is not None,
 
891
 
892
  @app.route('/api/generate', methods=['POST'])
893
  def generate():
894
+ """Generate response - ultra-fast path."""
895
+ global model, tokenizer, chat_function
896
 
897
  if model is None:
898
  return jsonify({'error': 'Model not loaded'}), 400
 
902
 
903
  data = request.json
904
  instruction = data.get('instruction', '')
905
+ steps = data.get('steps', 16) # Minimal for speed
906
  block_size = data.get('block_size', 32)
907
  max_new_tokens = data.get('max_new_tokens', 64)
908
+ parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
909
 
910
  if not instruction:
911
  return jsonify({'error': 'No instruction provided'}), 400
912
 
913
  try:
 
914
  with torch.inference_mode():
915
  raw_output, response = chat_function(
916
+ model, tokenizer, instruction,
917
+ steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
918
+ temperature=0.7, top_k=50, top_p=0.9,
919
+ repetition_penalty=1.2, no_repeat_ngram_size=3,
920
+ verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks,
921
  )
922
 
923
+ return jsonify({'response': response, 'raw_output': raw_output})
924
 
925
  except Exception as e:
926
  if __debug__:
 
930
 
931
  @app.route('/api/generate-stream', methods=['POST'])
932
  def generate_stream():
933
+ """Generate with streaming - optimized."""
934
  global model, tokenizer, chat_function
935
 
936
  if model is None:
937
  return jsonify({'error': 'Model not loaded'}), 400
938
 
 
 
 
939
  data = request.json
940
  instruction = data.get('instruction', '')
941
  steps = data.get('steps', 16)
942
  block_size = data.get('block_size', 32)
943
  max_new_tokens = data.get('max_new_tokens', 64)
944
+ parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
945
 
946
  if not instruction:
947
  return jsonify({'error': 'No instruction provided'}), 400
948
 
949
  def generate_events():
950
+ event_queue = queue.Queue(maxsize=50) # Limited queue
951
  generation_complete = {'done': False, 'result': None}
952
 
953
  def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
954
  try:
955
  visualizer = create_streaming_visualizer()
956
  data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
957
+ event_queue.put({'type': 'update', 'data': data}, block=False, timeout=0.1)
958
  except queue.Full:
959
+ pass # Drop frames if too slow
960
 
961
  def run_generation():
962
  try:
 
976
  generation_complete['done'] = True
977
  event_queue.put(None)
978
 
979
+ # Start generation thread
980
+ gen_thread = threading.Thread(target=run_generation, daemon=True)
 
981
  gen_thread.start()
982
 
983
+ yield f"data: {json.dumps({'type': 'start', 'ts': time.time()})}\n\n"
984
 
985
+ # Stream events
986
  while not generation_complete['done'] or not event_queue.empty():
987
  try:
988
  event = event_queue.get(timeout=0.1)
 
994
 
995
  gen_thread.join(timeout=2.0)
996
 
997
+ # Send final result
998
  if generation_complete['result']:
999
  raw_output, response = generation_complete['result']
1000
+ yield f"data: {json.dumps({'type': 'complete' if raw_output != 'error' else 'error',
1001
+ 'response': response, 'error': response if raw_output == 'error' else None})}\n\n"
 
 
1002
 
1003
  return Response(
1004
  stream_with_context(generate_events()),
 
1012
 
1013
  @app.route('/api/status', methods=['GET'])
1014
  def status():
1015
+ """Get detailed status."""
1016
  return jsonify({
1017
  'model_loaded': model is not None,
1018
+ 'cpu_cores': os.cpu_count(),
1019
+ 'physical_cores': PHYSICAL_CORES,
1020
  'torch_threads': torch.get_num_threads(),
1021
  'interop_threads': torch.get_num_interop_threads(),
 
1022
  'optimizations': {
1023
  'onnx_runtime': USE_ONNX_RUNTIME,
1024
  'ipex': USE_IPEX,
1025
  'torch_compile': USE_TORCH_COMPILE,
1026
+ 'quantization': QUANTIZE_MODEL,
1027
+ 'warmup_iterations': WARMUP_ITERATIONS
1028
  }
1029
  })
1030
 
1031
  if __name__ == '__main__':
1032
  print("\n" + "=" * 70)
1033
+ print("ULTRA-FAST CPU INFERENCE SERVER v2.0")
1034
  print("=" * 70)
1035
+ print(f"CPU Configuration: {PHYSICAL_CORES} physical cores")
1036
+ print(f"Optimizations: ONNX={USE_ONNX_RUNTIME} | IPEX={USE_IPEX} | Compile={USE_TORCH_COMPILE}")
1037
  print("=" * 70 + "\n")
1038
 
1039
  app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
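For reference, a minimal client sketch against the endpoints this commit exposes (/api/load, /api/generate, /api/generate-stream, served on port 7860 per app.run above). The payload fields mirror the route handlers in the diff; the base URL, the prompt text, and the use of the requests library are assumptions for illustration, not part of the commit.

# Minimal client sketch for the Flask endpoints defined in app.py.
# Assumes the server above is running locally on port 7860.
import json
import requests

BASE = "http://localhost:7860"

# Ask the server to load the model (returns immediately if already loaded).
requests.post(f"{BASE}/api/load", json={})

# Blocking generation via /api/generate.
resp = requests.post(f"{BASE}/api/generate", json={
    "instruction": "Explain diffusion language models in one sentence.",
    "steps": 16,
    "block_size": 32,
    "max_new_tokens": 64,
})
print(resp.json().get("response"))

# Streaming generation via /api/generate-stream (server-sent events).
with requests.post(f"{BASE}/api/generate-stream",
                   json={"instruction": "Hello"}, stream=True) as stream:
    for line in stream.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            if event.get("type") in ("complete", "error"):
                print(event.get("response") or event.get("error"))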