from peft import get_peft_model, LoraConfig, AdaLoraConfig, TaskType import os import hydra from omegaconf import DictConfig, OmegaConf from utils import ( train_text_to_text_model, model_inference, initialize_text_to_text_model, transform_dataset, merge_llama, ) import json import math from datasets import load_dataset import wandb from data import * from typing import List import torch from copy import deepcopy import logging from tqdm import tqdm, trange from typing import Tuple, List, Dict from peft.tuners.lora.layer import Linear as LoraLinear from split import rebuild import re import itertools import matplotlib.pyplot as plt from commonsense_evaluate import common_evaluate from eval_humaneval import humaneval # from eval_mtbench import evaluate_mtbench_from_model log = logging.getLogger(__name__) s = 0 def kron(A, B): return (A[:, None, :, None] * B[None, :, None, :]).reshape(A.shape[0] * B.shape[0], A.shape[1] * B.shape[1]) def modified_gram_schmidt(W, eps=1e-12): """ Modified Gram–Schmidt QR W: (m, n) Returns: Q: (m, n) R: (n, n) """ m, n = W.shape Q = W.clone() R = torch.zeros(n, n, device=W.device, dtype=W.dtype) for i in range(n): R[i, i] = torch.norm(Q[:, i]) if R[i, i] < eps: raise RuntimeError("Linearly dependent columns") Q[:, i] = Q[:, i] / R[i, i] for j in range(i + 1, n): R[i, j] = torch.dot(Q[:, i], Q[:, j]) Q[:, j] = Q[:, j] - R[i, j] * Q[:, i] return Q, R def seed_everything(seed: int): import random, os import numpy as np import torch random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True def find_all_linear_modules(model) -> List[str]: r""" Finds all available modules to apply lora. """ linear_cls = torch.nn.Linear output_layer_names = ["lm_head", "embed_tokens"] module_names = set() for name, module in model.named_modules(): if isinstance(module, linear_cls) and not any( [output_layer in name for output_layer in output_layer_names] ): module_names.add(name.split(".")[-1]) return list(module_names) def find_hidden_state_size(model): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): return min(module.weight.shape) return None def calculate_gain( nonlinearity, param ) -> float: linear_fns = [ "linear", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d", ] if nonlinearity in linear_fns or nonlinearity == "sigmoid": return 1 elif nonlinearity == "tanh": return 5.0 / 3 elif nonlinearity == "relu": return math.sqrt(2.0) elif nonlinearity == "leaky_relu": if param is None: negative_slope = 0.01 elif ( not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float) ): # True/False are instances of int, hence check above negative_slope = param else: raise ValueError(f"negative_slope {param} not a valid number") return math.sqrt(2.0 / (1 + negative_slope**2)) elif nonlinearity == "selu": return ( 3.0 / 4 ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) else: raise ValueError(f"Unsupported nonlinearity {nonlinearity}") def kaimings(weight, a=math.sqrt(5), fan=4096): nonlinearity = "leaky_relu" generator = None gain = calculate_gain(nonlinearity, a) std = gain / math.sqrt(fan) bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation with torch.no_grad(): return weight.uniform_(-bound, bound, generator=generator) @torch.no_grad() def reinit_lora_modules(name, module, init_config, **kwargs): r""" Reinitialize the lora model with the given configuration. """ lora_r1 = kwargs["lora_r1"] lora_r2 = kwargs["lora_r2"] lora_r = kwargs["lora_r"] # lora_r1 = min(module.lora_A.default.weight.shape) # lora_r2 = min(module.lora_B.default.weight.shape) a_dim = max(module.lora_A.default.weight.shape) b_dim = max(module.lora_B.default.weight.shape) if init_config.mode == "simple": match init_config.lora_A: case "gaussian": torch.nn.init.normal_( module.lora_A.default.weight, mean=0.0, std=init_config.lora_A_std ) case "kaiming": # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 torch.nn.init.kaiming_uniform_(module.lora_A.default.weight, a=math.sqrt(5)) case "kaimings": kaimings(module.lora_A.default.weight, a=math.sqrt(5), fan=module.weight.size(1)) case "fan_out_kaiming": torch.nn.init.kaiming_normal_( module.lora_A.default.weight, mode="fan_out" ) case "xavier": torch.nn.init.xavier_normal_(module.lora_A.default.weight) case "zeros": torch.nn.init.zeros_(module.lora_A.default.weight) case "unit": torch.nn.init.normal_( module.lora_A.default.weight, mean=0.0, std=1.0 / (a_dim**0.5) ) case "orthogonal": torch.nn.init.orthogonal_(module.lora_A.default.weight) case _: raise ValueError(f"Unknown lora_A initialization: {init_config.lora_A}") match init_config.lora_B: case "gaussian": torch.nn.init.normal_( module.lora_B.default.weight, mean=0.0, std=init_config.lora_B_std ) case "kaiming": torch.nn.init.kaiming_normal_(module.lora_B.default.weight.T, a=math.sqrt(5)) case "fan_out_kaiming": torch.nn.init.kaiming_normal_( module.lora_B.default.weight, mode="fan_out" ) case "xavier": torch.nn.init.xavier_normal_(module.lora_B.default.weight) case "zeros": torch.nn.init.zeros_(module.lora_B.default.weight) case "unit": torch.nn.init.normal_( module.lora_B.default.weight, mean=0.0, std=1.0 / (b_dim**0.5) ) case "orthogonal": torch.nn.init.orthogonal_(module.lora_B.default.weight) case _: raise ValueError(f"Unknown lora_B initialization: {init_config.lora_B}") if init_config.get("scale", "") == "stable": gamma = init_config.stable_gamma #module.lora_B.default.weight.data *= (m**0.25) / gamma**0.5 #module.lora_A.default.weight.data *= (n**0.25) / gamma**0.5 #module.lora_B.default.weight.data *= (m**0.25) #module.lora_A.default.weight.data *= (n**0.25) module.lora_B.default.weight.data *= 1 module.lora_A.default.weight.data *= 1 elif init_config.mode == "svd": U, S, V = torch.svd_lowrank(module.weight.float(), q=4 * lora_r, niter=4) V = V.T m, n = module.weight.shape if init_config.scale == "default": S = S / module.scaling["default"] module.lora_B.default.weight = torch.nn.Parameter( (U[:, :lora_r] * torch.sqrt(S[:lora_r])).contiguous() ) module.lora_A.default.weight = torch.nn.Parameter( (V[:lora_r, :].T * torch.sqrt(S[:lora_r])).T.contiguous() ) elif init_config.scale == "stable": gamma = init_config.stable_gamma module.lora_B.default.weight = torch.nn.Parameter( (U[:, :lora_r] * (m**0.25) / gamma**0.5).contiguous() ) module.lora_A.default.weight = torch.nn.Parameter( (V[:lora_r, :] * (n**0.25) / gamma**0.5).contiguous() ) elif init_config.scale == "unit": module.lora_B.default.weight = torch.nn.Parameter( (U[:, :lora_r]).contiguous() ) module.lora_A.default.weight = torch.nn.Parameter( (V[:lora_r, :]).contiguous() ) elif init_config.scale == "normalized": S_sum = S[:lora_r].sum() module.lora_B.default.weight = torch.nn.Parameter( (U[:, :lora_r] * torch.sqrt(S[:lora_r])/torch.sqrt(S_sum)*lora_r**0.5).contiguous() ) module.lora_A.default.weight = torch.nn.Parameter( (V[:lora_r, :].T * torch.sqrt(S[:lora_r])/torch.sqrt(S_sum)*lora_r**0.5).T.contiguous() ) elif init_config.mode == "qr": W = module.weight.float() k,d = W.shape Q, R = torch.linalg.qr(W, mode="reduced") diag = torch.sign(torch.diag(R)) diag[diag == 0] = 1.0 D = torch.diag(diag) Q = Q @ D R = D @ R print(torch.min(torch.diag(R))) lambda_vals = torch.abs(torch.diag(R)) perm = torch.argsort(lambda_vals, descending=True) I1 = perm[:lora_r2] I2 = perm[lora_r2:lora_r1+lora_r2] Q1 = Q[:, I1] # (m, r_high) R1 = R[I1] Q2 = Q[:, I2] R2 = R[I2] B = Q1[:k // lora_r1] @ R1[:, :lora_r2] A = (Q2[:d // lora_r2] @ R2[:, :lora_r1]).T module.lora_B.default.weight = torch.nn.Parameter(B.contiguous().to(module.lora_B.default.weight.device)) module.lora_A.default.weight = torch.nn.Parameter(A.contiguous().to(module.lora_A.default.weight.device)) elif init_config.mode == "gradient": named_grad = kwargs["named_grads"] grad_name = ".".join(name.split(".")[2:]) + ".weight" grads = named_grad[grad_name] # print(grads.shape) if lora_r1 == 1 and lora_r2 == 1: U, S, V = torch.svd_lowrank(-grads.cuda().float(), q=512, niter=16) else: U, S, V = torch.svd_lowrank(rebuild(-grads.float(),lora_r1, lora_r2), q=4*lora_r, niter=16) V = V.T # set direction if init_config.direction == "ArBr": if lora_r1 == 1 and lora_r2 == 1: B = U[:, :lora_r] @ torch.diag(torch.sqrt(S[:lora_r])) / torch.sqrt(S[0]) / 128.0 **0.5 A = torch.diag(torch.sqrt(S[:lora_r])) @ V[:lora_r, :] / torch.sqrt(S[0]) / 128.0 **0.5 module.lora_B.default.weight = torch.nn.Parameter(B.contiguous().to(module.lora_B.default.weight.device)) module.lora_A.default.weight = torch.nn.Parameter(A.contiguous().to(module.lora_A.default.weight.device)) else: for i in range(lora_r): B = (S[i] / S[0] / 1024)**0.5 * V[i, :].reshape([lora_r2, grads.shape[0]//lora_r1]).T A = (S[i] / S[0] / 1024)**0.5 * U[:, i].reshape([grads.shape[1]//lora_r2,lora_r1]).T module.lora_A.default.weight[i::lora_r] = torch.nn.Parameter(A.contiguous().to(module.lora_A.default.weight.device)) module.lora_B.default.weight[:,i::lora_r] = torch.nn.Parameter(B.contiguous().to(module.lora_B.default.weight.device)) elif init_config.direction == "A2rBr": B = U[:, :lora_r] A = V[lora_r : 2 * lora_r, :] elif init_config.direction == "ArB2r": B = U[:, lora_r : 2 * lora_r] A = V[:lora_r, :] scaling_factor = module.scaling["default"] if init_config.scale == "gd": A = A / scaling_factor B = B / scaling_factor elif init_config.scale == "unit": # Because A,B is orthogonal, do not need to scale pass elif init_config.scale == "stable": m, n = grads.shape # m: feature_out, n: feature_in # the scale of output is only related to the feature_out gamma = init_config.stable_gamma elif init_config.scale == "weightS": _, S, _ = torch.svd_lowrank(module.weight.float(), q=4 * lora_r, niter=4) S = S / module.scaling["default"] avg_s = torch.sqrt(S[:lora_r]).mean().to(A.device) B = B * avg_s A = A * avg_s # module.lora_B.default.weight = torch.nn.Parameter(B.contiguous().to(module.lora_B.default.weight.device)) # module.lora_A.default.weight = torch.nn.Parameter(A.contiguous().to(module.lora_A.default.weight.device)) with torch.no_grad(): # consider dtype not in init_config if "dtype" not in init_config: pass elif init_config.dtype == "bf16": module.lora_A.default.weight.data = module.lora_A.default.weight.data.to( torch.bfloat16 ) module.lora_B.default.weight.data = module.lora_B.default.weight.data.to( torch.bfloat16 ) elif init_config.dtype == "fp32": module.lora_A.default.weight.data = module.lora_A.default.weight.data.to( torch.float32 ) module.lora_B.default.weight.data = module.lora_B.default.weight.data.to( torch.float32 ) # If lora_A@lora_B is not zero, then we need to subtract lora_A@lora_B from the original weight matrix if init_config.mode == "qr": offset = (kron(module.lora_B.default.weight.contiguous(),module.lora_A.default.weight.contiguous())).to( module.weight.data.device ) else: offset = 0 # offset = (module.lora_B.default.weight @ module.lora_A.default.weight).to( # module.weight.data.device # ) scaling_factor = module.scaling["default"] offset *= scaling_factor if "norm_clip" in init_config and init_config.norm_clip: # for numerical stability, offset's largest value must be less then weight's largest value ratio = torch.max(torch.abs(module.weight.data)) / torch.max( torch.abs(offset) ) if ratio < 1: offset *= ratio module.lora_A.default.weight.data *= ratio**0.5 module.lora_B.default.weight.data *= ratio**0.5 log.warning(f"Clipping offset by {ratio}") try: module.weight.data -= offset except: breakpoint() def reinit_lora(model, init_config, **kwargs): r""" Reinitialize the lora model with the given configuration. """ for name, module in tqdm( model.named_modules(), desc="Reinitializing Lora", total=len(list(model.named_modules())), ): if isinstance(module, LoraLinear): reinit_lora_modules(name, module, init_config, **kwargs) return model def get_record_gradient_hook(model, record_dict): def record_gradient_hook(grad): for n, p in model.named_parameters(): if p.requires_grad and p.grad is not None: if n not in record_dict: record_dict[n] = p.grad.cpu() else: record_dict[n] += p.grad.cpu() p.grad = None return grad return record_gradient_hook def estimate_gradient( model, dataset, batch_size: int = 4 ) -> Dict[str, List[torch.Tensor]]: r""" Estimate the gradient of the model on the given dataset """ log.info("Estimating gradient") model.train() named_grads = {} hooks = [] for name, param in model.named_parameters(): hook = param.register_hook(get_record_gradient_hook(model, named_grads)) hooks.append(hook) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) num = 0 for batch in tqdm(dataloader, desc="Estimating gradient"): num += 1 batch = {k: v.to(model.device) for k, v in batch.items()} outputs = model(**batch) outputs.loss.backward() get_record_gradient_hook(model, named_grads)(None) # get gradient of last layer # make sure the gradient is cleared for n, p in model.named_parameters(): if p.grad is not None: p.grad = None for n, g in named_grads.items(): named_grads[n] /= num for hook in hooks: hook.remove() torch.cuda.empty_cache() return named_grads def extract_num(text): # Regex pattern to find the number following '####' pattern = r'####\s*(\d+)' # Using re.search to find the first match match = re.search(pattern, text) if match: result = match.group(1) print(text) else: print(text) result = "" try: return int(result.replace(",", "")) except: print(f"'{result}' can't be converted") return 0 def eval_gsm8k(model,tokenizer,model_type, test_set): all = 0 correct = 0 t = tqdm(test_set) for example in t: # print(example['x']) pred_text = model_inference(model, tokenizer, example['x'], model_type, max_target_length=512) gt = extract_num(example["y"]) pred = extract_num(pred_text) correct += int(gt == pred) all += 1 t.set_description(f"Accuracy: {correct / all * 100:02f}%") print("Acc:", correct / all) # append to gsm8k_results.txt (create if not exists) if not os.path.exists("gsm8k_results.txt"): with open("gsm8k_results.txt", "w") as f: f.write("Model Acc\n") with open("gsm8k_results.txt", "a") as f: f.write(f"{model_name} {correct / all}\n") @hydra.main(version_base="1.2", config_path="conf", config_name="config") def run_exp(cfg: DictConfig): log.info(OmegaConf.to_yaml(cfg)) seed_everything(cfg.seed) model_name = cfg.model.name model_type = cfg.model.type dataset_name = cfg.dataset_name dataset_func = DATASET_MAP[dataset_name] use_peft = cfg.peft.use_peft if_use_rslora = cfg.peft.use_rslora lora_r = cfg.peft.lora_r lora_r1 = cfg.peft.lora_r1 lora_r2 = cfg.peft.lora_r2 lora_relative_r = cfg.peft.lora_relative_r lora_target_modules = cfg.peft.lora_target_modules train_embeddings = cfg.peft.train_embeddings if cfg.dry_run: return if use_peft: lora_r = cfg.peft.lora_r lora_r1 = cfg.peft.lora_r1 lora_r2 = cfg.peft.lora_r2 lora_alpha = cfg.peft.lora_alpha lora_relative_r = None init = cfg.init.mode else: lora_r = None lora_target_modules = None lora_relative_r = None train_embeddings = True config = { "model_name": model_name, "dataset_name": dataset_name, "use_peft": use_peft, "lora_r1": lora_r1, "lora_r2": lora_r2, "lora_r": lora_r, "lora_alpha": lora_alpha, "init": init, "lora_target_modules": str(lora_target_modules), "lora_relative_r": lora_relative_r, "train_embeddings": train_embeddings, } if cfg.wandb.name: name = cfg.wandb.name else: name = "_".join([f"{k}={v}" for k, v in config.items()]) cfg.wandb.project += "_" + cfg.dataset_name wandb.init( project=cfg.wandb.project, name=name, config=config, ) train_set, val_set, eval_set = dataset_func() model, tokenizer = initialize_text_to_text_model( model_name, model_type, cfg.model.bf16, cfg.peft.use_peft, flash_attention=True ) additional_kwargs = {} if use_peft and cfg.init.mode == "gradient": if isinstance(train_set, list): temp_set = train_set[: cfg.init.bsz * cfg.init.iters] else: temp_set = train_set.select(range(cfg.init.bsz * cfg.init.iters)) transform_dataset( model_type=model_type, dataset=temp_set, tokenizer=tokenizer, max_length=cfg.init.max_length, ) # named_grads = estimate_layer_inputs(model, temp_set, cfg.init.bsz) named_grads = estimate_gradient(model, temp_set, cfg.init.bsz) additional_kwargs["named_grads"] = named_grads additional_kwargs["lora_r1"] = lora_r1 additional_kwargs["lora_r"] = lora_r additional_kwargs["lora_r2"] = lora_r2 if lora_target_modules == "all": lora_target_modules = find_all_linear_modules(model) else: lora_target_modules = list(lora_target_modules) if lora_target_modules else [] if lora_relative_r is not None: hidden_size = find_hidden_state_size(model) lora_r = int(hidden_size * lora_relative_r) log.info(f"lora_r is set to {hidden_size} * {lora_relative_r} = {lora_r}") if use_peft and cfg.peft.get("dora", False): log.info("Using Dora") peft_config = LoraConfig( r1=lora_r1, r2=lora_r2, lora_alpha=cfg.peft.lora_alpha, target_modules=lora_target_modules, use_rslora=if_use_rslora, use_dora=True, ) orig_model_params = sum(p.numel() for p in model.parameters()) model = get_peft_model(model, peft_config) trainable_params, all_param = model.get_nb_trainable_parameters() rate = { "trainable_params": trainable_params, "orig_params": orig_model_params, "all_params": all_param, "trainable_ratio": trainable_params / all_param, "param_ratio": trainable_params / orig_model_params, } elif use_peft and cfg.peft.get("adalora", False): log.info("Using AdaLora") peft_config = AdaLoraConfig( task_type=TaskType.CAUSAL_LM, target_r=lora_r, lora_alpha=cfg.peft.lora_alpha, target_modules=lora_target_modules, total_step=int(len(train_set)/cfg.model.real_batch_size)*cfg.model.epochs, ) orig_model_params = sum(p.numel() for p in model.parameters()) model = get_peft_model(model, peft_config) trainable_params, all_param = model.get_nb_trainable_parameters() rate = { "trainable_params": trainable_params, "orig_params": orig_model_params, "all_params": all_param, "trainable_ratio": trainable_params / all_param, "param_ratio": trainable_params / orig_model_params, } elif use_peft: peft_config = LoraConfig( r1=lora_r1, r2=lora_r2, r= lora_r, lora_alpha=cfg.peft.lora_alpha, target_modules=lora_target_modules, use_rslora=if_use_rslora, ) orig_model_params = sum(p.numel() for p in model.parameters()) model = get_peft_model(model, peft_config) reinit_lora(model, cfg.init, **additional_kwargs) if train_embeddings: model.lm_head.weight.requires_grad = True trainable_params, all_param = model.get_nb_trainable_parameters() rate = { "trainable_params": trainable_params, "orig_params": orig_model_params, "all_params": all_param, "trainable_ratio": trainable_params / all_param, "param_ratio": trainable_params / orig_model_params, } save_dir = os.path.join( "results", f"{cfg.wandb.project}/{name}/{cfg.seed}", "orig_checkpoint" ) model.save_pretrained(save_dir) adapter_config = json.load(open(os.path.join(save_dir, "adapter_config.json"))) adapter_config["lora_alpha"] = -adapter_config["lora_alpha"] json.dump( adapter_config, open(os.path.join(save_dir, "adapter_config.json"), "w") ) else: # full finetune all_param = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) rate = { "trainable_params": trainable_params, "orig_params": all_param, "all_params": all_param, "trainable_ratio": trainable_params / all_param, "param_ratio": 1, } log.info(rate) wandb.summary.update(rate) training_loop = train_text_to_text_model global s print(s) model = training_loop( f"{cfg.wandb.project}/{name}", train_set, val_set, model, tokenizer, model_type, num_train_epochs=cfg.model.epochs, per_device_batch_size=cfg.model.per_device_batch_size, real_batch_size=cfg.model.real_batch_size, bf16=cfg.model.bf16, eval_epochs=cfg.model.eval_epochs, early_stopping_patience=cfg.model.early_stopping_patience, max_length=cfg.model.max_length, logging_steps=cfg.model.logging_steps, use_loraplus=cfg.peft.use_loraplus, loraplus_lr_ratio=cfg.peft.loraplus_lr_ratio, learning_rate=cfg.model.learning_rate, # deepspeed=( # "z3_offload_all_bf16.json" if cfg.peft == False else None # ), gradient_checkpointing=cfg.get("gradient_checkpointing", False), seed=cfg.seed, ) save_dir = os.path.join( "results", f"{cfg.wandb.project}/{name}/{cfg.seed}" ) if not use_peft: model.save_pretrained(save_dir) tokenizer.save_pretrained(save_dir) else: # merge_llama(os.path.join("results", f"{cfg.wandb.project}/{name}/{cfg.seed}")) pass log.info(f"Saving model to {save_dir}") if dataset_name == 'meta_math': train_set, val_set, eval_set = load_gsm8k() model.generation_config.pad_token_id = tokenizer.pad_token_id eval_gsm8k(model,tokenizer,model_type,eval_set) if dataset_name == 'codefeedback': model.generation_config.pad_token_id = tokenizer.pad_token_id humaneval(model,tokenizer,save_dir, model_type) wandb.finish() if __name__ == "__main__": run_exp()