#!/usr/bin/env python3
"""
╔══════════════════════════════════════════════════════════════════════════════╗
β•‘                                                                              β•‘
β•‘                 πŸ”§ Paris MoE - Weight Quantization Utility πŸ”§                β•‘
β•‘                                                                              β•‘
β•‘   Converts weights between formats:                                          β•‘
β•‘     β€’ Input:  .pt (PyTorch) or .safetensors (F32 or BF16)                    β•‘
β•‘     β€’ Output: BF16 or INT8 safetensors                                       β•‘
β•‘                                                                              β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
Usage:
# Convert original .pt files to BF16 safetensors
python quantize.py --input /path/to/weights/ --output ./weights/bf16 --format bf16
# Convert to INT8 safetensors
python quantize.py --input /path/to/weights/ --output ./weights/int8 --format int8
# Convert from existing safetensors (bf16 -> int8)
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
Input Formats Supported:
- PyTorch .pt files (original training checkpoints)
- SafeTensors .safetensors files (F32 or BF16)
Output Formats:
- bf16: BFloat16 safetensors (best quality, ~1.2GB per expert)
- int8: INT8 quantized safetensors (~580MB per expert)
"""
import argparse
import gc
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
from safetensors.torch import save_file, load_file
# ═══════════════════════════════════════════════════════════════════════════════
# FILE DETECTION
# ═══════════════════════════════════════════════════════════════════════════════
def detect_input_format(input_dir: Path) -> Tuple[str, Dict[str, Path]]:
"""
Detect input format and locate weight files.
Returns:
format: 'pt' or 'safetensors'
files: Dict mapping 'expert_0'..'expert_7', 'router' to file paths
"""
files = {}
# Check for PyTorch .pt files (original format)
pt_patterns = [
# Pattern 1: Full training checkpoint names
("dit_xl2_multi_expert_pretrained_text_new_dataset_expert_{}_best.pt", "expert_{}"),
("laion_router_preclustered_dit_berthead_b2_improved_router_best.pt", "router"),
# Pattern 2: Simple names
("expert_{}_best.pt", "expert_{}"),
("expert_{}.pt", "expert_{}"),
("router_best.pt", "router"),
("router.pt", "router"),
]
# Check for SafeTensors files
st_patterns = [
("expert_{}.safetensors", "expert_{}"),
("router.safetensors", "router"),
]
# Try PyTorch patterns first
for pattern, key_pattern in pt_patterns:
if "{}" in pattern:
# Expert pattern
for i in range(8):
filename = pattern.format(i)
filepath = input_dir / filename
if filepath.exists():
key = key_pattern.format(i)
files[key] = filepath
else:
# Router pattern
filepath = input_dir / pattern
if filepath.exists():
files[key_pattern] = filepath
    expert_count = sum(1 for k in files if k.startswith("expert_"))
    if expert_count == 8:  # All 8 experts found
        return 'pt', files
# Try SafeTensors patterns
files = {}
for pattern, key_pattern in st_patterns:
if "{}" in pattern:
for i in range(8):
filename = pattern.format(i)
filepath = input_dir / filename
if filepath.exists():
key = key_pattern.format(i)
files[key] = filepath
else:
filepath = input_dir / pattern
if filepath.exists():
files[key_pattern] = filepath
    expert_count = sum(1 for k in files if k.startswith("expert_"))
    if expert_count == 8:
        return 'safetensors', files
# List what we found
print(f"Found files in {input_dir}:")
for f in sorted(input_dir.glob("*")):
print(f" {f.name}")
raise ValueError(f"Could not find weight files in {input_dir}")
# ═══════════════════════════════════════════════════════════════════════════════
# LOADING UTILITIES
# ═══════════════════════════════════════════════════════════════════════════════
def load_pt_expert(filepath: Path, expert_id: int) -> Tuple[dict, Optional[object]]:
"""
Load expert weights from PyTorch checkpoint.
Returns:
state_dict: Model weights
config: Config object if available
"""
print(f" Loading {filepath.name}...")
ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
# Try EMA weights first (preferred for inference)
ema_key = f'expert_{expert_id}_ema_state_dict'
regular_key = f'expert_{expert_id}_state_dict'
if ema_key in ckpt:
state_dict = ckpt[ema_key]
print(f" Using EMA weights")
elif regular_key in ckpt:
state_dict = ckpt[regular_key]
print(f" Using regular weights (no EMA)")
else:
# Try to find any state dict key
for k in ckpt.keys():
if 'state_dict' in k and 'optimizer' not in k:
state_dict = ckpt[k]
print(f" Using key: {k}")
break
else:
raise KeyError(f"No state dict found in {filepath}")
config = ckpt.get('config', None)
return state_dict, config
def load_pt_router(filepath: Path) -> Tuple[dict, Optional[object]]:
"""Load router weights from PyTorch checkpoint."""
print(f" Loading {filepath.name}...")
ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
if 'router_state_dict' in ckpt:
state_dict = ckpt['router_state_dict']
else:
raise KeyError(f"router_state_dict not found in {filepath}")
config = ckpt.get('config', None)
return state_dict, config
def load_safetensors_weights(filepath: Path) -> dict:
"""Load weights from SafeTensors file."""
print(f" Loading {filepath.name}...")
return load_file(str(filepath))
# ═══════════════════════════════════════════════════════════════════════════════
# QUANTIZATION
# ═══════════════════════════════════════════════════════════════════════════════
def convert_to_bf16(state_dict: dict) -> dict:
"""Convert all floating point tensors to bfloat16."""
bf16_state = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor) and v.is_floating_point():
bf16_state[k] = v.to(torch.bfloat16)
else:
bf16_state[k] = v
return bf16_state
def is_layernorm_key(key: str) -> bool:
"""Check if a key belongs to a LayerNorm layer."""
ln_patterns = ['norm', 'layernorm', 'layer_norm', 'ln_', 'scale_shift_table']
key_lower = key.lower()
return any(p in key_lower for p in ln_patterns)
def quantize_tensor_int8(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
"""
Quantize a tensor to INT8 with min/max scaling.
Formula: int8 = round((x - min) / (max - min) * 255) - 128
"""
if tensor.numel() == 0:
return tensor.to(torch.int8), 0.0, 0.0
t_float = tensor.float()
t_min = t_float.min().item()
t_max = t_float.max().item()
if t_min == t_max:
return torch.zeros_like(tensor, dtype=torch.int8), t_min, t_max
# Quantize: map [min, max] to [-128, 127]
normalized = (t_float - t_min) / (t_max - t_min)
int8_tensor = (normalized * 255 - 128).round().clamp(-128, 127).to(torch.int8)
return int8_tensor, t_min, t_max
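# Illustrative sketch (not used by this script): the inverse of the mapping
# above, for consumers of the INT8 files. Assumes the tensor was produced by
# quantize_tensor_int8 and that t_min / t_max are the values stored alongside it.
def dequantize_tensor_int8(int8_tensor: torch.Tensor, t_min: float, t_max: float) -> torch.Tensor:
    """Reconstruct an approximate float32 tensor from its INT8 form."""
    if t_min == t_max:
        # Degenerate case: the original tensor was constant (or empty).
        return torch.full_like(int8_tensor, t_min, dtype=torch.float32)
    # Undo: int8 = round((x - min) / (max - min) * 255 - 128)
    normalized = (int8_tensor.float() + 128.0) / 255.0
    return normalized * (t_max - t_min) + t_min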
def convert_to_int8(state_dict: dict) -> Tuple[dict, Dict[str, int]]:
    """
    Convert a state dict to INT8 quantized format.

    LayerNorm and small tensors are kept in float32. Quantization parameters
    are stored alongside each quantized tensor as "<key>._min" / "<key>._max".

    Returns:
        quantized: Converted state dict
        stats: Parameter counts per storage dtype ('int8', 'float32')
    """
quantized = {}
stats = {'float32': 0, 'int8': 0}
for key, tensor in state_dict.items():
if not isinstance(tensor, torch.Tensor):
continue
# Skip LayerNorm layers - keep as float32
if is_layernorm_key(key):
quantized[key] = tensor.float()
stats['float32'] += tensor.numel()
# Only quantize weight tensors with enough elements
elif tensor.numel() >= 16 and tensor.dtype in [torch.float32, torch.float16, torch.bfloat16]:
int8_tensor, t_min, t_max = quantize_tensor_int8(tensor)
quantized[key] = int8_tensor
quantized[f"{key}._min"] = torch.tensor([t_min], dtype=torch.float32)
quantized[f"{key}._max"] = torch.tensor([t_max], dtype=torch.float32)
stats['int8'] += tensor.numel()
else:
# Keep small tensors as float32
quantized[key] = tensor.float()
stats['float32'] += tensor.numel()
return quantized, stats
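# Illustrative sketch (not used by this script): how a consumer could rebuild a
# plain float32 state dict from a file written by convert_to_int8, relying only
# on the key layout above (tensor, "<key>._min", "<key>._max"). Uses the example
# helper dequantize_tensor_int8 defined earlier in this file.
def dequantize_state_dict(quantized: dict) -> dict:
    """Reverse convert_to_int8: return a float32 state dict."""
    restored = {}
    for key, tensor in quantized.items():
        if key.endswith("._min") or key.endswith("._max"):
            continue  # quantization parameters, consumed below
        if tensor.dtype == torch.int8:
            t_min = quantized[f"{key}._min"].item()
            t_max = quantized[f"{key}._max"].item()
            restored[key] = dequantize_tensor_int8(tensor, t_min, t_max)
        else:
            restored[key] = tensor.float()  # LayerNorm / small tensors stored as float32
    return restored
# Example usage (illustrative; the path is a placeholder):
#   int8_state = load_file("weights/int8/expert_0.safetensors")
#   fp32_state = dequantize_state_dict(int8_state)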
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN CONVERSION
# ═══════════════════════════════════════════════════════════════════════════════
def convert_weights(input_dir: Path, output_dir: Path, output_format: str):
"""
Convert weights to specified format.
Args:
input_dir: Directory containing input weights
output_dir: Directory to write output weights
output_format: 'bf16' or 'int8'
"""
print(f"""
╔══════════════════════════════════════════════════════════════════════════════╗
β•‘ πŸ”§ Paris MoE Weight Conversion πŸ”§ β•‘
╠══════════════════════════════════════════════════════════════════════════════╣
β•‘ Input: {str(input_dir):<60} β•‘
β•‘ Output: {str(output_dir):<60} β•‘
β•‘ Format: {output_format.upper():<60} β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
""")
# Detect input format
input_format, files = detect_input_format(input_dir)
print(f"πŸ“‚ Detected input format: {input_format}")
print(f"πŸ“ Found {len(files)} weight files")
# Create output directory
output_dir.mkdir(parents=True, exist_ok=True)
# Track sizes
sizes = {'input': 0, 'output': 0}
expert_config = None
router_config = None
# Process experts
print("\n🧠 Converting experts...")
for i in range(8):
key = f"expert_{i}"
if key not in files:
print(f" ⚠️ {key} not found, skipping")
continue
filepath = files[key]
sizes['input'] += filepath.stat().st_size
# Load weights
if input_format == 'pt':
state_dict, config = load_pt_expert(filepath, i)
if config is not None and expert_config is None:
expert_config = config
else:
state_dict = load_safetensors_weights(filepath)
# Convert
if output_format == 'bf16':
converted = convert_to_bf16(state_dict)
else:
converted, stats = convert_to_int8(state_dict)
print(f" INT8: {stats['int8']:,} params, Float32: {stats['float32']:,} params")
# Save
output_path = output_dir / f"expert_{i}.safetensors"
save_file(converted, str(output_path))
sizes['output'] += output_path.stat().st_size
print(f" βœ… Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
# Clean up
del state_dict, converted
gc.collect()
# Process router
if 'router' in files:
print("\nπŸ“‘ Converting router...")
filepath = files['router']
sizes['input'] += filepath.stat().st_size
if input_format == 'pt':
state_dict, config = load_pt_router(filepath)
if config is not None:
router_config = config
else:
state_dict = load_safetensors_weights(filepath)
    # The router is always kept in BF16 (never INT8-quantized) for stability
converted = convert_to_bf16(state_dict)
output_path = output_dir / "router.safetensors"
save_file(converted, str(output_path))
sizes['output'] += output_path.stat().st_size
print(f" βœ… Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")
del state_dict, converted
gc.collect()
# Save configs if from .pt files
if expert_config is not None:
config_path = output_dir / "config.pt"
torch.save({'config': expert_config}, config_path)
print(f" βœ… Saved: config.pt")
if router_config is not None:
config_path = output_dir / "router_config.pt"
torch.save({'config': router_config}, config_path)
print(f" βœ… Saved: router_config.pt")
# Summary
compression = sizes['input'] / sizes['output'] if sizes['output'] > 0 else 1
print(f"""
╔══════════════════════════════════════════════════════════════════════════════╗
β•‘ πŸ“Š Conversion Summary πŸ“Š β•‘
╠══════════════════════════════════════════════════════════════════════════════╣
β•‘ Input size: {sizes['input']/1e9:>8.2f} GB β•‘
β•‘ Output size: {sizes['output']/1e9:>8.2f} GB β•‘
β•‘ Compression: {compression:>8.1f}x β•‘
╠══════════════════════════════════════════════════════════════════════════════╣
β•‘ βœ… Conversion complete! β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
""")
# List output files
print("πŸ“ Output files:")
for f in sorted(output_dir.glob("*")):
print(f" {f.name}: {f.stat().st_size/1e6:.1f} MB")
# ═══════════════════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════════════════
def parse_args():
parser = argparse.ArgumentParser(
description="πŸ”§ Paris MoE - Weight Quantization Utility",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Convert original .pt files to BF16
python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16
# Convert to INT8 from .pt files
python quantize.py --input /path/to/weights --output ./weights/int8 --format int8
# Convert from BF16 safetensors to INT8
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
"""
)
parser.add_argument("--input", "-i", type=str, required=True,
help="Input directory containing weight files")
parser.add_argument("--output", "-o", type=str, required=True,
help="Output directory for converted weights")
parser.add_argument("--format", "-f", type=str, required=True,
choices=["bf16", "int8"],
help="Output format: bf16 or int8")
return parser.parse_args()
def main():
args = parse_args()
input_dir = Path(args.input)
output_dir = Path(args.output)
if not input_dir.exists():
print(f"❌ Error: Input directory does not exist: {input_dir}")
return 1
convert_weights(input_dir, output_dir, args.format)
return 0
if __name__ == "__main__":
exit(main())