Upload setup_fine_tune_mps.ipynb
Browse files- setup_fine_tune_mps.ipynb +107 -0
setup_fine_tune_mps.ipynb
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "5aa6bfeb",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
# === 1. Set environment variables before any imports ===
import os
from pathlib import Path

# Force a custom model-card template path.
# NOTE(review): "HUGGINGFACE_HUB_TEMPLATES_PATH" is not an environment
# variable the huggingface_hub docs describe reading — confirm it has any
# effect, or whether the monkey patches below are doing all the work.
custom_template_dir = Path("./custom_templates").resolve()
os.environ["HUGGINGFACE_HUB_TEMPLATES_PATH"] = str(custom_template_dir)
# 0.0 disables the MPS high-watermark limit so allocations are not capped.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

# === 2. Initialize the template ===
custom_template_dir.mkdir(parents=True, exist_ok=True)
template_file = custom_template_dir / "modelcard_template.md"
template_content = """---
language: en
tags:
- generated_from_trainer
model-index:
- name: fine-tuned
  results: []
---

# Model Card

This model was fine-tuned using PEFT/LoRA."""

if not template_file.exists():
    # Path.write_text with an explicit encoding instead of a bare
    # open()/write() pair — avoids platform-default-encoding surprises.
    template_file.write_text(template_content, encoding="utf-8")
    print("✅ 自定义模板已创建")
# === 3. Force-disable the default template ===
import huggingface_hub.repocard as hf_card


def _disable_default_template():
    """Stop huggingface_hub's ModelCard from loading its bundled template.

    Clears the class-level default template path and swaps ``from_template``
    for a classmethod that ignores any supplied template path and always
    builds the card from the notebook-level ``template_content`` string.
    """

    def _from_custom_template(cls, template_path=None, **kwargs):
        # Whatever path or kwargs the caller passes, always use our content.
        return cls(template_content)

    # Disable default template loading.
    hf_card.ModelCard._default_template_path = None
    # Redirect the template-loading classmethod.
    hf_card.ModelCard.from_template = classmethod(_from_custom_template)


_disable_default_template()
# === 4. Hardened monkey patch ===
# === Correct way to apply the monkey patch ===
from functools import wraps
from transformers import Trainer
import shutil  # NOTE(review): unused here; kept in case later cells rely on it

# Save the original method ON THE CLASS, exactly once. The previous version
# rebound a module-level global (_original_trainer_save) on every cell
# re-run; after the first run that global captured the already-patched
# method, so the next checkpoint save recursed forever. Storing the original
# under a guarded class attribute makes re-running this cell safe.
if not hasattr(Trainer, "_hf_original_save_model"):
    Trainer._hf_original_save_model = Trainer.save_model


@wraps(Trainer._hf_original_save_model)
def _patched_save_model(self, output_dir=None, **kwargs):
    """Save the model, then overwrite README.md with the fixed template.

    ``output_dir`` is optional, mirroring ``Trainer.save_model`` — the old
    patch made it required, which broke internal ``self.save_model()`` calls.
    """
    # Call the original save method via the reference stored on the class.
    Trainer._hf_original_save_model(self, output_dir, **kwargs)

    # Trainer.save_model falls back to args.output_dir when passed None;
    # resolve the same way so the README lands next to the checkpoint.
    resolved_dir = output_dir if output_dir is not None else self.args.output_dir
    output_path = Path(resolved_dir)
    temp_file = output_path / "README.tmp"
    target_file = output_path / "README.md"

    try:
        # Write the template to a temp file first ...
        with open(temp_file, "w") as f:
            f.write(template_content)
        # ... then swap it in. Path.replace (unlike Path.rename) also
        # atomically overwrites an existing README.md on Windows.
        temp_file.replace(target_file)
    except Exception as e:
        # Best-effort: a failed README write must not abort checkpointing.
        print(f"⚠️ 模板写入失败: {str(e)}")

    print(f"✅ 检查点 {output_path.name} 保存完成")


# Install the patch. Reassigning on re-run is now harmless because the
# original is pinned on the class, never re-captured from the patched slot.
Trainer.save_model = _patched_save_model
# === 5. Verify the Hugging Face configuration ===
import huggingface_hub
# Echo the template path set in step 1 so the user can confirm it took effect
# ('未设置' = 'not set' fallback when the variable is missing).
print(f"当前模板路径: {os.environ.get('HUGGINGFACE_HUB_TEMPLATES_PATH', '未设置')}")


# === Remaining code unchanged, but this import order must be preserved ===
# NOTE(review): these imports must come AFTER the env vars and monkey patches
# above — importing torch/transformers first would bake in the defaults.
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
| 97 |
+
]
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"metadata": {
|
| 101 |
+
"language_info": {
|
| 102 |
+
"name": "python"
|
| 103 |
+
}
|
| 104 |
+
},
|
| 105 |
+
"nbformat": 4,
|
| 106 |
+
"nbformat_minor": 5
|
| 107 |
+
}
|