Update app.py

app.py CHANGED
@@ -487,9 +487,43 @@ import torch
 from transformers import AutoTokenizer
 import threading
 import queue
+import warnings
+
+# ============ CRITICAL: CONFIGURE THREADS BEFORE TORCH OPERATIONS ============
+# Must be set IMMEDIATELY at module import time
+def setup_cpu_threads():
+    """Configure CPU threads BEFORE any PyTorch parallel work starts."""
+    cpu_count = os.cpu_count() or 1
+    physical_cores = cpu_count // 2 if cpu_count > 1 else 1
+
+    # Set environment variables FIRST
+    os.environ["OMP_NUM_THREADS"] = str(physical_cores)
+    os.environ["MKL_NUM_THREADS"] = str(physical_cores)
+    os.environ["NUMEXPR_NUM_THREADS"] = str(physical_cores)
+
+    # Set PyTorch threads BEFORE any operations
+    try:
+        torch.set_num_threads(physical_cores)
+        torch.set_num_interop_threads(physical_cores)
+    except RuntimeError as e:
+        warnings.warn(f"Could not set threads: {e} (already initialized)")
+
+    print(f"✓ CPU threads configured: {physical_cores} physical cores")
+    return physical_cores
+
+# Call immediately
+PHYSICAL_CORES = setup_cpu_threads()
+# ============================================================================
+
+# Configuration flags
+USE_ONNX_RUNTIME = False
+USE_IPEX = False
+USE_TORCH_COMPILE = True
+QUANTIZE_MODEL = True
+WARMUP_ITERATIONS = 3
 
 app = Flask(__name__, static_folder='static', static_url_path='/static')
-CORS(app)
+CORS(app, resources={r"/api/*": {"origins": "*"}})  # More permissive for testing
 
 # Global state
 model = None
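Context for the hunk above: PyTorch only lets the inter-op thread pool be sized before it starts (and `set_num_interop_threads` may only be called once), which is why the commit moves the configuration to import time and wraps it in `try/except RuntimeError`. A minimal standalone repro of that constraint, independent of this Space:

```python
import torch

torch.set_num_interop_threads(4)  # fine: the inter-op pool has not started yet
torch.set_num_threads(4)          # intra-op threads can be set repeatedly

try:
    # A second call (or any call after parallel work has started) fails.
    torch.set_num_interop_threads(2)
except RuntimeError as e:
    print(f"as expected: {e}")
```

Note that the `cpu_count // 2` physical-core estimate assumes two-way SMT; on machines without hyper-threading it undercounts the available cores by half.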
@@ -498,15 +532,7 @@ config = None
 device = None
 DiffusionLLM = None
 chat_function = None
-
-
-# ==================== CONFIGURATION ====================
-USE_ONNX_RUNTIME = False  # Set True for ONNX (fastest)
-USE_IPEX = False          # Set True for Intel CPUs
-USE_TORCH_COMPILE = True  # Set True for PyTorch 2.0+ (good default)
-QUANTIZE_MODEL = True     # INT8/BF16 quantization
-WARMUP_ITERATIONS = 3     # Warmup for stable performance
-# =======================================================
+ModelConfig = None  # Will be imported from infer-base.py
 
 def find_file(filename, search_dirs=None):
     """Find a file in current directory or parent directories."""
@@ -531,53 +557,25 @@ def try_import_module(filepath, module_name):
         module_dir = os.path.dirname(filepath)
         if module_dir not in sys.path:
             sys.path.insert(0, module_dir)
+
         spec = importlib.util.spec_from_file_location(module_name, filepath)
         if spec is None:
             print(f"Could not create spec for {filepath}")
             return None
+
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
         spec.loader.exec_module(module)
 
+        print(f"✓ Successfully imported {module_name}")
         return module
     except Exception as e:
-        print(f"Error importing {filepath}: {e}")
+        print(f"✗ Error importing {filepath}: {e}")
         if __debug__:
             import traceback
             traceback.print_exc()
         return None
 
-def configure_cpu_optimization():
-    """Configure CPU for maximum performance."""
-    print("\n" + "=" * 60)
-    print("CPU OPTIMIZATION CONFIGURATION")
-    print("=" * 60)
-
-    # Get CPU info
-    cpu_count = os.cpu_count()
-    physical_cores = cpu_count // 2 if cpu_count else cpu_count
-
-    # Optimal thread settings
-    threads = physical_cores or cpu_count or 1
-    torch.set_num_threads(threads)
-    torch.set_num_interop_threads(threads)
-
-    # Environment variables for MKL/OMP
-    os.environ["OMP_NUM_THREADS"] = str(threads)
-    os.environ["MKL_NUM_THREADS"] = str(threads)
-    os.environ["NUMEXPR_NUM_THREADS"] = str(threads)
-
-    print(f"✓ CPU Cores: {cpu_count} ({physical_cores} physical)")
-    print(f"✓ Threads: {threads}")
-    print(f"✓ OMP/MKL threads: {threads}")
-
-    # Intel-specific optimizations
-    if "intel" in torch.__version__.lower() or USE_IPEX:
-        torch.backends.quantized.engine = 'fbgemm'
-        print("✓ Intel FBGEMM backend enabled")
-
-    print("=" * 60 + "\n")
-
 def quantize_model(model):
     """Apply quantization for faster inference."""
     if not QUANTIZE_MODEL:
@@ -585,20 +583,13 @@ def quantize_model(model):
 
     print("\nApplying quantization...")
     try:
-        #
-        if not USE_IPEX:
-            model = torch.quantization.quantize_dynamic(
-                model,
-                {torch.nn.Linear},
-                dtype=torch.qint8
-            )
-            print("✓ Applied INT8 dynamic quantization")
-        elif USE_IPEX:
-            # IPEX BF16 optimization
-            model = model.to(torch.bfloat16)
-            print("✓ Applied BF16 precision (IPEX)")
-
+        # Dynamic quantization - no calibration needed, works on any model
+        model = torch.quantization.quantize_dynamic(
+            model,
+            {torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding},
+            dtype=torch.qint8
+        )
+        print("✓ INT8 dynamic quantization applied")
         return model
     except Exception as e:
        print(f"⚠ Quantization failed: {e}")
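For reference, a minimal sketch of what `quantize_dynamic` does, on a toy model rather than the DiffusionLLM. Only `nn.Linear` is shown: Linear and the RNN family are the module types with dynamic INT8 kernels on CPU, so the extra entries in the set above (`Conv1d`, `Embedding`) may be ignored or rejected depending on the PyTorch version — which the `try/except` in the hunk would then catch.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 8)).eval()
qnet = torch.quantization.quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)

# Weights are stored as INT8 and dequantized on the fly; activations are
# quantized dynamically per batch, so no calibration pass is needed.
print(qnet[0])  # DynamicQuantizedLinear(in_features=256, out_features=256, ...)
with torch.inference_mode():
    print(qnet(torch.randn(1, 256)).shape)  # torch.Size([1, 8])
```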
@@ -606,124 +597,114 @@ def quantize_model(model):
 
 def compile_model(model):
     """Compile model for maximum speed."""
-    global USE_TORCH_COMPILE, USE_ONNX_RUNTIME
-
     print("\nCompiling model...")
 
-    #
+    # ONNX Runtime (BEST performance)
     if USE_ONNX_RUNTIME:
         try:
             import onnxruntime as ort
-            # Export to ONNX (one-time cost, but worth it)
             onnx_path = "model_optimized.onnx"
+
+            # Export if not exists
             if not os.path.exists(onnx_path):
-                print("Exporting to ONNX format...")
-                dummy_input = torch.randint(0, 100, (1,
+                print("Exporting model to ONNX format...")
+                dummy_input = torch.randint(0, 100, (1, 64))
                 torch.onnx.export(
-                    model,
-                    dummy_input,
-                    onnx_path,
+                    model, dummy_input, onnx_path,
                     input_names=['input_ids'],
                     output_names=['logits'],
                     dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'}},
                     opset_version=16
                 )
 
-            # Create
+            # Create optimized session
             sess_options = ort.SessionOptions()
             sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-            sess_options.
-            sess_options.enable_mem_pattern = True
+            sess_options.intra_op_num_threads = PHYSICAL_CORES
 
-            sess_options.intra_op_num_threads = torch.get_num_threads()
-
-            provider = 'CPUExecutionProvider'
-            compiled_model = ort.InferenceSession(onnx_path, sess_options, providers=[provider])
-
-            print("✓ ONNX Runtime compilation complete")
-            return compiled_model
+            return ort.InferenceSession(onnx_path, sess_options)
         except Exception as e:
-            print(f"⚠ ONNX Runtime failed: {e}
-            USE_ONNX_RUNTIME = False
+            print(f"⚠ ONNX Runtime failed: {e}")
 
-    #
+    # Intel IPEX
     if USE_IPEX:
         try:
             import intel_extension_for_pytorch as ipex
             model = ipex.optimize(model, dtype=torch.bfloat16, level="O3")
-            print("✓ Intel IPEX
+            print("✓ Intel IPEX optimization applied")
             return model
         except Exception as e:
             print(f"⚠ IPEX failed: {e}")
-            USE_IPEX = False
 
-    #
+    # torch.compile
     if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
         try:
-            model = torch.compile(
-                model,
-                mode="max-autotune",
-                fullgraph=True,
-                backend="inductor"
-            )
-            print("✓ torch.compile (max-autotune) applied")
+            model = torch.compile(model, mode="max-autotune")
+            print("✓ torch.compile applied")
         except Exception as e:
-            print(f"⚠ torch.compile failed: {e}
+            print(f"⚠ torch.compile failed: {e}")
 
     return model
 
 def warmup_model(model, tokenizer, chat_func):
-    """Warmup the model
+    """Warmup the model."""
     if WARMUP_ITERATIONS == 0:
         return
 
     print("\nWarming up model...")
-
+    start = time.time()
 
     try:
         with torch.inference_mode():
             for i in range(WARMUP_ITERATIONS):
-                chat_func(
-                    model, tokenizer, "Hello
-                    steps=
-                    temperature=0.
+                chat_func(
+                    model, tokenizer, "Hello",
+                    steps=4, block_size=16, max_new_tokens=8,
+                    temperature=0.7, top_k=50, top_p=0.9,
                     repetition_penalty=1.2, no_repeat_ngram_size=3,
-                    verbose=False, visualize_fn=None, parallel_blocks=
+                    verbose=False, visualize_fn=None, parallel_blocks=PHYSICAL_CORES
                 )
-                print(f" Warmup {i+1}/{WARMUP_ITERATIONS}
+                print(f" Warmup {i+1}/{WARMUP_ITERATIONS}...")
     except Exception as e:
         print(f"⚠ Warmup failed: {e}")
 
-    print(f"✓ Warmup
+    print(f"✓ Warmup complete ({time.time() - start:.2f}s)")
 
 def load_model_internal():
-    """Load
-    global model, tokenizer, config, device, DiffusionLLM, chat_function,
+    """Load model with ultra-fast optimizations."""
+    global model, tokenizer, config, device, DiffusionLLM, chat_function, ModelConfig
 
     if model is not None:
         return True
 
     try:
-        print("=" *
+        print("\n" + "=" * 70)
         print("ULTRA-FAST CPU MODEL LOADING")
-        print("=" *
+        print("=" * 70)
 
-        #
-
-        # Import modules
+        # FIRST: Import modules to get ModelConfig
+        print("\n1. Loading modules...")
         base_path = find_file("infer-base.py")
         if base_path is None:
             raise RuntimeError("Could not find infer-base.py")
 
         base_mod = try_import_module(base_path, "infer_base")
-        if base_mod is None
+        if base_mod is None:
+            raise RuntimeError("Failed to import infer-base.py")
+
+        # CRITICAL: Register ModelConfig for pickle
+        if hasattr(base_mod, 'ModelConfig'):
+            ModelConfig = base_mod.ModelConfig
+            sys.modules['__main__'].ModelConfig = ModelConfig
+            print("✓ ModelConfig registered for pickle")
+
+        if not hasattr(base_mod, 'DiffusionLLM'):
             raise RuntimeError("DiffusionLLM class not found")
 
         DiffusionLLM = base_mod.DiffusionLLM
+        print("✓ DiffusionLLM loaded")
 
+        # Import chat function
         chat_path = find_file("infer-chat.py")
         if chat_path is None:
             raise RuntimeError("Could not find infer-chat.py")
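When `USE_ONNX_RUNTIME` is switched on, `compile_model` above returns an `ort.InferenceSession` instead of a torch module, so downstream code calls it via `run()` rather than `__call__`. A sketch of consuming the exported graph (input/output names follow the `torch.onnx.export` call above; the graph covers a single forward pass only, not the sampling loop):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model_optimized.onnx",
                            providers=["CPUExecutionProvider"])
# int64 ids, matching the torch.randint dummy input used for export
input_ids = np.random.randint(0, 100, size=(1, 64), dtype=np.int64)
(logits,) = sess.run(["logits"], {"input_ids": input_ids})
print(logits.shape)  # (1, 64, vocab_size) for a token-level LM head
```

Since `chat_function` presumably expects a torch module, the ONNX path would also need a wrapper before it could be used end-to-end — consistent with the flag defaulting to `False` here.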
@@ -733,12 +714,13 @@ def load_model_internal():
             raise RuntimeError("Chat function not found")
 
         chat_function = chat_mod.chat
+        print("✓ Chat function loaded")
 
         # Device
         device = torch.device("cpu")
 
         # Load tokenizer
-        print("\
+        print("\n2. Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(
             "Qwen/Qwen2.5-0.5B",
             use_fast=True,
@@ -746,7 +728,7 @@ def load_model_internal():
         )
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-        print("✓ Fast tokenizer
+        print("✓ Fast tokenizer ready")
 
         # Find model
         checkpoint_dirs = ["checkpoints", "../checkpoints", "./checkpoints"]
@@ -763,26 +745,27 @@ def load_model_internal():
         if model_path is None:
             raise RuntimeError("Model checkpoint not found")
 
-        print(f"\
+        print(f"\n3. Loading checkpoint: {model_path}")
 
-        # Load checkpoint
+        # CRITICAL: Load checkpoint AFTER ModelConfig is registered
         checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
         config = checkpoint['config']
 
-        # Create
+        # Create model
+        print("4. Building model...")
         model = DiffusionLLM(config)
-        state_dict = checkpoint['model_state']
 
-        #
-
-        state_dict = {k: v.float() for k, v in state_dict.items()}
+        # Load weights
+        state_dict = checkpoint['model_state']
+        if not USE_IPEX:
+            state_dict = {k: v.float() for k, v in state_dict.items()}
 
         model.load_state_dict(state_dict)
         model.eval()
+        model = model.to(device)
 
         # Apply optimizations
         model = quantize_model(model)
-        model = model.to(device)
         model = compile_model(model)
 
         # Warmup
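Why the "CRITICAL" ordering around `torch.load` matters: pickle records classes by module path, and a checkpoint saved from a training script that was run directly records its config class as `__main__.ModelConfig`. Unpickling inside this Flask process therefore fails with "Can't get attribute 'ModelConfig' on <module '__main__'>" unless the name is re-exported into `__main__` first — which is exactly what `sys.modules['__main__'].ModelConfig = ModelConfig` in the earlier hunk does before the load. A self-contained sketch of the failure mode (class body is a hypothetical stand-in):

```python
import pickle, sys

class ModelConfig:  # stand-in for the class defined in infer-base.py
    vocab_size = 151936

blob = pickle.dumps(ModelConfig())           # pickled as "__main__.ModelConfig"
del sys.modules["__main__"].ModelConfig      # simulate a process lacking the name
try:
    pickle.loads(blob)
except AttributeError as e:
    print(e)  # Can't get attribute 'ModelConfig' on <module '__main__' ...>

sys.modules["__main__"].ModelConfig = ModelConfig  # the server's fix
print(type(pickle.loads(blob)).__name__)     # ModelConfig
```

The same mechanism is why `weights_only=False` is required here: the checkpoint carries a pickled config object, not just tensors.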
@@ -790,24 +773,28 @@ def load_model_internal():
 
         # Print summary
         num_params = sum(p.numel() for p in model.parameters()) / 1e6
-
-
-
-        print(
+        framework = "ONNX Runtime" if USE_ONNX_RUNTIME else "IPEX" if USE_IPEX else "PyTorch"
+        precision = "INT8" if QUANTIZE_MODEL and not USE_IPEX else "BF16" if USE_IPEX else "FP32"
+
+        print("\n" + "=" * 70)
+        print(f"✓✓✓ MODEL LOADED & ULTRA-OPTIMIZED ({framework} + {precision}) ✓✓✓")
+        print("=" * 70)
         print(f"Parameters: {num_params:.1f}M")
-        print(f"CPU Threads: {
-        print(f"Quantization: {'INT8' if QUANTIZE_MODEL else 'BF16' if USE_IPEX else 'FP32'}")
+        print(f"CPU Threads: {PHYSICAL_CORES}")
         if 'step' in checkpoint:
             print(f"Training steps: {checkpoint['step']}")
-
+        if 'best_val_loss' in checkpoint:
+            print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")
+        print("=" * 70 + "\n")
 
         return True
 
     except Exception as e:
-        print(f"\
+        print(f"\n✗ ERROR LOADING MODEL: {e}")
         if __debug__:
             import traceback
             traceback.print_exc()
+        print("=" * 70 + "\n")
         return False
 
 def create_streaming_visualizer():
@@ -818,7 +805,6 @@ def create_streaming_visualizer():
             is_masked_list = [is_masked_list]
 
         try:
-            # Decode only once for efficiency
             context_text = tok.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
         except Exception:
             context_text = str(context_ids[0].tolist())
@@ -828,10 +814,9 @@ def create_streaming_visualizer():
             block_tokens = mask_block[0].tolist()
             block_data = []
 
-            #
+            # Efficient batch decoding
             token_ids_to_decode = []
             positions = []
-
             for i, token_id in enumerate(block_tokens):
                 if not is_masked[0, i]:
                     token_ids_to_decode.append(token_id)
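The "efficient batch decoding" above collects the finalized token ids and their positions so the tokenizer is presumably invoked once per block instead of once per token. The idea in isolation (token ids are hypothetical):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", use_fast=True)
token_ids = [9707, 11, 1879]  # hypothetical ids for the unmasked positions

per_token = [tok.decode([tid]) for tid in token_ids]      # n tokenizer calls
batched = tok.batch_decode([[tid] for tid in token_ids])  # one batched call
assert per_token == batched
```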
@@ -845,7 +830,7 @@
             except Exception:
                 decoded_tokens = [str(int(tid)) for tid in token_ids_to_decode]
 
-            # Reconstruct
+            # Reconstruct
             decoded_idx = 0
             for i, token_id in enumerate(block_tokens):
                 if is_masked[0, i]:
@@ -875,11 +860,11 @@ def index():
 @app.route('/api/load', methods=['POST'])
 def load_model_endpoint():
     """Load the model."""
+    global model
+
     data = request.json or {}
     check_only = data.get('check_only', False)
 
-    global model
-
     if check_only:
         return jsonify({
             'loaded': model is not None,
@@ -906,8 +891,8 @@ def load_model_endpoint():
 
 @app.route('/api/generate', methods=['POST'])
 def generate():
-    """Generate response -
-    global model, tokenizer,
+    """Generate response - ultra-fast path."""
+    global model, tokenizer, chat_function
 
     if model is None:
         return jsonify({'error': 'Model not loaded'}), 400
@@ -917,38 +902,25 @@ def generate():
 
     data = request.json
     instruction = data.get('instruction', '')
-    steps = data.get('steps', 16) #
+    steps = data.get('steps', 16)  # Minimal for speed
     block_size = data.get('block_size', 32)
     max_new_tokens = data.get('max_new_tokens', 64)
-    parallel_blocks = data.get('parallel_blocks',
+    parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
 
     if not instruction:
         return jsonify({'error': 'No instruction provided'}), 400
 
     try:
-        # Fast path: no overhead
         with torch.inference_mode():
             raw_output, response = chat_function(
-                model,
-                tokenizer,
-                instruction,
-                steps=steps,
-                block_size=block_size,
-                max_new_tokens=max_new_tokens,
-                temperature=0.7, # Slightly lower for faster sampling
-                top_k=50,
-                top_p=0.9,
-                repetition_penalty=1.2,
-                no_repeat_ngram_size=3,
-                verbose=False,
-                visualize_fn=None,
-                parallel_blocks=parallel_blocks,
+                model, tokenizer, instruction,
+                steps=steps, block_size=block_size, max_new_tokens=max_new_tokens,
+                temperature=0.7, top_k=50, top_p=0.9,
+                repetition_penalty=1.2, no_repeat_ngram_size=3,
+                verbose=False, visualize_fn=None, parallel_blocks=parallel_blocks,
             )
 
-        return jsonify({
-            'response': response,
-            'raw_output': raw_output
-        })
+        return jsonify({'response': response, 'raw_output': raw_output})
 
     except Exception as e:
         if __debug__:
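Exercising the non-streaming endpoint from a client (a sketch; the port comes from `app.run` below and the parameter defaults from the `data.get` calls above):

```python
import requests

resp = requests.post(
    "http://localhost:7860/api/generate",
    json={"instruction": "Say hello", "steps": 16,
          "block_size": 32, "max_new_tokens": 64},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])
```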
@@ -958,36 +930,33 @@ def generate():
 
 @app.route('/api/generate-stream', methods=['POST'])
 def generate_stream():
-    """Generate with streaming - optimized
+    """Generate with streaming - optimized."""
     global model, tokenizer, chat_function
 
     if model is None:
         return jsonify({'error': 'Model not loaded'}), 400
 
-    if chat_function is None:
-        return jsonify({'error': 'Chat function not available'}), 400
-
     data = request.json
     instruction = data.get('instruction', '')
     steps = data.get('steps', 16)
     block_size = data.get('block_size', 32)
     max_new_tokens = data.get('max_new_tokens', 64)
-    parallel_blocks = data.get('parallel_blocks',
+    parallel_blocks = data.get('parallel_blocks', PHYSICAL_CORES)
 
     if not instruction:
         return jsonify({'error': 'No instruction provided'}), 400
 
     def generate_events():
-        event_queue = queue.Queue(maxsize=
+        event_queue = queue.Queue(maxsize=50)  # Limited queue
         generation_complete = {'done': False, 'result': None}
 
         def streaming_visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear=True):
             try:
                 visualizer = create_streaming_visualizer()
                 data = visualizer(tok, context_ids, mask_blocks, is_masked_list, cfg, clear)
-                event_queue.put({'type': 'update', 'data': data}, block=False)
+                event_queue.put({'type': 'update', 'data': data}, block=False, timeout=0.1)
             except queue.Full:
-                pass # Drop
+                pass  # Drop frames if too slow
 
         def run_generation():
             try:
@@ -1007,13 +976,13 @@ def generate_stream():
             generation_complete['done'] = True
             event_queue.put(None)
 
-
-        gen_thread = threading.Thread(target=run_generation)
-        gen_thread.daemon = True
+        # Start generation thread
+        gen_thread = threading.Thread(target=run_generation, daemon=True)
         gen_thread.start()
 
-        yield f"data: {json.dumps({'type': 'start'})}\n\n"
+        yield f"data: {json.dumps({'type': 'start', 'ts': time.time()})}\n\n"
 
+        # Stream events
         while not generation_complete['done'] or not event_queue.empty():
             try:
                 event = event_queue.get(timeout=0.1)
@@ -1025,12 +994,11 @@ def generate_stream():
 
         gen_thread.join(timeout=2.0)
 
+        # Send final result
         if generation_complete['result']:
             raw_output, response = generation_complete['result']
-            if raw_output
-
-            else:
-                yield f"data: {json.dumps({'type': 'complete', 'response': response})}\n\n"
+            yield f"data: {json.dumps({'type': 'complete' if raw_output != 'error' else 'error',
+                                       'response': response, 'error': response if raw_output == 'error' else None})}\n\n"
 
     return Response(
         stream_with_context(generate_events()),
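The streaming endpoint emits server-sent events ('start', then 'update' frames from the visualizer, then 'complete' or 'error'), so a client reads the response line by line rather than as one JSON body. A minimal consumer sketch:

```python
import json
import requests

with requests.post("http://localhost:7860/api/generate-stream",
                   json={"instruction": "Say hello", "max_new_tokens": 32},
                   stream=True, timeout=300) as r:
    for line in r.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alives / blank separators
        event = json.loads(line[len(b"data: "):])
        if event["type"] == "complete":
            print(event["response"])
        elif event["type"] == "error":
            print("error:", event.get("error"))
```

Dropping 'update' frames when the bounded queue is full (the `queue.Full` handler above) trades visual smoothness for never blocking the generation thread.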
@@ -1044,35 +1012,28 @@ def generate_stream():
 
 @app.route('/api/status', methods=['GET'])
 def status():
-    """Get
+    """Get detailed status."""
     return jsonify({
         'model_loaded': model is not None,
+        'cpu_cores': os.cpu_count(),
+        'physical_cores': PHYSICAL_CORES,
         'torch_threads': torch.get_num_threads(),
         'interop_threads': torch.get_num_interop_threads(),
-        'cpu_count': os.cpu_count(),
         'optimizations': {
             'onnx_runtime': USE_ONNX_RUNTIME,
             'ipex': USE_IPEX,
             'torch_compile': USE_TORCH_COMPILE,
-            'quantization': QUANTIZE_MODEL
+            'quantization': QUANTIZE_MODEL,
+            'warmup_iterations': WARMUP_ITERATIONS
         }
     })
 
 if __name__ == '__main__':
     print("\n" + "=" * 70)
-    print("ULTRA-FAST CPU INFERENCE SERVER")
+    print("ULTRA-FAST CPU INFERENCE SERVER v2.0")
     print("=" * 70)
-    print("
-    print("
-    print(" ✓ Intel IPEX:", USE_IPEX)
-    print(" ✓ torch.compile:", USE_TORCH_COMPILE)
-    print(" ✓ Quantization:", QUANTIZE_MODEL)
-    print(" ✓ Multi-threading")
-    print(" ✓ Inference mode")
-    print(" ✓ Fast tokenizer")
-    print(" ✓ Memory layout optimization")
-    print("\nTo install ONNX Runtime: pip install onnxruntime")
-    print("To install Intel IPEX: pip install intel-extension-for-pytorch")
+    print(f"CPU Configuration: {PHYSICAL_CORES} physical cores")
+    print(f"Optimizations: ONNX={USE_ONNX_RUNTIME} | IPEX={USE_IPEX} | Compile={USE_TORCH_COMPILE}")
     print("=" * 70 + "\n")
 
     app.run(debug=False, host='0.0.0.0', port=7860, threaded=True)
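End-to-end smoke test of the three JSON endpoints once the server is up (same assumed host and port as above; posting an empty body to `/api/load` presumably triggers the actual load, while `check_only` just reports residency):

```python
import requests

base = "http://localhost:7860"
# Report whether the model is already resident
print(requests.post(f"{base}/api/load", json={"check_only": True}).json())
# Trigger the load, then inspect thread and optimization settings
print(requests.post(f"{base}/api/load", json={}, timeout=600).json())
print(requests.get(f"{base}/api/status").json())
```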