jkralev commited on
Commit
c0b77ea
·
verified ·
1 Parent(s): c1f583c

Upload folder using huggingface_hub

Browse files
DistributedTrainer.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.trainer import *
2
+
3
class DistributedTrainer(Trainer):
    """Trainer variant for models that manage their own device placement.

    This subclass overrides only ``_inner_training_loop``. The body is a copy of
    the stock ``transformers.Trainer._inner_training_loop`` with two deliberate
    deviations (marked with NOTE(review) below):

    1. ``use_accelerator_prepare`` is forced to ``False`` after optimizer
       creation, so the *model* is never passed through ``accelerator.prepare``
       — only the optimizer is prepared. This leaves a manually device-mapped
       (model-parallel) model exactly where its constructor placed it.
    2. The running loss tensor is allocated on ``model.out_device`` instead of
       ``args.device``, so it lives on the same device as the losses the model
       returns.

    Everything else (checkpoint resume, gradient accumulation, clipping,
    logging, callbacks) follows the upstream implementation.
    """

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        """Run the full training loop.

        Args:
            batch_size: Per-step train batch size selected by ``train()`` (may
                shrink across retries when ``auto_find_batch_size`` is on).
            args: The ``TrainingArguments`` in effect for this run.
            resume_from_checkpoint: Optional path to a checkpoint directory to
                resume model/optimizer/scheduler/RNG state from.
            trial: Optuna/Ray trial object during a hyper-parameter search.
            ignore_keys_for_eval: Model-output keys to drop during evaluation.

        Returns:
            ``TrainOutput(global_step, average_train_loss, metrics)``.
        """
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            if self.state.train_batch_size != self._train_batch_size:
                from accelerate.utils import release_memory

                (self.model_wrapped,) = release_memory(self.model_wrapped)
                self.model_wrapped = self.model

                # Check for DeepSpeed *after* the initial pass and modify the config
                if self.is_deepspeed_enabled:
                    # Temporarily unset `self.args.train_batch_size`
                    original_bs = self.args.per_device_train_batch_size
                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                    self.propagate_args_to_deepspeed(True)
                    self.args.per_device_train_batch_size = original_bs
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.is_fsdp_xla_v2_enabled:
            train_dataloader = tpu_spmd_dataloader(train_dataloader)

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self.get_total_train_batch_size(args)

        (
            num_train_epochs,
            num_update_steps_per_epoch,
            num_examples,
            num_train_samples,
            epoch_based,
            len_dataloader,
            max_steps,
        ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)

        num_train_tokens = None
        if self.args.include_tokens_per_second:
            num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
            # If going by epochs, multiply tokens linearly
            if len_dataloader is not None and epoch_based:
                num_train_tokens *= args.num_train_epochs
            # Otherwise since its steps, we just multiply by grad accum
            else:
                num_train_tokens *= args.gradient_accumulation_steps

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torchrun or torch.distributed.launch (deprecated))."
                )
            else:
                DebugUnderflowOverflow(self.model)

        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

        # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
        if is_fsdp2:
            delay_optimizer_creation = False

        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size

        # Compute absolute values for logging, eval, and save if given as ratio
        self.state.compute_steps(args, max_steps)

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = model is self.model

        if use_accelerator_prepare and self.is_fsdp_enabled:
            # In case of auto_find_batch_size=True
            # Remove FSDP wrapping from sub-models.
            self.model = unwrap_model(self.model, recursive=True)

        if delay_optimizer_creation:
            if use_accelerator_prepare:
                # configure fsdp plugin for qlora if any
                self._fsdp_qlora_plugin_updates()
                if self.accelerator.mixed_precision != "fp8":
                    self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        # NOTE(review): deliberate deviation from the stock Trainer — forcing this
        # to False skips `accelerator.prepare(self.model)` entirely (only the
        # optimizer is prepared in the `else` branch below). Presumably this is
        # to keep a manually device-mapped / model-parallel model untouched;
        # confirm that no DDP/mixed-precision wrapping is expected here.
        use_accelerator_prepare = False
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    # We should avoid accelerate preparing the model in TP case since we dont need it as it is handled by transformers from_pretrained and also it goes into DDP based preparation.
                    if self.is_tp_enabled:
                        self.optimizer = self.accelerator.prepare(self.optimizer)
                    else:
                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # ckpt loading
        if resume_from_checkpoint is not None:
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(
                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
                )
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)
        self._load_scaler(resume_from_checkpoint)

        # important: at this point:
        # self.model is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        # Update the references
        for attr in ("model", "optimizer", "lr_scheduler"):
            setattr(self.callback_handler, attr, getattr(self, attr))
        self.callback_handler.train_dataloader = train_dataloader

        self.state.init_training_references(self, max_steps, num_train_epochs, trial)

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        # NOTE(review): deviation from the stock Trainer, which allocates this on
        # `args.device`. Using `model.out_device` assumes the wrapped model
        # exposes an `out_device` attribute naming the device its losses are
        # returned on — confirm against the model class used with this trainer.
        tr_loss = torch.tensor(0.0, device=model.out_device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None
        learning_rate = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if args.eval_on_start:
            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_dataloader)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            step = -1
            rng_to_sync = False

            # Handle resumption from checkpoint
            if epoch == epochs_trained and resume_from_checkpoint is not None:
                if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                    step = steps_trained_in_current_epoch - 1
                    rng_to_sync = True
                elif steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)

            epoch_iterator = iter(epoch_dataloader)
            # We chunkify the epoch iterator into gradient accumulation steps `n` batches
            remainder = steps_in_epoch % args.gradient_accumulation_steps
            if remainder == 0:
                remainder = args.gradient_accumulation_steps
            update_step = -1
            total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
                remainder < args.gradient_accumulation_steps
            )
            for _ in range(total_updates):
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
                # Store the number of batches for current gradient accumulation
                # This is used to correctly scale the loss when the last accumulation step has fewer batches
                self.current_gradient_accumulation_steps = len(batch_samples)
                for i, inputs in enumerate(batch_samples):
                    step += 1
                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                    # Since we perform prefetching, we need to manually set sync_gradients
                    self.accelerator.gradient_state._set_sync_gradients(do_sync_step)

                    if self.args.include_num_input_tokens_seen not in ["no", False]:
                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
                        if main_input_name not in inputs:
                            logger.warning(
                                "Tried to track the number of tokens seen, however the current model is "
                                "not configured properly to know what item is the input. To fix this, add "
                                "a `main_input_name` attribute to the model class you are using."
                            )
                        else:
                            if self.args.include_num_input_tokens_seen == "non_padding":
                                if "attention_mask" in inputs:
                                    input_tokens = inputs["attention_mask"].sum()
                                elif (
                                    self.processing_class is not None
                                    and hasattr(self.processing_class, "pad_token_id")
                                    and self.processing_class.pad_token_id is not None
                                ):
                                    input_tokens = (
                                        inputs[main_input_name] != self.processing_class.pad_token_id
                                    ).sum()
                                else:
                                    logger.warning(
                                        "Could not determine method to count non-padding tokens, falling back to counting all tokens."
                                    )
                                    input_tokens = inputs[main_input_name].numel()
                            else:
                                input_tokens = inputs[main_input_name].numel()

                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()

                    if rng_to_sync:
                        self._load_rng_state(resume_from_checkpoint)
                        rng_to_sync = False

                    if step % args.gradient_accumulation_steps == 0:
                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                    # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
                    context = (
                        functools.partial(self.accelerator.no_sync, model=model)
                        if i != len(batch_samples) - 1
                        and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                        else contextlib.nullcontext
                    )
                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                    if (
                        args.logging_nan_inf_filter
                        and not is_torch_xla_available()
                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                    ):
                        # if loss is nan or inf simply add the average of previous logged losses
                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                    else:
                        if tr_loss.device != tr_loss_step.device:
                            raise ValueError(
                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                            )
                        tr_loss = tr_loss + tr_loss_step

                    self.current_flos += float(self.floating_point_ops(inputs))

                    if do_sync_step:
                        # Since we perform prefetching, we need to manually set sync_gradients to True
                        self.accelerator.gradient_state._set_sync_gradients(True)

                        # Gradient clipping
                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
                            if is_sagemaker_mp_enabled() and args.fp16:
                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                            elif self.use_apex:
                                from apex import amp

                                # Revert to normal clipping otherwise, handling Apex or full precision
                                _grad_norm = nn.utils.clip_grad_norm_(
                                    amp.master_params(self.optimizer),
                                    args.max_grad_norm,
                                )
                            else:
                                grad_norm_context = contextlib.nullcontext
                                if self.is_tp_enabled:
                                    from torch.distributed._tensor.experimental import implicit_replication

                                    grad_norm_context = implicit_replication
                                with grad_norm_context():
                                    _grad_norm = self.accelerator.clip_grad_norm_(
                                        model.parameters(),
                                        args.max_grad_norm,
                                    )

                            if (
                                is_accelerate_available()
                                and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                            ):
                                grad_norm = model.get_global_grad_norm()
                                # In some cases the grad norm may not return a float
                                if hasattr(grad_norm, "item"):
                                    grad_norm = grad_norm.item()
                            else:
                                grad_norm = _grad_norm

                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

                        context = contextlib.nullcontext
                        if self.is_tp_enabled:
                            from torch.distributed._tensor.experimental import implicit_replication

                            context = implicit_replication

                        with context():
                            self.optimizer.step()

                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

                        # get learning rate before update
                        learning_rate = self._get_learning_rate()

                        if not self.accelerator.optimizer_step_was_skipped:
                            # Delay optimizer scheduling until metrics are generated
                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                self.lr_scheduler.step()

                        model.zero_grad()
                        self.state.global_step += 1
                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                        self._maybe_log_save_evaluate(
                            tr_loss,
                            grad_norm,
                            model,
                            trial,
                            epoch,
                            ignore_keys_for_eval,
                            start_time,
                            learning_rate=learning_rate,
                        )
                    else:
                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                    # PyTorch/XLA relies on the data loader to insert the mark_step for
                    # each step. Since we are breaking the loop early, we need to manually
                    # insert the mark_step here.
                    if self.control.should_epoch_stop or self.control.should_training_stop:
                        if is_torch_xla_available():
                            xm.mark_step()
                        break
                # We also need to break out of the nested loop
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
            if step < 0:
                logger.warning(
                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(
                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
            )

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_xla_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
        train_loss = self._total_loss_scalar / effective_global_step

        metrics = speed_metrics(
            "train",
            start_time,
            num_samples=num_train_samples,
            num_steps=self.state.max_steps,
            num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint, ignore_errors=True)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        # Wait for the checkpoint to be uploaded.
        self._finish_current_push()

        # After training we make sure to retrieve back the original forward pass method
        # for the embedding layer by removing the forward post hook.
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)

        return TrainOutput(self.state.global_step, train_loss, metrics)
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Package init: register the MIC21 summarizer with the transformers Auto*
# factories so that the "mic21_summarizer" model_type resolves to the local
# classes (enables AutoConfig/AutoModel .from_pretrained on this repo).
from transformers import AutoConfig, AutoModel
from .configuration_mic21 import MIC21SummarizerConfig
from .modeling_mic21 import MIC21SummarizerModel

# Map the model_type string to the config class, then the config class to the
# model class; order matters only in that both registrations must run at import.
AutoConfig.register("mic21_summarizer", MIC21SummarizerConfig)
AutoModel.register(MIC21SummarizerConfig, MIC21SummarizerModel)
configuration_mic21.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ import torch
3
+
4
class MIC21SummarizerConfig(PretrainedConfig):
    """Configuration for the MIC21 summarizer model.

    Bundles the identifiers of the two backbone checkpoints (a causal LM for
    text and a ResNet image classifier) together with device-placement and
    generation settings consumed by ``MIC21SummarizerModel``.

    Args:
        hf_text_model: Hub id of the causal-LM text backbone.
        hf_image_model: Hub id of the image-classification backbone.
        im_model_cuda_id: CUDA device index intended for the image model.
        device_map: ``device_map`` strategy forwarded to model loading.
        memory_map: Optional per-device max-memory mapping; defaults to a
            fresh empty dict per instance.
        attn_implementation: Attention backend name for the text model.
        in_device: Device index the model expects its inputs on.
        out_device: Device index the model emits its outputs on.
        output_length: Maximum generated summary length (tokens).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "mic21_summarizer"

    def __init__(
        self,
        hf_text_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        hf_image_model="microsoft/resnet-50",
        im_model_cuda_id=0,
        device_map="auto",
        memory_map=None,
        attn_implementation="eager",
        in_device=0,
        out_device=0,
        output_length=40,
        **kwargs,
    ):
        self.hf_text_model = hf_text_model
        self.hf_image_model = hf_image_model
        self.im_model_cuda_id = im_model_cuda_id
        self.device_map = device_map
        # Bug fix: the original signature used a mutable default (`memory_map={}`),
        # which is shared across every instance — mutating one config's map would
        # silently leak into all others. Use a None sentinel and build a fresh
        # dict per instance instead (backward compatible: the effective default
        # value is still an empty dict).
        self.memory_map = {} if memory_map is None else memory_map
        self.attn_implementation = attn_implementation
        self.in_device = in_device
        self.out_device = out_device
        self.output_length = output_length
        # Tell transformers where to find these custom classes when loading
        # from the Hub with trust_remote_code.
        self.auto_map = {
            "AutoConfig": "jkralev/mic21_model--configuration_mic21.MIC21SummarizerConfig",
            "AutoModel": "jkralev/mic21_model--modeling_mic21.MIC21SummarizerModel"}
        super().__init__(**kwargs)
mic21_preprocess.py ADDED
File without changes
modeling_mic21.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.modeling_utils import PreTrainedModel
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import pdb
5
+ from transformers import OffloadedCache,DynamicCache
6
+ from .configuration_mic21 import MIC21SummarizerConfig
7
+ import numpy as np
8
+ from transformers import AutoImageProcessor, ResNetForImageClassification
9
+
10
class MIC21SummarizerModel(PreTrainedModel):
    """Image-captioning model.

    Frozen ResNet features are projected (LayerNorm -> Dropout -> Linear)
    into a frozen causal LM's embedding space, spliced into a chat prompt,
    and decoded greedily. Only the projection head is trainable.
    """

    config_class = MIC21SummarizerConfig
    is_parallelizable = True
    model_parallel = True
    place_model_on_device = False
    model_wrapped = {}

    def __init__(self, config):
        super().__init__(config)
        # Components live in a plain dict so PreTrainedModel does not register
        # the frozen backbones as submodules (keeps them out of state_dict).
        self.components = {"image_model": None, "llm": None, "tokenizer": None, "image_processor": None}
        self.components["image_model"] = ResNetForImageClassification.from_pretrained(config.hf_image_model)
        self.components["image_processor"] = AutoImageProcessor.from_pretrained(config.hf_image_model)

        self.components["llm"] = AutoModelForCausalLM.from_pretrained(config.hf_text_model, torch_dtype=torch.float16)
        self.components["tokenizer"] = AutoTokenizer.from_pretrained(config.hf_text_model)

        # Trainable head projecting the 49 (=7x7) flattened spatial positions of
        # the last ResNet feature map into the LLM hidden size.
        self.projection_layer = torch.nn.Linear(49, self.components["llm"].config.hidden_size, dtype=torch.float)
        # BUG FIX: the original assigned ``projection_layer`` twice and never
        # created ``projection_norm``, which forward() calls -> AttributeError.
        # The commented-out original line shows LayerNorm(49) was intended.
        self.projection_norm = torch.nn.LayerNorm(49, eps=1e-5, bias=True)
        self.projection_dropout = torch.nn.Dropout(0.1)

        # Freeze both backbones; only the projection head receives gradients.
        for param in self.components["image_model"].parameters():
            param.requires_grad = False

        for param in self.components["llm"].parameters():
            param.requires_grad = False

        self.im_model_cuda_id = config.im_model_cuda_id
        self.output_length = config.output_length

    def forward(self, images, titles):
        """Generate a title for each image; compute CE loss when titles given.

        Args:
            images: batch of images accepted by the image processor.
            titles: list of reference title strings, or None for inference.

        Returns:
            ``{"loss", "logits"}`` when titles is not None, else
            ``{"logits": out_logits}`` with the full decoded logit sequence.
        """
        prepared_images = self.components["image_processor"](images, return_tensors="pt")

        # Last hidden state of the ResNet: (batch, n_filters, H, W); flatten HxW.
        img_features = self.components["image_model"](**prepared_images, output_hidden_states=True)
        img_features = img_features["hidden_states"][-1]
        (batch_size, nfilter, nx, ny) = img_features.shape
        img_features = img_features.view(batch_size, nfilter, nx * ny)

        messages = [
            {"role": "system", "content": "Generate title and description for the provided image. The image features are: "},
            {"role": "user", "content": "Generate a title:"}]

        tokenized_messages = self.components["tokenizer"].apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
        vectorized_messages = self.components["llm"].model.embed_tokens(tokenized_messages[0]).unsqueeze(0)
        vectorized_messages = vectorized_messages.repeat(batch_size, 1, 1)
        # First EOS position in the prompt: visual embeddings are spliced in
        # just before it (system/user turn boundary in the chat template).
        first_eos_index = (tokenized_messages[0] == self.components["tokenizer"].eos_token_id).nonzero()[0].item()

        # Project the first 256 filter channels into the LLM hidden size.
        visual_embeddings = self.projection_layer(
            self.projection_dropout(self.projection_norm(img_features[:, 0:256, :])))

        combined_embeds = torch.cat([
            vectorized_messages[:, :first_eos_index - 1, :],
            visual_embeddings.half(),
            vectorized_messages[:, first_eos_index:, :]], dim=1)

        # KV cache offloaded to CPU to cut GPU memory use during decoding.
        self.cache = OffloadedCache()

        # Prefill pass over the combined prompt, then greedy decoding.
        outputs = self.components["llm"](inputs_embeds=combined_embeds, past_key_values=self.cache, use_cache=True)
        logits = outputs.logits[:, -1]
        out_logits = logits.unsqueeze(1)
        new_tok = torch.argmax(logits, dim=-1)

        max_len = 64 if self.output_length is None else self.output_length

        for _ in range(max_len):
            outputs = self.components["llm"](
                input_ids=new_tok.unsqueeze(0).permute(1, 0), past_key_values=self.cache, use_cache=True)
            logits = outputs.logits[:, -1]
            out_logits = torch.cat([out_logits, logits.unsqueeze(1)], dim=1)
            new_tok = torch.argmax(logits, dim=-1)
            # NOTE(review): this break can never fire — max_len is never None at
            # this point — so decoding always runs the full max_len steps.
            # Preserved as-is; confirm whether early EOS stopping was intended.
            if max_len is None and new_tok.item() == self.components["tokenizer"].eos_token_id:
                break

        if titles is not None:
            # max_len + 1 padded targets match the 1 (prefill) + max_len
            # collected logit steps along the sequence dimension.
            target_tok = self.components["tokenizer"](
                titles, add_special_tokens=False, max_length=max_len + 1, padding='max_length')
            loss = torch.nn.CrossEntropyLoss()(
                out_logits.permute((0, 2, 1)), torch.LongTensor(target_tok["input_ids"]))
            # NOTE(review): the training path returns only the last step's
            # logits while the inference path returns the full sequence —
            # preserved as-is; confirm this asymmetry is intended.
            return {"loss": loss, "logits": logits}

        return {"logits": out_logits}