# shillm / src / generate_vec.py
# Hugging Face Hub page header (kept for provenance):
#   nkasmanoff's picture
#   Upload folder using huggingface_hub
#   a264cff verified
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import torch
import os
import argparse
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        file_path: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (commonly a trailing one at EOF)
        # are skipped instead of being handed to json.loads, which would
        # raise on them.
        return [json.loads(line) for line in f if line.strip()]
def get_hidden_p_and_r(model, tokenizer, prompts, responses, layer_list=None):
    """Collect per-layer hidden-state summaries for prompt/response pairs.

    Each prompt is concatenated with its response and run through ``model``
    with ``output_hidden_states=True``. For every requested layer three
    vectors are recorded per sample:

    * the mean hidden state over the first ``prompt_len`` tokens,
    * the hidden state at token ``prompt_len - 1`` (last prompt token),
    * the mean hidden state over the tokens from ``prompt_len`` onward.

    Args:
        model: Object exposing ``config.num_hidden_layers`` and ``device``
            and callable as ``model(**inputs, output_hidden_states=True)``.
        tokenizer: Callable tokenizer that also provides ``encode``.
        prompts: Prompt strings, one per sample.
        responses: Response strings, paired one-to-one with ``prompts``.
        layer_list: Indices into ``outputs.hidden_states`` to record.
            Defaults to every index in ``0..num_hidden_layers`` inclusive.

    Returns:
        A tuple ``(prompt_avg, prompt_last, response_avg)``. Each is a list
        of length ``num_hidden_layers + 1`` indexed by layer. Entries for
        recorded layers are tensors concatenated over samples along dim 0;
        entries for layers not in ``layer_list`` remain empty lists.

    Raises:
        ValueError: If ``prompts`` and ``responses`` differ in length, or if
            they are empty (there would be nothing to concatenate).
    """
    prompts = list(prompts)
    responses = list(responses)
    # zip() would silently drop the unmatched tail; surface the mismatch.
    if len(prompts) != len(responses):
        raise ValueError(
            f"prompts and responses must have the same length, "
            f"got {len(prompts)} and {len(responses)}"
        )
    # torch.cat on an empty list fails with an opaque RuntimeError; fail
    # early with a message that names the actual problem.
    if not prompts:
        raise ValueError("prompts and responses are empty; nothing to process")

    max_layer = model.config.num_hidden_layers
    if layer_list is None:
        layer_list = list(range(max_layer + 1))

    # One slot per hidden_states index. NOTE(review): the "+ 1" presumably
    # accounts for the embedding output preceding the transformer layers --
    # confirm against the transformers documentation.
    prompt_avg = [[] for _ in range(max_layer + 1)]
    response_avg = [[] for _ in range(max_layer + 1)]
    prompt_last = [[] for _ in range(max_layer + 1)]

    texts = [p + a for p, a in zip(prompts, responses)]
    for text, prompt in tqdm(zip(texts, prompts), total=len(texts)):
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(
            model.device
        )
        # The prompt/response boundary is taken from tokenizing the prompt
        # alone. NOTE(review): this assumes those tokens are a prefix of the
        # joint tokenization of prompt + response -- verify for the
        # tokenizer in use.
        prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        for layer in layer_list:
            hidden = outputs.hidden_states[layer]
            prompt_avg[layer].append(
                hidden[:, :prompt_len, :].mean(dim=1).detach().cpu()
            )
            response_avg[layer].append(
                hidden[:, prompt_len:, :].mean(dim=1).detach().cpu()
            )
            prompt_last[layer].append(hidden[:, prompt_len - 1, :].detach().cpu())
        # Drop the reference to the full set of hidden states before the
        # next forward pass.
        del hidden, outputs

    for layer in layer_list:
        prompt_avg[layer] = torch.cat(prompt_avg[layer], dim=0)
        prompt_last[layer] = torch.cat(prompt_last[layer], dim=0)
        response_avg[layer] = torch.cat(response_avg[layer], dim=0)
    return prompt_avg, prompt_last, response_avg
import pandas as pd
import os
def get_persona_effective(
    pos_path, neg_path, trait, threshold=50, coherence_threshold=50
):
    """Select row pairs where the persona instruction visibly took effect.

    Both CSV files are read and compared row by row. A row pair is kept when
    the positive sample scores at least ``threshold`` on ``trait``, the
    negative sample scores below ``100 - threshold`` on ``trait``, and both
    samples have a ``coherence`` score of at least ``coherence_threshold``.

    Args:
        pos_path: CSV path for the positive-persona samples.
        neg_path: CSV path for the negative-persona samples.
        trait: Name of the score column to threshold on.
        threshold: Minimum trait score for positive rows; negative rows must
            score below ``100 - threshold``.
        coherence_threshold: Minimum ``coherence`` score required of both
            rows in a pair. Defaults to 50.

    Returns:
        A 6-tuple: the filtered positive frame, the filtered negative frame,
        then the positive prompts, negative prompts, positive answers and
        negative answers as plain lists (``prompt`` / ``answer`` columns).

    Raises:
        ValueError: If the two CSV files do not have the same number of rows,
            since rows are paired by position.
        KeyError: If ``trait``, ``coherence``, ``prompt`` or ``answer`` is
            missing from either file.
    """
    persona_pos = pd.read_csv(pos_path)
    persona_neg = pd.read_csv(neg_path)

    # The mask below combines columns from both frames and is then applied to
    # each of them, which only makes sense when rows line up one-to-one.
    if len(persona_pos) != len(persona_neg):
        raise ValueError(
            f"positive and negative CSVs must have the same number of rows, "
            f"got {len(persona_pos)} and {len(persona_neg)}"
        )

    mask = (
        (persona_pos[trait] >= threshold)
        & (persona_neg[trait] < 100 - threshold)
        & (persona_pos["coherence"] >= coherence_threshold)
        & (persona_neg["coherence"] >= coherence_threshold)
    )

    persona_pos_effective = persona_pos[mask]
    persona_neg_effective = persona_neg[mask]

    return (
        persona_pos_effective,
        persona_neg_effective,
        persona_pos_effective["prompt"].tolist(),
        persona_neg_effective["prompt"].tolist(),
        persona_pos_effective["answer"].tolist(),
        persona_neg_effective["answer"].tolist(),
    )
def _mean_diff_per_layer(pos_by_layer, neg_by_layer):
    """Stack, layer by layer, mean-over-samples(pos) minus mean-over-samples(neg)."""
    return torch.stack(
        [
            pos.mean(0).float() - neg.mean(0).float()
            for pos, neg in zip(pos_by_layer, neg_by_layer)
        ],
        dim=0,
    )


def save_persona_vector(model_name, pos_path, neg_path, trait, save_dir, threshold=75):
    """Compute persona difference vectors for ``trait`` and save them to disk.

    The positive/negative CSVs are filtered with ``get_persona_effective``,
    hidden-state summaries are extracted for both sets with
    ``get_hidden_p_and_r``, and for each summary kind the per-layer
    difference of sample means (positive minus negative) is stacked into one
    tensor and written with ``torch.save``.

    Files written under ``save_dir``:
        ``{trait}_prompt_avg_diff.pt``, ``{trait}_response_avg_diff.pt`` and
        ``{trait}_prompt_last_diff.pt``.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.
        pos_path: CSV path for the positive-persona samples.
        neg_path: CSV path for the negative-persona samples.
        trait: Score column used for filtering; also the output file prefix.
        save_dir: Output directory; created if it does not exist.
        threshold: Trait-score threshold forwarded to
            ``get_persona_effective``.

    Raises:
        ValueError: If no sample pair survives the filter.
    """
    # Filter the data first: it is cheap, and an over-strict threshold should
    # fail before an expensive model load rather than after it.
    (
        _,
        _,
        pos_prompts,
        neg_prompts,
        pos_responses,
        neg_responses,
    ) = get_persona_effective(pos_path, neg_path, trait, threshold)
    if not pos_prompts or not neg_prompts:
        raise ValueError(
            f"No samples passed the filter for trait {trait!r} at "
            f"threshold {threshold}; cannot compute persona vectors"
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    prompt_avg, prompt_last, response_avg = {}, {}, {}
    (
        prompt_avg["pos"],
        prompt_last["pos"],
        response_avg["pos"],
    ) = get_hidden_p_and_r(model, tokenizer, pos_prompts, pos_responses)
    (
        prompt_avg["neg"],
        prompt_last["neg"],
        response_avg["neg"],
    ) = get_hidden_p_and_r(model, tokenizer, neg_prompts, neg_responses)

    # Compute every difference before touching the filesystem so a failure
    # here does not leave a partially written output directory.
    diffs = [
        (kind, _mean_diff_per_layer(hidden["pos"], hidden["neg"]))
        for kind, hidden in (
            ("prompt_avg", prompt_avg),
            ("response_avg", response_avg),
            ("prompt_last", prompt_last),
        )
    ]

    os.makedirs(save_dir, exist_ok=True)
    for kind, diff in diffs:
        torch.save(diff, f"{save_dir}/{trait}_{kind}_diff.pt")

    print(f"Persona vectors saved to {save_dir}")
if __name__ == "__main__":
    # Command-line entry point: every path/name option is a required string;
    # only --threshold is optional (integer, default 85).
    cli = argparse.ArgumentParser()
    for flag in ("--model_name", "--pos_path", "--neg_path", "--trait", "--save_dir"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument("--threshold", type=int, default=85)
    opts = cli.parse_args()

    save_persona_vector(
        model_name=opts.model_name,
        pos_path=opts.pos_path,
        neg_path=opts.neg_path,
        trait=opts.trait,
        save_dir=opts.save_dir,
        threshold=opts.threshold,
    )