import sys
import logging
import os
import datasets
from datasets import load_dataset
import torch
import transformers
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from typing import Dict, List

##########################
# Convert to BF16 format
##########################


def convert_checkpoint_to_bf16(checkpoint_dir: str = './checkpoint-3457',
                               output_dir: str = './') -> None:
    """Reload a training checkpoint in bfloat16 and re-save it.

    Loads the model weights from ``checkpoint_dir`` cast to bf16 on CPU
    (no GPU needed for a dtype conversion), writes the converted weights
    to ``output_dir``, then releases the model.

    Args:
        checkpoint_dir: Directory containing the saved checkpoint.
        output_dir: Directory to write the bf16 model to.
    """
    model_bf16 = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # cast weights to bf16 at load time
        device_map="cpu",            # conversion is pure I/O + dtype cast; keep off GPU
    )
    # The weights are already bf16 from the load above. Note that
    # save_pretrained() takes no ``torch_dtype`` parameter — the original
    # call passed one, which was silently ignored — so it is dropped here.
    model_bf16.save_pretrained(output_dir)
    # Free the model before any subsequent large allocation.
    del model_bf16
    # Nothing was placed on the GPU (device_map="cpu"), but clear the CUDA
    # cache defensively when CUDA is present, as the original script did.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


convert_checkpoint_to_bf16()