psychologyphd committed on
Commit
68278de
·
verified ·
1 Parent(s): 898b1ed

Upload setup_fine_tune_mps.ipynb

Browse files
Files changed (1) hide show
  1. setup_fine_tune_mps.ipynb +107 -0
setup_fine_tune_mps.ipynb ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa6bfeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === 1. Set environment variables before any other imports ===\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Force a custom model-card template location\n",
    "custom_template_dir = Path(\"./custom_templates\").resolve()\n",
    "os.environ[\"HUGGINGFACE_HUB_TEMPLATES_PATH\"] = str(custom_template_dir)\n",
    "os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n",
    "\n",
    "# === 2. Initialize the template ===\n",
    "custom_template_dir.mkdir(parents=True, exist_ok=True)\n",
    "template_file = custom_template_dir / \"modelcard_template.md\"\n",
    "# NOTE: 'results:' must be indented two spaces to nest under the\n",
    "# '- name:' sequence item, otherwise the model-index YAML is invalid.\n",
    "template_content = \"\"\"---\n",
    "language: en\n",
    "tags:\n",
    "- generated_from_trainer\n",
    "model-index:\n",
    "- name: fine-tuned\n",
    "  results: []\n",
    "---\n",
    "\n",
    "# Model Card\n",
    "\n",
    "This model was fine-tuned using PEFT/LoRA.\"\"\"\n",
    "\n",
    "if not template_file.exists():\n",
    "    with open(template_file, \"w\") as f:\n",
    "        f.write(template_content)\n",
    "    print(\"✅ 自定义模板已创建\")\n",
    "\n",
    "# === 3. Force-disable the default template ===\n",
    "import huggingface_hub.repocard as hf_card\n",
    "\n",
    "def _disable_default_template():\n",
    "    # Disable default template loading\n",
    "    hf_card.ModelCard._default_template_path = None\n",
    "    # Redirect the template-loading classmethod to our in-memory template\n",
    "    hf_card.ModelCard.from_template = classmethod(\n",
    "        lambda cls, template_path=None, **kwargs: cls(template_content))\n",
    "\n",
    "_disable_default_template()\n",
    "\n",
    "# === 4. Hardened monkey patch ===\n",
    "from functools import wraps\n",
    "from transformers import Trainer\n",
    "import shutil\n",
    "\n",
    "# Stash the pristine save_model exactly once, so re-running this cell\n",
    "# never captures the already-patched function as the \"original\".\n",
    "if not hasattr(Trainer, \"_hf_original_save_model\"):\n",
    "    Trainer._hf_original_save_model = Trainer.save_model\n",
    "_original_trainer_save = Trainer._hf_original_save_model\n",
    "\n",
    "@wraps(_original_trainer_save)\n",
    "def _patched_save_model(self, output_dir=None, **kwargs):\n",
    "    \"\"\"Wrap Trainer.save_model: save as usual, then drop in our README.\n",
    "\n",
    "    output_dir must default to None — Trainer calls save_model() with no\n",
    "    arguments at the end of training; in that case fall back to\n",
    "    self.args.output_dir, mirroring the original method.\n",
    "    \"\"\"\n",
    "    # Delegate to the saved original method first\n",
    "    _original_trainer_save(self, output_dir, **kwargs)\n",
    "\n",
    "    # Atomically install the template as README.md\n",
    "    output_path = Path(output_dir if output_dir is not None else self.args.output_dir)\n",
    "    temp_file = output_path / \"README.tmp\"\n",
    "    target_file = output_path / \"README.md\"\n",
    "\n",
    "    try:\n",
    "        # Write to a temp file first ...\n",
    "        with open(temp_file, \"w\") as f:\n",
    "            f.write(template_content)\n",
    "        # ... then swap it in. Path.replace (unlike rename) atomically\n",
    "        # overwrites an existing README.md on every platform.\n",
    "        temp_file.replace(target_file)\n",
    "    except Exception as e:\n",
    "        # Best-effort: a README failure must never abort checkpointing\n",
    "        print(f\"⚠️ 模板写入失败: {str(e)}\")\n",
    "\n",
    "    print(f\"✅ 检查点 {output_path.name} 保存完成\")\n",
    "\n",
    "Trainer.save_model = _patched_save_model\n",
    "\n",
    "# === 5. Verify the Hugging Face configuration ===\n",
    "import huggingface_hub\n",
    "print(f\"当前模板路径: {os.environ.get('HUGGINGFACE_HUB_TEMPLATES_PATH', '未设置')}\")\n",
    "\n",
    "\n",
    "# === Keep the remaining code unchanged, but preserve this import order ===\n",
    "import torch\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments\n",
    "from peft import LoraConfig\n",
    "from trl import SFTTrainer"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}