#!/usr/bin/env python3
"""
Grok-2 FP8 Dequantization Script
Converts FP8 quantized weights back to BF16 for inference.
Usage:
# Just dequantize and save (for later use)
python dequantize.py --input /path/to/fp8/model --output /path/to/bf16/model
# Verify dequantization quality (requires original model)
python dequantize.py --input /path/to/fp8/model --verify /path/to/original/model
The FP8 format reduces storage from ~539GB to ~272GB.
Dequantization restores full BF16 precision for inference.
"""
import argparse
import json
import shutil
import torch
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from typing import Dict, Optional
from collections import defaultdict
def load_fp8_weights(model_path: Path) -> tuple[dict, dict]:
    """
    Read every ``*.safetensors`` shard under *model_path* and split the
    tensors into weights and quantization scales.

    A key ending in ``.scale`` is treated as the per-tensor scale for the
    weight whose key is the same string minus that suffix; every other key
    is a plain weight tensor.

    Returns:
        Tuple of (weights_dict, scales_dict). Scales are keyed by the base
        weight name (the ``.scale`` suffix is stripped).
    """
    weights: dict = {}
    scales: dict = {}
    suffix = '.scale'
    shards = sorted(model_path.glob("*.safetensors"))
    print(f"Found {len(shards)} shard files")
    for shard in tqdm(shards, desc="Loading shards"):
        with safe_open(str(shard), framework="pt") as reader:
            for name in reader.keys():
                value = reader.get_tensor(name)
                if name.endswith(suffix):
                    # Strip the '.scale' suffix to get the weight's key.
                    scales[name[:-len(suffix)]] = value
                else:
                    weights[name] = value
    return weights, scales
def dequantize_weights(weights: dict, scales: dict,
                       output_dtype: torch.dtype = torch.bfloat16) -> dict:
    """
    Restore FP8-quantized tensors to a higher-precision dtype.

    Every weight that has a matching entry in *scales* is promoted to
    float32, divided by its scale (scale has one value per output channel,
    i.e. shape [out_features], broadcast across the last weight dimension),
    and cast to *output_dtype*. Weights without a scale pass through:
    floating-point tensors are cast to *output_dtype*, all others are kept
    exactly as loaded.

    Returns:
        New dict mapping the same keys to dequantized / converted tensors.
    """
    restored = {}
    n_fp8 = 0
    n_passthrough = 0
    for name, value in tqdm(weights.items(), desc="Dequantizing"):
        scale = scales.get(name)
        if scale is not None:
            # bf16_weight = fp8_weight / scale, applied per output channel.
            full = value.to(torch.float32) / scale.unsqueeze(-1)
            restored[name] = full.to(output_dtype)
            n_fp8 += 1
        else:
            # Unquantized tensor: convert dtype only when it is floating point.
            restored[name] = value.to(output_dtype) if value.is_floating_point() else value
            n_passthrough += 1
    print(f"Dequantized {n_fp8} FP8 tensors, preserved {n_passthrough} tensors")
    return restored
def verify_dequantization(dequantized: dict, original_path: Path,
                          sample_keys: int = 5) -> dict:
    """
    Compare a sample of dequantized tensors against the original BF16
    weights and report error statistics.

    Args:
        dequantized: Mapping of tensor name -> dequantized tensor.
        original_path: Directory holding the original ``*.safetensors`` shards.
        sample_keys: Maximum number of tensors to compare.

    Returns:
        Dict of metric lists ('cosine_similarities', 'mean_abs_errors',
        'max_abs_errors', 'relative_errors'), one entry per tensor checked.
    """
    print(f"\nVerifying against original: {original_path}")
    # Load some original weights for comparison
    orig_files = sorted(original_path.glob("*.safetensors"))
    metrics = {
        'cosine_similarities': [],
        'mean_abs_errors': [],
        'max_abs_errors': [],
        'relative_errors': [],
    }
    checked = 0
    for orig_file in orig_files:
        if checked >= sample_keys:
            break
        with safe_open(str(orig_file), framework="pt") as f:
            for key in f.keys():
                # Fix: stop scanning remaining keys once the sample quota is
                # met (previously the loop iterated every key of every shard).
                if checked >= sample_keys:
                    break
                if key not in dequantized:
                    continue
                orig = f.get_tensor(key).to(torch.float32)
                dequant = dequantized[key].to(torch.float32)
                if orig.shape != dequant.shape:
                    print(f"  Shape mismatch for {key}: {orig.shape} vs {dequant.shape}")
                    continue
                # Compute metrics
                diff = (orig - dequant).abs()
                mae = diff.mean().item()
                max_err = diff.max().item()
                # Epsilon guards against division by zero on zero weights.
                rel_err = (diff / (orig.abs() + 1e-8)).mean().item()
                # Cosine similarity over the flattened tensors.
                cos_sim = torch.nn.functional.cosine_similarity(
                    orig.flatten().unsqueeze(0),
                    dequant.flatten().unsqueeze(0)
                ).item()
                metrics['mean_abs_errors'].append(mae)
                metrics['max_abs_errors'].append(max_err)
                metrics['relative_errors'].append(rel_err)
                metrics['cosine_similarities'].append(cos_sim)
                print(f"  {key}:")
                print(f"    Cosine sim: {cos_sim:.6f}")
                print(f"    MAE: {mae:.6f}, Max: {max_err:.6f}, Rel: {rel_err*100:.2f}%")
                checked += 1
    # Summary
    if metrics['cosine_similarities']:
        n = len(metrics['cosine_similarities'])
        print(f"\nSummary ({n} tensors checked):")
        print(f"  Avg Cosine Similarity: {sum(metrics['cosine_similarities'])/n:.6f}")
        print(f"  Avg MAE: {sum(metrics['mean_abs_errors'])/n:.6f}")
        print(f"  Avg Relative Error: {sum(metrics['relative_errors'])/n*100:.2f}%")
    return metrics
def save_dequantized(dequantized: dict, output_path: Path,
                     input_path: Path, max_shard_size: int = 5_000_000_000):
    """
    Write dequantized weights as sharded safetensors files plus an index
    JSON, and copy model config/code files from *input_path*.

    Shards are filled greedily in insertion order: a new shard starts when
    adding the next tensor would exceed *max_shard_size* bytes (a single
    tensor larger than the limit still gets its own shard).

    Fixes over the previous version: shards are planned first and written
    directly under their final ``model-NNNNN-of-MMMMM`` names (no
    placeholder files to rename afterwards), and the weight_map placeholder
    substitution is no longer performed twice.

    Args:
        dequantized: Mapping of tensor name -> tensor to save.
        output_path: Destination directory (created if missing).
        input_path: Source model directory to copy config files from.
        max_shard_size: Soft per-shard byte limit (default 5 GB).
    """
    output_path.mkdir(parents=True, exist_ok=True)
    total_size = sum(t.numel() * t.element_size() for t in dequantized.values())
    print(f"\nTotal dequantized size: {total_size / 1e9:.2f} GB")
    # Pass 1: plan shard membership so the shard count is known up front.
    shards = []        # list of key-lists, one per shard
    current_keys = []
    current_size = 0
    for key, tensor in dequantized.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if current_keys and current_size + tensor_size > max_shard_size:
            shards.append(current_keys)
            current_keys = []
            current_size = 0
        current_keys.append(key)
        current_size += tensor_size
    if current_keys:
        shards.append(current_keys)
    total_shards = len(shards)
    # Pass 2: write each shard under its final name and build the index map.
    weight_map = {}
    with tqdm(total=len(dequantized), desc="Saving") as progress:
        for idx, keys in enumerate(shards):
            shard_name = f"model-{idx:05d}-of-{total_shards:05d}.safetensors"
            save_file({k: dequantized[k] for k in keys}, output_path / shard_name)
            for k in keys:
                weight_map[k] = shard_name
            progress.update(len(keys))
    # Save index
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map
    }
    with open(output_path / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)
    print(f"Saved {total_shards} shards to {output_path}")
    # Copy supporting config / tokenizer / modeling files when present.
    config_files = [
        "config.json",
        "tokenizer_config.json",
        "tokenizer.tok.json",
        "configuration_grok2.py",
        "modeling_grok2.py",
        "tokenization_grok2.py",
        "__init__.py",
    ]
    for cfg in config_files:
        src = input_path / cfg
        if src.exists():
            shutil.copy(src, output_path / cfg)
            print(f"Copied {cfg}")
def main():
    """Command-line entry point: load FP8 shards, dequantize, then verify and/or save."""
    parser = argparse.ArgumentParser(
        description="Dequantize Grok-2 FP8 weights to BF16",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Dequantize and save
python dequantize.py --input ./Grok-2-FP8 --output ./Grok-2-BF16
# Verify quality against original
python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original
# Memory-efficient: process without saving (just verify)
python dequantize.py --input ./Grok-2-FP8 --verify ./grok-2-original --no-save
"""
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to FP8 quantized model")
    parser.add_argument("--output", type=str,
                        help="Path to save dequantized BF16 model")
    parser.add_argument("--verify", type=str,
                        help="Path to original BF16 model for quality verification")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Output dtype (default: bfloat16)")
    parser.add_argument("--no-save", action="store_true",
                        help="Don't save output (useful with --verify)")
    args = parser.parse_args()

    source_dir = Path(args.input)
    # Map the CLI choice onto a concrete torch dtype.
    target_dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }[args.dtype]

    print(f"Loading FP8 weights from: {source_dir}")
    weights, scales = load_fp8_weights(source_dir)
    print(f"Loaded {len(weights)} weights, {len(scales)} scales")

    print(f"\nDequantizing to {args.dtype}...")
    dequantized = dequantize_weights(weights, scales, target_dtype)

    if args.verify:
        verify_dequantization(dequantized, Path(args.verify))

    if args.output and not args.no_save:
        destination = Path(args.output)
        save_dequantized(dequantized, destination, source_dir)
        print(f"\nDequantized model saved to: {destination}")
    elif not args.verify:
        print("\nNo output path specified. Use --output to save dequantized weights.")


if __name__ == "__main__":
    main()