File size: 7,115 Bytes
c5a22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
# Copy from https://gist.github.com/Stella2211/10f5bd870387ec1ddb9932235321068e
# This is a great work.
import json
from pathlib import Path
import torch
from tqdm import tqdm
import struct
from typing import Dict, Any
import sys
# input file
if(len(sys.argv) < 3):
print("Usage: mem_eff_fp8_convert.py {fp16 model path} {output path}")
sys.exit(1)
path = sys.argv[1]
output =sys.argv[2]
model_file = Path(path)
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
# read safetensors metadata
def read_safetensors_metadata(path: str):
with open(path, 'rb') as f:
header_size = int.from_bytes(f.read(8), 'little')
header_json = f.read(header_size).decode('utf-8')
header = json.loads(header_json)
metadata = header.get('__metadata__', {})
return metadata
metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4)) #show metadata
sd_pruned = dict() #initialize empty dict
with MemoryEfficientSafeOpen(path) as reader:
keys = reader.keys()
for key in tqdm(keys): #for each key in the safetensors file
sd_pruned[key] = reader.get_tensor(key).to(torch.float8_e4m3fn) #convert to fp8
# save the pruned safetensors file
mem_eff_save_file(sd_pruned, output, metadata={"format": "pt", **metadata}) |