Upload main.py with huggingface_hub
main.py
ADDED
@@ -0,0 +1,397 @@
import argparse
import deepspeed

parser = argparse.ArgumentParser(description='sp')
parser.add_argument('--basepath', type=str, default='/home/lyh/weights/hf/llama31chat/8B/')
parser.add_argument('--trainpath', type=str,
                    default="/home/lyh/code/nlp/developing/vllmbase/vllm/gedata/l318b.jsonl")
parser.add_argument('--testpath', type=str,
                    default="/home/lyh/code/nlp/developing/vllmbase/vllm/gedata/0318.json")
parser.add_argument('--savedir', type=str, default='0')
parser.add_argument('--model_type', type=str, default='llama', choices=['llama', 'qwen3'],
                    help="Model architecture type: 'llama' or 'qwen3'")
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
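# Example launch (illustrative; ds_config.json is a placeholder name, not a file
# shipped with this upload):
#   deepspeed main.py --deepspeed_config ds_config.json --basepath <model_dir> \
#       --trainpath train.jsonl --testpath test.json --savedir ckpts --model_type qwen3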
import json
import re

deepspeed_config = args.deepspeed_config
with open(deepspeed_config) as f:
    ds_config = json.load(f)

# [MODIFIED] Select config path based on model_type
config_path_map = {
    'llama': 'config.json',
    'qwen3': 'config_qwen3.json'
}
config_path = config_path_map.get(args.model_type, 'config.json')

train_config = {
    "bs": ds_config["train_micro_batch_size_per_gpu"],
    "num_epochs": 15,
    "num_workers": 16,
    "max_len": 1536,
    "config_path": config_path,
    "gradient_checkpointing": False
}
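# "bs" is read straight from the DeepSpeed config so the two can never disagree.
# A minimal config that satisfies this script could look like the following
# (illustrative values, not shipped with this upload):
#   {
#     "train_micro_batch_size_per_gpu": 4,
#     "gradient_accumulation_steps": 1,
#     "optimizer": {"type": "AdamW", "params": {"lr": 3e-5}},
#     "bf16": {"enabled": true}
#   }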

from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
from cnets import padding

torch.backends.cuda.matmul.allow_tf32 = True
from accelerate.utils import set_seed

set_seed(0)
from cnets import Model
from configs import EConfig
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from tqdm import tqdm
# import accelerate
import numpy as np
from transformers import PreTrainedTokenizerBase, get_linear_schedule_with_warmup


def build_dataset_rank(
        tokenizer, datapath, model_type='llama'
):
    ds = load_dataset('json', data_files=datapath)
    ds = ds['train']
    ds = ds.shuffle(seed=42)
    ds1 = ds
    original_columns1 = ds1.column_names
    num_proc = 8

    # [MODIFIED] Auto-detect chat format from conversation string
    # Will be set dynamically in preprocess_function based on actual format

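    # preprocess_function maps each ShareGPT-style record, i.e.
    # {"id": ..., "conversations": [{"from": "human" | "gpt", "value": ...}, ...]},
    # to three aligned 1-D tensors: input_ids, attention_mask (all ones), and
    # loss_mask, which is 1 only on assistant tokens so the loss ignores prompts.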
    def preprocess_function(examples):
        new_examples = {
            "attention_mask": [],
            "input_ids": [],
            "loss_mask": []
        }
        for i in range(len(examples['id'])):
            messages = [
                {"role": "system",
                 "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
            ]
            convroles = ["user", "assistant"]
            roles = {"human": "user", "gpt": "assistant"}
            source = examples['conversations'][i]
            if not source:
                continue
            if roles[source[0]["from"]] != "user":
                # Skip the first message if it is not from the human
                source = source[1:]
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == convroles[j % 2], f"{i}"
                # if sentence["from"]=="gpt":
                #     sentence["value"]=" "+sentence["value"]
                messages.append(
                    {"role": role, "content": sentence["value"]}
                )
            # Try to use the tokenizer's chat template; fall back to manual ChatML formatting
            try:
                conversation = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            except (ValueError, AttributeError):
                # Manually format as ChatML (used by Qwen and many others)
                conversation = ""
                for msg in messages:
                    role = msg["role"]
                    content = msg["content"]
                    conversation += f"<|im_start|>{role}\n{content}<|im_end|>\n"

            if not tokenizer.pad_token_id:
                tokenizer.pad_token_id = tokenizer.unk_token_id

            input_ids = tokenizer(
                conversation,
                return_tensors="pt",
                add_special_tokens=False,
            ).input_ids[0]
            # filter out samples longer than max_len
            if len(input_ids) > train_config["max_len"]:
                continue
            loss_mask = torch.ones_like(input_ids)
            # print(i)

            total_len = len(input_ids)

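            # sep marks the boundary between a user prompt and the assistant reply
            # within a turn; sep2 marks the start of each new user turn. Detecting
            # them from the rendered string keeps the masking template-agnostic.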
            # Auto-detect format and set separators
            if "<|im_start|>" in conversation and "<|im_end|>" in conversation:
                # ChatML format (Qwen, default fallback)
                sep = "<|im_end|>\n<|im_start|>assistant\n"
                sep2 = "<|im_end|>\n<|im_start|>user\n"
            elif "<|eot_id|>" in conversation:
                # LLaMA-3 format
                sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
                sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            else:
                # Unknown format, skip this sample
                continue

            turns = conversation.split(sep2)

            # [MODIFIED] Skip samples with invalid conversation structure
            if len(turns) < 2:
                continue

            turns[1] = turns[0] + sep2 + turns[1]
            turns = turns[1:]

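            # The walk below zeroes loss_mask over everything that is not an
            # assistant reply. Per turn (token counts are illustrative, and the
            # -2/-3/+1/+3 offsets are hardcoded for the Llama tokenizer):
            #   turn = "<user prompt>" + sep + "<assistant reply>"
            #   instruction_len counts the prompt side, so that span is zeroed,
            #   then cur_len advances by the full turn length.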
            cur_len = 1
            loss_mask[:cur_len] = 0
            for i, turn in enumerate(turns):
                if turn == "":
                    break
                turn_len = len(tokenizer(turn).input_ids)

                parts = turn.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                # These offsets are hardcoded for the Llama tokenizer to keep the mask aligned.
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

                # Ignore the user instructions
                if i == 0:
                    loss_mask[cur_len: cur_len + instruction_len - 2] = 0
                else:
                    loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
                cur_len += turn_len
                if i != 0:
                    cur_len += 3
                # cur_len+=2

                # if i != 0 and not tokenizer.legacy:
                #     # The legacy and non-legacy modes handle special tokens differently
                #     cur_len -= 1

            loss_mask[cur_len:] = 0
            attention_mask = torch.ones_like(loss_mask)

            # new_examples["conversation"].append(conversation)
            new_examples["input_ids"].append(input_ids[None, :])
            new_examples["loss_mask"].append(loss_mask[None, :])
            new_examples["attention_mask"].append(attention_mask[None, :])

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns1,
        load_from_cache_file=False
    )

    ds1.set_format(type="torch")
    return ds1


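# Collator for variable-length samples: every tensor is right-padded with zeros to
# the longest sequence in the batch, so padded positions carry attention_mask = 0
# and loss_mask = 0 and are ignored downstream.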
class DataCollatorWithPadding:

    def paddingtensor(self, intensors, N):
        B, n, S = intensors.shape
        padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
        outtensors = torch.cat((intensors, padding_tensor), dim=1)
        return outtensors

    def paddingtensor2D(self, intensors, N):
        B, n = intensors.shape
        padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
        outtensors = torch.cat((intensors, padding_tensor), dim=1)
        return outtensors

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_length = max(item['input_ids'].shape[1] for item in features)
        batch_input_ids = torch.cat([self.paddingtensor2D(item['input_ids'], max_length) for item in features])
        batch_attention_mask = torch.cat(
            [self.paddingtensor2D(item['attention_mask'], max_length) for item in features])
        batch_loss_mask = torch.cat(
            [self.paddingtensor2D(item['loss_mask'], max_length) for item in features])

        batch = {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "loss_mask": batch_loss_mask,
        }
        return batch


tokenizer = AutoTokenizer.from_pretrained(args.basepath)
# [MODIFIED] Pass model_type to build_dataset_rank
traindataset = build_dataset_rank(tokenizer, args.trainpath, model_type=args.model_type)
testdataset = build_dataset_rank(tokenizer, args.testpath, model_type=args.model_type)

config = EConfig.from_pretrained(train_config["config_path"])
# [MODIFIED] Pass model_type to Model
model = Model(config, ds_config, train_config, path=args.basepath, load_emb=True, load_head=True, model_type=args.model_type)
model.scandata(args.trainpath, args.basepath)


criterion = nn.SmoothL1Loss(reduction="none")
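# Note: criterion is defined but never used below; the forward pass of Model
# returns its own per-position losses (plosses, vlosses) and accuracies (acces).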

num_epochs = train_config["num_epochs"]

model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters(),
                                                     )

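# deepspeed.initialize builds the optimizer (and any LR scheduler) from the JSON
# config passed via --deepspeed_config, so neither is constructed by hand here.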
global_rank = deepspeed.comm.get_rank()
rank = deepspeed.comm.get_local_rank()
world_size = deepspeed.comm.get_world_size()
if global_rank == 0:
    import wandb

    wandb.login(key=os.environ.get("WANDB_API_KEY"))  # read the API key from the environment rather than hardcoding it
    wandb.init(project="qwen3", entity="model-acceleration", config=ds_config)

os.makedirs(args.savedir, exist_ok=True)

sampler = DistributedSampler(testdataset, num_replicas=world_size, rank=global_rank, shuffle=False)
test_loader = DataLoader(testdataset, batch_size=train_config["bs"], sampler=sampler, num_workers=4, pin_memory=True,
                         collate_fn=DataCollatorWithPadding())

train_sampler = DistributedSampler(traindataset, num_replicas=world_size, rank=global_rank, shuffle=True)
train_loader = DataLoader(traindataset, batch_size=train_config["bs"], sampler=train_sampler, num_workers=4,
                          pin_memory=True,
                          collate_fn=DataCollatorWithPadding())


def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
    max_a = -1
    for subdir in os.listdir(directory):
        match = re.match(r"state_(\d+)", subdir)
        if match:
            a_value = int(match.group(1))
            subdir_path = os.path.join(directory, subdir)
            file_path = os.path.join(subdir_path, filename)
            if os.path.isdir(subdir_path) and os.path.exists(file_path):
                max_a = max(max_a, a_value)
    if max_a == -1:
        return None, 0
    return f"{directory}/state_{max_a}", max_a + 1

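# Resume logic: restart from the newest complete checkpoint state_<N> (a folder
# counts as complete only if DeepSpeed's zero_to_fp32.py helper script was written
# into it) and continue at epoch N + 1; if none exists, start_epoch is 0.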
checkpoint_path, start_epoch = find_max_state_with_file(args.savedir)
if checkpoint_path:
    print(f"load from {checkpoint_path}")
    model_engine.load_checkpoint(checkpoint_path)


for epoch in range(start_epoch, num_epochs):
    train_sampler.set_epoch(epoch + 1)
    print(f"Now training epoch {epoch}")

    model.train()
    epoch_acces = [[] for _ in range(model.length)]
    epoch_plosses = [[] for _ in range(model.length)]

    for batch_idx, data in enumerate(tqdm(train_loader)):

        model.zero_grad()

        plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                               attention_mask=data["attention_mask"].to(rank),
                                               loss_mask=data["loss_mask"],
                                               )

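        # Later positions are discounted geometrically (weight 0.8 ** i for
        # position i), so the objective emphasizes the earliest predicted tokens.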
        ploss_weight = [0.8 ** i for i in range(len(plosses))]
        ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])
        loss = ploss
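        # backward()/step() go through the DeepSpeed engine so that loss scaling,
        # gradient accumulation, and ZeRO partitioning are handled by the runtime.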
        model_engine.backward(loss)

        model_engine.step()

        if global_rank == 0:
            logdict = {"train/lr": optimizer.optimizer.param_groups[0]["lr"]}
            for i in range(len(plosses)):
                logdict[f"train/ploss_{i}"] = plosses[i].item()
            for i in range(len(acces)):
                logdict[f"train/acc_{i}"] = acces[i]
            wandb.log(logdict)
        epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
        epoch_plosses = [epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))]

    for i in range(len(epoch_acces)):
        acc_i = torch.tensor(epoch_acces[i]).cuda().mean()
        deepspeed.comm.all_reduce(acc_i, op=deepspeed.comm.ReduceOp.AVG)
        acc_i = acc_i.item()
        if global_rank == 0:
            wandb.log({f"train/epochacc_{i}": acc_i})
            print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i}, Acc: {acc_i:.2f}")

    for i in range(len(epoch_plosses)):
        loss_i = torch.tensor(epoch_plosses[i]).cuda().mean()
        deepspeed.comm.all_reduce(loss_i, op=deepspeed.comm.ReduceOp.AVG)
        loss_i = loss_i.item()
        if global_rank == 0:
            wandb.log({f"train/epochploss_{i}": loss_i})
            print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")

    epoch_acces = [[] for _ in range(model.length)]
    epoch_plosses = [[] for _ in range(model.length)]

    for batch_idx, data in enumerate(tqdm(test_loader)):
        with torch.no_grad():
            plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                                   attention_mask=data["attention_mask"].to(rank),
                                                   loss_mask=data["loss_mask"],
                                                   )
            epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
            epoch_plosses = [epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))]

    for i in range(len(epoch_acces)):
        acc_i = torch.tensor(epoch_acces[i]).cuda().mean()
        deepspeed.comm.all_reduce(acc_i, op=deepspeed.comm.ReduceOp.AVG)
        acc_i = acc_i.item()
        if global_rank == 0:
            wandb.log({f"test/epochacc_{i}": acc_i})
            print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, Acc: {acc_i:.2f}")

    for i in range(len(epoch_plosses)):
        loss_i = torch.tensor(epoch_plosses[i]).cuda().mean()
        deepspeed.comm.all_reduce(loss_i, op=deepspeed.comm.ReduceOp.AVG)
        loss_i = loss_i.item()
        if global_rank == 0:
            wandb.log({f"test/epochploss_{i}": loss_i})
            print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")
    # clear out any redundant CUDA cache after each epoch
    torch.cuda.empty_cache()

    # Save a checkpoint every epoch so training can be resumed
    model_engine.save_16bit_model(f"{args.savedir}/state_{epoch}", exclude_frozen_parameters=True)
    model_engine.save_checkpoint(save_dir=f"{args.savedir}/state_{epoch}")

    # Save disk space: delete old checkpoints, keeping only the most recent 3
    if global_rank == 0 and epoch > 2:
        old_checkpoint = f"{args.savedir}/state_{epoch - 3}"
        if os.path.exists(old_checkpoint):
            import shutil
            shutil.rmtree(old_checkpoint)
            print(f"Removed old checkpoint: {old_checkpoint}")