| | import deepspeed |
| | import torch |
| | from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl |
| |
|
| | from ovis.util.constants import END_LINE, BEGIN_LINE |
| | from ovis.util.utils import rank0_print |
| |
|
| |
|
class TuneTauCallback(TrainerCallback):
    """Linearly anneal the visual tokenizer's ``tau`` over training.

    At the beginning of every step, ``visual_tokenizer.config.tau`` is set to
    ``args.visual_max_tau`` at step 0 and decreases linearly toward
    ``args.visual_min_tau`` as ``state.global_step`` approaches
    ``state.max_steps``.

    ``args`` must expose ``visual_max_tau`` and ``visual_min_tau``. These are
    presumably provided by a project-specific ``TrainingArguments`` subclass;
    confirm against the training entry point.
    """

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update ``tau`` on the model's visual tokenizer for the current step."""
        visual_tokenizer = kwargs['model'].get_visual_tokenizer()
        max_tau = args.visual_max_tau
        min_tau = args.visual_min_tau
        max_step = state.max_steps
        if max_step > 0:
            # Clamp so tau stays between the two configured endpoints even if
            # global_step ever exceeds max_steps.
            ratio = min(max(state.global_step / max_step, 0.0), 1.0)
        else:
            # max_steps is unknown / unset (0): avoid ZeroDivisionError and keep
            # tau at its starting value.
            ratio = 0.0
        visual_tokenizer.config.tau = max_tau - (max_tau - min_tau) * ratio
| |
|
| |
|
class MonitorCallback(TrainerCallback):
    """Print the model's monitor tensors at selected points during training.

    The tensors come from ``model.get_monitor_tensors()`` and are printed via
    ``rank0_print``. Printing happens in three cases:

    - at the beginning of every step whose ``global_step`` is a multiple of
      ``args.monitor_step``;
    - at step 10, as an extra early snapshot;
    - at the end of every epoch.
    """

    def _monitoring(self, model, step):
        """Print each monitored tensor's sum and full contents.

        Each entry is framed by ``BEGIN_LINE`` / ``END_LINE``. The tensors are
        wrapped in ``deepspeed.zero.GatheredParameters`` so that, if they are
        ZeRO-partitioned parameters, their full values are visible while
        printing.
        """
        with torch.no_grad():
            # Fetch once so the tensors gathered are exactly the tensors printed.
            monitor_tensors = model.get_monitor_tensors()
            with deepspeed.zero.GatheredParameters(monitor_tensors.values()):
                for k, v in monitor_tensors.items():
                    rank0_print(BEGIN_LINE)
                    rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
                    rank0_print(v)
                    rank0_print(END_LINE)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Monitor on multiples of ``args.monitor_step`` and at step 10."""
        model = kwargs['model']
        step = state.global_step
        monitor_step = args.monitor_step
        # A monitor_step of 0 (or None) would raise in the modulo below.
        # Treat it as "periodic monitoring disabled".
        periodic = bool(monitor_step) and step % monitor_step == 0
        if periodic or step == 10:
            self._monitoring(model, step)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Always monitor at the end of an epoch."""
        model = kwargs['model']
        step = state.global_step
        self._monitoring(model, step)
| |
|