from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import json
import os

import pandas as pd
import torch


def load_jsonl(file_path):
    with open(file_path, "r") as f:
        return [json.loads(line) for line in f]


def get_hidden_p_and_r(model, tokenizer, prompts, responses, layer_list=None):
    """Collect per-layer hidden states for each (prompt, response) pair.

    Returns three lists indexed by layer:
      - prompt_avg: mean hidden state over the prompt tokens
      - prompt_last: hidden state at the last prompt token
      - response_avg: mean hidden state over the response tokens
    """
    max_layer = model.config.num_hidden_layers
    if layer_list is None:
        layer_list = list(range(max_layer + 1))
    prompt_avg = [[] for _ in range(max_layer + 1)]
    response_avg = [[] for _ in range(max_layer + 1)]
    prompt_last = [[] for _ in range(max_layer + 1)]

    texts = [p + a for p, a in zip(prompts, responses)]
    for text, prompt in tqdm(zip(texts, prompts), total=len(texts)):
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(
            model.device
        )
        # Token count of the prompt alone, used to split prompt vs. response positions.
        prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        for layer in layer_list:
            prompt_avg[layer].append(
                outputs.hidden_states[layer][:, :prompt_len, :]
                .mean(dim=1)
                .detach()
                .cpu()
            )
            response_avg[layer].append(
                outputs.hidden_states[layer][:, prompt_len:, :]
                .mean(dim=1)
                .detach()
                .cpu()
            )
            prompt_last[layer].append(
                outputs.hidden_states[layer][:, prompt_len - 1, :].detach().cpu()
            )
        del outputs

    # Concatenate per-example activations into (num_examples, hidden_size) tensors per layer.
    for layer in layer_list:
        prompt_avg[layer] = torch.cat(prompt_avg[layer], dim=0)
        prompt_last[layer] = torch.cat(prompt_last[layer], dim=0)
        response_avg[layer] = torch.cat(response_avg[layer], dim=0)
    return prompt_avg, prompt_last, response_avg


def get_persona_effective(pos_path, neg_path, trait, threshold=50):
    """Keep only paired rows where the trait score clearly separates the
    positive and negative generations and both responses are coherent."""
    persona_pos = pd.read_csv(pos_path)
    persona_neg = pd.read_csv(neg_path)
    mask = (
        (persona_pos[trait] >= threshold)
        & (persona_neg[trait] < 100 - threshold)
        & (persona_pos["coherence"] >= 50)
        & (persona_neg["coherence"] >= 50)
    )
    persona_pos_effective = persona_pos[mask]
    persona_neg_effective = persona_neg[mask]
    persona_pos_effective_prompts = persona_pos_effective["prompt"].tolist()
    persona_neg_effective_prompts = persona_neg_effective["prompt"].tolist()
    persona_pos_effective_responses = persona_pos_effective["answer"].tolist()
    persona_neg_effective_responses = persona_neg_effective["answer"].tolist()
    return (
        persona_pos_effective,
        persona_neg_effective,
        persona_pos_effective_prompts,
        persona_neg_effective_prompts,
        persona_pos_effective_responses,
        persona_neg_effective_responses,
    )


def save_persona_vector(model_name, pos_path, neg_path, trait, save_dir, threshold=75):
    """Compute and save persona vectors: per-layer differences of mean
    activations between the positive and negative sets, for three
    aggregation schemes (prompt average, response average, last prompt token)."""
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    (
        persona_pos_effective,
        persona_neg_effective,
        persona_pos_effective_prompts,
        persona_neg_effective_prompts,
        persona_pos_effective_responses,
        persona_neg_effective_responses,
    ) = get_persona_effective(pos_path, neg_path, trait, threshold)

    (
        persona_effective_prompt_avg,
        persona_effective_prompt_last,
        persona_effective_response_avg,
    ) = ({}, {}, {})
    (
        persona_effective_prompt_avg["pos"],
        persona_effective_prompt_last["pos"],
        persona_effective_response_avg["pos"],
    ) = get_hidden_p_and_r(
        model, tokenizer, persona_pos_effective_prompts, persona_pos_effective_responses
    )
    (
        persona_effective_prompt_avg["neg"],
        persona_effective_prompt_last["neg"],
        persona_effective_response_avg["neg"],
    ) = get_hidden_p_and_r(
        model, tokenizer, persona_neg_effective_prompts, persona_neg_effective_responses
    )

    # Persona vector per layer: mean(positive activations) - mean(negative activations).
    persona_effective_prompt_avg_diff = torch.stack(
        [
            persona_effective_prompt_avg["pos"][l].mean(0).float()
            - persona_effective_prompt_avg["neg"][l].mean(0).float()
            for l in range(len(persona_effective_prompt_avg["pos"]))
        ],
        dim=0,
    )
    persona_effective_response_avg_diff = torch.stack(
        [
            persona_effective_response_avg["pos"][l].mean(0).float()
            - persona_effective_response_avg["neg"][l].mean(0).float()
            for l in range(len(persona_effective_response_avg["pos"]))
        ],
        dim=0,
    )
    persona_effective_prompt_last_diff = torch.stack(
        [
            persona_effective_prompt_last["pos"][l].mean(0).float()
            - persona_effective_prompt_last["neg"][l].mean(0).float()
            for l in range(len(persona_effective_prompt_last["pos"]))
        ],
        dim=0,
    )

    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        persona_effective_prompt_avg_diff, f"{save_dir}/{trait}_prompt_avg_diff.pt"
    )
    torch.save(
        persona_effective_response_avg_diff, f"{save_dir}/{trait}_response_avg_diff.pt"
    )
    torch.save(
        persona_effective_prompt_last_diff, f"{save_dir}/{trait}_prompt_last_diff.pt"
    )
    print(f"Persona vectors saved to {save_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--pos_path", type=str, required=True)
    parser.add_argument("--neg_path", type=str, required=True)
    parser.add_argument("--trait", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--threshold", type=int, default=85)
    args = parser.parse_args()

    save_persona_vector(
        args.model_name,
        args.pos_path,
        args.neg_path,
        args.trait,
        args.save_dir,
        args.threshold,
    )
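# Example invocation (illustrative only: the script filename, model, CSV paths,
# and trait name below are placeholders, not values shipped with this script).
# The positive/negative CSVs are expected to contain "prompt", "answer",
# "coherence", and a column named after the trait, as read in get_persona_effective:
#
#   python generate_vec.py \
#       --model_name Qwen/Qwen2.5-7B-Instruct \
#       --pos_path eval_results/trait_pos.csv \
#       --neg_path eval_results/trait_neg.csv \
#       --trait evil \
#       --save_dir persona_vectors \
#       --threshold 60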