#!/usr/bin/env python3 """ ╔══════════════════════════════════════════════════════════════════════════════╗ ║ ║ ║ 🔧 Paris MoE - Weight Quantization Utility 🔧 ║ ║ ║ ║ Converts weights between formats: ║ ║ • Input: .pt (PyTorch) or .safetensors (F32 or BF16) ║ ║ • Output: BF16 or INT8 safetensors ║ ║ ║ ╚══════════════════════════════════════════════════════════════════════════════╝ Usage: # Convert original .pt files to BF16 safetensors python quantize.py --input /path/to/weights/ --output ./weights/bf16 --format bf16 # Convert to INT8 safetensors python quantize.py --input /path/to/weights/ --output ./weights/int8 --format int8 # Convert from existing safetensors (bf16 -> int8) python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8 Input Formats Supported: - PyTorch .pt files (original training checkpoints) - SafeTensors .safetensors files (F32 or BF16) Output Formats: - bf16: BFloat16 safetensors (best quality, ~1.2GB per expert) - int8: INT8 quantized safetensors (~580MB per expert) """ import argparse import os import gc from pathlib import Path from typing import Dict, Optional, Tuple import json import torch from safetensors.torch import save_file, load_file from safetensors import safe_open from tqdm import tqdm # ═══════════════════════════════════════════════════════════════════════════════ # FILE DETECTION # ═══════════════════════════════════════════════════════════════════════════════ def detect_input_format(input_dir: Path) -> Tuple[str, Dict[str, Path]]: """ Detect input format and locate weight files. Returns: format: 'pt' or 'safetensors' files: Dict mapping 'expert_0'..'expert_7', 'router' to file paths """ files = {} # Check for PyTorch .pt files (original format) pt_patterns = [ # Pattern 1: Full training checkpoint names ("dit_xl2_multi_expert_pretrained_text_new_dataset_expert_{}_best.pt", "expert_{}"), ("laion_router_preclustered_dit_berthead_b2_improved_router_best.pt", "router"), # Pattern 2: Simple names ("expert_{}_best.pt", "expert_{}"), ("expert_{}.pt", "expert_{}"), ("router_best.pt", "router"), ("router.pt", "router"), ] # Check for SafeTensors files st_patterns = [ ("expert_{}.safetensors", "expert_{}"), ("router.safetensors", "router"), ] # Try PyTorch patterns first for pattern, key_pattern in pt_patterns: if "{}" in pattern: # Expert pattern for i in range(8): filename = pattern.format(i) filepath = input_dir / filename if filepath.exists(): key = key_pattern.format(i) files[key] = filepath else: # Router pattern filepath = input_dir / pattern if filepath.exists(): files[key_pattern] = filepath if len(files) >= 8: # At least 8 experts found return 'pt', files # Try SafeTensors patterns files = {} for pattern, key_pattern in st_patterns: if "{}" in pattern: for i in range(8): filename = pattern.format(i) filepath = input_dir / filename if filepath.exists(): key = key_pattern.format(i) files[key] = filepath else: filepath = input_dir / pattern if filepath.exists(): files[key_pattern] = filepath if len(files) >= 8: return 'safetensors', files # List what we found print(f"Found files in {input_dir}:") for f in sorted(input_dir.glob("*")): print(f" {f.name}") raise ValueError(f"Could not find weight files in {input_dir}") # ═══════════════════════════════════════════════════════════════════════════════ # LOADING UTILITIES # ═══════════════════════════════════════════════════════════════════════════════ def load_pt_expert(filepath: Path, expert_id: int) -> Tuple[dict, Optional[object]]: """ Load expert weights from PyTorch checkpoint. Returns: state_dict: Model weights config: Config object if available """ print(f" Loading {filepath.name}...") ckpt = torch.load(filepath, map_location='cpu', weights_only=False) # Try EMA weights first (preferred for inference) ema_key = f'expert_{expert_id}_ema_state_dict' regular_key = f'expert_{expert_id}_state_dict' if ema_key in ckpt: state_dict = ckpt[ema_key] print(f" Using EMA weights") elif regular_key in ckpt: state_dict = ckpt[regular_key] print(f" Using regular weights (no EMA)") else: # Try to find any state dict key for k in ckpt.keys(): if 'state_dict' in k and 'optimizer' not in k: state_dict = ckpt[k] print(f" Using key: {k}") break else: raise KeyError(f"No state dict found in {filepath}") config = ckpt.get('config', None) return state_dict, config def load_pt_router(filepath: Path) -> Tuple[dict, Optional[object]]: """Load router weights from PyTorch checkpoint.""" print(f" Loading {filepath.name}...") ckpt = torch.load(filepath, map_location='cpu', weights_only=False) if 'router_state_dict' in ckpt: state_dict = ckpt['router_state_dict'] else: raise KeyError(f"router_state_dict not found in {filepath}") config = ckpt.get('config', None) return state_dict, config def load_safetensors_weights(filepath: Path) -> dict: """Load weights from SafeTensors file.""" print(f" Loading {filepath.name}...") return load_file(str(filepath)) # ═══════════════════════════════════════════════════════════════════════════════ # QUANTIZATION # ═══════════════════════════════════════════════════════════════════════════════ def convert_to_bf16(state_dict: dict) -> dict: """Convert all floating point tensors to bfloat16.""" bf16_state = {} for k, v in state_dict.items(): if isinstance(v, torch.Tensor) and v.is_floating_point(): bf16_state[k] = v.to(torch.bfloat16) else: bf16_state[k] = v return bf16_state def is_layernorm_key(key: str) -> bool: """Check if a key belongs to a LayerNorm layer.""" ln_patterns = ['norm', 'layernorm', 'layer_norm', 'ln_', 'scale_shift_table'] key_lower = key.lower() return any(p in key_lower for p in ln_patterns) def quantize_tensor_int8(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, float]: """ Quantize a tensor to INT8 with min/max scaling. Formula: int8 = round((x - min) / (max - min) * 255) - 128 """ if tensor.numel() == 0: return tensor.to(torch.int8), 0.0, 0.0 t_float = tensor.float() t_min = t_float.min().item() t_max = t_float.max().item() if t_min == t_max: return torch.zeros_like(tensor, dtype=torch.int8), t_min, t_max # Quantize: map [min, max] to [-128, 127] normalized = (t_float - t_min) / (t_max - t_min) int8_tensor = (normalized * 255 - 128).round().clamp(-128, 127).to(torch.int8) return int8_tensor, t_min, t_max def convert_to_int8(state_dict: dict) -> dict: """ Convert state dict to INT8 quantized format. LayerNorm and small tensors are kept in float32. Quantization parameters (_min, _max) are stored alongside. """ quantized = {} stats = {'float32': 0, 'int8': 0} for key, tensor in state_dict.items(): if not isinstance(tensor, torch.Tensor): continue # Skip LayerNorm layers - keep as float32 if is_layernorm_key(key): quantized[key] = tensor.float() stats['float32'] += tensor.numel() # Only quantize weight tensors with enough elements elif tensor.numel() >= 16 and tensor.dtype in [torch.float32, torch.float16, torch.bfloat16]: int8_tensor, t_min, t_max = quantize_tensor_int8(tensor) quantized[key] = int8_tensor quantized[f"{key}._min"] = torch.tensor([t_min], dtype=torch.float32) quantized[f"{key}._max"] = torch.tensor([t_max], dtype=torch.float32) stats['int8'] += tensor.numel() else: # Keep small tensors as float32 quantized[key] = tensor.float() stats['float32'] += tensor.numel() return quantized, stats # ═══════════════════════════════════════════════════════════════════════════════ # MAIN CONVERSION # ═══════════════════════════════════════════════════════════════════════════════ def convert_weights(input_dir: Path, output_dir: Path, output_format: str): """ Convert weights to specified format. Args: input_dir: Directory containing input weights output_dir: Directory to write output weights output_format: 'bf16' or 'int8' """ print(f""" ╔══════════════════════════════════════════════════════════════════════════════╗ ║ 🔧 Paris MoE Weight Conversion 🔧 ║ ╠══════════════════════════════════════════════════════════════════════════════╣ ║ Input: {str(input_dir):<60} ║ ║ Output: {str(output_dir):<60} ║ ║ Format: {output_format.upper():<60} ║ ╚══════════════════════════════════════════════════════════════════════════════╝ """) # Detect input format input_format, files = detect_input_format(input_dir) print(f"📂 Detected input format: {input_format}") print(f"📁 Found {len(files)} weight files") # Create output directory output_dir.mkdir(parents=True, exist_ok=True) # Track sizes sizes = {'input': 0, 'output': 0} expert_config = None router_config = None # Process experts print("\n🧠 Converting experts...") for i in range(8): key = f"expert_{i}" if key not in files: print(f" ⚠️ {key} not found, skipping") continue filepath = files[key] sizes['input'] += filepath.stat().st_size # Load weights if input_format == 'pt': state_dict, config = load_pt_expert(filepath, i) if config is not None and expert_config is None: expert_config = config else: state_dict = load_safetensors_weights(filepath) # Convert if output_format == 'bf16': converted = convert_to_bf16(state_dict) else: converted, stats = convert_to_int8(state_dict) print(f" INT8: {stats['int8']:,} params, Float32: {stats['float32']:,} params") # Save output_path = output_dir / f"expert_{i}.safetensors" save_file(converted, str(output_path)) sizes['output'] += output_path.stat().st_size print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)") # Clean up del state_dict, converted gc.collect() # Process router if 'router' in files: print("\n📡 Converting router...") filepath = files['router'] sizes['input'] += filepath.stat().st_size if input_format == 'pt': state_dict, config = load_pt_router(filepath) if config is not None: router_config = config else: state_dict = load_safetensors_weights(filepath) # Router always kept in bf16/float32 for stability converted = convert_to_bf16(state_dict) output_path = output_dir / "router.safetensors" save_file(converted, str(output_path)) sizes['output'] += output_path.stat().st_size print(f" ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)") del state_dict, converted gc.collect() # Save configs if from .pt files if expert_config is not None: config_path = output_dir / "config.pt" torch.save({'config': expert_config}, config_path) print(f" ✅ Saved: config.pt") if router_config is not None: config_path = output_dir / "router_config.pt" torch.save({'config': router_config}, config_path) print(f" ✅ Saved: router_config.pt") # Summary compression = sizes['input'] / sizes['output'] if sizes['output'] > 0 else 1 print(f""" ╔══════════════════════════════════════════════════════════════════════════════╗ ║ 📊 Conversion Summary 📊 ║ ╠══════════════════════════════════════════════════════════════════════════════╣ ║ Input size: {sizes['input']/1e9:>8.2f} GB ║ ║ Output size: {sizes['output']/1e9:>8.2f} GB ║ ║ Compression: {compression:>8.1f}x ║ ╠══════════════════════════════════════════════════════════════════════════════╣ ║ ✅ Conversion complete! ║ ╚══════════════════════════════════════════════════════════════════════════════╝ """) # List output files print("📁 Output files:") for f in sorted(output_dir.glob("*")): print(f" {f.name}: {f.stat().st_size/1e6:.1f} MB") # ═══════════════════════════════════════════════════════════════════════════════ # CLI # ═══════════════════════════════════════════════════════════════════════════════ def parse_args(): parser = argparse.ArgumentParser( description="🔧 Paris MoE - Weight Quantization Utility", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Convert original .pt files to BF16 python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16 # Convert to INT8 from .pt files python quantize.py --input /path/to/weights --output ./weights/int8 --format int8 # Convert from BF16 safetensors to INT8 python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8 """ ) parser.add_argument("--input", "-i", type=str, required=True, help="Input directory containing weight files") parser.add_argument("--output", "-o", type=str, required=True, help="Output directory for converted weights") parser.add_argument("--format", "-f", type=str, required=True, choices=["bf16", "int8"], help="Output format: bf16 or int8") return parser.parse_args() def main(): args = parse_args() input_dir = Path(args.input) output_dir = Path(args.output) if not input_dir.exists(): print(f"❌ Error: Input directory does not exist: {input_dir}") return 1 convert_weights(input_dir, output_dir, args.format) return 0 if __name__ == "__main__": exit(main())