#!/usr/bin/env python3 """ Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM. Converts to LLaMA-compatible format for standard inference engines. Usage: python extract_llm.py --input ./HyperCLOVAX-SEED-Think-32B --output ./HyperCLOVAX-SEED-Text-Think-32B Requirements: pip install safetensors torch tqdm """ import argparse import json import os import shutil from pathlib import Path from collections import defaultdict from safetensors import safe_open from safetensors.torch import save_file import torch from tqdm import tqdm def load_weight_index(model_path: Path) -> dict: """Load the safetensors weight index file.""" index_path = model_path / "model.safetensors.index.json" with open(index_path, "r") as f: return json.load(f) def extract_llm_weights(model_path: Path, output_path: Path): """ Extract LLM weights from VLM. Key mapping: - model.language_model.model.* → model.* - model.language_model.lm_head.* → lm_head.* All vision encoder and MM projector weights are excluded. """ output_path.mkdir(parents=True, exist_ok=True) weight_index = load_weight_index(model_path) weight_map = weight_index["weight_map"] # Filter and remap LLM weights llm_weights = {} for key, shard_file in weight_map.items(): if key.startswith("model.language_model."): if key.startswith("model.language_model.model."): new_key = key.replace("model.language_model.model.", "model.") elif key.startswith("model.language_model.lm_head."): new_key = key.replace("model.language_model.", "") else: new_key = key.replace("model.language_model.", "") llm_weights[new_key] = (key, shard_file) print(f"Found {len(llm_weights)} LLM weight tensors") print(f"Excluded {len(weight_map) - len(llm_weights)} vision/projector tensors") # Group by source shard for efficient loading shard_to_weights = defaultdict(list) for new_key, (old_key, shard_file) in llm_weights.items(): shard_to_weights[shard_file].append((old_key, new_key)) # Load all LLM tensors all_tensors = {} shard_files = sorted(set(shard_to_weights.keys())) print(f"\nLoading weights from {len(shard_files)} shards...") for shard_file in tqdm(shard_files, desc="Loading shards"): shard_path = model_path / shard_file with safe_open(shard_path, framework="pt", device="cpu") as f: for old_key, new_key in shard_to_weights[shard_file]: tensor = f.get_tensor(old_key) all_tensors[new_key] = tensor print(f"\nTotal tensors extracted: {len(all_tensors)}") total_size = sum(t.numel() * t.element_size() for t in all_tensors.values()) print(f"Total size: {total_size / 1e9:.2f} GB") # Save as sharded safetensors (~5GB per shard) max_shard_size = 5 * 1024 * 1024 * 1024 print("\nSaving extracted weights...") save_sharded_safetensors(all_tensors, output_path, max_shard_size) return list(all_tensors.keys()) def save_sharded_safetensors(tensors: dict, output_path: Path, max_shard_size: int): """Save tensors as sharded safetensors files with index.""" sorted_keys = sorted(tensors.keys()) shards = [] current_shard = {} current_size = 0 shard_idx = 1 weight_map = {} for key in sorted_keys: tensor = tensors[key] tensor_size = tensor.numel() * tensor.element_size() if current_size + tensor_size > max_shard_size and current_shard: shards.append((shard_idx, current_shard)) shard_idx += 1 current_shard = {} current_size = 0 current_shard[key] = tensor current_size += tensor_size if current_shard: shards.append((shard_idx, current_shard)) total_shards = len(shards) total_size = sum(t.numel() * t.element_size() for t in tensors.values()) for shard_idx, shard_tensors in tqdm(shards, desc="Saving shards"): shard_name = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors" shard_path = output_path / shard_name save_file(shard_tensors, shard_path) for key in shard_tensors.keys(): weight_map[key] = shard_name # Create index file index = { "metadata": {"total_size": total_size}, "weight_map": weight_map } index_path = output_path / "model.safetensors.index.json" with open(index_path, "w") as f: json.dump(index, f, indent=2) print(f"Saved {total_shards} shards to {output_path}") def create_llama_config(original_config_path: Path, output_path: Path): """ Create LLaMA-compatible config from VLM config. Note: HyperCLOVAX uses attention_multiplier ≈ 1/sqrt(head_dim) which matches standard LLaMA scaled dot-product attention. """ with open(original_config_path, "r") as f: vlm_config = json.load(f) text_config = vlm_config["text_config"] llama_config = { "architectures": ["LlamaForCausalLM"], "attention_bias": text_config.get("attention_bias", False), "attention_dropout": text_config.get("attention_dropout", 0.0), "bos_token_id": text_config.get("bos_token_id", 128000), "eos_token_id": text_config.get("eos_token_id", 128001), "head_dim": text_config.get("head_dim", 128), "hidden_act": text_config.get("hidden_act", "silu"), "hidden_size": text_config.get("hidden_size", 5120), "initializer_range": text_config.get("initializer_range", 0.006), "intermediate_size": text_config.get("intermediate_size", 24192), "max_position_embeddings": text_config.get("max_position_embeddings", 131072), "mlp_bias": text_config.get("mlp_bias", False), "model_type": "llama", "num_attention_heads": text_config.get("num_attention_heads", 40), "num_hidden_layers": text_config.get("num_hidden_layers", 72), "num_key_value_heads": text_config.get("num_key_value_heads", 8), "pad_token_id": text_config.get("pad_token_id", 0), "pretraining_tp": 1, "rms_norm_eps": text_config.get("rms_norm_eps", 1e-05), "rope_scaling": text_config.get("rope_scaling", None), "rope_theta": text_config.get("rope_theta", 50000000), "tie_word_embeddings": text_config.get("tie_word_embeddings", False), "torch_dtype": "bfloat16", "transformers_version": "4.52.4", "use_cache": True, "vocab_size": text_config.get("vocab_size", 128256), } config_path = output_path / "config.json" with open(config_path, "w") as f: json.dump(llama_config, f, indent=2) print(f"Saved LLaMA config to {config_path}") # Generation config gen_config = { "bos_token_id": llama_config["bos_token_id"], "eos_token_id": llama_config["eos_token_id"], "pad_token_id": llama_config["pad_token_id"], "do_sample": True, "temperature": 0.7, "top_p": 0.9, "max_length": 4096 } gen_config_path = output_path / "generation_config.json" with open(gen_config_path, "w") as f: json.dump(gen_config, f, indent=2) return llama_config def copy_tokenizer_files(original_path: Path, output_path: Path): """Copy tokenizer files from original model.""" tokenizer_files = [ "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "added_tokens.json", "vocab.json", "merges.txt", "chat_template.jinja" ] copied = [] for fname in tokenizer_files: src = original_path / fname if src.exists(): dst = output_path / fname shutil.copy2(src, dst) copied.append(fname) print(f"Copied tokenizer files: {copied}") def main(): parser = argparse.ArgumentParser( description="Extract text-only LLM from HyperCLOVAX-SEED-Think-32B VLM", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Example: # Download original VLM huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Think-32B \\ --local-dir ./HyperCLOVAX-SEED-Think-32B # Extract text-only LLM python extract_llm.py \\ --input ./HyperCLOVAX-SEED-Think-32B \\ --output ./HyperCLOVAX-SEED-Text-Think-32B """ ) parser.add_argument( "--input", "-i", type=Path, required=True, help="Path to original HyperCLOVAX-SEED-Think-32B VLM" ) parser.add_argument( "--output", "-o", type=Path, required=True, help="Output path for extracted text-only LLM" ) args = parser.parse_args() if not args.input.exists(): print(f"Error: Input path does not exist: {args.input}") return 1 if not (args.input / "model.safetensors.index.json").exists(): print(f"Error: model.safetensors.index.json not found in {args.input}") return 1 print("=" * 60) print("HyperCLOVAX VLM → Text-only LLM Extraction") print("=" * 60) print(f"Input: {args.input}") print(f"Output: {args.output}") print("\n[Step 1] Extracting LLM weights...") extracted_keys = extract_llm_weights(args.input, args.output) print("\n[Step 2] Creating LLaMA-compatible config...") config = create_llama_config(args.input / "config.json", args.output) print("\n[Step 3] Copying tokenizer files...") copy_tokenizer_files(args.input, args.output) print("\n" + "=" * 60) print("Extraction complete!") print(f"Output: {args.output}") print("=" * 60) print(f"\nModel summary:") print(f" - Architecture: LlamaForCausalLM") print(f" - Hidden size: {config['hidden_size']}") print(f" - Layers: {config['num_hidden_layers']}") print(f" - Attention heads: {config['num_attention_heads']}") print(f" - KV heads: {config['num_key_value_heads']}") print(f" - Vocab size: {config['vocab_size']}") print(f" - Max context: {config['max_position_embeddings']}") print(f"\nYou can now use the model with vLLM, transformers, or other LLaMA-compatible frameworks.") return 0 if __name__ == "__main__": exit(main())