# project_natal/src/matrix_transform.py
import re
from typing import Dict, Optional, Tuple

import torch


class NatalArchitectureTrimmer:
    """Convert checkpoint tensors, one at a time, into a trimmed, packed form.

    For every ``(name, tensor)`` pair passed to :meth:`process_layer_weights`
    the trimmer either drops the tensor, buffers it until the matching
    Q/K/V projections of the same layer have been seen, or returns a dict
    with the tensors that should be written to the output checkpoint.
    """

    # Name fragments identifying the three attention projections, in the
    # order in which they are concatenated into the fused QKV matrix.
    _QKV_KEYS = ("wq", "wk", "wv")

    # The look-behind keeps the short ``h`` alternative from matching inside
    # a longer identifier (e.g. ``patch.3.`` contains ``h.3.``).
    _LAYER_IDX_RE = re.compile(r'(?<![A-Za-z0-9])(?:layers|blocks|h)\.(\d+)\.')

    def __init__(self):
        # layer index -> {"wq" | "wk" | "wv": tensor} for projections that
        # are still waiting for their siblings.
        self.qkv_buffer = {}

    def extract_layer_idx(self, name: str) -> Optional[int]:
        """Return the layer number embedded in a parameter name.

        Args:
            name: Parameter name such as ``model.layers.12.attention.wq.weight``.

        Returns:
            The integer following ``layers.``, ``blocks.`` or ``h.``, or
            ``None`` when the name contains no such component.
        """
        match = self._LAYER_IDX_RE.search(name)
        return int(match.group(1)) if match else None

    def pack_ternary_weights(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pack four ternary values ``{-1, 0, 1}`` into each output byte.

        Each value is shifted to ``{0, 1, 2}`` and stored in 2 bits; the
        first value of every group of four occupies the lowest two bits.

        Args:
            tensor: Tensor of any shape whose elements are -1, 0 or 1.

        Returns:
            A 1-D ``torch.int8`` tensor with ``ceil(numel / 4)`` elements.
            A trailing partial group is padded with the code for zero.
        """
        # Pack in uint8 so that ``2 << 6`` (= 128) does not rely on signed
        # overflow; reshape (instead of view) also accepts non-contiguous
        # inputs such as transposed weights.
        flat = (tensor + 1).to(torch.uint8).reshape(-1)

        # Pad to a multiple of four; code 1 corresponds to the value zero.
        remainder = (4 - flat.numel() % 4) % 4
        if remainder:
            flat = torch.cat([flat, flat.new_full((remainder,), 1)])

        packed = (
            flat[0::4]
            | (flat[1::4] << 2)
            | (flat[2::4] << 4)
            | (flat[3::4] << 6)
        )
        # Reinterpret the bytes as int8 to keep the original output dtype.
        return packed.view(torch.int8)

    def _quantize_and_pack(
        self, tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ternarize ``tensor`` along its last dim and return (packed, scale)."""
        # Per-row scale: mean absolute value, clamped away from zero so the
        # division below is always defined.
        scale = tensor.abs().mean(dim=-1, keepdim=True).clamp(min=1e-5)
        ternary = torch.clamp(torch.round(tensor / scale), -1.0, 1.0)
        # The scale is stored as float16 to save space.
        return self.pack_ternary_weights(ternary), scale.half()

    def process_layer_weights(
        self, name: str, tensor: torch.Tensor
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Drop sub-norms, fuse Q/K/V projections and bit-pack weight matrices.

        Args:
            name: Parameter name of ``tensor`` in the source checkpoint.
            tensor: The parameter value.

        Returns:
            ``None`` when the tensor is dropped (``sub_norm``) or buffered
            until the remaining Q/K/V projections of its layer arrive;
            otherwise a dict mapping output names to tensors. Quantized
            entries come in pairs: ``<name>.packed`` and
            ``<name>.packed.scale``. All other tensors are returned
            unchanged under their original name.
        """
        # 1. Remove sub-norm parameters.
        if "sub_norm" in name:
            return None

        # 2. QKV attention fusion, keyed by layer index. Only matrices are
        #    fused: a 1-D tensor (e.g. a bias) whose name contains
        #    wq/wk/wv would otherwise overwrite the buffered weight or
        #    break the concatenation, so it falls through unchanged.
        layer_idx = self.extract_layer_idx(name)
        qkv_key = next((k for k in self._QKV_KEYS if k in name), None)
        if layer_idx is not None and qkv_key is not None and tensor.ndim >= 2:
            pending = self.qkv_buffer.setdefault(layer_idx, {})
            pending[qkv_key] = tensor
            if len(pending) < len(self._QKV_KEYS):
                return None

            del self.qkv_buffer[layer_idx]
            fused = torch.cat([pending[k] for k in self._QKV_KEYS], dim=0)
            packed_weights, scale = self._quantize_and_pack(fused)
            fused_name = f"model.layers.{layer_idx}.attention.qkv_proj.packed"
            return {fused_name: packed_weights, fused_name + ".scale": scale}

        # 3. Quantize and pack every other weight matrix.
        if "weight" in name and tensor.ndim >= 2:
            packed_weights, scale = self._quantize_and_pack(tensor)
            packed_name = name.replace(".weight", ".packed")
            return {packed_name: packed_weights, packed_name + ".scale": scale}

        # Everything else (1-D weights, non-weight tensors) passes through
        # unchanged.
        return {name: tensor}