""" |
|
|
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
|
|
β β |
|
|
β π§ Paris MoE - Weight Quantization Utility π§ β |
|
|
β β |
|
|
β Converts weights between formats: β |
|
|
β β’ Input: .pt (PyTorch) or .safetensors (F32 or BF16) β |
|
|
β β’ Output: BF16 or INT8 safetensors β |
|
|
β β |
|
|
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
|
|
|
|
|
Usage: |
|
|
# Convert original .pt files to BF16 safetensors |
|
|
python quantize.py --input /path/to/weights/ --output ./weights/bf16 --format bf16 |
|
|
|
|
|
# Convert to INT8 safetensors |
|
|
python quantize.py --input /path/to/weights/ --output ./weights/int8 --format int8 |
|
|
|
|
|
# Convert from existing safetensors (bf16 -> int8) |
|
|
python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8 |
|
|
|
|
|
Input Formats Supported: |
|
|
- PyTorch .pt files (original training checkpoints) |
|
|
- SafeTensors .safetensors files (F32 or BF16) |
|
|
|
|
|
Output Formats: |
|
|
- bf16: BFloat16 safetensors (best quality, ~1.2GB per expert) |
|
|
- int8: INT8 quantized safetensors (~580MB per expert) |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import gc |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Tuple |
|
|
import json |
|
|
|
|
|
import torch |
|
|
from safetensors.torch import save_file, load_file |
|
|
from safetensors import safe_open |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_input_format(input_dir: Path) -> Tuple[str, Dict[str, Path]]:
    """
    Detect input format and locate weight files.

    Returns:
        format: 'pt' or 'safetensors'
        files: Dict mapping 'expert_0'..'expert_7', 'router' to file paths
    """
    files = {}

    # Known .pt checkpoint names: original training filenames first,
    # then generic fallbacks.
    pt_patterns = [
        ("dit_xl2_multi_expert_pretrained_text_new_dataset_expert_{}_best.pt", "expert_{}"),
        ("laion_router_preclustered_dit_berthead_b2_improved_router_best.pt", "router"),
        ("expert_{}_best.pt", "expert_{}"),
        ("expert_{}.pt", "expert_{}"),
        ("router_best.pt", "router"),
        ("router.pt", "router"),
    ]

    # SafeTensors filenames.
    st_patterns = [
        ("expert_{}.safetensors", "expert_{}"),
        ("router.safetensors", "router"),
    ]

    # Try .pt patterns first.
    for pattern, key_pattern in pt_patterns:
        if "{}" in pattern:
            # Expand the placeholder for each of the 8 experts.
            for i in range(8):
                filename = pattern.format(i)
                filepath = input_dir / filename
                if filepath.exists():
                    key = key_pattern.format(i)
                    files[key] = filepath
        else:
            filepath = input_dir / pattern
            if filepath.exists():
                files[key_pattern] = filepath

    # All 8 experts found (router is optional).
    if len(files) >= 8:
        return 'pt', files

    # Fall back to SafeTensors patterns.
    files = {}
    for pattern, key_pattern in st_patterns:
        if "{}" in pattern:
            for i in range(8):
                filename = pattern.format(i)
                filepath = input_dir / filename
                if filepath.exists():
                    key = key_pattern.format(i)
                    files[key] = filepath
        else:
            filepath = input_dir / pattern
            if filepath.exists():
                files[key_pattern] = filepath

    if len(files) >= 8:
        return 'safetensors', files

    # Nothing matched: list what is present to aid debugging, then fail.
    print(f"Found files in {input_dir}:")
    for f in sorted(input_dir.glob("*")):
        print(f"  {f.name}")

    raise ValueError(f"Could not find weight files in {input_dir}")


def load_pt_expert(filepath: Path, expert_id: int) -> Tuple[dict, Optional[object]]:
    """
    Load expert weights from a PyTorch checkpoint.

    Returns:
        state_dict: Model weights
        config: Config object if available
    """
    print(f"  Loading {filepath.name}...")
    ckpt = torch.load(filepath, map_location='cpu', weights_only=False)

    # Prefer EMA weights when present; fall back to the regular state dict.
    ema_key = f'expert_{expert_id}_ema_state_dict'
    regular_key = f'expert_{expert_id}_state_dict'

    if ema_key in ckpt:
        state_dict = ckpt[ema_key]
        print("    Using EMA weights")
    elif regular_key in ckpt:
        state_dict = ckpt[regular_key]
        print("    Using regular weights (no EMA)")
    else:
        # Last resort: take the first non-optimizer state dict in the checkpoint.
        for k in ckpt.keys():
            if 'state_dict' in k and 'optimizer' not in k:
                state_dict = ckpt[k]
                print(f"    Using key: {k}")
                break
        else:
            raise KeyError(f"No state dict found in {filepath}")

    config = ckpt.get('config', None)
    return state_dict, config


def load_pt_router(filepath: Path) -> Tuple[dict, Optional[object]]:
    """Load router weights from a PyTorch checkpoint."""
    print(f"  Loading {filepath.name}...")
    ckpt = torch.load(filepath, map_location='cpu', weights_only=False)

    if 'router_state_dict' in ckpt:
        state_dict = ckpt['router_state_dict']
    else:
        raise KeyError(f"router_state_dict not found in {filepath}")

    config = ckpt.get('config', None)
    return state_dict, config


def load_safetensors_weights(filepath: Path) -> dict:
    """Load weights from a SafeTensors file."""
    print(f"  Loading {filepath.name}...")
    return load_file(str(filepath))


def convert_to_bf16(state_dict: dict) -> dict:
    """Convert all floating point tensors to bfloat16."""
    bf16_state = {}
    for k, v in state_dict.items():
        if isinstance(v, torch.Tensor) and v.is_floating_point():
            bf16_state[k] = v.to(torch.bfloat16)
        else:
            bf16_state[k] = v
    return bf16_state


def is_layernorm_key(key: str) -> bool:
    """Check if a key belongs to a LayerNorm layer."""
    ln_patterns = ['norm', 'layernorm', 'layer_norm', 'ln_', 'scale_shift_table']
    key_lower = key.lower()
    return any(p in key_lower for p in ln_patterns)


def quantize_tensor_int8(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
    """
    Quantize a tensor to INT8 with min/max scaling.

    Formula: int8 = round((x - min) / (max - min) * 255) - 128
    """
    if tensor.numel() == 0:
        return tensor.to(torch.int8), 0.0, 0.0

    t_float = tensor.float()
    t_min = t_float.min().item()
    t_max = t_float.max().item()

    # Constant tensor: zero range, so store zeros and recover t_min on dequant.
    if t_min == t_max:
        return torch.zeros_like(tensor, dtype=torch.int8), t_min, t_max

    # Normalize to [0, 1], then shift into the signed INT8 range [-128, 127].
    normalized = (t_float - t_min) / (t_max - t_min)
    int8_tensor = (normalized * 255 - 128).round().clamp(-128, 127).to(torch.int8)

    return int8_tensor, t_min, t_max


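# A minimal consumer-side sketch (not called by this script): inverting the
# min/max mapping above to recover approximate float weights. The helper name
# `dequantize_tensor_int8` is our own, added here purely for illustration.
def dequantize_tensor_int8(int8_tensor: torch.Tensor, t_min: float, t_max: float) -> torch.Tensor:
    """Invert quantize_tensor_int8: x ≈ (int8 + 128) / 255 * (max - min) + min."""
    if t_min == t_max:
        # Zero-range tensors were stored as all zeros; every element equals t_min.
        return torch.full_like(int8_tensor, t_min, dtype=torch.float32)
    normalized = (int8_tensor.float() + 128.0) / 255.0
    return normalized * (t_max - t_min) + t_min

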
def convert_to_int8(state_dict: dict) -> Tuple[dict, dict]:
    """
    Convert state dict to INT8 quantized format.

    LayerNorm and small tensors are kept in float32.
    Quantization parameters (_min, _max) are stored alongside.

    Returns:
        quantized: State dict with INT8 tensors plus _min/_max sidecar tensors
        stats: Parameter counts per storage dtype
    """
    quantized = {}
    stats = {'float32': 0, 'int8': 0}

    for key, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue

        # LayerNorm parameters are kept in float32 (see docstring).
        if is_layernorm_key(key):
            quantized[key] = tensor.float()
            stats['float32'] += tensor.numel()
        # Quantize reasonably-sized floating point tensors.
        elif tensor.numel() >= 16 and tensor.dtype in [torch.float32, torch.float16, torch.bfloat16]:
            int8_tensor, t_min, t_max = quantize_tensor_int8(tensor)
            quantized[key] = int8_tensor
            quantized[f"{key}._min"] = torch.tensor([t_min], dtype=torch.float32)
            quantized[f"{key}._max"] = torch.tensor([t_max], dtype=torch.float32)
            stats['int8'] += tensor.numel()
        else:
            # Small or non-float tensors are not worth quantizing.
            quantized[key] = tensor.float()
            stats['float32'] += tensor.numel()

    return quantized, stats


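# Consumer-side sketch (an assumption, not part of the conversion pipeline):
# loading an INT8 file saved below and folding the _min/_max sidecar tensors
# back into float32 weights. `load_int8_state_dict` is a hypothetical helper.
def load_int8_state_dict(filepath: Path) -> dict:
    """Load an INT8-quantized safetensors file and dequantize it to float32."""
    raw = load_file(str(filepath))
    restored = {}
    for key, tensor in raw.items():
        if key.endswith("._min") or key.endswith("._max"):
            continue  # sidecar quantization params, consumed below
        if tensor.dtype == torch.int8 and f"{key}._min" in raw:
            t_min = raw[f"{key}._min"].item()
            t_max = raw[f"{key}._max"].item()
            restored[key] = dequantize_tensor_int8(tensor, t_min, t_max)
        else:
            restored[key] = tensor  # LayerNorm / small tensors stored as float32
    return restored

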
def convert_weights(input_dir: Path, output_dir: Path, output_format: str):
    """
    Convert weights to specified format.

    Args:
        input_dir: Directory containing input weights
        output_dir: Directory to write output weights
        output_format: 'bf16' or 'int8'
    """
    print(f"""
╔══════════════════════════════════════════════════════════════════════════════╗
║                       🔧 Paris MoE Weight Conversion 🔧                        ║
╠══════════════════════════════════════════════════════════════════════════════╣
║  Input:   {str(input_dir):<60}       ║
║  Output:  {str(output_dir):<60}      ║
║  Format:  {output_format.upper():<60}      ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")

    # Detect input format and locate weight files.
    input_format, files = detect_input_format(input_dir)
    print(f"🔍 Detected input format: {input_format}")
    print(f"📦 Found {len(files)} weight files")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Track sizes for the summary, plus any configs found in .pt checkpoints.
    sizes = {'input': 0, 'output': 0}
    expert_config = None
    router_config = None

    # Convert the 8 experts.
    print("\n🔧 Converting experts...")
    for i in range(8):
        key = f"expert_{i}"
        if key not in files:
            print(f"  ⚠️ {key} not found, skipping")
            continue

        filepath = files[key]
        sizes['input'] += filepath.stat().st_size

        # Load.
        if input_format == 'pt':
            state_dict, config = load_pt_expert(filepath, i)
            if config is not None and expert_config is None:
                expert_config = config
        else:
            state_dict = load_safetensors_weights(filepath)

        # Convert.
        if output_format == 'bf16':
            converted = convert_to_bf16(state_dict)
        else:
            converted, stats = convert_to_int8(state_dict)
            print(f"    INT8: {stats['int8']:,} params, Float32: {stats['float32']:,} params")

        # Save.
        output_path = output_dir / f"expert_{i}.safetensors"
        save_file(converted, str(output_path))
        sizes['output'] += output_path.stat().st_size

        print(f"  ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")

        # Free memory before the next expert.
        del state_dict, converted
        gc.collect()

    # Convert the router.
    if 'router' in files:
        print("\n📡 Converting router...")
        filepath = files['router']
        sizes['input'] += filepath.stat().st_size

        if input_format == 'pt':
            state_dict, config = load_pt_router(filepath)
            if config is not None:
                router_config = config
        else:
            state_dict = load_safetensors_weights(filepath)

        # The router is small, so it is always saved as BF16.
        converted = convert_to_bf16(state_dict)

        output_path = output_dir / "router.safetensors"
        save_file(converted, str(output_path))
        sizes['output'] += output_path.stat().st_size

        print(f"  ✅ Saved: {output_path.name} ({output_path.stat().st_size / 1e6:.1f} MB)")

        del state_dict, converted
        gc.collect()

    # Save configs extracted from the .pt checkpoints, if any.
    if expert_config is not None:
        config_path = output_dir / "config.pt"
        torch.save({'config': expert_config}, config_path)
        print("  ✅ Saved: config.pt")

    if router_config is not None:
        config_path = output_dir / "router_config.pt"
        torch.save({'config': router_config}, config_path)
        print("  ✅ Saved: router_config.pt")

    # Print summary.
    compression = sizes['input'] / sizes['output'] if sizes['output'] > 0 else 1
    print(f"""
╔══════════════════════════════════════════════════════════════════════════════╗
║                           📊 Conversion Summary 📊                             ║
╠══════════════════════════════════════════════════════════════════════════════╣
║  Input size:  {sizes['input']/1e9:>8.2f} GB                                                    ║
║  Output size: {sizes['output']/1e9:>8.2f} GB                                                    ║
║  Compression: {compression:>8.1f}x                                                      ║
╠══════════════════════════════════════════════════════════════════════════════╣
║  ✅ Conversion complete!                                                      ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")

    print("📁 Output files:")
    for f in sorted(output_dir.glob("*")):
        print(f"  {f.name}: {f.stat().st_size/1e6:.1f} MB")


def parse_args():
    parser = argparse.ArgumentParser(
        description="🔧 Paris MoE - Weight Quantization Utility",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert original .pt files to BF16
  python quantize.py --input /path/to/weights --output ./weights/bf16 --format bf16

  # Convert to INT8 from .pt files
  python quantize.py --input /path/to/weights --output ./weights/int8 --format int8

  # Convert from BF16 safetensors to INT8
  python quantize.py --input ./weights/bf16 --output ./weights/int8 --format int8
"""
    )

    parser.add_argument("--input", "-i", type=str, required=True,
                        help="Input directory containing weight files")
    parser.add_argument("--output", "-o", type=str, required=True,
                        help="Output directory for converted weights")
    parser.add_argument("--format", "-f", type=str, required=True,
                        choices=["bf16", "int8"],
                        help="Output format: bf16 or int8")

    return parser.parse_args()


def main():
    args = parse_args()

    input_dir = Path(args.input)
    output_dir = Path(args.output)

    if not input_dir.exists():
        print(f"❌ Error: Input directory does not exist: {input_dir}")
        return 1

    convert_weights(input_dir, output_dir, args.format)
    return 0


if __name__ == "__main__":
    sys.exit(main())