# Copy from https://gist.github.com/Stella2211/10f5bd870387ec1ddb9932235321068e
# This is a great work.
#
# NOTE(review): the copy of this file under review had been mangled — newlines
# collapsed and every span between '<' and the next '>' deleted (HTML
# tag-stripping).  Code below marked NOTE(review) is reconstructed from the
# safetensors file-format specification and the well-known upstream
# implementation; confirm against the original gist.

import json
import struct
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import torch

try:
    from tqdm import tqdm
except ImportError:  # NOTE(review): made optional so the module imports without tqdm
    def tqdm(iterable, **kwargs):
        return iterable


class MemoryEfficientSafeOpen:
    """Streaming reader for .safetensors files (does not support metadata loading).

    Reads one tensor at a time from disk instead of materializing the whole
    checkpoint, which keeps peak memory at roughly one tensor's size.
    """

    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        """Return tensor names, excluding the reserved '__metadata__' entry."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read a single tensor's bytes from disk and deserialize it.

        Raises:
            KeyError: if *key* is not present in the file header.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None  # zero-length data region: empty tensor
        else:
            # adjust offset by header size: data region begins after the
            # 8-byte length prefix plus the JSON header itself
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Parse the safetensors header: 8-byte LE length, then JSON."""
        # NOTE(review): the '<Q' format was lost to corruption; restored per
        # the safetensors spec (little-endian unsigned 64-bit header length).
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    # NOTE(review): the methods below were entirely deleted by the corruption
    # and are reconstructed from the upstream implementation.

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Turn raw bytes into a torch tensor with the dtype/shape from the header."""
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # frombuffer needs a writable buffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # float8 types need a guarded view because older torch lacks them
        if metadata["dtype"] in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string to a torch dtype (None if unknown)."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # register float8 dtypes only when this torch build provides them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret raw bytes as a float8 tensor, or fail loudly if unsupported."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        if dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        raise ValueError(
            f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)"
        )


# NOTE(review): the signature, _TYPES table, _ALIGN constant and the tensor
# write loop were lost to corruption and are reconstructed from the upstream
# implementation; the header-building code and validate_metadata survived and
# are kept verbatim.
def mem_eff_save_file(
    tensors: Dict[str, torch.Tensor],
    filename: str,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Write *tensors* to *filename* in safetensors format, one tensor at a time.

    Streams each tensor's bytes straight to disk instead of serializing the
    whole state dict in memory first.

    Raises:
        ValueError: if a metadata key is not a string.
    """
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256  # pad the JSON header so tensor data starts on this boundary

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        """Coerce metadata values to str; reject non-string keys."""
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)

    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor: zero-length data region
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    # pad with spaces so (8-byte prefix + header) is a multiple of _ALIGN
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        # NOTE(review): '<Q' restored per the safetensors spec (see _read_header)
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:  # nothing to write for empty tensors
                continue
            if v.dim() == 0:
                v = v.unsqueeze(0)  # view(torch.uint8) needs at least one dim
            v.contiguous().view(torch.uint8).cpu().numpy().tofile(f)


def convert_to_fp8(input_path: str, output_path: str) -> None:
    """Stream a safetensors checkpoint and rewrite it with tensors cast to fp8.

    NOTE(review): the original conversion loop was lost to corruption; this
    reconstruction casts fp16/bf16/fp32 tensors to float8_e4m3fn and passes
    everything else through unchanged — confirm the policy against the gist.
    """
    state_dict = {}
    with MemoryEfficientSafeOpen(input_path) as f:
        for key in tqdm(f.keys()):
            t = f.get_tensor(key)
            if t.dtype in (torch.float16, torch.bfloat16, torch.float32):
                t = t.to(torch.float8_e4m3fn)
            state_dict[key] = t
    mem_eff_save_file(state_dict, output_path)


def main(argv=None):
    """CLI entry point: mem_eff_fp8_convert.py {fp16 model path} {output path}."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        print("Usage: mem_eff_fp8_convert.py {fp16 model path} {output path}")
        sys.exit(1)
    path = argv[1]
    output = argv[2]
    model_file = Path(path)
    convert_to_fp8(str(model_file), output)


# Guarded so importing this module (e.g. for mem_eff_save_file) has no side
# effects; the original ran the argv check at import time.
if __name__ == "__main__":
    main()