| | |
| | """ |
| | Grok-2 FP8 Dequantization Script |
| | |
| | Converts FP8 quantized weights back to BF16 for inference. |
| | |
| | Usage: |
| | # Just dequantize and save (for later use) |
| | python dequantize.py --input /path/to/fp8/model --output /path/to/bf16/model |
| | |
| | # Verify dequantization quality (requires original model) |
| | python dequantize.py --input /path/to/fp8/model --verify /path/to/original/model |
| | |
| | The FP8 format reduces storage from ~539GB to ~272GB. |
| | Dequantization restores full BF16 precision for inference. |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import shutil |
| | import torch |
| | from pathlib import Path |
| | from safetensors import safe_open |
| | from safetensors.torch import save_file |
| | from tqdm import tqdm |
| | from typing import Dict, Optional |
| | from collections import defaultdict |
| |
|
| |
|
def load_fp8_weights(model_path: Path) -> tuple[dict, dict]:
    """Collect FP8 weight tensors and their scales from a model directory.

    Scale tensors are stored under keys ending in ``.scale``; each one is
    mapped back to the base key of the weight it belongs to.

    Args:
        model_path: Directory containing ``*.safetensors`` shard files.

    Returns:
        Tuple of (weights_dict, scales_dict).
    """
    weights: dict = {}
    scales: dict = {}

    shards = sorted(model_path.glob("*.safetensors"))
    print(f"Found {len(shards)} shard files")

    for shard in tqdm(shards, desc="Loading shards"):
        with safe_open(str(shard), framework="pt") as handle:
            for name in handle.keys():
                value = handle.get_tensor(name)
                if name.endswith('.scale'):
                    # Strip the ".scale" suffix to recover the weight key.
                    scales[name[:-6]] = value
                else:
                    weights[name] = value

    return weights, scales
| |
|
| |
|
def dequantize_weights(weights: dict, scales: dict,
                       output_dtype: torch.dtype = torch.bfloat16) -> dict:
    """Dequantize FP8 weights using their scales.

    Formula: bf16_weight = fp8_weight / scale.
    Scale is per output channel (dimension 0). Tensors with no matching
    scale pass through unchanged, except floating-point ones which are
    cast to ``output_dtype``.
    """
    result: dict = {}
    n_dequantized = 0
    n_preserved = 0

    for name, value in tqdm(weights.items(), desc="Dequantizing"):
        scale = scales.get(name)
        if scale is not None:
            # Promote to float32 before dividing; the per-channel scale
            # broadcasts over the trailing dimension via unsqueeze(-1).
            restored = value.to(torch.float32) / scale.unsqueeze(-1)
            result[name] = restored.to(output_dtype)
            n_dequantized += 1
        else:
            if value.is_floating_point():
                result[name] = value.to(output_dtype)
            else:
                # Integer/bool tensors (e.g. buffers) are kept as-is.
                result[name] = value
            n_preserved += 1

    print(f"Dequantized {n_dequantized} FP8 tensors, preserved {n_preserved} tensors")
    return result
| |
|
| |
|
def verify_dequantization(dequantized: dict, original_path: Path,
                          sample_keys: int = 5) -> dict:
    """Verify dequantization quality against original BF16 weights.

    Compares up to ``sample_keys`` tensors present in both ``dequantized``
    and the original checkpoint, reporting cosine similarity plus mean,
    max, and relative absolute error per tensor and as averages.

    Args:
        dequantized: Mapping of tensor name -> dequantized tensor.
        original_path: Directory with the original ``*.safetensors`` shards.
        sample_keys: Maximum number of tensors to compare.

    Returns:
        Dict with per-tensor quality metric lists.
    """
    print(f"\nVerifying against original: {original_path}")

    orig_files = sorted(original_path.glob("*.safetensors"))

    metrics = {
        'cosine_similarities': [],
        'mean_abs_errors': [],
        'max_abs_errors': [],
        'relative_errors': [],
    }

    checked = 0
    for orig_file in orig_files:
        if checked >= sample_keys:
            break

        with safe_open(str(orig_file), framework="pt") as f:
            for key in f.keys():
                # Fix: stop scanning once the quota is met — the original
                # kept iterating every remaining key of every shard.
                if checked >= sample_keys:
                    break
                if key not in dequantized:
                    continue

                orig = f.get_tensor(key).to(torch.float32)
                dequant = dequantized[key].to(torch.float32)

                if orig.shape != dequant.shape:
                    print(f"  Shape mismatch for {key}: {orig.shape} vs {dequant.shape}")
                    continue

                diff = (orig - dequant).abs()
                mae = diff.mean().item()
                max_err = diff.max().item()
                # Epsilon avoids division by zero on exactly-zero weights.
                rel_err = (diff / (orig.abs() + 1e-8)).mean().item()

                cos_sim = torch.nn.functional.cosine_similarity(
                    orig.flatten().unsqueeze(0),
                    dequant.flatten().unsqueeze(0)
                ).item()

                metrics['mean_abs_errors'].append(mae)
                metrics['max_abs_errors'].append(max_err)
                metrics['relative_errors'].append(rel_err)
                metrics['cosine_similarities'].append(cos_sim)

                print(f"  {key}:")
                print(f"    Cosine sim: {cos_sim:.6f}")
                print(f"    MAE: {mae:.6f}, Max: {max_err:.6f}, Rel: {rel_err*100:.2f}%")

                checked += 1

    if metrics['cosine_similarities']:
        n = len(metrics['cosine_similarities'])
        print(f"\nSummary ({n} tensors checked):")
        print(f"  Avg Cosine Similarity: {sum(metrics['cosine_similarities'])/n:.6f}")
        print(f"  Avg MAE: {sum(metrics['mean_abs_errors'])/len(metrics['mean_abs_errors']):.6f}")
        print(f"  Avg Relative Error: {sum(metrics['relative_errors'])/len(metrics['relative_errors'])*100:.2f}%")

    return metrics
| |
|
| |
|
def save_dequantized(dequantized: dict, output_path: Path,
                     input_path: Path, max_shard_size: int = 5_000_000_000):
    """Save dequantized weights as sharded safetensors files plus an index.

    Greedily packs tensors into shards of at most ``max_shard_size`` bytes
    (a single tensor larger than the limit still gets its own shard),
    writes a ``model.safetensors.index.json`` weight map, and copies known
    config files from ``input_path``.

    Args:
        dequantized: Mapping of tensor name -> tensor to persist.
        output_path: Destination directory (created if missing).
        input_path: Source model directory to copy config files from.
        max_shard_size: Soft per-shard byte limit (default 5 GB).
    """
    output_path.mkdir(parents=True, exist_ok=True)

    total_size = sum(t.numel() * t.element_size() for t in dequantized.values())
    print(f"\nTotal dequantized size: {total_size / 1e9:.2f} GB")

    # The shard count is unknown until packing finishes, so shards are
    # written under a placeholder name and renamed afterwards.
    placeholder = "XXXXX"

    current_shard = {}
    current_size = 0
    shard_idx = 0
    weight_map = {}

    for key, tensor in tqdm(dequantized.items(), desc="Saving"):
        tensor_size = tensor.numel() * tensor.element_size()

        if current_size + tensor_size > max_shard_size and current_shard:
            # Flush the current shard and start a new one.
            save_file(current_shard,
                      output_path / f"model-{shard_idx:05d}-of-{placeholder}.safetensors")
            shard_idx += 1
            current_shard = {}
            current_size = 0

        current_shard[key] = tensor
        weight_map[key] = f"model-{shard_idx:05d}-of-{placeholder}.safetensors"
        current_size += tensor_size

    # Flush the final (possibly only) shard.
    if current_shard:
        save_file(current_shard,
                  output_path / f"model-{shard_idx:05d}-of-{placeholder}.safetensors")
        shard_idx += 1

    total_shards = shard_idx
    suffix = f"{total_shards:05d}"

    # Fix: resolve the placeholder exactly once. The original rewrote the
    # weight map in place and then re-ran the same replacement as a no-op
    # dict comprehension when building the index.
    weight_map = {k: v.replace(placeholder, suffix) for k, v in weight_map.items()}

    # Rename the on-disk shards to their final names.
    for i in range(total_shards):
        tmp_name = output_path / f"model-{i:05d}-of-{placeholder}.safetensors"
        if tmp_name.exists():
            tmp_name.rename(output_path / f"model-{i:05d}-of-{suffix}.safetensors")

    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open(output_path / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} shards to {output_path}")

    # Ship the auxiliary files needed to load the model with the weights.
    config_files = [
        "config.json",
        "tokenizer_config.json",
        "tokenizer.tok.json",
        "configuration_grok2.py",
        "modeling_grok2.py",
        "tokenization_grok2.py",
        "__init__.py",
    ]
    for cfg in config_files:
        src = input_path / cfg
        if src.exists():
            shutil.copy(src, output_path / cfg)
            print(f"Copied {cfg}")
| |
|
| |
|
def main():
    """CLI entry point: load FP8 shards, dequantize, then verify and/or save."""
    parser = argparse.ArgumentParser(
        description="Dequantize Grok-2 FP8 weights to BF16",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Dequantize and save
  python dequantize.py --input ./Grok-2-FP8 --output ./Grok-2-BF16

  # Verify quality against original
  python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original

  # Memory-efficient: process without saving (just verify)
  python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original --no-save
"""
    )

    parser.add_argument("--input", type=str, required=True,
                        help="Path to FP8 quantized model")
    parser.add_argument("--output", type=str,
                        help="Path to save dequantized BF16 model")
    parser.add_argument("--verify", type=str,
                        help="Path to original BF16 model for quality verification")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Output dtype (default: bfloat16)")
    parser.add_argument("--no-save", action="store_true",
                        help="Don't save output (useful with --verify)")

    args = parser.parse_args()

    input_path = Path(args.input)

    # Map the dtype name straight to the torch dtype; argparse choices
    # guarantee the key exists.
    output_dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }[args.dtype]

    print(f"Loading FP8 weights from: {input_path}")
    weights, scales = load_fp8_weights(input_path)
    print(f"Loaded {len(weights)} weights, {len(scales)} scales")

    print(f"\nDequantizing to {args.dtype}...")
    dequantized = dequantize_weights(weights, scales, output_dtype)

    if args.verify:
        verify_dequantization(dequantized, Path(args.verify))

    if args.output and not args.no_save:
        destination = Path(args.output)
        save_dequantized(dequantized, destination, input_path)
        print(f"\nDequantized model saved to: {destination}")
    elif not args.verify:
        print("\nNo output path specified. Use --output to save dequantized weights.")
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|