import transformers from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GPTNeoXForCausalLM import os import torch import torch.nn as nn import torch.nn.functional as F import random from dataclasses import dataclass, field from typing import Optional from peft import ( get_peft_model, PeftModel, PeftConfig ) from torch.nn.functional import gelu import math from safetensors.torch import load_file from transformers.modeling_outputs import ModelOutput import random import copy device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @dataclass class ModelArguments: model_name_or_path: str = field(default="mistralai/Mistral-7B-Instruct-v0.2") separate_decoder_name: str = field(default="") lora_r: int = field(default=128, metadata={"help": "lora rank"}) lora_dropout: float = field(default=0.05, metadata={"help": "lora dropout"}) full_precision: bool = field(default=True, metadata={"help": "whether use int4 for the base model"}) train: bool = field( default=True, metadata={ "help": "if true, the model ckpt will be initialized for training; else, it's for inference" }, ) lora_init: bool = field( default=False, metadata={"help": "True: Use zero and gaussian initialization; False: Load adapters from LoftQ in HF hub."}, ) token: Optional[str] = field( default=None, metadata={"help": "HF token to access to private models, e.g., meta-llama"}, ) adapter_name_or_path: Optional[str] = field( default=None, metadata={"help": "Path to the LoRA adapter. Used in evaluation or resuming from the checkpoint."}, ) lora_alpha: int = field( default=16, metadata={"help": "LoftQ does not require this config. Used for QLoRA."}, ) ckpt_dir: Optional[str] = field(default=None, metadata={"help": "checkpoint dir for inference."}) @dataclass class DataArguments: data_name: str = field( default=None, metadata={"help": "Path to the training data."} ) debug_data: bool = field( default=False, metadata={ "help": "Enable debug dataset to quickly verify the training process" }, ) batch_size: int = field(default=1, metadata={"help": "batch size during inference"}) @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") model_max_length: int = field( default=28000, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) restore_from: str = field( default="", metadata={ "help": "The checkpoint that should be restored from for fine-tuning" }, ) per_device_train_batch_size: int = field( default=1, ) per_device_eval_batch_size: int = field( default=1, ) expt_name: str = field( default="default", metadata={"help": "Experiment name"}, ) icot_train_path: str = field(default="/users/k24020023/efficient_cot/icae/code/coconut/icot_gsm8k/train.txt", metadata={"help":"The training data path"}) num_latent: int = field(default=5, metadata={"help": "The number of latent for training or inference."}) use_lora: bool = field(default=True, metadata={"help": "Use lora or not."}) greedy: bool = field(default=False, metadata={"help": "Greedy decoding during inference."}) exp_mode: bool = field(default=False, metadata={"help": "Use partial number of data. for debugging."}) exp_data_num: int = field(default=10000, metadata={"help": "The number of data used in exp mode"}) use_prj: bool = field(default=False, metadata={"help": "Use a prj module after the llm for latent generation."}) prj_dim: int = field(default=2048, metadata={"help": "The hidden dim of the projection module."}) prj_dropout: float = field(default=0.0, metadata={"help": "Dropout ratio of the projection module."}) prj_no_ln: bool = field(default=False, metadata={"help": "Remove the Layer Norm layer for the projection module."}) distill_loss_div_std: bool = field(default=False, metadata={"help": "Divide the distillation loss by a std for normallisation."}) distill_loss_type: str = field(default="smooth_l1", metadata={"help": "Specify the distillation loss. Use smoothL1 by default."}) distill_loss_factor: float = field(default=1.0, metadata={"help": "A multiplier of the distillation loss."}) ref_loss_factor: float = field(default=1.0, metadata={"help": "A multiplier of the distillation loss."}) inf_latent_iterations: int = field(default=1, metadata={"help": ""}) inf_num_iterations: int = field(default=5, metadata={"help": "Run multiple times during inference"}) remove_eos: bool = field(default=False, metadata={"help": "Do not add as a delimiter to split QA."}) print_ref_model_stats: bool = field(default=False, metadata={"help": "Print some stats for the teacher task."}) include_last_cot: bool = field(default=False, metadata={"help": "Include the last CoT step in the training data."}) fix_attn_mask: bool = field(default=False, metadata={"help": "Correct a bug about attention mask."}) log_full: bool = field(default=False, metadata={"help": "Log all losses."}) print_loss: bool = field(default=True) max_token_num: int = field(default=1000, metadata={"help": "Limit the longest data to avoid OOM."}) def print_trainable_parameters(model): trainable_parameters = 0 all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_parameters += param.numel() print( f"trainable params: {trainable_parameters} || all params: {all_param} || trainable%: {100 * trainable_parameters / all_param}" ) # for name, param in model.named_parameters(): # if param.requires_grad: # print(name, param.shape) def freeze_model(model): for _, param in model.named_parameters(): param.requires_grad = False class CODI(torch.nn.Module): def __init__(self, model_args, training_args, lora_config): super().__init__() self.model_args = model_args self.training_args = training_args self.model_name = model_args.model_name_or_path model_wrapper_class = AutoModelForCausalLM if model_args.full_precision: self.codi = model_wrapper_class.from_pretrained( self.model_name, torch_dtype=( torch.float16 if training_args.bf16 is False else torch.bfloat16 ), use_flash_attention_2=False, resume_download=True, ) else: self.codi = model_wrapper_class.from_pretrained( self.model_name, torch_dtype=( torch.float16 if training_args.bf16 is False else torch.bfloat16 ), use_flash_attention_2=False, resume_download=True, quantization_config=transformers.BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type='nf4', ) ) ori_vocab_size = self.codi.config.vocab_size self.training = self.model_args.train # special tokens to enclose the latent embeddings self.pad_token_id = ori_vocab_size self.bot_id = ori_vocab_size + 1 self.eot_id = ori_vocab_size + 2 self.codi.resize_token_embeddings( ori_vocab_size + 3 ) # dummy values for mem tokens self.dim = self.codi.config.hidden_size self.num_latent = training_args.num_latent self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False) # LoRA if training_args.use_lora: self.codi = get_peft_model(self.codi, lora_config) # Projection Layer self.use_prj = training_args.use_prj self.prj_no_ln = training_args.prj_no_ln if training_args.use_prj: self.prj = nn.Sequential( nn.Dropout(training_args.prj_dropout), nn.Linear(self.dim, training_args.prj_dim), nn.GELU(), nn.Linear(training_args.prj_dim, self.dim), ) if not self.prj_no_ln: self.prj.add_module("ln", nn.LayerNorm(self.dim)) # Losses self.print_loss = training_args.print_loss self.ref_loss_factor = training_args.ref_loss_factor # Cross Entropy Loss self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100) # Distillation Loss self.distill_loss_div_std = training_args.distill_loss_div_std self.distill_loss_type = training_args.distill_loss_type self.distill_loss_factor = training_args.distill_loss_factor if self.distill_loss_type == "smooth_l1": self.distill_loss_fct = nn.SmoothL1Loss() elif self.distill_loss_type == "l2": self.distill_loss_fct = nn.MSELoss() else: raise NotImplementedError # general self.fix_attn_mask = training_args.fix_attn_mask if self.tokenizer.pad_token_id is None: self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) self.tokenizer.pad_token_id = self.pad_token_id if self.training: self.init() def get_embd(self, model, model_name): try: if "pythia" in model_name: return model.get_base_model().gpt_neox.embed_in elif "gpt2" in model_name: try: return model.get_base_model().transformer.wte except Exception: # no lora return model.transformer.wte else: try: return model.get_base_model().model.embed_tokens except Exception: # no lora return model.model.embed_tokens except AttributeError: if "pythia" in model_name: return model.gpt_neox.embed_in raise NotImplementedError def init(self): print_trainable_parameters(self) if ( self.training_args.restore_from is not None and self.training_args.restore_from != "" ): print( f"Loading from the pretrained checkpoint: {self.training_args.restore_from}..." ) state_dict = load_file(self.training_args.restore_from) self.load_state_dict(state_dict) print(f"Finished loading from {self.training_args.restore_from}") def forward( self, encoder_input_ids: torch.LongTensor = None, decoder_input_ids: torch.LongTensor = None, ref_input_ids: torch.LongTensor = None, labels: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, ref_answer_position: Optional[torch.LongTensor] = None, model_answer_position: Optional[torch.LongTensor] = None, ref_attention_mask: Optional[torch.LongTensor] = None, ref_labels: torch.LongTensor = None, step: int = None, step_ratio: float = None ): if not self.fix_attn_mask: ref_attention_mask = None # Encode the question past_key_values = None outputs = self.codi(input_ids=encoder_input_ids, use_cache=True, output_hidden_states=True, past_key_values=past_key_values, attention_mask=encoder_attention_mask) past_key_values = outputs.past_key_values latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1) # as the next input if self.use_prj: latent_embd = self.prj(latent_embd) len_pred_loss = 0 dynamic_mask = None if self.fix_attn_mask: dynamic_mask = torch.ones((encoder_attention_mask.size(0), self.num_latent), device=ref_labels.device) # Iterate over the latent embeddings distill_loss_total = 0 ce_loss_total = 0 with torch.no_grad(): ref_outputs = self.codi(input_ids=ref_input_ids, output_hidden_states=True, attention_mask=ref_attention_mask) ref_outputs_with_grad = self.codi(input_ids=ref_input_ids, output_hidden_states=True, attention_mask=ref_attention_mask) # Formatting for deprecated exps ref_outputs_list = [ref_outputs] ref_input_ids = [ref_input_ids] # Process the position tensor # Normalise the position definition if "llama" in self.model_name.lower() or "qwen" in self.model_name.lower(): # there is one more token standing for " " model_answer_position = model_answer_position + 1 ref_answer_position = ref_answer_position + 1 # For DEBUG: Print the probability of the teacher task to predict the correct answer if self.training_args.print_ref_model_stats: for i, (ref_inputs, ref_outputs) in enumerate(zip(ref_input_ids, ref_outputs_list)): # evalutae the reference model if len(ref_outputs_list) > 1: pos = ref_answer_position[i] else: pos = ref_answer_position ref_probs = torch.nn.functional.softmax(ref_outputs.logits, dim=-1) input_positions = (pos-1).unsqueeze(1).unsqueeze(1).expand(-1, -1, ref_probs.size(2)) ref_probs_at_positions = ref_probs.gather(1, input_positions) probe_positions_positions = pos.unsqueeze(1) probe_positions = ref_inputs.gather(1, probe_positions_positions).unsqueeze(1) ref_probs_of_target = ref_probs_at_positions.gather(2, probe_positions) print(f'stage{i}: mean of the prob of the target token: {ref_probs_of_target.mean()}') # the model answer position is the position of the eot token to predict the first token of the response model_answer_position = model_answer_position - 1 ref_answer_position = ref_answer_position -1 num_latent = self.num_latent if self.num_latent != 0: for i in range(num_latent): # Implicit CoT generation outputs = self.codi(inputs_embeds=latent_embd, use_cache=True, output_hidden_states=True, past_key_values=past_key_values) past_key_values = outputs.past_key_values latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1) if self.use_prj: latent_embd = self.prj(latent_embd) # Calculate the distillation loss if i == num_latent - 1: # the last latent embedding # Decode the final answer in natural language embds = self.get_embd(self.codi, self.model_name)(decoder_input_ids) if dynamic_mask is not None: # Prevent attending the paddings decoder_mask = torch.ones((embds.size(0), embds.size(1)), dtype=torch.bool).to(dynamic_mask) dynamic_mask = torch.cat((encoder_attention_mask, dynamic_mask, decoder_mask), dim=1) dynamic_mask = dynamic_mask.bool() # Student task's output outputs = self.codi(inputs_embeds=embds, use_cache=True, output_hidden_states=True, past_key_values=past_key_values, attention_mask=dynamic_mask) # Teacher task's output ref_outputs = ref_outputs_list[0] distill_loss = 0 # Calculate distillation loss between the teacher's logits and the student's logits for every layer for j, (out, ref_out) in enumerate(zip(outputs.hidden_states, ref_outputs.hidden_states)): ref_selected = ref_out.gather(1, ref_answer_position.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, ref_out.size(-1))) out_selected = out.gather(1, model_answer_position.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, out.size(-1))) distill_loss_tmp = self.distill_loss_fct(out_selected, ref_selected.detach()) if self.distill_loss_div_std: if self.distill_loss_type == 'l2': distill_loss_tmp /= ref_selected.std() distill_loss_tmp /= ref_selected.std() distill_loss += distill_loss_tmp distill_loss /= len(outputs.hidden_states) if self.print_loss: print(f'latent{i}: distill_loss={distill_loss}') distill_loss_total += distill_loss # Calculate the CE loss for the student task if i == num_latent - 1: logits = outputs.logits effective_logits = logits[:, :-1, :] effective_logits = effective_logits.reshape(-1, logits.size(-1)) target_ids = labels[:, 1:].reshape(-1) ce_loss = self.loss_fct(effective_logits, target_ids) ce_loss_total += ce_loss # Calculate the CE loss for the teacher task ref_ce_loss = 0 ref_logits = ref_outputs_with_grad.logits effective_ref_logits = ref_logits[:, :-1, :] effective_ref_logits = effective_ref_logits.reshape(-1, ref_logits.size(-1)) ref_target_ids = ref_labels[:, 1:].reshape(-1) ref_ce_loss = self.loss_fct(effective_ref_logits, ref_target_ids) ref_ce_loss *= self.ref_loss_factor # Weigh the distillation loss distill_loss *= self.distill_loss_factor distill_loss_total *= self.distill_loss_factor if self.print_loss: print(f'loss={ce_loss+distill_loss}, ce_loss={ce_loss}, distill_loss={distill_loss}, ce_loss_total={ce_loss_total}, distill_loss_total={distill_loss_total}, ref_ce_loss={ref_ce_loss}') loss = ce_loss_total + distill_loss_total + ref_ce_loss if ce_loss_total != 0: ce_loss_total = ce_loss_total.detach().item() if distill_loss_total != 0: distill_loss_total = distill_loss_total.detach().item() if ref_ce_loss != 0: ref_ce_loss = ref_ce_loss.detach().item() return {"loss": loss, "logits": logits, "ce_loss": ce_loss_total, "distill_loss": distill_loss_total, "ref_ce_loss": ref_ce_loss}