import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GPTNeoXForCausalLM
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from dataclasses import dataclass, field
from typing import Optional
from peft import (
    get_peft_model,
    PeftModel,
    PeftConfig,
)
from torch.nn.functional import gelu
import math
from safetensors.torch import load_file
from transformers.modeling_outputs import ModelOutput
import copy
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.2")
    separate_decoder_name: str = field(default="")
    lora_r: int = field(default=128, metadata={"help": "LoRA rank"})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout"})
    full_precision: bool = field(default=True, metadata={"help": "If False, load the base model in 4-bit (int4)."})
    train: bool = field(
        default=True,
        metadata={
            "help": "If True, the model checkpoint will be initialized for training; otherwise, for inference."
        },
    )
    lora_init: bool = field(
        default=False,
        metadata={"help": "True: use zero and Gaussian initialization; False: load adapters from LoftQ on the HF hub."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={"help": "HF token to access private models, e.g., meta-llama"},
    )
    adapter_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the LoRA adapter. Used in evaluation or when resuming from a checkpoint."},
    )
    lora_alpha: int = field(
        default=16,
        metadata={"help": "LoftQ does not require this config. Used for QLoRA."},
    )
    ckpt_dir: Optional[str] = field(default=None, metadata={"help": "Checkpoint dir for inference."})
|
|
@dataclass
class DataArguments:
    data_name: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    debug_data: bool = field(
        default=False,
        metadata={
            "help": "Enable a debug dataset to quickly verify the training process"
        },
    )
    batch_size: int = field(default=1, metadata={"help": "Batch size during inference"})
|
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=28000,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    restore_from: str = field(
        default="",
        metadata={
            "help": "The checkpoint that should be restored from for fine-tuning"
        },
    )
    per_device_train_batch_size: int = field(
        default=1,
    )
    per_device_eval_batch_size: int = field(
        default=1,
    )
    expt_name: str = field(
        default="default",
        metadata={"help": "Experiment name"},
    )
    icot_train_path: str = field(default="/users/k24020023/efficient_cot/icae/code/coconut/icot_gsm8k/train.txt", metadata={"help": "The training data path"})
    num_latent: int = field(default=5, metadata={"help": "The number of latent tokens for training or inference."})
    use_lora: bool = field(default=True, metadata={"help": "Use LoRA or not."})
    greedy: bool = field(default=False, metadata={"help": "Greedy decoding during inference."})
    exp_mode: bool = field(default=False, metadata={"help": "Use a partial number of data points, for debugging."})
    exp_data_num: int = field(default=10000, metadata={"help": "The number of data points used in exp mode"})
    use_prj: bool = field(default=False, metadata={"help": "Use a projection module after the LLM for latent generation."})
    prj_dim: int = field(default=2048, metadata={"help": "The hidden dim of the projection module."})
    prj_dropout: float = field(default=0.0, metadata={"help": "Dropout ratio of the projection module."})
    prj_no_ln: bool = field(default=False, metadata={"help": "Remove the LayerNorm layer from the projection module."})
    distill_loss_div_std: bool = field(default=False, metadata={"help": "Divide the distillation loss by a std for normalisation."})
    distill_loss_type: str = field(default="smooth_l1", metadata={"help": "Specify the distillation loss. Uses smooth_l1 by default."})
    distill_loss_factor: float = field(default=1.0, metadata={"help": "A multiplier of the distillation loss."})
    ref_loss_factor: float = field(default=1.0, metadata={"help": "A multiplier of the reference (teacher) CE loss."})
    inf_latent_iterations: int = field(default=1, metadata={"help": "The number of latent iterations during inference."})
    inf_num_iterations: int = field(default=5, metadata={"help": "Run multiple times during inference"})
    remove_eos: bool = field(default=False, metadata={"help": "Do not add <eos> as a delimiter to split QA."})
    print_ref_model_stats: bool = field(default=False, metadata={"help": "Print some stats for the teacher task."})
    include_last_cot: bool = field(default=False, metadata={"help": "Include the last CoT step in the training data."})
    fix_attn_mask: bool = field(default=False, metadata={"help": "Correct a bug in the attention mask."})
    log_full: bool = field(default=False, metadata={"help": "Log all losses."})
    print_loss: bool = field(default=True)
    max_token_num: int = field(default=1000, metadata={"help": "Limit the longest data samples to avoid OOM."})
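

# A minimal usage sketch (not part of the original file): the three argument dataclasses
# above are typically parsed together from the command line with transformers.HfArgumentParser.
# The function name `parse_codi_args` is illustrative only.
def parse_codi_args():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    return model_args, data_args, training_args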
|
|
def print_trainable_parameters(model):
    trainable_parameters = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_parameters += param.numel()
    print(
        f"trainable params: {trainable_parameters} || all params: {all_param} || trainable%: {100 * trainable_parameters / all_param}"
    )
|
|
|
|
def freeze_model(model):
    for _, param in model.named_parameters():
        param.requires_grad = False
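

# Overview (descriptive comment): CODI wraps a single causal LM (self.codi, optionally
# LoRA-wrapped) that plays two roles with shared weights:
#   * student: encodes the question, rolls out `num_latent` continuous latent thoughts
#     autoregressively (optionally through a projection module), then decodes the answer;
#   * teacher: runs over the reference sequence (ref_input_ids) containing explicit reasoning.
# forward() returns the sum of the student answer CE loss, the teacher CE loss
# (scaled by ref_loss_factor), and a per-layer hidden-state distillation loss at the
# answer position (scaled by distill_loss_factor).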
|
|
class CODI(torch.nn.Module):
    def __init__(self, model_args, training_args, lora_config):
        super().__init__()
        self.model_args = model_args
        self.training_args = training_args
        self.model_name = model_args.model_name_or_path
        model_wrapper_class = AutoModelForCausalLM
        if model_args.full_precision:
            self.codi = model_wrapper_class.from_pretrained(
                self.model_name,
                torch_dtype=(
                    torch.float16 if training_args.bf16 is False else torch.bfloat16
                ),
                use_flash_attention_2=False,
                resume_download=True,
            )
        else:
            self.codi = model_wrapper_class.from_pretrained(
                self.model_name,
                torch_dtype=(
                    torch.float16 if training_args.bf16 is False else torch.bfloat16
                ),
                use_flash_attention_2=False,
                resume_download=True,
                quantization_config=transformers.BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=False,
                    bnb_4bit_quant_type='nf4',
                ),
            )

        ori_vocab_size = self.codi.config.vocab_size
        self.training = self.model_args.train

        # Special token ids appended after the original vocabulary: padding plus
        # beginning-of-thought and end-of-thought markers.
        self.pad_token_id = ori_vocab_size
        self.bot_id = ori_vocab_size + 1
        self.eot_id = ori_vocab_size + 2

        self.codi.resize_token_embeddings(ori_vocab_size + 3)

        self.dim = self.codi.config.hidden_size
        self.num_latent = training_args.num_latent
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)

        if training_args.use_lora:
            self.codi = get_peft_model(self.codi, lora_config)

        # Optional projection applied to the last hidden state before it is fed back
        # in as the next latent thought.
        self.use_prj = training_args.use_prj
        self.prj_no_ln = training_args.prj_no_ln
        if training_args.use_prj:
            self.prj = nn.Sequential(
                nn.Dropout(training_args.prj_dropout),
                nn.Linear(self.dim, training_args.prj_dim),
                nn.GELU(),
                nn.Linear(training_args.prj_dim, self.dim),
            )
            if not self.prj_no_ln:
                self.prj.add_module("ln", nn.LayerNorm(self.dim))

        self.print_loss = training_args.print_loss
        self.ref_loss_factor = training_args.ref_loss_factor

        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        self.distill_loss_div_std = training_args.distill_loss_div_std
        self.distill_loss_type = training_args.distill_loss_type
        self.distill_loss_factor = training_args.distill_loss_factor
        if self.distill_loss_type == "smooth_l1":
            self.distill_loss_fct = nn.SmoothL1Loss()
        elif self.distill_loss_type == "l2":
            self.distill_loss_fct = nn.MSELoss()
        else:
            raise NotImplementedError

        self.fix_attn_mask = training_args.fix_attn_mask

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.tokenizer.pad_token_id = self.pad_token_id

        if self.training:
            self.init()
|
    def get_embd(self, model, model_name):
        """Return the input token-embedding module of the (possibly PEFT-wrapped) base model."""
        try:
            if "pythia" in model_name:
                return model.get_base_model().gpt_neox.embed_in
            elif "gpt2" in model_name:
                try:
                    return model.get_base_model().transformer.wte
                except Exception:
                    return model.transformer.wte
            else:
                try:
                    return model.get_base_model().model.embed_tokens
                except Exception:
                    return model.model.embed_tokens
        except AttributeError:
            if "pythia" in model_name:
                return model.gpt_neox.embed_in
            raise NotImplementedError
|
|
    def init(self):
        print_trainable_parameters(self)
        if (
            self.training_args.restore_from is not None
            and self.training_args.restore_from != ""
        ):
            print(
                f"Loading from the pretrained checkpoint: {self.training_args.restore_from}..."
            )
            state_dict = load_file(self.training_args.restore_from)
            self.load_state_dict(state_dict)
            print(f"Finished loading from {self.training_args.restore_from}")
|
|
    def forward(
        self,
        encoder_input_ids: torch.LongTensor = None,
        decoder_input_ids: torch.LongTensor = None,
        ref_input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        ref_answer_position: Optional[torch.LongTensor] = None,
        model_answer_position: Optional[torch.LongTensor] = None,
        ref_attention_mask: Optional[torch.LongTensor] = None,
        ref_labels: torch.LongTensor = None,
        step: int = None,
        step_ratio: float = None,
    ):
        if not self.fix_attn_mask:
            ref_attention_mask = None

        # Encode the question (student path) and keep the KV cache so the latent
        # thoughts can be generated autoregressively on top of it.
        past_key_values = None
        outputs = self.codi(input_ids=encoder_input_ids, use_cache=True, output_hidden_states=True, past_key_values=past_key_values, attention_mask=encoder_attention_mask)
        past_key_values = outputs.past_key_values
        latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)
        if self.use_prj:
            latent_embd = self.prj(latent_embd)
|
|
        len_pred_loss = 0
        dynamic_mask = None
        if self.fix_attn_mask:
            dynamic_mask = torch.ones((encoder_attention_mask.size(0), self.num_latent), device=ref_labels.device)

        distill_loss_total = 0
        ce_loss_total = 0
|
|
        # Teacher passes over the reference sequence: a no-grad pass whose hidden states
        # serve as distillation targets, and a separate pass with gradients for the
        # teacher CE objective.
        with torch.no_grad():
            ref_outputs = self.codi(input_ids=ref_input_ids, output_hidden_states=True, attention_mask=ref_attention_mask)
        ref_outputs_with_grad = self.codi(input_ids=ref_input_ids, output_hidden_states=True, attention_mask=ref_attention_mask)

        ref_outputs_list = [ref_outputs]
        ref_input_ids = [ref_input_ids]
        # Shift answer positions by one for Llama/Qwen models (likely to account for the
        # BOS token their tokenizers prepend).
        if "llama" in self.model_name.lower() or "qwen" in self.model_name.lower():
            model_answer_position = model_answer_position + 1
            ref_answer_position = ref_answer_position + 1

        if self.training_args.print_ref_model_stats:
            for i, (ref_inputs, ref_outputs) in enumerate(zip(ref_input_ids, ref_outputs_list)):
                if len(ref_outputs_list) > 1:
                    pos = ref_answer_position[i]
                else:
                    pos = ref_answer_position
                ref_probs = torch.nn.functional.softmax(ref_outputs.logits, dim=-1)
                input_positions = (pos - 1).unsqueeze(1).unsqueeze(1).expand(-1, -1, ref_probs.size(2))
                ref_probs_at_positions = ref_probs.gather(1, input_positions)
                probe_positions_positions = pos.unsqueeze(1)
                probe_positions = ref_inputs.gather(1, probe_positions_positions).unsqueeze(1)
                ref_probs_of_target = ref_probs_at_positions.gather(2, probe_positions)
                print(f'stage{i}: mean of the prob of the target token: {ref_probs_of_target.mean()}')

        # Move to the position of the token preceding the answer: its hidden state is the
        # one that predicts the answer token and is the position used for distillation below.
        model_answer_position = model_answer_position - 1
        ref_answer_position = ref_answer_position - 1
| |
        num_latent = self.num_latent
        if self.num_latent != 0:
            for i in range(num_latent):
                # One autoregressive latent step: feed back the previous latent embedding
                # and extend the KV cache.
                outputs = self.codi(inputs_embeds=latent_embd, use_cache=True, output_hidden_states=True, past_key_values=past_key_values)
                past_key_values = outputs.past_key_values
                latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)
                if self.use_prj:
                    latent_embd = self.prj(latent_embd)

                if i == num_latent - 1:
                    # After the last latent step, decode the answer prompt on top of the
                    # cached question and latent thoughts.
                    embds = self.get_embd(self.codi, self.model_name)(decoder_input_ids)

                    if dynamic_mask is not None:
                        decoder_mask = torch.ones((embds.size(0), embds.size(1)), dtype=torch.bool).to(dynamic_mask)
                        dynamic_mask = torch.cat((encoder_attention_mask, dynamic_mask, decoder_mask), dim=1)
                        dynamic_mask = dynamic_mask.bool()

                    outputs = self.codi(inputs_embeds=embds, use_cache=True, output_hidden_states=True, past_key_values=past_key_values, attention_mask=dynamic_mask)

                    ref_outputs = ref_outputs_list[0]

                    # Self-distillation: match the student's hidden state at the answer
                    # position with the teacher's, layer by layer.
                    distill_loss = 0
                    for j, (out, ref_out) in enumerate(zip(outputs.hidden_states, ref_outputs.hidden_states)):
                        ref_selected = ref_out.gather(1, ref_answer_position.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, ref_out.size(-1)))
                        out_selected = out.gather(1, model_answer_position.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, out.size(-1)))

                        distill_loss_tmp = self.distill_loss_fct(out_selected, ref_selected.detach())

                        if self.distill_loss_div_std:
                            # Normalise by the teacher's std (for l2 the division happens
                            # twice, i.e. by the variance overall).
                            if self.distill_loss_type == 'l2':
                                distill_loss_tmp /= ref_selected.std()
                            distill_loss_tmp /= ref_selected.std()
                        distill_loss += distill_loss_tmp

                    distill_loss /= len(outputs.hidden_states)

                    if self.print_loss:
                        print(f'latent{i}: distill_loss={distill_loss}')

                    distill_loss_total += distill_loss

                    # Student CE loss over the decoded answer tokens.
                    logits = outputs.logits
                    effective_logits = logits[:, :-1, :]
                    effective_logits = effective_logits.reshape(-1, logits.size(-1))
                    target_ids = labels[:, 1:].reshape(-1)
                    ce_loss = self.loss_fct(effective_logits, target_ids)
                    ce_loss_total += ce_loss
|
|
        # Teacher CE loss on the reference sequence (trains the shared model on the
        # explicit-reasoning task).
        ref_ce_loss = 0
        ref_logits = ref_outputs_with_grad.logits
        effective_ref_logits = ref_logits[:, :-1, :]
        effective_ref_logits = effective_ref_logits.reshape(-1, ref_logits.size(-1))
        ref_target_ids = ref_labels[:, 1:].reshape(-1)
        ref_ce_loss = self.loss_fct(effective_ref_logits, ref_target_ids)
        ref_ce_loss *= self.ref_loss_factor

        distill_loss *= self.distill_loss_factor
        distill_loss_total *= self.distill_loss_factor

        if self.print_loss:
            print(f'loss={ce_loss+distill_loss}, ce_loss={ce_loss}, distill_loss={distill_loss}, ce_loss_total={ce_loss_total}, distill_loss_total={distill_loss_total}, ref_ce_loss={ref_ce_loss}')

        loss = ce_loss_total + distill_loss_total + ref_ce_loss

        # Report plain scalars for logging; only `loss` keeps the computation graph.
        if ce_loss_total != 0:
            ce_loss_total = ce_loss_total.detach().item()
        if distill_loss_total != 0:
            distill_loss_total = distill_loss_total.detach().item()
        if ref_ce_loss != 0:
            ref_ce_loss = ref_ce_loss.detach().item()

        return {"loss": loss, "logits": logits, "ce_loss": ce_loss_total, "distill_loss": distill_loss_total, "ref_ce_loss": ref_ce_loss}
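

    # Hedged usage sketch (comment only; illustrative values such as `target_modules`
    # are assumptions and not taken from this file):
    #
    #   from peft import LoraConfig
    #   model_args, data_args, training_args = parse_codi_args()
    #   lora_config = LoraConfig(
    #       r=model_args.lora_r,
    #       lora_alpha=model_args.lora_alpha,
    #       lora_dropout=model_args.lora_dropout,
    #       target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # model-dependent
    #       task_type="CAUSAL_LM",
    #   )
    #   model = CODI(model_args, training_args, lora_config).to(device)
    #   # A training batch matching forward()'s keyword arguments yields a dict with
    #   # "loss", "logits", "ce_loss", "distill_loss" and "ref_ce_loss".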