Update train.py
Browse files
train.py
CHANGED
|
@@ -9,6 +9,9 @@ import torch
|
|
| 9 |
|
| 10 |
class ModelTrainer:
|
| 11 |
def __init__(self, model_id, system_prompts_path):
|
|
|
|
|
|
|
|
|
|
| 12 |
self.model_id = model_id
|
| 13 |
|
| 14 |
# 加载系统提示词
|
|
@@ -27,7 +30,9 @@ class ModelTrainer:
|
|
| 27 |
trust_remote_code=True,
|
| 28 |
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
|
| 29 |
device_map='auto', # 自动选择设备
|
| 30 |
-
low_cpu_mem_usage=True
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
# 使用更轻量的LoRA配置
|
|
|
|
| 9 |
|
| 10 |
class ModelTrainer:
|
| 11 |
def __init__(self, model_id, system_prompts_path):
|
| 12 |
+
# 确保临时文件夹存在
|
| 13 |
+
os.makedirs("temp_model_dir", exist_ok=True)
|
| 14 |
+
|
| 15 |
self.model_id = model_id
|
| 16 |
|
| 17 |
# 加载系统提示词
|
|
|
|
| 30 |
trust_remote_code=True,
|
| 31 |
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
|
| 32 |
device_map='auto', # 自动选择设备
|
| 33 |
+
low_cpu_mem_usage=True,
|
| 34 |
+
offload_folder="temp_model_dir", # 添加临时文件夹
|
| 35 |
+
use_safetensors=True # 使用 safetensors
|
| 36 |
)
|
| 37 |
|
| 38 |
# 使用更轻量的LoRA配置
|