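# Source: MiMo-V2.5-QEdge / map_compressor.py
# Web file-viewer residue kept for provenance (not part of the module):
#   uploader avatar "Qapdex's picture"; commit "Update map_compressor.py";
#   revision 1632b11 (verified); size 3.71 kB;
#   page controls: Raw / History / Blame / Contribute / Delete.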
# project_natal/src/map_compressor.py
import os
import json
import gc
import re
import torch
from safetensors.torch import load_file, save_file
from matrix_transform import NatalArchitectureTrimmer
class NatalPipeline:
    """Prune MoE experts of a sharded safetensors checkpoint, shard by shard.

    The pipeline first scans every shard to rank experts by the mean absolute
    value of their fused (3-D) weights, then rewrites each shard keeping only
    the top ``target_experts`` experts and hands every tensor to
    ``NatalArchitectureTrimmer.process_layer_weights`` before saving.
    """

    def __init__(self, model_dir, output_dir, target_experts=16):
        """Load the checkpoint index and prepare the output directory.

        Args:
            model_dir: Directory holding the shards and
                ``model.safetensors.index.json``.
            output_dir: Directory the processed shards are written to; it is
                created if missing.
            target_experts: Number of experts to keep.

        Raises:
            OSError: If the index file cannot be opened.
            json.JSONDecodeError: If the index file is not valid JSON.
        """
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.target_experts = target_experts
        os.makedirs(output_dir, exist_ok=True)
        self.index_path = os.path.join(model_dir, "model.safetensors.index.json")
        with open(self.index_path, "r", encoding="utf-8") as f:
            self.index_data = json.load(f)
        # Filled in by analyze_experts_globally(); None until then.
        self.dominant_expert_indices = None
        self.trimmer = NatalArchitectureTrimmer()

    def _shard_names(self):
        """Return the sorted, de-duplicated shard file names from the index."""
        return sorted(set(self.index_data["weight_map"].values()))

    @staticmethod
    def _is_router_gate(tensor_name: str) -> bool:
        """Return True when a dotted path component of the name is ``gate``.

        A plain substring test would also match ``gate_proj`` /
        ``gate_up_proj`` projection weights, which must not be sliced by
        expert index.
        NOTE(review): assumes the router module is named ``gate`` (e.g.
        ``...mlp.gate.weight``) -- confirm against the checkpoint's names.
        """
        return "gate" in tensor_name.split(".")

    @staticmethod
    def _select_dominant_experts(magnitudes, target_experts):
        """Pick the sorted indices of the experts with the largest mean score.

        Args:
            magnitudes: List of 1-D per-expert score tensors, one per scanned
                weight; all must have the same length to be stacked.
            target_experts: Number of experts requested.

        Returns:
            A sorted 1-D index tensor. Falls back to
            ``arange(target_experts)`` when no scores were collected.
        """
        if not magnitudes:
            return torch.arange(target_experts)
        global_avg = torch.stack(magnitudes).mean(dim=0)
        # topk raises when k exceeds the number of entries; asking for more
        # experts than exist simply keeps all of them.
        k = min(target_experts, global_avg.numel())
        _, top_indices = torch.topk(global_avg, k)
        return top_indices.sort().values

    def analyze_experts_globally(self):
        """Scan all shards up front and fix the globally dominant experts.

        Stores the result in ``self.dominant_expert_indices``.
        """
        print("[Natal] Starte globale Experten-Analyse über alle Shards...")
        expert_magnitudes = []
        for shard_name in self._shard_names():
            tensors = load_file(os.path.join(self.model_dir, shard_name))
            for name, tensor in tensors.items():
                if "mlp.experts" in name and tensor.ndim == 3:
                    # Per-expert mean absolute weight; axis 0 is treated as
                    # the expert axis (it is the axis sliced later on).
                    expert_magnitudes.append(
                        tensor.float().abs().mean(dim=(-1, -2))
                    )
            # Release the shard before loading the next one to cap peak RAM.
            del tensors
            gc.collect()

        self.dominant_expert_indices = self._select_dominant_experts(
            expert_magnitudes, self.target_experts
        )
        if expert_magnitudes:
            print(f"[Natal] Globale Top-{self.target_experts} Experten fixiert: {self.dominant_expert_indices.tolist()}")

    def _trim_tensor(self, tensor_name: str, tensor_data: torch.Tensor) -> torch.Tensor:
        """Slice one tensor down to the dominant experts where applicable.

        Fused 3-D expert weights are sliced along axis 0. Router-gate tensors
        are sliced along axis 0 as well; every other tensor is returned
        unchanged.
        """
        if "mlp.experts" in tensor_name:
            if tensor_data.ndim == 3:
                indices = self.dominant_expert_indices.to(tensor_data.device)
                return torch.index_select(tensor_data, 0, indices)
            return tensor_data

        if self._is_router_gate(tensor_name):
            indices = self.dominant_expert_indices.to(tensor_data.device)
            # index_select also handles 1-D per-expert tensors, which the
            # two-axis indexing form cannot.
            tensor_data = torch.index_select(tensor_data, 0, indices)
            if tensor_data.ndim >= 2:
                # Re-normalise the routing gate (kept from the original code).
                # NOTE(review): this softmaxes the stored gate weights over
                # their last axis -- confirm this is really intended.
                tensor_data = torch.nn.functional.softmax(tensor_data, dim=-1)
        return tensor_data

    def process_layer_by_layer(self):
        """Rewrite every shard with pruned experts and trimmed weights.

        Runs the global expert analysis first, then for each shard slices the
        tensors, passes each one to the trimmer and saves the collected
        results under the same shard file name in ``output_dir``.
        """
        self.analyze_experts_globally()
        for shard_name in self._shard_names():
            current_tensors = load_file(os.path.join(self.model_dir, shard_name))
            processed_tensors = {}
            for tensor_name, tensor_data in current_tensors.items():
                # Slice MoE experts before anything else touches the tensor.
                tensor_data = self._trim_tensor(tensor_name, tensor_data)
                # Hand over to the bit packer; a None result drops the tensor.
                res = self.trimmer.process_layer_weights(tensor_name, tensor_data)
                if res is not None:
                    processed_tensors.update(res)
            save_file(processed_tensors, os.path.join(self.output_dir, shard_name))
            print(f"[Natal] Shard {shard_name} echt komprimiert gesichert.")
            # Release both dicts before the next shard to cap peak RAM.
            del current_tensors, processed_tensors
            gc.collect()
if __name__ == "__main__":
    # Script entry point: run the full pipeline on the default weight folders.
    NatalPipeline(
        model_dir="./weights/original",
        output_dir="./weights/natal_processed",
    ).process_layer_by_layer()