import json

import torch
from safetensors.torch import load_file, save_file
|
|
def fp16_to_fp8(tensor):
    """Quantize a tensor to int8 using one symmetric per-tensor scale.

    Despite the name, no FP8 dtype is produced: values are mapped onto the
    integer range [-127, 127] and stored as ``torch.int8``. The original
    values are approximately recovered as ``quantized * scale``.

    Args:
        tensor: Floating-point tensor to quantize (the caller in this file
            passes ``torch.float16`` tensors).

    Returns:
        A ``(quantized, scale)`` tuple: ``quantized`` is an int8 tensor with
        the same shape as ``tensor``; ``scale`` is a 0-dim tensor equal to
        ``max(|tensor|) / 127``, or 1 when that value would be zero or the
        tensor is empty.
    """
    if tensor.numel() == 0:
        # max() raises on an empty tensor; there is nothing to quantize, so
        # return an empty int8 tensor with a neutral scale.
        return tensor.to(torch.int8), tensor.new_ones(())

    scale = tensor.abs().max() / 127.0
    # An all-zero tensor (or a maximum small enough to underflow) yields a
    # zero scale, which would turn the division below into 0/0 = NaN.
    # Any non-zero scale represents such a tensor exactly, so use 1.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Divide in float32 and round to nearest: a bare cast to int8 truncates
    # toward zero, which roughly doubles the worst-case quantization error.
    quantized = (tensor.to(torch.float32) / scale.to(torch.float32)).round()
    return quantized.clamp(-127, 127).to(torch.int8), scale
|
|
def fp8_to_fp16(tensor_fp8, scale):
    """Dequantize a tensor: cast it to float16, then multiply by ``scale``.

    Args:
        tensor_fp8: Quantized tensor to restore.
        scale: Multiplier applied after the cast (the value produced
            alongside the quantized tensor).

    Returns:
        The product of the float16 cast of ``tensor_fp8`` and ``scale``.
    """
    as_half = tensor_fp8.to(torch.float16)
    restored = as_half * scale
    return restored
|
|
def convert_model_to_fp8(model_path, output_path):
    """Quantize every FP16 tensor in a safetensors file and save the result.

    Each ``torch.float16`` tensor is replaced by the quantized tensor that
    ``fp16_to_fp8`` returns; tensors of any other dtype are copied through
    unchanged.

    The per-tensor scales are required to dequantize the output, so besides
    being returned they are written into the output file's safetensors
    metadata under the ``"fp8_scales"`` key, as a JSON object mapping tensor
    name to scale. Without this the saved file could not be reversed once
    the return value is discarded.

    Args:
        model_path: Path of the safetensors file to read.
        output_path: Path the converted safetensors file is written to.

    Returns:
        Dict mapping the name of each converted tensor to the scale that
        ``fp16_to_fp8`` returned for it.
    """
    tensors = load_file(model_path)

    converted_tensors = {}
    scales = {}

    for key, tensor in tensors.items():
        if tensor.dtype == torch.float16:
            tensor_fp8, scale = fp16_to_fp8(tensor)
            converted_tensors[key] = tensor_fp8
            scales[key] = scale
            print(f"Converted tensor {key} from FP16 to FP8.")
        else:
            # Non-FP16 tensors are stored as-is.
            converted_tensors[key] = tensor

    # safetensors metadata is a flat str -> str mapping, so all scales are
    # serialized into a single JSON string.
    metadata = {
        "fp8_scales": json.dumps(
            {name: float(value) for name, value in scales.items()}
        ),
    }
    save_file(converted_tensors, output_path, metadata=metadata)
    print(f"Model saved in FP8 format to {output_path}")

    return scales
|
|
| |
if __name__ == "__main__":
    # Run the conversion only when this file is executed as a script, so
    # importing the module for its helpers does not read or write any files.
    convert_model_to_fp8(
        'flowgram.safetensors',
        'flowgram_fp8.safetensors'
    )