File size: 5,703 Bytes
2020344 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import struct
import torch
import numpy as np
class HGSFormat:
    """Reader/writer for the ``HGS1`` container of 8-bit quantized tensors.

    Every tensor is stored with affine quantization, i.e. it is recovered as
    ``value ~= q * scale + min_val`` where ``q`` is an unsigned byte.

    File layout (all numbers little-endian):

    * 4-byte magic ``b'HGS1'``
    * ``uint32`` number of tensors
    * one index entry per tensor: ``uint32`` name length, UTF-8 name,
      ``uint32`` rank, one ``uint64`` per dimension, ``float32`` minimum,
      ``float32`` scale, ``uint64`` element count, ``uint64`` payload size,
      ``uint64`` absolute payload offset
    * the raw ``uint8`` payloads, in index order
    """

    MAGIC = b'HGS1'

    # Fixed-size tail of an index entry: min, scale, numel, payload size, offset.
    _ENTRY_TAIL_FMT = '<ffQQQ'

    @staticmethod
    def quantize_tensor(tensor: torch.Tensor, bits=8):
        """Quantize *tensor* to unsigned integers stored one per byte.

        Args:
            tensor: Tensor to quantize; it is detached and moved to the CPU.
            bits: Number of quantization bits. Must be in ``1..8`` because the
                result is stored as ``uint8``.

        Returns:
            A ``(q_data, min_val, scale)`` tuple where ``q_data`` is a
            ``numpy.uint8`` array with the shape of *tensor* and
            ``min_val`` / ``scale`` are Python floats.

        Raises:
            ValueError: If *bits* is outside ``1..8``.
        """
        if not 1 <= bits <= 8:
            raise ValueError(
                "bits must be between 1 and 8 (values are stored as uint8), "
                "got %r" % (bits,)
            )

        tensor_cpu = tensor.detach().cpu()

        # min()/max() are undefined for empty tensors; there is nothing to
        # quantize, so return an empty payload with the degenerate scale.
        if tensor_cpu.numel() == 0:
            empty = np.zeros(tuple(tensor_cpu.shape), dtype=np.uint8)
            return empty, 0.0, 1e-8

        # Subtraction is not defined for bool tensors.
        if tensor_cpu.dtype == torch.bool:
            tensor_cpu = tensor_cpu.to(torch.float32)

        levels = 2 ** bits - 1
        min_val = tensor_cpu.min()
        max_val = tensor_cpu.max()
        scale = (max_val - min_val) / levels
        if scale == 0:
            # Constant tensor: avoid dividing by zero below.
            scale = 1e-8

        q_tensor = (
            ((tensor_cpu - min_val) / scale)
            .round()
            .clamp(0, levels)
            .to(torch.uint8)
        )
        return q_tensor.numpy(), float(min_val), float(scale)

    @staticmethod
    def dequantize_tensor(q_data, min_val, scale):
        """Return the float32 tensor ``q_data * scale + min_val``."""
        return torch.tensor(q_data, dtype=torch.float32) * scale + min_val

    @staticmethod
    def _entry_size(name_bytes, rank):
        """Return the size in bytes of one index entry."""
        tail = struct.calcsize(HGSFormat._ENTRY_TAIL_FMT)
        # name length + name + rank + dimensions + fixed tail
        return 4 + len(name_bytes) + 4 + 8 * rank + tail

    @staticmethod
    def _unpack(f, fmt):
        """Read exactly one ``struct`` record of format *fmt* from *f*."""
        return struct.unpack(fmt, f.read(struct.calcsize(fmt)))

    @staticmethod
    def save_hgs(model_state_dict, filepath):
        """Quantize every tensor of *model_state_dict* and write it to *filepath*.

        Args:
            model_state_dict: Mapping of tensor name to ``torch.Tensor``.
            filepath: Destination path; the file is created or overwritten.
        """
        # Quantize before touching the file so that a failure here does not
        # leave a truncated file behind.
        entries = []
        for name, tensor in model_state_dict.items():
            q_data, min_val, scale = HGSFormat.quantize_tensor(tensor, bits=8)
            entries.append({
                'name_bytes': name.encode('utf-8'),
                'shape': tuple(tensor.shape),
                'numel': tensor.numel(),
                'min_val': min_val,
                'scale': scale,
                'q_data': q_data,
            })

        # The index size is known up front, so every payload offset can be
        # computed before writing instead of being patched in afterwards.
        offset = len(HGSFormat.MAGIC) + 4 + sum(
            HGSFormat._entry_size(e['name_bytes'], len(e['shape']))
            for e in entries
        )

        with open(filepath, 'wb') as f:
            f.write(HGSFormat.MAGIC)
            f.write(struct.pack('<I', len(entries)))

            for entry in entries:
                payload_size = entry['q_data'].nbytes
                f.write(struct.pack('<I', len(entry['name_bytes'])))
                f.write(entry['name_bytes'])
                f.write(struct.pack('<I', len(entry['shape'])))
                f.write(b''.join(struct.pack('<Q', dim) for dim in entry['shape']))
                f.write(struct.pack(
                    HGSFormat._ENTRY_TAIL_FMT,
                    entry['min_val'],
                    entry['scale'],
                    entry['numel'],
                    payload_size,
                    offset,
                ))
                offset += payload_size

            for entry in entries:
                f.write(entry['q_data'].tobytes())

    @staticmethod
    def load_hgs(filepath, device='cpu'):
        """Load a file written by :meth:`save_hgs` into a state dict.

        Args:
            filepath: Path of the file to read.
            device: Device the dequantized tensors are moved to.

        Returns:
            Dict mapping each tensor name to a float32 tensor of the stored
            shape.

        Raises:
            RuntimeError: If the magic bytes do not match or a tensor payload
                is shorter than the index says.
        """
        state_dict = {}
        with open(filepath, 'rb') as f:
            if f.read(len(HGSFormat.MAGIC)) != HGSFormat.MAGIC:
                # Message reads "Invalid file format".
                raise RuntimeError("Неверный формат файла")

            (num_layers,) = HGSFormat._unpack(f, '<I')

            index = []
            for _ in range(num_layers):
                (name_len,) = HGSFormat._unpack(f, '<I')
                name = f.read(name_len).decode('utf-8')
                (rank,) = HGSFormat._unpack(f, '<I')
                shape = tuple(
                    HGSFormat._unpack(f, '<Q')[0] for _ in range(rank)
                )
                min_val, scale, numel, q_data_len, offset = HGSFormat._unpack(
                    f, HGSFormat._ENTRY_TAIL_FMT
                )
                index.append({
                    'name': name,
                    'shape': shape,
                    'min_val': min_val,
                    'scale': scale,
                    'numel': numel,
                    'q_data_len': q_data_len,
                    'offset': offset,
                })

            for entry in index:
                f.seek(entry['offset'])
                payload = f.read(entry['q_data_len'])
                if len(payload) != entry['q_data_len']:
                    raise RuntimeError(
                        "Truncated HGS file: tensor %r expects %d bytes, got %d"
                        % (entry['name'], entry['q_data_len'], len(payload))
                    )
                q_data = np.frombuffer(payload, dtype=np.uint8)
                tensor = HGSFormat.dequantize_tensor(
                    q_data, entry['min_val'], entry['scale']
                )
                state_dict[entry['name']] = tensor.reshape(entry['shape']).to(device)
        return state_dict
|