"""Quantize the FP16 tensors of a safetensors checkpoint to 8 bits.

Despite the "fp8" naming (kept for backward compatibility), this is
symmetric per-tensor INT8 quantization: each FP16 tensor is stored as
``int8`` values plus one floating-point scale, with
``original ~= int8_values * scale``.
"""

import torch

# Largest magnitude representable by the symmetric int8 range [-127, 127].
# -128 is deliberately unused so the range is symmetric around zero.
INT8_MAX = 127


def fp16_to_fp8(tensor):
    """Quantize a floating-point tensor to int8 with one per-tensor scale.

    Args:
        tensor: Floating-point tensor (FP16 in this script's usage).

    Returns:
        ``(quantized, scale)``:
        - ``quantized`` is an ``int8`` tensor of the same shape.
        - ``scale`` is a 0-d ``float32`` tensor such that
          ``quantized * scale`` approximates ``tensor``.

    Raises:
        ValueError: If ``tensor`` contains inf or NaN, since no finite
            scale can represent it.
    """
    # An empty tensor has no max; there is nothing to quantize.
    if tensor.numel() == 0:
        return tensor.to(torch.int8), torch.ones((), dtype=torch.float32, device=tensor.device)

    # Compute the scale in float32. In fp16, amax / 127 underflows to 0 for
    # small-magnitude tensors, which would turn every dequantized value into 0.
    amax = tensor.abs().max().to(torch.float32)
    if not torch.isfinite(amax):
        raise ValueError("cannot quantize a tensor containing inf or NaN values")

    # An all-zero tensor would give scale 0 and a 0/0 division; any positive
    # scale maps zeros to zeros, so use 1.0.
    scale = amax / INT8_MAX if amax > 0 else amax.new_ones(())

    # Round to nearest rather than truncate: truncation biases every value
    # toward zero.
    quantized = torch.round(tensor.to(torch.float32) / scale)
    quantized = quantized.clamp(-INT8_MAX, INT8_MAX).to(torch.int8)
    return quantized, scale


def fp8_to_fp16(tensor_fp8, scale):
    """Dequantize an int8 tensor produced by ``fp16_to_fp8`` back to FP16.

    Args:
        tensor_fp8: The quantized ``int8`` tensor.
        scale: The scale returned alongside it. A tensor or a Python number
            is accepted.

    Returns:
        A ``float16`` tensor approximating the original values.
    """
    # Multiply in float32 and cast once at the end. The result is float16
    # regardless of the scale's dtype.
    return (tensor_fp8.to(torch.float32) * scale).to(torch.float16)


def convert_model_to_fp8(model_path, output_path, scales_path=None):
    """Quantize every FP16 tensor in a safetensors file and save the result.

    Non-FP16 tensors are copied through unchanged.

    Args:
        model_path: Path to the input ``.safetensors`` file.
        output_path: Path to write the quantized ``.safetensors`` file.
        scales_path: Optional path to write the per-tensor scales as a
            separate ``.safetensors`` file (same keys as the quantized
            tensors). Without the scales the quantized file cannot be
            dequantized.

    Returns:
        Dict mapping each quantized tensor's key to its 0-d scale tensor.
    """
    # Imported here so the quantization helpers above work without safetensors.
    from safetensors.torch import load_file, save_file

    tensors = load_file(model_path)

    converted_tensors = {}
    scales = {}
    for key, tensor in tensors.items():
        if tensor.dtype == torch.float16:
            tensor_fp8, scale = fp16_to_fp8(tensor)
            converted_tensors[key] = tensor_fp8
            scales[key] = scale
            print(f"Converted tensor {key} from FP16 to FP8.")
        else:
            # Keep non-FP16 tensors as is.
            converted_tensors[key] = tensor

    save_file(converted_tensors, output_path)
    print(f"Model saved in FP8 format to {output_path}")

    if scales_path is not None:
        save_file(scales, scales_path)
        print(f"Per-tensor scales saved to {scales_path}")

    return scales


if __name__ == "__main__":
    # Example usage; guarded so importing this module has no side effects.
    convert_model_to_fp8(
        'flowgram.safetensors',  # Input FP16 model
        'flowgram_fp8.safetensors'  # Output FP8 model
    )