from src.probes import (
    ProbeClassification,
    ProbeClassificationMixScaler,
    LinearProbeClassification,
    LinearProbeClassificationMixScaler,
)
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm

from src.dataset import llama_v2_prompt

device = "cuda"
torch_device = "cuda"


def load_probe_classifier(model_func, input_dim, num_classes, weight_path, **kwargs):
    """
    Instantiate a probe classifier and load its pretrained weights.

    Args:
        model_func (callable): Probe class to instantiate (e.g. ProbeClassification).
        input_dim (int): Input dimension for the classifier.
        num_classes (int): Number of classes for classification.
        weight_path (str): Path to the pretrained weights.

    Returns:
        model: The probe model with loaded weights.
    """
    # Instantiate the model
    model = model_func(device, num_classes, input_dim, **kwargs)
    # Load the pretrained weights into the model
    model.load_state_dict(torch.load(weight_path))
    return model


num_classes = {"age": 4, "gender": 2, "education": 3, "socioeco": 3}


def return_classifier_dict(directory, model_func, chosen_layer=None, mix_scaler=False,
                           sklearn=False, **kwargs):
    """
    Build a nested dict of probe classifiers, keyed by attribute category and layer.

    Checkpoint files are expected to be named "<category>_..._<layer>.pth".
    """
    checkpoint_paths = [f for f in os.listdir(directory) if f.endswith(".pth")]
    classifier_dict = {}
    for checkpoint_path in checkpoint_paths:
        category = checkpoint_path[:checkpoint_path.find("_")]
        weight_path = os.path.join(directory, checkpoint_path)
        num_class = num_classes[category]
        # sklearn-style binary probes use a single output logit
        if category == "gender" and sklearn:
            num_class = 1
        if category not in classifier_dict:
            classifier_dict[category] = {}

        if mix_scaler:
            classifier_dict[category]["all"] = load_probe_classifier(
                model_func, 5120, num_classes=num_class, weight_path=weight_path, **kwargs)
        else:
            layer_num = int(checkpoint_path[checkpoint_path.rfind("_") + 1:
                                            checkpoint_path.rfind(".pth")])
            if chosen_layer is None or layer_num == chosen_layer:
                try:
                    classifier_dict[category][layer_num] = load_probe_classifier(
                        model_func, 5120, num_classes=num_class, weight_path=weight_path, **kwargs)
                except Exception as e:
                    print(category)
                    # print(e)
    return classifier_dict


def split_into_messages(text: str) -> list[str]:
    """Split the body of a LLaMA-2 prompt into alternating user/assistant messages."""
    # Constants used for splitting
    B_INST, E_INST = "[INST]", "[/INST]"

    parts = []
    current_message = ""
    for word in text.split():
        # If we encounter a start or end token and there's a current message, store it
        if word in (B_INST, E_INST) and current_message:
            parts.append(current_message.strip())
            current_message = ""
        # If the word is not a token, add it to the current message
        elif word not in (B_INST, E_INST):
            current_message += word + " "

    # Append any remaining message
    if current_message:
        parts.append(current_message.strip())
    return parts


def llama_v2_reverse(prompt: str) -> list[dict]:
    """Parse a LLaMA-2 chat prompt back into a list of role/content messages."""
    # Constants used in the LLaMA-2 chat template
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    BOS, EOS = "<s>", "</s>"

    messages = []

    # Extract the system message, if present
    sys_start = prompt.find(B_SYS)
    sys_end = prompt.rfind(E_SYS)
    if sys_start != -1 and sys_end != -1:
        system_msg = prompt[sys_start + len(B_SYS): sys_end]
        messages.append({"role": "system", "content": system_msg})
        prompt = prompt[sys_end + len(E_SYS):]

    # The remaining messages alternate between user and assistant turns
    user_ai_msgs = split_into_messages(prompt)
    user_turn = True
    for message in user_ai_msgs:
        if user_turn:
            messages.append({"role": "user", "content": message})
        else:
            messages.append({"role": "assistant", "content": message})
        user_turn = not user_turn

    return messages

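
# Illustrative usage (hypothetical messages; assumes `llama_v2_prompt` renders
# the standard LLaMA-2 chat template that `llama_v2_reverse` parses). Splitting
# is whitespace-based, so spacing inside a message is normalized, and BOS/EOS
# markers from multi-turn prompts may leak into recovered assistant turns; for
# a simple system + user prompt the round trip is clean:
#
#   msgs = [{"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"}]
#   recovered = llama_v2_reverse(llama_v2_prompt(msgs))
#   # recovered == msgs (same roles, same contents)
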
def optimize_one_inter_rep(inter_rep, layer_name, target, probe,
                           lr=1e-2, N=4, normalized=False):
    """
    Shift an intermediate representation toward the target class along the
    probe's weight direction. `lr` is accepted for interface compatibility but
    unused: the edit is a single additive step scaled by `N`.
    """
    tensor = inter_rep.clone().to(torch_device).requires_grad_(True)
    rep_f = lambda: tensor
    target_clone = target.clone().to(torch_device).to(torch.float)

    if normalized:
        # Scale the shift relative to the norm of the representation
        cur_input_tensor = rep_f() + target_clone.view(1, -1) @ probe.proj[0].weight * N * 100 / rep_f().norm()
    else:
        cur_input_tensor = rep_f() + target_clone.view(1, -1) @ probe.proj[0].weight * N
    return cur_input_tensor.clone()


def edit_inter_rep_multi_layers(output, layer_name):
    """
    Edit the last-token representation of a decoder layer during the forward pass.

    This function must be called from the driving script: `classifier_dict`,
    `attribute`, `cf_target`, `residual`, `lr`, and `N` are module-level globals
    that are not defined inside this function.
    """
    # Parse the layer index out of the module name
    if residual:
        layer_num = layer_name[layer_name.rfind("model.layers.") + len("model.layers."):]
    else:
        layer_num = layer_name[layer_name.rfind("model.layers.") + len("model.layers."):layer_name.rfind(".mlp")]
    layer_num = int(layer_num)

    # Probes are keyed one index above the decoder-layer number
    probe = classifier_dict[attribute][layer_num + 1]

    # Edit only the representation of the last token
    cloned_inter_rep = output[0][0][-1].unsqueeze(0).detach().clone().to(torch.float)
    with torch.enable_grad():
        cloned_inter_rep = optimize_one_inter_rep(cloned_inter_rep, layer_name,
                                                  cf_target, probe, lr=lr, N=N)
    output[0][0][-1] = cloned_inter_rep[0].to(torch.float16)
    return output

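
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline): `edit_inter_rep_multi_layers`
# reads `classifier_dict`, `attribute`, `cf_target`, `residual`, `lr`, and `N`
# from module-level globals, so the calling script must set them first. The
# helper below shows one hypothetical way to attach the edit to a HuggingFace
# LLaMA-2-13B model (assumed layout `model.model.layers[i]`, hidden size 5120)
# with plain PyTorch forward hooks; the original script may instead use a
# tracing utility that passes (output, layer_name) to an edit callback.
def register_probe_edit_hooks(model, layer_indices):
    """Attach `edit_inter_rep_multi_layers` to the chosen decoder layers.

    Assumes the global `residual` is True, so layer names take the form
    "model.layers.<i>". Returns the hook handles so the caller can remove
    them once generation is finished.
    """
    handles = []
    for idx in layer_indices:
        module = model.model.layers[idx]
        name = f"model.layers.{idx}"
        # Bind `name` per iteration; the hook rewrites the layer output in place.
        hook = lambda mod, inp, out, name=name: edit_inter_rep_multi_layers(out, name)
        handles.append(module.register_forward_hook(hook))
    return handles

# Example (hypothetical paths and hyperparameters):
#   classifier_dict = return_classifier_dict("checkpoints/probes", LinearProbeClassification)
#   attribute, cf_target = "age", F.one_hot(torch.tensor(2), num_classes=4)
#   residual, lr, N = True, 1e-2, 4
#   handles = register_probe_edit_hooks(model, layer_indices=[20, 21, 22])
#   ...  # run model.generate(...) with the edits active
#   for h in handles:
#       h.remove()
# ---------------------------------------------------------------------------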