|
|
|
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
import struct |
|
|
from typing import Dict, Any |
|
|
import sys |
|
|
|
|
|
|
|
|
if len(sys.argv) < 3:
    # Usage errors go to stderr so they never mix with redirected stdout.
    print("Usage: mem_eff_fp8_convert.py {fp16 model path} {output path}", file=sys.stderr)
    sys.exit(1)

# Positional CLI arguments: source safetensors file and destination path.
path = sys.argv[1]
output = sys.argv[2]

model_file = Path(path)
# Fail fast with a readable message instead of a traceback from open() later.
if not model_file.is_file():
    print(f"Error: input file not found: {path}", file=sys.stderr)
    sys.exit(1)
|
|
|
|
|
class MemoryEfficientSafeOpen:
    """Read tensors from a ``.safetensors`` file one at a time.

    Only the JSON header is parsed up front. Each tensor's bytes are read on
    demand by :meth:`get_tensor`. Use it as a context manager so the file
    handle opened in ``__init__`` is closed by ``__exit__``.
    """

    def __init__(self, filename):
        """Parse the header of *filename* and open the file for tensor reads.

        Args:
            filename: Path to a safetensors file.
        """
        self.filename = filename
        # header: parsed JSON header dict; header_size: its length in bytes.
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        """Return the tensor names in header order, excluding ``__metadata__``."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read and return the tensor stored under *key*.

        Raises:
            KeyError: if *key* is not a tensor entry in the header.
            ValueError: if the file ends before the tensor's data does, or
                the tensor's dtype is not supported.
        """
        # "__metadata__" is a string->string map, not a tensor entry.
        if key not in self.header or key == "__metadata__":
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            nbytes = offset_end - offset_start
            # Data offsets are relative to the start of the data section, which
            # follows the 8-byte header-length prefix and the header itself.
            self.file.seek(self.header_size + 8 + offset_start)
            # readinto fills a writable buffer directly, avoiding a second
            # bytes -> bytearray copy of the whole tensor.
            tensor_bytes = bytearray(nbytes)
            nread = self.file.readinto(tensor_bytes)
            if nread != nbytes:
                raise ValueError(
                    f"Tensor '{key}' is truncated: expected {nbytes} bytes, got {nread}"
                )

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Return ``(header_dict, header_size)`` from the file's prefix."""
        with open(self.filename, "rb") as f:
            # First 8 bytes: little-endian unsigned 64-bit header length.
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Reinterpret raw bytes as a tensor with the header's dtype and shape.

        Args:
            tensor_bytes: Raw tensor data, or ``None`` for a zero-byte tensor.
            metadata: Header entry with ``"dtype"`` and ``"shape"`` keys.

        Raises:
            ValueError: if the dtype string is not supported.
        """
        dtype_str = metadata["dtype"]
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            # torch.frombuffer wants a writable buffer; only copy when the
            # caller handed us immutable bytes.
            if not isinstance(tensor_bytes, bytearray):
                tensor_bytes = bytearray(tensor_bytes)
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        if dtype_str in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, dtype_str, shape)

        dtype = self._get_torch_dtype(dtype_str)
        if dtype is None:
            raise ValueError(f"Unsupported dtype: {dtype_str}")
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string to a torch dtype, or ``None`` if unknown."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # float8 dtypes only exist in newer PyTorch builds.
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """View a uint8 tensor as the requested float8 dtype.

        Raises:
            ValueError: if the running PyTorch lacks that float8 dtype.
        """
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
|
|
|
|
|
|
|
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Write *tensors* to *filename* in safetensors format, one tensor at a time.

    The full header is computed before the file is opened, so an unsupported
    dtype is reported without creating the file. Tensor data is then streamed
    in the same order as the header entries.

    Args:
        tensors: Mapping of tensor name to tensor. Tensors on any device are
            accepted; each one is copied to CPU only while it is written.
        filename: Destination path (overwritten).
        metadata: Optional metadata stored under ``"__metadata__"``. Non-string
            values are converted with ``str`` and a warning is printed.

    Raises:
        ValueError: if a metadata key is not a string, or a tensor's dtype has
            no safetensors equivalent.
    """
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    # The header is padded so the data section starts on this byte boundary.
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        """Return a copy of *metadata* with every value coerced to ``str``."""
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        dtype_str = _TYPES.get(v.dtype)
        if dtype_str is None:
            raise ValueError(f"Unsupported dtype for tensor '{k}': {v.dtype}")
        # Zero-element tensors get an empty [offset, offset] range.
        size = v.numel() * v.element_size()
        header[k] = {"dtype": dtype_str, "shape": list(v.shape), "data_offsets": [offset, offset + size]}
        offset += size

    hjson = json.dumps(header).encode("utf-8")
    # Pad with spaces (valid JSON whitespace) so 8-byte prefix + header is aligned.
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for v in tensors.values():
            if v.numel() == 0:
                continue
            # detach(): .numpy() refuses tensors that require grad.
            # reshape(-1): flattens and turns 0-dim scalars into one element so
            # the uint8 reinterpretation is always legal.
            # cpu(): no-op for CPU tensors; copies CUDA/other-device tensors.
            raw = v.detach().contiguous().reshape(-1).view(torch.uint8).cpu()
            raw.numpy().tofile(f)
|
|
|
|
|
|
|
|
|
|
|
def read_safetensors_metadata(path: str):
    """Return the ``__metadata__`` mapping from a safetensors file header.

    Only the length prefix and the JSON header are read; tensor data is not
    touched. Returns an empty dict when the header has no ``__metadata__``.
    """
    with open(path, 'rb') as fh:
        length_prefix = fh.read(8)
        header_len = int.from_bytes(length_prefix, 'little')
        raw_header = fh.read(header_len)
    parsed = json.loads(raw_header.decode('utf-8'))
    return parsed.get('__metadata__', {})
|
|
|
|
|
# Carry the source file's metadata through to the converted file.
metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4))

sd_pruned = dict()

with MemoryEfficientSafeOpen(path) as reader:
    keys = reader.keys()
    for key in tqdm(keys):
        tensor = reader.get_tensor(key)
        # Only floating-point tensors are quantized. Integer/bool tensors (any
        # index or mask buffers, if present) would be corrupted by rounding to
        # float8's 3 mantissa bits, so they are stored unchanged.
        # NOTE(review): the cast to float8_e4m3fn may not saturate; magnitudes
        # beyond the format's max (448) could become NaN. Consider clamping if
        # weights can exceed that range.
        if tensor.is_floating_point():
            tensor = tensor.to(torch.float8_e4m3fn)
        sd_pruned[key] = tensor

mem_eff_save_file(sd_pruned, output, metadata={"format": "pt", **metadata})