"""
Extract the language-model (text-only) weights from a Gemma 4 multimodal safetensors checkpoint.

- Keeps only tensors whose key contains 'language_model'
- Renames keys: model.language_model.X -> model.X
- Saves the result as sharded safetensors (up to 10 GiB per shard)
- Writes model.safetensors.index.json mapping each tensor to its shard file
"""
|
|
| import glob |
| import json |
| import os |
|
|
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
# Directory holding the original multimodal checkpoint (*.safetensors files).
SRC_DIR = "/workspace/llm/gemma-4-31B-it"
# Directory that receives the text-only shards and their index.
DST_DIR = "/workspace/llm/gemma-4-31B-Text"
# Shard size limit in bytes: 10 GiB.
MAX_SHARD_SIZE = 10 * 2**30
|
|
|
|
def _read_tensor_sizes(path):
    """Return ``{tensor_name: nbytes}`` for every tensor in a safetensors file.

    Only the header is read: an 8-byte little-endian length prefix followed by
    a JSON object. Each tensor's byte size is the span of its ``data_offsets``.
    No tensor data is loaded.

    Raises:
        ValueError: if the file is too short to hold the header length prefix.
    """
    with open(path, "rb") as fh:
        prefix = fh.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path} is not a valid safetensors file (truncated header)")
        header_len = int.from_bytes(prefix, "little")
        header = json.loads(fh.read(header_len))
    return {
        name: info["data_offsets"][1] - info["data_offsets"][0]
        for name, info in header.items()
        if name != "__metadata__"  # optional free-form metadata entry, not a tensor
    }


def _collect_sources(src_files):
    """Map each renamed language-model key to ``(source_path, source_key, nbytes)``.

    Keeps only keys containing ``"language_model"``. In each kept key, every
    ``model.language_model.`` is rewritten to ``model.``.

    Raises:
        ValueError: if two source tensors would be written under the same name.
    """
    sources = {}
    for path in src_files:
        print(f"Reading {os.path.basename(path)}...")
        for key, nbytes in _read_tensor_sizes(path).items():
            if "language_model" not in key:
                continue
            new_key = key.replace("model.language_model.", "model.")
            if new_key in sources:
                prev_path, prev_key, _ = sources[new_key]
                raise ValueError(
                    f"Output key {new_key!r} produced by both {prev_key!r} "
                    f"({os.path.basename(prev_path)}) and {key!r} ({os.path.basename(path)})"
                )
            sources[new_key] = (path, key, nbytes)
    return sources


def _plan_shards(sizes, max_shard_size):
    """Greedily group keys, in sorted order, into shards of at most ``max_shard_size`` bytes.

    A tensor larger than the limit gets a shard of its own. Returns a list of
    key lists, one per shard.
    """
    shards = []
    current = []
    current_size = 0
    for key in sorted(sizes):
        nbytes = sizes[key]
        # Close the current shard before it would overflow. An empty shard always
        # accepts the tensor, so oversized tensors are still written.
        if current and current_size + nbytes > max_shard_size:
            shards.append(current)
            current = []
            current_size = 0
        current.append(key)
        current_size += nbytes
    if current:
        shards.append(current)
    return shards


def _load_shard(keys, sources):
    """Load only the tensors named in ``keys``, opening each source file once."""
    by_file = {}
    for key in keys:
        path, src_key, _ = sources[key]
        by_file.setdefault(path, []).append((key, src_key))
    shard = {}
    for path, pairs in by_file.items():
        with safe_open(path, framework="pt", device="cpu") as f:
            for new_key, src_key in pairs:
                shard[new_key] = f.get_tensor(src_key)
    return shard


def main():
    """Write a text-only, resharded copy of the checkpoint in SRC_DIR to DST_DIR.

    Steps:
    1. Read tensor sizes from the safetensors headers, so the shard layout can
       be planned without loading any tensor data.
    2. For each shard, load its tensors, save them, and drop the references
       before moving on. At most one shard's tensors are referenced at a time,
       rather than the whole language model.
    3. Write ``model.safetensors.index.json`` with the total byte size and a
       key -> shard-file map.

    Raises:
        FileNotFoundError: if SRC_DIR contains no ``*.safetensors`` files.
        ValueError: if no ``language_model`` tensors are found, or if two
            tensors map to the same output key.
    """
    os.makedirs(DST_DIR, exist_ok=True)

    src_files = sorted(glob.glob(os.path.join(SRC_DIR, "*.safetensors")))
    print(f"Source files: {len(src_files)}")
    if not src_files:
        raise FileNotFoundError(f"No *.safetensors files found in {SRC_DIR}")

    # Pass 1: decide what to extract and how big each tensor is (headers only).
    sources = _collect_sources(src_files)
    print(f"Extracted {len(sources)} tensors")
    if not sources:
        raise ValueError(f"No 'language_model' tensors found in {SRC_DIR}")

    sizes = {key: nbytes for key, (_, _, nbytes) in sources.items()}
    shards = _plan_shards(sizes, MAX_SHARD_SIZE)
    print(f"Splitting into {len(shards)} shards")

    # Pass 2: materialize and save one shard at a time.
    total_shards = len(shards)
    weight_map = {}
    total_size = 0
    for i, shard_keys in enumerate(shards):
        filename = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        filepath = os.path.join(DST_DIR, filename)
        shard = _load_shard(shard_keys, sources)
        shard_size = sum(t.nelement() * t.element_size() for t in shard.values())
        # GiB, matching the unit of MAX_SHARD_SIZE.
        print(f"Saving {filename} ({shard_size / 2**30:.2f} GiB, {len(shard)} tensors)...")
        # Record the framework in the header metadata, as Hugging Face's own
        # save_pretrained does; some loaders check the "format" key.
        save_file(shard, filepath, metadata={"format": "pt"})
        # Drop this shard's tensor references before loading the next shard.
        del shard
        total_size += shard_size
        for key in shard_keys:
            weight_map[key] = filename

    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    index_path = os.path.join(DST_DIR, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)

    print(f"Done! Index written to {index_path}")
|
|
|
|
# Run the extraction only when executed as a script, not on import.
# main() returns None, so SystemExit(None) exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|
|