import os
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from rich.console import Console
from rich.table import Table
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PretrainedConfig,
    PreTrainedModel,
)
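# Overview: this script compares SFT/PPO/DPO fine-tunes of a Pythia base model on
# the TL;DR summarization validation set. For each query it generates a response
# from every model, scores the responses with a separately trained scalar reward
# model, and color-codes each response token by where the fine-tuned model's
# top-1 token ranks under the base model (a rough token-level view of how far
# each fine-tune drifts from the base policy). Results are printed both as rich
# tables and as LaTeX.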
|
|
@dataclass
class RunRecord:
    wandb_url: str
    hf_repo_url: str
    hf_repo_id: str
    revision: str
|
|
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.normal_(layer.weight, std=std)
    torch.nn.init.constant_(layer.bias, val=bias_const)
    return layer
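# `layer_init` is reused below to initialize the reward model's scalar head with
# std = 1 / sqrt(hidden_size + 1).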
|
|
class ScalarModelConfig(PretrainedConfig):
    def __init__(
        self,
        base_model: str = "EleutherAI/pythia-160m",
        base_config: PretrainedConfig = AutoConfig.from_pretrained("EleutherAI/pythia-160m"),
        hidden_size: int = 768,
        bias: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.base_config = base_config
        self.hidden_size = hidden_size
        self.bias = bias
|
|
class ScalarModel(PreTrainedModel):
    config_class = ScalarModelConfig

    def __init__(self, config: ScalarModelConfig):
        super().__init__(config)
        self.config = config
        self.lm_backbone = AutoModel.from_pretrained(
            config.base_model,
            config=self.config.base_config,
            trust_remote_code=True,
        )
        self.scalar_head = layer_init(
            nn.Linear(self.config.hidden_size, 1),
            std=1 / np.sqrt(self.config.hidden_size + 1),
        )

    def forward(self, **kwargs):
        output = self.lm_backbone(**kwargs)
        reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias
        return reward
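# `ScalarModel` maps the backbone's final hidden states to one scalar per token;
# `config.bias` subtracts a stored offset (e.g. a reward normalization constant).
# `get_reward` below reads the per-sequence reward at the last non-pad token.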
|
|
def generate(lm_backbone, queries, tokenizer, generation_config):
    """generate in a way that does not affect padding tokens"""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1)
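# Note: pads are zeroed out of `input_ids` and excluded via `attention_mask`, and
# the original (possibly padded) queries are re-concatenated with the generated
# continuation, so padding tokens pass through generation unchanged.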
|
|
def forward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )
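# `position_ids` is an exclusive cumulative sum of the attention mask, so with
# left padding the first real token gets position 0 (pads collapse to position 0
# and are masked out of attention anyway).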
|
|
def get_reward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    reward_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=True,
        output_hidden_states=True,
    )
    sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device)
    return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths], reward_logits
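# `argmax` over `eq(pad)` finds the first pad position; subtracting 1 yields the
# index of the last non-pad token, where the reward is read. If a sequence has no
# pad at all, `argmax` returns 0 and the index wraps to -1, i.e. the final token.
# This expects a model whose forward returns per-token scalars (the `ScalarModel`
# above), not a standard classification head.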
|
|
def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> None:
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        table.add_row(*row.astype(str).tolist())
    console.rule(f"[bold red]{title}")
    console.print(table)
|
|
console = Console()
if not os.path.exists("release_runs.csv"):
    import wandb  # lazy import: only needed when rebuilding the cache

    keys = {
        "sft": "refactor-chosen-rejected3",
        "reward": "refactor-chosen-rejected3",
        "ppo_left_padding_new_nowhiten_reward": "refactor-chosen-rejected3",
        "dpo": "refactor-chosen-rejected3",
    }
    runs = []
    for exp_name, tag in keys.items():
        runs.extend(
            list(
                wandb.Api().runs(
                    path="costa-huang/tldr_summarize",
                    filters={
                        "$and": [
                            {"config.exp_name.value": exp_name},
                            {"tags": {"$in": [tag]}},
                        ]
                    },
                )
            )
        )
    table = defaultdict(list)
    for i in range(len(runs)):
        table["base_model"].append(runs[i].config["base_model"])
        table["exp"].append(runs[i].config["exp_name"])
        table["seed"].append(runs[i].config["seed"])
        table["wandb_url"].append(runs[i].url)
        table["hf_repo_url"].append(runs[i].config["hf_repo_url"])
        table["hf_repo_id"].append(runs[i].config["hf_repo_id"])
        table["revision"].append(runs[i].config["run_name"])
    df = pd.DataFrame(table)
    df.to_csv("release_runs.csv", index=False)
else:
    df = pd.read_csv("release_runs.csv")
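# Runs are cached locally; delete `release_runs.csv` to re-query wandb.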
|
# collapse duplicate rows so `.loc[(base_model, exp, seed)]` yields a single record
df = df.groupby(["base_model", "exp", "seed"]).agg(lambda x: x.tolist()[0])
|
base_model_name = "EleutherAI/pythia-1b-deduped"
sft_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "sft", 55513)])
ppo_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "ppo_left_padding_new_nowhiten_reward", 55513)])
dpo_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "dpo", 55513)])
rm_record = RunRecord(**df.loc[("EleutherAI/pythia-6.9b-deduped", "reward", 55513)])

device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
|
response_length = 100
validation_generation_config = GenerationConfig(
    max_new_tokens=response_length,
    temperature=(0.01 + 1e-7),
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
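# temperature ~= 0.01 with `do_sample=True` is effectively greedy decoding while
# keeping the sampling code path; `top_k=0.0` and `top_p=1.0` disable truncation.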
|
sft_dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144")
base_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)

fine_tuned_models = {}
console.print("loading", sft_record)
fine_tuned_models["SFT Model"] = AutoModelForCausalLM.from_pretrained(
    sft_record.hf_repo_id,
    revision=sft_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", ppo_record)
fine_tuned_models["PPO Model"] = AutoModelForCausalLM.from_pretrained(
    ppo_record.hf_repo_id,
    revision=ppo_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", dpo_record)
fine_tuned_models["DPO Model"] = AutoModelForCausalLM.from_pretrained(
    dpo_record.hf_repo_id,
    revision=dpo_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", rm_record)
scalar_model_config = ScalarModelConfig.from_pretrained(
    rm_record.hf_repo_id,
    revision=rm_record.revision,
    trust_remote_code=True,
)
|
# hack: the stored `_name_or_path` points at a local training path (e.g.
# `models/EleutherAI/pythia-6.9b-deduped/...`); keep only the org/model part so
# the backbone resolves on the Hugging Face Hub
original_model = "/".join(scalar_model_config.base_config["_name_or_path"].split("/")[1:3])
scalar_model_config.base_config["_name_or_path"] = original_model
scalar_model_config.base_model = original_model
|
# load the reward model as a `ScalarModel` with the patched config; its forward
# returns per-token scalar logits, which is what `get_reward` expects
rm: PreTrainedModel = ScalarModel.from_pretrained(
    rm_record.hf_repo_id,
    revision=rm_record.revision,
    trust_remote_code=True,
    config=scalar_model_config,
).to(device)
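# `rm` (a pythia-6.9b-deduped fine-tune per `rm_record`) scores every response
# generated below.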
|
nchecks = 4
colors = {
    0: "on blue",
    1: "on yellow",
    2: "on yellow",
    3: "on red",
}
latex_colors = {
    0: r"\sethlcolor{LightBlue}",
    1: r"\sethlcolor{LightYellow}",
    2: r"\sethlcolor{LightYellow}",
    3: r"\sethlcolor{LightRed}",
}
include_logits = True
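# Color legend for the token-level comparison below: bucket 0 (blue) means the
# fine-tuned model's top-1 token is also the base model's top-1; buckets 1-2
# (yellow) mean it is the base model's 2nd or 3rd choice; bucket 3 (red) means
# 4th choice or absent from the base model's top `nchecks` entirely.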
|
|
for i in range(len(sft_dataset["validation"])):
    rich_table = defaultdict(list)
    latex_table = defaultdict(list)
    query = torch.LongTensor(sft_dataset["validation"][i : i + 1]["query_token"]).to(device)
    context_length = query.shape[1]
    query_reference_response = torch.cat(
        (
            query,
            torch.LongTensor(tokenizer.encode(sft_dataset["validation"][i]["reference_response"])).to(device).unsqueeze(0),
        ),
        dim=1,
    )
    for table in [rich_table, latex_table]:
        table["Type"].append("Query")
        table["Content"].append(tokenizer.decode(query[0], skip_special_tokens=True))
        table["Score (RM)"].append("N/A")
    with torch.no_grad():
        model_stats = defaultdict(list)
|
        for fine_tuned_model_name, fine_tuned_model in fine_tuned_models.items():
            fine_tuned_model_query_response = generate(fine_tuned_model, query, tokenizer, validation_generation_config)
            fine_tuned_model_response = fine_tuned_model_query_response[:, context_length:]
            fine_tuned_model_reward, fine_tuned_model_reward_logits = get_reward(
                rm, fine_tuned_model_query_response, tokenizer
            )
            fine_tuned_model_reward_logits = fine_tuned_model_reward_logits.squeeze(-1)[:, context_length - 1 :]

            fine_tuned_model_output = forward(fine_tuned_model, fine_tuned_model_query_response, tokenizer)
            base_model_output = forward(base_model, fine_tuned_model_query_response, tokenizer)
            fine_tuned_model_logits = fine_tuned_model_output.logits[:, context_length - 1 : -1]
            _, fine_tuned_model_topk_indices = fine_tuned_model_logits.topk(10)
            base_model_logits = base_model_output.logits[:, context_length - 1 : -1]
            _, base_model_topk_indices = base_model_logits.topk(10)
            matches = (
                fine_tuned_model_topk_indices[:, :, 0:1].expand(-1, -1, nchecks) == base_model_topk_indices[:, :, 0:nchecks]
            )
            matched = matches.sum(2)
            match_idx = matches.float().argmax(2)
            final_matches = torch.where(matched > 0, match_idx, nchecks - 1)
            stats = torch.stack([(final_matches == i).sum(1) for i in range(nchecks)]).T
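            # `final_matches[b, t] == k` means the fine-tuned model's top-1 token
            # at position t is the base model's (k+1)-th ranked token, with the
            # last bucket also absorbing "no match"; `stats` counts tokens per
            # bucket, and the colored tables below visualize the same buckets.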
|
            fine_tuned_model_all_logprob = F.log_softmax(fine_tuned_model_logits, dim=-1)
            fine_tuned_model_logprob = torch.gather(
                fine_tuned_model_all_logprob, 2, fine_tuned_model_response.unsqueeze(-1)
            ).squeeze(-1)
            base_model_all_logprob = F.log_softmax(base_model_logits, dim=-1)
            base_model_logprob = torch.gather(base_model_all_logprob, 2, fine_tuned_model_response.unsqueeze(-1)).squeeze(-1)
            # per-token KL estimate: log-ratio of fine-tuned vs. base probabilities
            kl = fine_tuned_model_logprob - base_model_logprob
|
            final_matches = final_matches.tolist()
            fine_tuned_model_response = fine_tuned_model_response.tolist()
            for table in [rich_table, latex_table]:
                table["Type"].append(f"{fine_tuned_model_name} Response")
            latex_table["Content"].append(
                "".join(
                    [
                        f"{latex_colors[jt]}" r"\hl{" f"{tokenizer.decode(it)}" "}"
                        for it, jt in zip(fine_tuned_model_response[0], final_matches[0])
                    ]
                )
            )
            rich_table["Content"].append(
                "".join(
                    [
                        f"[{colors[jt]}]{tokenizer.decode(it)}[/{colors[jt]}]"
                        for it, jt in zip(fine_tuned_model_response[0], final_matches[0])
                    ]
                )
            )
            for table in [rich_table, latex_table]:
                table["Score (RM)"].append(str(round(fine_tuned_model_reward[0][0].item(), 4)))
                if include_logits:
                    table["Type"].append(f"{fine_tuned_model_name} Reward Logits")
                    table["Content"].append([round(logit, 4) for logit in fine_tuned_model_reward_logits[0].tolist()])
                    table["Score (RM)"].append(str(round(fine_tuned_model_reward[0][0].item(), 4)))
|
        reference_reward, reference_reward_logits = get_reward(rm, query_reference_response, tokenizer)
        reference_reward_logits = reference_reward_logits.squeeze(-1)[:, context_length - 1 :]
        for table in [rich_table, latex_table]:
            table["Type"].append("Reference response")
            table["Content"].append(sft_dataset["validation"][i]["reference_response"])
            table["Score (RM)"].append(str(round(reference_reward[0][0].item(), 4)))
            if include_logits:
                table["Type"].append("Reference Reward Logits")
                table["Content"].append([round(logit, 4) for logit in reference_reward_logits[0].tolist()])
                table["Score (RM)"].append(str(round(reference_reward[0][0].item(), 4)))
|
        base_model_query_response = generate(base_model, query, tokenizer, validation_generation_config)
        base_model_response = base_model_query_response[:, context_length:]
        base_model_reward, base_model_reward_logits = get_reward(rm, base_model_query_response, tokenizer)
        base_model_reward_logits = base_model_reward_logits.squeeze(-1)[:, context_length - 1 :]
        for table in [rich_table, latex_table]:
            table["Type"].append("Base Model Response")
            table["Content"].append(tokenizer.decode(base_model_response[0], skip_special_tokens=True))
            table["Score (RM)"].append(str(round(base_model_reward[0][0].item(), 4)))
            if include_logits:
                table["Type"].append("Base Model Reward Logits")
                table["Content"].append([round(logit, 4) for logit in base_model_reward_logits[0].tolist()])
                table["Score (RM)"].append(str(round(base_model_reward[0][0].item(), 4)))
|
    rich_df = pd.DataFrame(rich_table)
    latex_df = pd.DataFrame(latex_table)
    print_rich_table("Results", rich_df, console)
    print(latex_df.to_latex(index=False))
    if input("Continue? (press `n` to stop) ") == "n":
        break
    if i == 4:  # preview at most 5 validation examples
        break
|