appleeji commited on
Commit
dfd3147
·
verified ·
1 Parent(s): 38849f5

Upload main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.py +397 -0
main.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import deepspeed
3
+
4
+ parser = argparse.ArgumentParser(description='sp')
5
+ parser.add_argument('--basepath', type=str, default='/home/lyh/weights/hf/llama31chat/8B/')
6
+ parser.add_argument('--trainpath', type=str,
7
+ default="/home/lyh/code/nlp/developing/vllmbase/vllm/gedata/l318b.jsonl")
8
+ parser.add_argument('--testpath', type=str,
9
+ default="/home/lyh/code/nlp/developing/vllmbase/vllm/gedata/0318.json")
10
+ parser.add_argument('--savedir', type=str, default='0')
11
+ parser.add_argument('--model_type', type=str, default='llama', choices=['llama', 'qwen3'],
12
+ help="Model architecture type: 'llama' or 'qwen3'")
13
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
14
+ parser = deepspeed.add_config_arguments(parser)
15
+ args = parser.parse_args()
16
+ import json
17
+ import re
18
+
19
+ deepspeed_config = args.deepspeed_config
20
+ with open(deepspeed_config) as f:
21
+ ds_config = json.load(f)
22
+
23
+ # [MODIFIED] Select config path based on model_type
24
+ config_path_map = {
25
+ 'llama': 'config.json',
26
+ 'qwen3': 'config_qwen3.json'
27
+ }
28
+ config_path = config_path_map.get(args.model_type, 'config.json')
29
+
30
+ train_config = {
31
+ "bs": ds_config["train_micro_batch_size_per_gpu"],
32
+ "num_epochs": 15,
33
+ "num_workers": 16,
34
+ "max_len": 1536,
35
+ "config_path": config_path,
36
+ "gradient_checkpointing": False
37
+ }
38
+
39
+ from safetensors import safe_open
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer
41
+ import os
42
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
43
+ import torch
44
+ from cnets import padding
45
+
46
+ torch.backends.cuda.matmul.allow_tf32 = True
47
+ from accelerate.utils import set_seed
48
+
49
+ set_seed(0)
50
+ from cnets import Model
51
+ from configs import EConfig
52
+ from datasets import load_dataset
53
+ from dataclasses import dataclass, field
54
+ from typing import Any, Dict, List, Optional, Union
55
+
56
+ from torch import nn, optim
57
+ from torch.utils.data import Dataset, DataLoader, DistributedSampler
58
+ from tqdm import tqdm
59
+ # import accelerate
60
+ import numpy as np
61
+ from transformers import PreTrainedTokenizerBase, get_linear_schedule_with_warmup
62
+
63
+
64
+
65
+ def build_dataset_rank(
66
+ tokenizer, datapath, model_type='llama'
67
+ ):
68
+
69
+ ds = load_dataset('json', data_files=datapath)
70
+ ds = ds['train']
71
+ ds = ds.shuffle(seed=42)
72
+ ds1 = ds
73
+ original_columns1 = ds1.column_names
74
+ num_proc = 8
75
+
76
+ # [MODIFIED] Auto-detect chat format from conversation string
77
+ # Will be set dynamically in preprocess_function based on actual format
78
+
79
+ def preprocess_function(examples):
80
+ new_examples = {
81
+ "attention_mask": [],
82
+ "input_ids": [],
83
+ "loss_mask": []
84
+ }
85
+ for i in range(len(examples['id'])):
86
+ messages = [
87
+ {"role": "system",
88
+ "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
89
+ ]
90
+ convroles = ["user", "assistant"]
91
+ roles = {"human": "user", "gpt": "assistant"}
92
+ source = examples['conversations'][i]
93
+ if not source:
94
+ continue
95
+ if roles[source[0]["from"]] != "user":
96
+ # Skip the first one if it is not from human
97
+ source = source[1:]
98
+ for j, sentence in enumerate(source):
99
+ role = roles[sentence["from"]]
100
+ assert role == convroles[j % 2], f"{i}"
101
+ # if sentence["from"]=="gpt":
102
+ # sentence["value"]=" "+sentence["value"]
103
+ messages.append(
104
+ {"role": role, "content": sentence["value"]}
105
+ )
106
+ # Try to use tokenizer's chat template, fallback to manual ChatML formatting
107
+ try:
108
+ conversation = tokenizer.apply_chat_template(
109
+ messages,
110
+ tokenize=False,
111
+ add_generation_prompt=False,
112
+ )
113
+ except (ValueError, AttributeError):
114
+ # Manually format as ChatML (used by Qwen and many others)
115
+ conversation = ""
116
+ for msg in messages:
117
+ role = msg["role"]
118
+ content = msg["content"]
119
+ conversation += f"<|im_start|>{role}\n{content}<|im_end|>\n"
120
+
121
+ if not tokenizer.pad_token_id:
122
+ tokenizer.pad_token_id = tokenizer.unk_token_id
123
+
124
+ input_ids = tokenizer(
125
+ conversation,
126
+ return_tensors="pt",
127
+ add_special_tokens=False,
128
+ ).input_ids[0]
129
+ # filtering out the samples which is longer than max_len
130
+ if len(input_ids) > train_config["max_len"]:
131
+ continue
132
+ loss_mask = torch.ones_like(input_ids)
133
+ # print(i)
134
+
135
+ total_len = len(input_ids)
136
+
137
+ # Auto-detect format and set separators
138
+ if "<|im_start|>" in conversation and "<|im_end|>" in conversation:
139
+ # ChatML format (Qwen, default fallback)
140
+ sep = "<|im_end|>\n<|im_start|>assistant\n"
141
+ sep2 = "<|im_end|>\n<|im_start|>user\n"
142
+ elif "<|eot_id|>" in conversation:
143
+ # LLaMA-3 format
144
+ sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
145
+ sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"
146
+ else:
147
+ # Unknown format, skip this sample
148
+ continue
149
+
150
+ turns = conversation.split(sep2)
151
+
152
+ # [MODIFIED] Skip samples with invalid conversation structure
153
+ if len(turns) < 2:
154
+ continue
155
+
156
+ turns[1] = turns[0] + sep2 + turns[1]
157
+ turns = turns[1:]
158
+
159
+ cur_len = 1
160
+ loss_mask[:cur_len] = 0
161
+ for i, turn in enumerate(turns):
162
+ if turn == "":
163
+ break
164
+ turn_len = len(tokenizer(turn).input_ids)
165
+
166
+ parts = turn.split(sep)
167
+ if len(parts) != 2:
168
+ break
169
+ parts[0] += sep
170
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
171
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
172
+
173
+ # Ignore the user instructions
174
+ if i == 0:
175
+ loss_mask[cur_len: cur_len + instruction_len - 2] = 0
176
+ else:
177
+ loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
178
+ cur_len += turn_len
179
+ if i != 0:
180
+ cur_len += 3
181
+ # cur_len+=2
182
+
183
+ # if i != 0 and not tokenizer.legacy:
184
+ # # The legacy and non-legacy modes handle special tokens differently
185
+ # cur_len -= 1
186
+
187
+ loss_mask[cur_len:] = 0
188
+ attention_mask = torch.ones_like(loss_mask)
189
+
190
+ # new_examples["conversation"].append(conversation)
191
+ new_examples["input_ids"].append(input_ids[None, :])
192
+ new_examples["loss_mask"].append(loss_mask[None, :])
193
+ new_examples["attention_mask"].append(attention_mask[None, :])
194
+
195
+ return new_examples
196
+
197
+ ds1 = ds1.map(
198
+ preprocess_function,
199
+ batched=True,
200
+ num_proc=num_proc,
201
+ remove_columns=original_columns1,
202
+ load_from_cache_file=False
203
+ )
204
+
205
+
206
+ ds1.set_format(type="torch")
207
+ return ds1
208
+
209
+
210
+ class DataCollatorWithPadding:
211
+
212
+ def paddingtensor(self, intensors, N):
213
+ B, n, S = intensors.shape
214
+ # padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype)
215
+ padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
216
+ outtensors = torch.cat((intensors, padding_tensor), dim=1)
217
+ return outtensors
218
+
219
+ def paddingtensor2D(self, intensors, N):
220
+ B, n = intensors.shape
221
+ padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
222
+ outtensors = torch.cat((intensors, padding_tensor), dim=1)
223
+ return outtensors
224
+
225
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
226
+ max_length = max(item['input_ids'].shape[1] for item in features)
227
+ batch_input_ids = torch.cat([self.paddingtensor2D(item['input_ids'], max_length) for item in features])
228
+ batch_attention_mask = torch.cat(
229
+ [self.paddingtensor2D(item['attention_mask'], max_length) for item in features])
230
+ batch_loss_mask = torch.cat(
231
+ [self.paddingtensor2D(item['loss_mask'], max_length) for item in features])
232
+
233
+ batch = {
234
+ "input_ids": batch_input_ids,
235
+ "attention_mask": batch_attention_mask,
236
+ "loss_mask": batch_loss_mask,
237
+ }
238
+ return batch
239
+
240
+
241
+ tokenizer = AutoTokenizer.from_pretrained(args.basepath)
242
+ # [MODIFIED] Pass model_type to build_dataset_rank
243
+ traindataset = build_dataset_rank(tokenizer, args.trainpath, model_type=args.model_type)
244
+ testdataset = build_dataset_rank(tokenizer, args.testpath, model_type=args.model_type)
245
+
246
+ config = EConfig.from_pretrained(train_config["config_path"])
247
+ # [MODIFIED] Pass model_type to Model
248
+ model = Model(config, ds_config, train_config, path=args.basepath, load_emb=True, load_head=True, model_type=args.model_type)
249
+ model.scandata(args.trainpath, args.basepath)
250
+
251
+
252
+ criterion = nn.SmoothL1Loss(reduction="none")
253
+
254
+ num_epochs = train_config["num_epochs"]
255
+
256
+ model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
257
+ model=model,
258
+ model_parameters=model.parameters(),
259
+ )
260
+
261
+ global_rank = deepspeed.comm.get_rank()
262
+ rank = deepspeed.comm.get_local_rank()
263
+ world_size = deepspeed.comm.get_world_size()
264
+ if global_rank == 0:
265
+ import wandb
266
+
267
+ wandb.login(key="dcac9b6b99c4203de2a920453357bc8ed55a5baf")
268
+ wandb.init(project="qwen3", entity="model-acceleration", config=ds_config)
269
+
270
+ os.makedirs(args.savedir, exist_ok=True)
271
+
272
+ sampler = DistributedSampler(testdataset, num_replicas=world_size, rank=global_rank, shuffle=False)
273
+ test_loader = DataLoader(testdataset, batch_size=train_config["bs"], sampler=sampler, num_workers=4, pin_memory=True,
274
+ collate_fn=DataCollatorWithPadding())
275
+
276
+ train_sampler = DistributedSampler(traindataset, num_replicas=world_size, rank=global_rank, shuffle=True)
277
+ train_loader = DataLoader(traindataset, batch_size=train_config["bs"], sampler=train_sampler, num_workers=4,
278
+ pin_memory=True,
279
+ collate_fn=DataCollatorWithPadding())
280
+
281
+
282
+ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
283
+ max_a = -1
284
+ for subdir in os.listdir(directory):
285
+ match = re.match(r"state_(\d+)", subdir)
286
+ if match:
287
+ a_value = int(match.group(1))
288
+ subdir_path = os.path.join(directory, subdir)
289
+ file_path = os.path.join(subdir_path, filename)
290
+ if os.path.isdir(subdir_path) and os.path.exists(file_path):
291
+ max_a = max(max_a, a_value)
292
+ if max_a == -1:
293
+ return None, 0
294
+ return f"{directory}/state_{max_a}", max_a + 1
295
+
296
+
297
+ checkpoint_path, start_epoch = find_max_state_with_file(args.savedir)
298
+ if checkpoint_path:
299
+ print(f"load from {checkpoint_path}")
300
+ model_engine.load_checkpoint(checkpoint_path)
301
+
302
+
303
+
304
+ for epoch in range(start_epoch, num_epochs):
305
+ train_sampler.set_epoch(epoch+1)
306
+ print(f"Now training epoch {epoch}")
307
+
308
+ model.train()
309
+ epoch_acces = [[] for _ in range(model.length)]
310
+ epoch_plosses = [[] for _ in range(model.length)]
311
+
312
+
313
+ for batch_idx, data in enumerate(tqdm(train_loader)):
314
+
315
+ model.zero_grad()
316
+
317
+ plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
318
+ attention_mask=data["attention_mask"].to(rank),
319
+ loss_mask=data["loss_mask"],
320
+ )
321
+
322
+ ploss_weight = [0.8 ** i for i in range(len(plosses))]
323
+ ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])
324
+ loss = ploss
325
+ model_engine.backward(loss)
326
+
327
+
328
+ model_engine.step()
329
+
330
+ if global_rank == 0:
331
+ logdict = {"train/lr": optimizer.optimizer.param_groups[0]["lr"]}
332
+ for i in range(len(plosses)):
333
+ logdict[f"train/ploss_{i}"] = plosses[i].item()
334
+ for i in range(len(acces)):
335
+ logdict[f"train/acc_{i}"] = acces[i]
336
+ wandb.log(logdict)
337
+ epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
338
+ epoch_plosses = [epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))]
339
+
340
+
341
+ for i in range(len(epoch_acces)):
342
+ acc_i = torch.tensor(epoch_acces[i]).cuda().mean()
343
+ deepspeed.comm.all_reduce(acc_i, op=deepspeed.comm.ReduceOp.AVG)
344
+ acc_i = acc_i.item()
345
+ if global_rank == 0:
346
+ wandb.log({f"train/epochacc_{i}": acc_i})
347
+ print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i}, Acc: {acc_i:.2f}")
348
+
349
+ for i in range(len(epoch_plosses)):
350
+ loss_i = torch.tensor(epoch_plosses[i]).cuda().mean()
351
+ deepspeed.comm.all_reduce(loss_i, op=deepspeed.comm.ReduceOp.AVG)
352
+ loss_i = loss_i.item()
353
+ if global_rank == 0:
354
+ wandb.log({f"train/epochploss_{i}": loss_i})
355
+ print(f"Train Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")
356
+
357
+ epoch_acces = [[] for _ in range(model.length)]
358
+ epoch_plosses = [[] for _ in range(model.length)]
359
+
360
+ for batch_idx, data in enumerate(tqdm(test_loader)):
361
+ with torch.no_grad():
362
+ plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
363
+ attention_mask=data["attention_mask"].to(rank),
364
+ loss_mask=data["loss_mask"],
365
+ )
366
+ epoch_acces = [epoch_acces[i] + [acces[i]] for i in range(len(acces))]
367
+ epoch_plosses = [epoch_plosses[i] + [plosses[i].item()] for i in range(len(plosses))]
368
+
369
+ for i in range(len(epoch_acces)):
370
+ acc_i = torch.tensor(epoch_acces[i]).cuda().mean()
371
+ deepspeed.comm.all_reduce(acc_i, op=deepspeed.comm.ReduceOp.AVG)
372
+ acc_i = acc_i.item()
373
+ if global_rank == 0:
374
+ wandb.log({f"test/epochacc_{i}": acc_i})
375
+ print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, Acc: {acc_i:.2f}")
376
+
377
+ for i in range(len(epoch_plosses)):
378
+ loss_i = torch.tensor(epoch_plosses[i]).cuda().mean()
379
+ deepspeed.comm.all_reduce(loss_i, op=deepspeed.comm.ReduceOp.AVG)
380
+ loss_i = loss_i.item()
381
+ if global_rank == 0:
382
+ wandb.log({f"test/epochploss_{i}": loss_i})
383
+ print(f"Test Epoch [{epoch + 1}/{num_epochs}], position {i}, pLoss: {loss_i:.2f}")
384
+ # clear out the redundance cahce after each step
385
+ torch.cuda.empty_cache()
386
+
387
+ # 매 epoch마다 체크포인트 저장 (학습 재개 가능하도록)
388
+ model_engine.save_16bit_model(f"{args.savedir}/state_{epoch}", exclude_frozen_parameters=True)
389
+ deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
390
+
391
+ # 디스크 공간 절약: 오래된 체크포인트 삭제 (최근 3개만 유지)
392
+ if global_rank == 0 and epoch > 2:
393
+ old_checkpoint = f"{args.savedir}/state_{epoch - 3}"
394
+ if os.path.exists(old_checkpoint):
395
+ import shutil
396
+ shutil.rmtree(old_checkpoint)
397
+ print(f"Removed old checkpoint: {old_checkpoint}")