#!/usr/bin/env python3
"""
Grok-2 FP8 Dequantization Script

Converts FP8 quantized weights back to BF16 for inference.

Usage:
    # Just dequantize and save (for later use)
    python dequantize.py --input /path/to/fp8/model --output /path/to/bf16/model

    # Verify dequantization quality (requires original model)
    python dequantize.py --input /path/to/fp8/model --verify /path/to/original/model

The FP8 format reduces storage from ~539GB to ~272GB.
Dequantization restores full BF16 precision for inference.
"""

import argparse
import json
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm


def load_fp8_weights(model_path: Path) -> tuple[dict, dict]:
    """
    Load FP8 weights and their scales from safetensor files.

    Scale tensors are identified by a ``.scale`` key suffix and are
    returned separately, keyed by the base weight name.

    Args:
        model_path: Directory containing ``*.safetensors`` shard files.

    Returns:
        Tuple of (weights_dict, scales_dict).
    """
    weights: dict = {}
    scales: dict = {}

    shard_files = sorted(model_path.glob("*.safetensors"))
    print(f"Found {len(shard_files)} shard files")

    for shard_file in tqdm(shard_files, desc="Loading shards"):
        with safe_open(str(shard_file), framework="pt") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if key.endswith('.scale'):
                    base_key = key[:-6]  # Remove '.scale'
                    scales[base_key] = tensor
                else:
                    weights[key] = tensor

    return weights, scales


def dequantize_weights(weights: dict, scales: dict,
                       output_dtype: torch.dtype = torch.bfloat16) -> dict:
    """
    Dequantize FP8 weights using their scales.

    Formula: bf16_weight = fp8_weight / scale.
    Scale is per output channel (dimension 0); weights are assumed to be
    2-D ``[out_features, in_features]`` with a 1-D ``[out_features]`` scale.

    Tensors without a matching scale are preserved: floating-point tensors
    are cast to ``output_dtype``, everything else is passed through as-is.

    Args:
        weights: Mapping of tensor name -> weight tensor.
        scales: Mapping of tensor name -> per-channel scale tensor.
        output_dtype: Target dtype for the dequantized result.

    Returns:
        Mapping of tensor name -> dequantized (or preserved) tensor.
    """
    dequantized: dict = {}
    fp8_count = 0
    preserved_count = 0

    for key, tensor in tqdm(weights.items(), desc="Dequantizing"):
        if key in scales:
            # FP8 quantized tensor: divide by scale (per output channel).
            # Weight shape: [out_features, in_features]
            # Scale shape:  [out_features] -> broadcast via unsqueeze(-1)
            scale = scales[key]
            dequant = tensor.to(torch.float32) / scale.unsqueeze(-1)
            dequantized[key] = dequant.to(output_dtype)
            fp8_count += 1
        else:
            # Not quantized - preserve as-is (cast only floating-point data;
            # integer/bool tensors must keep their exact dtype).
            if tensor.is_floating_point():
                dequantized[key] = tensor.to(output_dtype)
            else:
                dequantized[key] = tensor
            preserved_count += 1

    print(f"Dequantized {fp8_count} FP8 tensors, preserved {preserved_count} tensors")
    return dequantized


def verify_dequantization(dequantized: dict, original_path: Path,
                          sample_keys: int = 5) -> dict:
    """
    Verify dequantization quality against original BF16 weights.

    Compares up to ``sample_keys`` tensors found in both the dequantized
    dict and the original checkpoint, reporting cosine similarity, mean /
    max absolute error and mean relative error.

    Args:
        dequantized: Mapping of tensor name -> dequantized tensor.
        original_path: Directory with the original ``*.safetensors`` shards.
        sample_keys: Maximum number of tensors to compare.

    Returns:
        Dict of metric-name -> list of per-tensor values.
    """
    print(f"\nVerifying against original: {original_path}")

    orig_files = sorted(original_path.glob("*.safetensors"))

    metrics: dict = {
        'cosine_similarities': [],
        'mean_abs_errors': [],
        'max_abs_errors': [],
        'relative_errors': [],
    }

    checked = 0
    for orig_file in orig_files:
        if checked >= sample_keys:
            break
        with safe_open(str(orig_file), framework="pt") as f:
            for key in f.keys():
                # Stop early once enough tensors have been sampled instead
                # of scanning every remaining key in the shard.
                if checked >= sample_keys:
                    break
                if key not in dequantized:
                    continue

                orig = f.get_tensor(key).to(torch.float32)
                dequant = dequantized[key].to(torch.float32)

                if orig.shape != dequant.shape:
                    print(f"  Shape mismatch for {key}: {orig.shape} vs {dequant.shape}")
                    continue

                # Compute metrics
                diff = (orig - dequant).abs()
                mae = diff.mean().item()
                max_err = diff.max().item()
                rel_err = (diff / (orig.abs() + 1e-8)).mean().item()

                # Cosine similarity over the flattened tensors
                cos_sim = torch.nn.functional.cosine_similarity(
                    orig.flatten().unsqueeze(0),
                    dequant.flatten().unsqueeze(0)
                ).item()

                metrics['mean_abs_errors'].append(mae)
                metrics['max_abs_errors'].append(max_err)
                metrics['relative_errors'].append(rel_err)
                metrics['cosine_similarities'].append(cos_sim)

                print(f"  {key}:")
                print(f"    Cosine sim: {cos_sim:.6f}")
                print(f"    MAE: {mae:.6f}, Max: {max_err:.6f}, Rel: {rel_err*100:.2f}%")

                checked += 1

    # Summary
    if metrics['cosine_similarities']:
        print(f"\nSummary ({len(metrics['cosine_similarities'])} tensors checked):")
        print(f"  Avg Cosine Similarity: {sum(metrics['cosine_similarities'])/len(metrics['cosine_similarities']):.6f}")
        print(f"  Avg MAE: {sum(metrics['mean_abs_errors'])/len(metrics['mean_abs_errors']):.6f}")
        print(f"  Avg Relative Error: {sum(metrics['relative_errors'])/len(metrics['relative_errors'])*100:.2f}%")

    return metrics


def save_dequantized(dequantized: dict, output_path: Path, input_path: Path,
                     max_shard_size: int = 5_000_000_000):
    """
    Save dequantized weights to sharded safetensors files.

    Writes a HuggingFace-style ``model.safetensors.index.json`` and copies
    the model's config/tokenizer files from the input directory.

    Shards are planned first (greedy packing in insertion order), so the
    total shard count is known before any file is written and every shard
    is created directly under its final ``model-NNNNN-of-TTTTT`` name —
    no placeholder names, no post-hoc renaming.

    Args:
        dequantized: Mapping of tensor name -> tensor to save.
        output_path: Directory to write shards into (created if needed).
        input_path: Source model directory to copy config files from.
        max_shard_size: Soft cap on shard size in bytes (default ~5 GB).
    """
    output_path.mkdir(parents=True, exist_ok=True)

    total_size = sum(t.numel() * t.element_size() for t in dequantized.values())
    print(f"\nTotal dequantized size: {total_size / 1e9:.2f} GB")

    # Pass 1: plan shards. A shard is closed once adding the next tensor
    # would push it past max_shard_size (a shard always holds at least one
    # tensor, so a single oversized tensor still gets its own shard).
    shards: list = []
    current_shard: dict = {}
    current_size = 0
    for key, tensor in dequantized.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if current_shard and current_size + tensor_size > max_shard_size:
            shards.append(current_shard)
            current_shard = {}
            current_size = 0
        current_shard[key] = tensor
        current_size += tensor_size
    if current_shard:
        shards.append(current_shard)

    total_shards = len(shards)

    # Pass 2: write each shard under its final name and build the index.
    weight_map = {}
    for idx, shard in enumerate(tqdm(shards, desc="Saving")):
        shard_name = f"model-{idx:05d}-of-{total_shards:05d}.safetensors"
        save_file(shard, output_path / shard_name)
        for key in shard:
            weight_map[key] = shard_name

    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open(output_path / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} shards to {output_path}")

    # Copy config/tokenizer/modeling files alongside the weights so the
    # output directory is directly loadable.
    config_files = [
        "config.json",
        "tokenizer_config.json",
        "tokenizer.tok.json",
        "configuration_grok2.py",
        "modeling_grok2.py",
        "tokenization_grok2.py",
        "__init__.py",
    ]
    for cfg in config_files:
        src = input_path / cfg
        if src.exists():
            shutil.copy(src, output_path / cfg)
            print(f"Copied {cfg}")


def main():
    """CLI entry point: parse args, dequantize, optionally verify/save."""
    parser = argparse.ArgumentParser(
        description="Dequantize Grok-2 FP8 weights to BF16",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Dequantize and save
  python dequantize.py --input ./Grok-2-FP8 --output ./Grok-2-BF16

  # Verify quality against original
  python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original

  # Memory-efficient: process without saving (just verify)
  python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original --no-save
"""
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to FP8 quantized model")
    parser.add_argument("--output", type=str,
                        help="Path to save dequantized BF16 model")
    parser.add_argument("--verify", type=str,
                        help="Path to original BF16 model for quality verification")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Output dtype (default: bfloat16)")
    parser.add_argument("--no-save", action="store_true",
                        help="Don't save output (useful with --verify)")

    args = parser.parse_args()

    input_path = Path(args.input)
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    output_dtype = dtype_map[args.dtype]

    print(f"Loading FP8 weights from: {input_path}")
    weights, scales = load_fp8_weights(input_path)
    print(f"Loaded {len(weights)} weights, {len(scales)} scales")

    print(f"\nDequantizing to {args.dtype}...")
    dequantized = dequantize_weights(weights, scales, output_dtype)

    if args.verify:
        verify_dequantization(dequantized, Path(args.verify))

    if args.output and not args.no_save:
        output_path = Path(args.output)
        save_dequantized(dequantized, output_path, input_path)
        print(f"\nDequantized model saved to: {output_path}")
    elif not args.verify:
        print("\nNo output path specified. Use --output to save dequantized weights.")


if __name__ == "__main__":
    main()