| |
| |
|
|
| """ |
| 从 Llama 3.2-1B 权重初始化 LoopLlama 模型 |
| """ |
|
|
| import os |
| import torch |
| from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer |
| from configuration_llama import LlamaConfig |
| from modeling_llama import LoopLlamaForCausalLM |
|
|
def setup_loopllama_from_pretrained(
    source_model_path="meta-llama/Llama-3.2-1B",
    target_path="./",
    loop_times=2
):
    """
    Create a LoopLlama model from a pretrained Llama model.

    Loads the source Llama checkpoint, builds a LoopLlama configuration that
    carries ``loop_times``, copies every matching weight tensor into the new
    model, saves model/config/tokenizer to ``target_path``, and verifies the
    saved artifacts by reloading them.

    Args:
        source_model_path: Source Llama model path or hub id.
        target_path: Directory to save the converted model into.
        loop_times: Number of loop iterations for the LoopLlama layers.

    Returns:
        Tuple of ``(loop_model, tokenizer)``.
    """
    print(f"Loading original Llama model from {source_model_path}...")
    original_model = LlamaForCausalLM.from_pretrained(
        source_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    original_config = original_model.config

    # Llama 3.x checkpoints may not load with the slow LlamaTokenizer; fall
    # back to AutoTokenizer. Catch Exception, not a bare except, so
    # KeyboardInterrupt/SystemExit still propagate.
    try:
        tokenizer = LlamaTokenizer.from_pretrained(source_model_path)
    except Exception:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(source_model_path)

    print("Creating LoopLlama configuration...")
    # NOTE: LlamaConfig here is the local configuration_llama.LlamaConfig
    # (it shadows the transformers import), which accepts ``loop_times``.
    loop_config = LlamaConfig(
        **original_config.to_dict(),
        loop_times=loop_times
    )

    print(f"Creating LoopLlama model with {loop_times} loop times...")
    loop_model = LoopLlamaForCausalLM(loop_config)

    print("Copying weights from original model...")
    original_state_dict = original_model.state_dict()
    loop_state_dict = loop_model.state_dict()

    # Copy in place under no_grad so autograd never tracks initialization;
    # skip (and report) missing keys and shape mismatches instead of crashing.
    with torch.no_grad():
        for key, target in loop_state_dict.items():
            if key not in original_state_dict:
                print(f"Warning: {key} not found in original model")
                continue
            source = original_state_dict[key]
            if source.shape != target.shape:
                print(f"Warning: shape mismatch for {key}: "
                      f"{tuple(source.shape)} vs {tuple(target.shape)}")
                continue
            print(f"Copying {key}")
            target.copy_(source)

    print(f"Saving LoopLlama model to {target_path}...")
    loop_model.save_pretrained(target_path)
    loop_config.save_pretrained(target_path)
    tokenizer.save_pretrained(target_path)

    print("Setup completed!")

    # Round-trip load to confirm the saved artifacts are consistent.
    print("Verifying model loading...")
    test_model = LoopLlamaForCausalLM.from_pretrained(
        target_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(f"Model loaded successfully. Loop times: {test_model.config.loop_times}")

    return loop_model, tokenizer
|
|
| if __name__ == "__main__": |
| import argparse |
| |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--source", default="/9950backfile/zjy_2/loopllama_cpt/loopllama-cpt/models/llama3_2-1B", help="Source Llama model") |
| parser.add_argument("--target", default="./", help="Target directory") |
| parser.add_argument("--loop_times", type=int, default=3, help="Number of loop times") |
| |
| args = parser.parse_args() |
| |
| setup_loopllama_from_pretrained( |
| source_model_path=args.source, |
| target_path=args.target, |
| loop_times=args.loop_times |
| ) |