import re
from typing import Dict, Optional

import torch


class NatalArchitectureTrimmer:
    """Convert model weights to bit-packed ternary form, one tensor at a time.

    Tensors are fed individually to :meth:`process_layer_weights`. The query,
    key and value projections of a layer are buffered until all three have
    been seen and are then fused into one packed tensor; other matrices named
    ``*weight*`` are packed on their own; everything else passes through.
    """

    # The word boundary in front of "h" stops a bare "h.<n>." segment from
    # matching inside longer words such as "switch.0.".
    _LAYER_IDX_RE = re.compile(r'(?:layers|blocks|\bh)\.(\d+)\.')

    # Checked in this order when deciding which projection a name refers to.
    _QKV_KEYS = ("wq", "wk", "wv")

    def __init__(self):
        # layer index -> {"wq" | "wk" | "wv": tensor}; an entry lives only
        # until the third projection of that layer arrives.
        self.qkv_buffer = {}

    def extract_layer_idx(self, name: str) -> Optional[int]:
        """Return the layer number embedded in *name*, or ``None`` if absent.

        The number is taken from the first ``layers.<n>.``, ``blocks.<n>.`` or
        ``h.<n>.`` segment found in the parameter name.
        """
        match = self._LAYER_IDX_RE.search(name)
        return int(match.group(1)) if match else None

    def pack_ternary_weights(self, tensor):
        """Pack four ternary values {-1, 0, 1} into each byte (2 bits per value).

        Values are stored as the codes 0, 1, 2 (value + 1). The first value of
        each group of four occupies the two lowest bits. If the element count
        is not a multiple of four, the tail is padded with code 1 (ternary 0).

        Args:
            tensor: Tensor whose elements are all -1, 0 or 1. Any shape.

        Returns:
            A 1-D ``torch.int8`` tensor with ``ceil(numel / 4)`` elements. The
            original shape is not recorded in the result.
        """
        # reshape (not view) so that non-contiguous inputs are accepted.
        # Packing in uint8 keeps the shift into the top bit well defined;
        # code 2 shifted left by 6 is 128, which does not fit a signed int8.
        codes = (tensor.reshape(-1) + 1).to(torch.uint8)

        missing = -codes.numel() % 4
        if missing:
            codes = torch.cat([codes, codes.new_full((missing,), 1)])

        packed = (
            codes[0::4]
            | (codes[1::4] << 2)
            | (codes[2::4] << 4)
            | (codes[3::4] << 6)
        )
        # Reinterpret the bytes as int8 so callers get the dtype they expect.
        return packed.view(torch.int8)

    def _quantize_and_pack(self, tensor):
        """Ternarise *tensor* and return ``(packed int8 tensor, float16 scale)``.

        The scale is the mean absolute value along the last dimension
        (kept as a size-1 dimension).
        """
        # The clamp avoids a division by zero for all-zero rows.
        scale = tensor.abs().mean(dim=-1, keepdim=True).clamp(min=1e-5)
        ternary = torch.clamp(torch.round(tensor / scale), -1.0, 1.0)
        return self.pack_ternary_weights(ternary), scale.half()

    def process_layer_weights(
        self, name, tensor
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Drop sub-norm tensors, fuse Q/K/V and bit-pack weight matrices.

        Args:
            name: Parameter name of the tensor.
            tensor: The parameter tensor.

        Returns:
            ``None`` when the tensor is dropped (name contains ``sub_norm``)
            or buffered (Q/K/V projection whose siblings are still missing).
            Otherwise a dict of output name to tensor: the packed weights
            plus a ``.scale`` entry for quantised tensors, or ``{name: tensor}``
            unchanged for everything else.
        """
        if "sub_norm" in name:
            return None

        layer_idx = self.extract_layer_idx(name)
        is_qkv = any(key in name for key in self._QKV_KEYS)
        # Only matrices take part in the fusion. A 1-D tensor whose name
        # happens to contain wq/wk/wv (e.g. a bias) must not replace the
        # buffered weight, so it falls through to the generic handling below.
        if layer_idx is not None and is_qkv and tensor.ndim >= 2:
            slot = next(key for key in self._QKV_KEYS if key in name)
            pending = self.qkv_buffer.setdefault(layer_idx, {})
            pending[slot] = tensor

            if len(pending) < len(self._QKV_KEYS):
                return None

            del self.qkv_buffer[layer_idx]
            fused = torch.cat([pending[key] for key in self._QKV_KEYS], dim=0)
            packed_weights, scale = self._quantize_and_pack(fused)

            fused_name = f"model.layers.{layer_idx}.attention.qkv_proj.packed"
            return {fused_name: packed_weights, fused_name + ".scale": scale}

        if "weight" in name and tensor.ndim >= 2:
            packed_weights, scale = self._quantize_and_pack(tensor)
            packed_name = name.replace(".weight", ".packed")
            return {packed_name: packed_weights, packed_name + ".scale": scale}

        return {name: tensor}