| |
| import os |
| import json |
| import gc |
| import re |
| import torch |
| from safetensors.torch import load_file, save_file |
| from matrix_transform import NatalArchitectureTrimmer |
|
|
class NatalPipeline:
    """Prune a sharded safetensors checkpoint down to a fixed set of experts.

    The pipeline first scans every shard to rank experts by mean absolute
    weight, then rewrites each shard keeping only the selected experts and
    passing every tensor through ``NatalArchitectureTrimmer``.
    """

    def __init__(self, model_dir, output_dir, target_experts=16):
        """Load the shard index and prepare the output directory.

        Args:
            model_dir: Directory containing ``model.safetensors.index.json``
                and the shard files it references.
            output_dir: Directory the processed shards are written to; it is
                created if it does not exist.
            target_experts: Number of experts to keep.

        Raises:
            OSError: If the index file cannot be opened.
            json.JSONDecodeError: If the index file is not valid JSON.
        """
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.target_experts = target_experts
        os.makedirs(output_dir, exist_ok=True)

        self.index_path = os.path.join(model_dir, "model.safetensors.index.json")
        with open(self.index_path, "r", encoding="utf-8") as f:
            self.index_data = json.load(f)

        # Filled in by analyze_experts_globally(); None until then.
        self.dominant_expert_indices = None
        self.trimmer = NatalArchitectureTrimmer()

    def _shard_names(self):
        """Return the distinct shard file names from the index, sorted."""
        return sorted(set(self.index_data["weight_map"].values()))

    def analyze_experts_globally(self):
        """Scan all shards up front to find the globally dominant experts.

        For every 3-D tensor whose name contains ``"mlp.experts"`` the mean
        absolute value per entry of the first dimension is computed. These
        per-tensor scores are averaged across all such tensors and the
        highest-scoring indices are stored, in ascending order, in
        ``self.dominant_expert_indices``. If no matching tensor is found the
        first ``target_experts`` indices are used instead.
        """
        print("[Natal] Starte globale Experten-Analyse über alle Shards...")
        expert_magnitudes = {}

        for shard_name in self._shard_names():
            tensors = load_file(os.path.join(self.model_dir, shard_name))
            for name, tensor in tensors.items():
                if "mlp.experts" in name and tensor.ndim == 3:
                    # Reduce the two trailing dims, leaving one score per
                    # entry of dim 0 (assumed to be the expert axis).
                    expert_magnitudes[name] = tensor.float().abs().mean(dim=(-1, -2))
            # Release the shard before loading the next one.
            del tensors
            gc.collect()

        if expert_magnitudes:
            # torch.stack requires every scored tensor to have the same
            # length along dim 0.
            stacked = torch.stack(list(expert_magnitudes.values()))
            global_avg = stacked.mean(dim=0)
            # topk raises if k exceeds the number of candidates, so clamp it.
            k = min(self.target_experts, global_avg.shape[0])
            _, top_indices = torch.topk(global_avg, k)
            self.dominant_expert_indices = top_indices.sort().values
            print(f"[Natal] Globale Top-{self.target_experts} Experten fixiert: {self.dominant_expert_indices.tolist()}")
        else:
            self.dominant_expert_indices = torch.arange(self.target_experts)

    def _prune_experts(self, tensor_name, tensor_data):
        """Return ``tensor_data`` restricted to the dominant experts.

        Tensors that are neither expert weights nor gate tensors are returned
        unchanged, as are expert tensors that are not 3-D.
        """
        if "mlp.experts" in tensor_name:
            if tensor_data.ndim == 3:
                keep = self.dominant_expert_indices.to(tensor_data.device)
                tensor_data = torch.index_select(tensor_data, 0, keep)
        elif "gate" in tensor_name:
            # NOTE(review): the original indentation was lost; this branch is
            # read as the alternative to the "mlp.experts" check. Confirm.
            # NOTE(review): the substring "gate" presumably also matches
            # non-router tensors such as "gate_proj" -- verify against the
            # checkpoint's tensor names.
            tensor_data = tensor_data[self.dominant_expert_indices, :]
            # NOTE(review): softmax is applied to the stored gate tensor
            # itself, which changes its values -- confirm this is intended.
            tensor_data = torch.nn.functional.softmax(tensor_data, dim=-1)
        return tensor_data

    def process_layer_by_layer(self):
        """Prune and trim every shard, writing the results to ``output_dir``.

        Runs the global expert analysis first, then processes one shard at a
        time so that only a single shard is held in memory. Each tensor is
        pruned and handed to ``self.trimmer.process_layer_weights``; a
        ``None`` result drops the tensor, anything else is merged into the
        output shard. Output shards keep the input shard file names.
        """
        self.analyze_experts_globally()

        for shard_name in self._shard_names():
            current_tensors = load_file(os.path.join(self.model_dir, shard_name))
            processed_tensors = {}

            for tensor_name, tensor_data in current_tensors.items():
                tensor_data = self._prune_experts(tensor_name, tensor_data)
                res = self.trimmer.process_layer_weights(tensor_name, tensor_data)
                if res is not None:
                    processed_tensors.update(res)

            save_file(processed_tensors, os.path.join(self.output_dir, shard_name))
            print(f"[Natal] Shard {shard_name} echt komprimiert gesichert.")
            # Release both dicts before the next shard is loaded.
            del current_tensors, processed_tensors
            gc.collect()
|
|
if __name__ == "__main__":
    # Script entry point: prune the checkpoint under ./weights/original and
    # write the processed shards to ./weights/natal_processed.
    pipeline = NatalPipeline(model_dir="./weights/original", output_dir="./weights/natal_processed")
    pipeline.process_layer_by_layer()
|
|