MiMo-V2.5-QEdge / matrix_transform.py
Qapdex's picture
Update matrix_transform.py
796be71 verified
Raw
History Blame Contribute Delete
3.46 kB
# project_natal/src/matrix_transform.py
import torch
import re
class NatalArchitectureTrimmer:
    """Convert a model state dict, tensor by tensor, into a ternary bit-packed layout.

    Tensors are fed one at a time through :meth:`process_layer_weights`, which

    * drops ``sub_norm`` tensors,
    * buffers the per-layer ``wq``/``wk``/``wv`` projection matrices and, once
      all three of a layer have arrived, emits them as one fused, packed
      ``qkv_proj`` tensor,
    * keeps embedding tables in full precision,
    * ternarizes and packs every other weight matrix, and
    * passes everything else (norm weights, biases, ...) through unchanged.
    """

    # Block index in names such as "model.layers.12.", "encoder.blocks.3." or
    # GPT-2 style "transformer.h.0.". The leading \b stops "h" from matching
    # the tail of a longer word (e.g. "graph.3.").
    _LAYER_IDX_RE = re.compile(r"\b(?:layers|blocks|h)\.(\d+)\.")

    # Attention projections fused into qkv_proj, in concatenation order.
    _QKV_KEYS = ("wq", "wk", "wv")

    # Name fragments of embedding tables, which are kept in full precision.
    # "wte"/"wpe" are GPT-2 style token/position embedding names.
    _EMBEDDING_MARKERS = ("embed", "wte", "wpe")

    def __init__(self):
        # layer index -> {"wq"/"wk"/"wv": tensor} for projections still waiting
        # for the rest of their triple. A layer whose triple is never completed
        # stays in this dict and is never emitted by process_layer_weights.
        self.qkv_buffer = {}

    def extract_layer_idx(self, name: str) -> "int | None":
        """Return the transformer-block index encoded in *name*, or ``None``.

        >>> NatalArchitectureTrimmer().extract_layer_idx("model.layers.7.mlp.w1.weight")
        7
        """
        match = self._LAYER_IDX_RE.search(name)
        return int(match.group(1)) if match else None

    def pack_ternary_weights(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pack ternary values {-1, 0, 1} four to a byte.

        Each value is mapped to a 2-bit code (-1 -> 0, 0 -> 1, 1 -> 2), and the
        codes of four consecutive elements (row-major order) share one byte:
        element ``i`` of each group occupies bits ``2*i`` and ``2*i + 1``. If
        the element count is not a multiple of four, the tail is padded with
        the code for 0.

        Args:
            tensor: tensor of any shape whose values are exactly -1, 0 or 1.

        Returns:
            1-D ``torch.int8`` tensor of ``ceil(numel / 4)`` bytes. The dtype is
            signed, so bytes with the top bit set read as negative numbers; only
            the bit pattern is meaningful.
        """
        # reshape (not view) so non-contiguous inputs such as transposes work.
        codes = (tensor + 1).to(torch.int8).reshape(-1)
        pad = -codes.numel() % 4
        if pad:
            codes = torch.nn.functional.pad(codes, (0, pad), value=1)  # code 1 == value 0
        return (
            codes[0::4]
            | (codes[1::4] << 2)
            | (codes[2::4] << 4)
            | (codes[3::4] << 6)
        )

    def _quantize_and_pack(self, packed_name: str, weight: torch.Tensor) -> "dict[str, torch.Tensor]":
        """Ternarize *weight* and return its packed bytes plus float16 scale.

        The scale is the mean absolute value along the last dimension (kept as
        a size-1 dim and floored at 1e-5 so all-zero rows don't divide by zero).
        ``weight / scale`` is rounded and clipped to {-1, 0, 1}, then packed.

        Returns:
            ``{packed_name: packed_bytes, packed_name + ".scale": scale}``.
        """
        scale = weight.abs().mean(dim=-1, keepdim=True).clamp(min=1e-5)
        ternary = torch.clamp(torch.round(weight / scale), -1.0, 1.0)
        return {
            packed_name: self.pack_ternary_weights(ternary),
            packed_name + ".scale": scale.half(),  # float16 to save space
        }

    def process_layer_weights(self, name: str, tensor: torch.Tensor) -> "dict[str, torch.Tensor] | None":
        """Convert one named tensor of the source state dict.

        Args:
            name: parameter name, e.g. ``"model.layers.0.attention.wq.weight"``.
            tensor: the parameter's values.

        Returns:
            ``None`` when the tensor is dropped (``sub_norm`` in the name) or
            buffered (a ``wq``/``wk``/``wv`` matrix whose layer triple is not
            yet complete). Otherwise, a dict mapping output names to tensors:

            * completed QKV triple -> ``model.layers.{i}.attention.qkv_proj.packed``
              plus ``.scale``, with wq, wk, wv concatenated along dim 0;
            * embeddings -> ``{name: tensor}`` unchanged;
            * other weight matrices (``"weight"`` in name, ``ndim >= 2``) ->
              ``.weight`` replaced by ``.packed`` (or ``.packed`` appended),
              plus ``.scale``;
            * anything else (norm weights, biases, ...) -> ``{name: tensor}``.
        """
        # 1. Drop sub-norm tensors entirely.
        if "sub_norm" in name:
            return None

        # 2. Fuse attention projections. Only matrices are buffered, so a 1-D
        #    bias named e.g. "...wq.bias" cannot overwrite the wq weight.
        layer_idx = self.extract_layer_idx(name)
        qkv_key = next((key for key in self._QKV_KEYS if key in name), None)
        if layer_idx is not None and qkv_key is not None and tensor.ndim >= 2:
            pending = self.qkv_buffer.setdefault(layer_idx, {})
            pending[qkv_key] = tensor
            if len(pending) < len(self._QKV_KEYS):
                return None
            fused = torch.cat([pending[key] for key in self._QKV_KEYS], dim=0)
            del self.qkv_buffer[layer_idx]
            fused_name = f"model.layers.{layer_idx}.attention.qkv_proj.packed"
            return self._quantize_and_pack(fused_name, fused)

        # 3. Keep embedding tables in full precision (needed for stability).
        if any(marker in name for marker in self._EMBEDDING_MARKERS):
            return {name: tensor}

        # 4. Ternarize and pack every other weight matrix.
        if "weight" in name and tensor.ndim >= 2:
            if ".weight" in name:
                packed_name = name.replace(".weight", ".packed")
            else:
                # Never store the packed bytes under the original name.
                packed_name = name + ".packed"
            return self._quantize_and_pack(packed_name, tensor)

        # 5. Norm weights, biases and anything else pass through unchanged.
        return {name: tensor}