# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
"""Copy a GGUF file and add the tensors from a "5D fix" safetensors file.

Every tensor found in ``--src`` is copied to ``--dst`` with its original
tensor type. Each tensor in the fix file is added as F32, right after the
source tensor whose name maps onto it (``.bias`` -> ``.weight``); fix tensors
that matched nothing are appended at the end.
"""
import argparse
import os

import gguf
import torch
from safetensors.torch import load_file
from tqdm import tqdm


def get_args():
    """Parse and validate the command line.

    Returns:
        The parsed ``argparse.Namespace`` with ``src``, ``dst``, ``fix`` and
        ``overwrite``.

    Exits via ``parser.error`` when ``--src`` is not a file, when ``--dst``
    exists and ``--overwrite`` was not given, or when ``--dst`` is the same
    file as ``--src``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", required=True)
    parser.add_argument("--dst", required=True)
    parser.add_argument(
        "--fix",
        required=False,
        help="Defaults to ./fix_5d_tensors_[arch].safetensors",
    )
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    if not os.path.isfile(args.src):
        parser.error(f"Invalid source file '{args.src}'")
    if not args.overwrite and os.path.exists(args.dst):
        parser.error(f"Output exists, use '--overwrite' ({args.dst})")
    # The output is only written after all tensors were queued from the
    # source, so writing onto the source would clobber the data being copied
    # (GGUFReader memory-maps its input).
    if os.path.exists(args.dst) and os.path.samefile(args.src, args.dst):
        parser.error(f"Output must not be the source file ({args.dst})")

    return args


def get_arch_str(reader):
    """Return the ``general.architecture`` value of *reader* as a str."""
    field = reader.get_field("general.architecture")
    return str(field.parts[field.data[-1]], encoding="utf-8")


def get_file_type(reader):
    """Return the ``general.file_type`` value of *reader* as a LlamaFileType."""
    field = reader.get_field("general.file_type")
    ft = int(field.parts[field.data[-1]])
    return gguf.LlamaFileType(ft)


def _load_fix_tensors(path):
    """Load the fix file at *path* into a ``{name: float32 ndarray}`` dict.

    Raises:
        OSError: if *path* is not an existing file.
    """
    if not os.path.isfile(path):
        raise OSError(f"No 5D tensor fix file: {path}")
    # Upcast before leaving torch: ``Tensor.numpy()`` cannot convert bfloat16,
    # and every fix tensor is stored as F32 anyway.
    return {k: v.to(torch.float32).numpy() for k, v in load_file(path).items()}


def _add_extra_key(writer, key, data):
    """Queue *data* on *writer* under the name *key*, stored as F32."""
    data_qtype = gguf.GGMLQuantizationType.F32
    data = gguf.quants.quantize(data, data_qtype)
    tqdm.write(f"Adding key {key} ({data.shape})")
    writer.add_tensor(key, data, raw_dtype=data_qtype)


def main():
    """Run the conversion described in the module docstring."""
    args = get_args()

    # read existing
    reader = gguf.GGUFReader(args.src)
    arch = get_arch_str(reader)
    file_type = get_file_type(reader)
    print(f"Detected arch: '{arch}' (ftype: {str(file_type)})")

    # prep fix
    fix_path = args.fix
    if fix_path is None:
        fix_path = f"./fix_5d_tensors_{arch}.safetensors"
    sd5d = _load_fix_tensors(fix_path)
    print("5D tensors:", sd5d.keys())

    # Names the writer will already hold; adding one of them a second time
    # would hand the writer a duplicate tensor name.
    existing = {tensor.name for tensor in reader.tensors}
    added = set()

    def add_if_new(key):
        """Add fix tensor *key* unless the output already contains it."""
        if key in added:
            return
        if key in existing:
            tqdm.write(f"Skipping key {key} (already present in source file)")
        else:
            _add_extra_key(writer, key, sd5d[key])
        added.add(key)

    # prep output
    writer = gguf.GGUFWriter(path=None, arch=arch)
    try:
        writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        writer.add_file_type(file_type)

        # main loop to add missing 5D tensor(s)
        for tensor in tqdm(reader.tensors):
            writer.add_tensor(
                tensor.name, tensor.data, raw_dtype=tensor.tensor_type
            )
            key5d = tensor.name.replace(".bias", ".weight")
            if key5d in sd5d:
                add_if_new(key5d)

        # brute force for any missed
        for key in sd5d:
            add_if_new(key)

        writer.write_header_to_file(path=args.dst)
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file(progress=True)
    finally:
        # Release the output handle even when a step above failed.
        writer.close()


if __name__ == "__main__":
    main()