# AHA-OLMO2 / tobf16.py
# Provenance: uploaded via huggingface_hub by xuan-luo (commit 36e3d1b, verified).
import sys
import logging
import os
import datasets
from datasets import load_dataset
import torch
import transformers
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from typing import Dict, List
##########################
# Convert to BF16 format
##########################

def convert_checkpoint_to_bf16(checkpoint_dir: str, output_dir: str) -> None:
    """Reload a saved checkpoint with bfloat16 weights and re-save it.

    The model is loaded onto CPU (``device_map="cpu"``, so no GPU memory is
    needed) with ``torch_dtype=torch.bfloat16``, which casts the weights at
    load time.  ``save_pretrained`` then writes those bf16 tensors out.

    Note: ``save_pretrained`` has no ``torch_dtype`` parameter — the saved
    dtype is determined entirely by how the model was loaded, so none is
    passed here.  (The original script passed one; it was silently ignored.)

    Args:
        checkpoint_dir: Directory containing the checkpoint to reload
            (e.g. a trainer ``checkpoint-XXXX`` folder).
        output_dir: Directory to write the bf16 model files into.
    """
    model_bf16 = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    model_bf16.save_pretrained(output_dir)
    # Free the model before anything else large is allocated.  empty_cache()
    # is a no-op for a CPU-only load, but harmless if any module ended up on
    # a GPU despite device_map="cpu".
    del model_bf16
    torch.cuda.empty_cache()


# Convert the final checkpoint, writing the bf16 files into the current
# directory — same effect as the original flat script.
convert_checkpoint_to_bf16('./checkpoint-3457', './')