import os
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from rich.console import Console
from rich.table import Table
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PretrainedConfig,
    PreTrainedModel,
)


@dataclass
class RunRecord:
    wandb_url: str
    hf_repo_url: str
    hf_repo_id: str
    revision: str


######
# RM model definition
######
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.normal_(layer.weight, std=std)
    torch.nn.init.constant_(layer.bias, val=bias_const)
    return layer


class ScalarModelConfig(PretrainedConfig):
    def __init__(
        self,
        base_model: str = "EleutherAI/pythia-160m",
        base_config: PretrainedConfig = AutoConfig.from_pretrained("EleutherAI/pythia-160m"),
        hidden_size: int = 768,
        bias: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.base_config = base_config
        self.hidden_size = hidden_size
        self.bias = bias


class ScalarModel(PreTrainedModel):
    config_class = ScalarModelConfig

    def __init__(self, config: ScalarModelConfig):
        super().__init__(config)
        self.config = config
        self.lm_backbone = AutoModel.from_pretrained(
            config.base_model,
            config=self.config.base_config,
            trust_remote_code=True,
        )
        self.scalar_head = layer_init(
            nn.Linear(self.config.hidden_size, 1),
            std=1 / np.sqrt(self.config.hidden_size + 1),
        )

    def forward(self, **kwargs):
        output = self.lm_backbone(**kwargs)
        reward = self.scalar_head(output.hidden_states[-1]) - self.config.bias
        return reward


######
# Utility functions
######
def generate(lm_backbone, queries, tokenizer, generation_config):
    """generate in a way that does not affect padding tokens"""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(),  # generation collapsed if this was turned on. TODO: why does generation collapse with this?
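        # pad tokens were replaced with id 0 above only so that `generate` sees
        # valid token ids; the attention mask keeps them from being attended to,
        # and the original `queries` (padding intact) are re-attached below.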
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1)


def forward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )


def get_reward(model, query_responses, tokenizer):
    # assumes the reward model returns per-token scalar logits of shape
    # (batch, seq_len, 1), as `ScalarModel` above does
    attention_mask = query_responses != tokenizer.pad_token_id
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    reward_logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=True,
        output_hidden_states=True,
    )
    # index of the last non-pad token; if a row has no padding, argmax returns 0
    # and the resulting -1 wraps around to the final position
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device)
    return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths], reward_logits


def print_rich_table(title: str, df: pd.DataFrame, console: Console) -> Table:
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        table.add_row(*row.astype(str).tolist())
    console.rule(f"[bold red]{title}")
    console.print(table)
    return table


######
# Start
######
console = Console()
if not os.path.exists("release_runs.csv"):
    import wandb

    keys = {
        "sft": "refactor-chosen-rejected3",
        "reward": "refactor-chosen-rejected3",
        "ppo_left_padding_new_nowhiten_reward": "refactor-chosen-rejected3",
        "dpo": "refactor-chosen-rejected3",
    }
    runs = []
    for exp_name, tag in keys.items():
        runs.extend(
            list(
                wandb.Api().runs(
                    path="costa-huang/tldr_summarize",
                    filters={
                        "$and": [
                            {"config.exp_name.value": exp_name},
                            {"tags": {"$in": [tag]}},
                        ]
                    },
                )
            )
        )
    table = defaultdict(list)
    for run in runs:
        table["base_model"].append(run.config["base_model"])
        table["exp"].append(run.config["exp_name"])
        table["seed"].append(run.config["seed"])
        table["wandb_url"].append(run.url)
        table["hf_repo_url"].append(run.config["hf_repo_url"])
        table["hf_repo_id"].append(run.config["hf_repo_id"])
        table["revision"].append(run.config["run_name"])
    df = pd.DataFrame(table)
    df.to_csv("release_runs.csv", index=False)
else:
    df = pd.read_csv("release_runs.csv")

df = df.groupby(["base_model", "exp", "seed"]).agg(lambda x: x.tolist()[0])

# feel free to change the base_model, exp, and seed; the seeds are 44413, 55513, 66613, 77713
base_model_name = "EleutherAI/pythia-1b-deduped"
sft_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "sft", 55513)])
ppo_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "ppo_left_padding_new_nowhiten_reward", 55513)])
dpo_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "dpo", 55513)])
# rm_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "reward", 55513)])
rm_record = RunRecord(**df.loc[("EleutherAI/pythia-6.9b-deduped", "reward", 55513)])  # larger (in some sense gold) RM

device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
response_length = 100
validation_generation_config = GenerationConfig(
    max_new_tokens=response_length,
    temperature=(0.01 + 1e-7),
    top_k=0,
    top_p=1.0,
    do_sample=True,
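    # temperature just above 0.01 with do_sample=True gives near-greedy sampling
    # while keeping the sampling code path; top_k=0 and top_p=1.0 disable both
    # truncation heuristics.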
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

sft_dataset = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144")
base_model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
fine_tuned_models = {}
console.print("loading", sft_record)
fine_tuned_models["SFT Model"] = AutoModelForCausalLM.from_pretrained(
    sft_record.hf_repo_id,
    revision=sft_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", ppo_record)
fine_tuned_models["PPO Model"] = AutoModelForCausalLM.from_pretrained(
    ppo_record.hf_repo_id,
    revision=ppo_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", dpo_record)
fine_tuned_models["DPO Model"] = AutoModelForCausalLM.from_pretrained(
    dpo_record.hf_repo_id,
    revision=dpo_record.revision,
    trust_remote_code=True,
).to(device)
console.print("loading", rm_record)
scalar_model_config = ScalarModelConfig.from_pretrained(
    rm_record.hf_repo_id,
    revision=rm_record.revision,
    trust_remote_code=True,
)
# hack to remove the path prefix:
# models/EleutherAI/pythia-6.9b-deduped/sft_model_55513 -> EleutherAI/pythia-6.9b-deduped
original_model = "/".join(scalar_model_config.base_config["_name_or_path"].split("/")[1:3])
scalar_model_config.base_config["_name_or_path"] = original_model
scalar_model_config.base_model = original_model
# rm: PreTrainedModel = ScalarModel.from_pretrained(
#     rm_record.hf_repo_id,
#     revision=rm_record.revision,
#     trust_remote_code=True,
#     config=scalar_model_config,
# ).to(device)
# NOTE: the commented block above loads the 6.9b ScalarModel referenced by
# `rm_record`; the line below swaps in a smaller 1b reward model instead.
rm = AutoModelForSequenceClassification.from_pretrained(
    "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr",
    num_labels=1,
).to(device)

nchecks = 4
colors = {
    0: "on blue",
    1: "on yellow",
    2: "on yellow",
    3: "on red",
}
latex_colors = {
    0: r"\sethlcolor{LightBlue}",
    1: r"\sethlcolor{LightYellow}",
    2: r"\sethlcolor{LightYellow}",
    3: r"\sethlcolor{LightRed}",
}
include_logits = True
for i in range(len(sft_dataset["validation"])):
    rich_table = defaultdict(list)
    latex_table = defaultdict(list)
    query = torch.LongTensor(sft_dataset["validation"][i : i + 1]["query_token"]).to(device)
    context_length = query.shape[1]
    query_reference_response = torch.cat(
        (
            query,
            torch.LongTensor(tokenizer.encode(sft_dataset["validation"][i]["reference_response"])).to(device).unsqueeze(0),
        ),
        dim=1,
    )
    for table in [rich_table, latex_table]:
        table["Type"].append("Query")
        table["Content"].append(tokenizer.decode(query[0], skip_special_tokens=True))
        table["Score (RM)"].append("N/A")
    with torch.no_grad():
        for fine_tuned_model_name, fine_tuned_model in fine_tuned_models.items():
            fine_tuned_model_query_response = generate(fine_tuned_model, query, tokenizer, validation_generation_config)
            fine_tuned_model_response = fine_tuned_model_query_response[:, context_length:]
            fine_tuned_model_reward, fine_tuned_model_reward_logits = get_reward(
                rm, fine_tuned_model_query_response, tokenizer
            )
            fine_tuned_model_reward_logits = fine_tuned_model_reward_logits.squeeze(-1)[:, context_length - 1 :]

            # AI2 visualization https://allenai.github.io/re-align/tds.html
            fine_tuned_model_output = forward(fine_tuned_model, fine_tuned_model_query_response, tokenizer)
            base_model_output = forward(base_model, fine_tuned_model_query_response, tokenizer)
            fine_tuned_model_logits = fine_tuned_model_output.logits[:, context_length - 1 : -1]
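            # Color-coding scheme: for each sampled token, check whether the
            # fine-tuned model's top-1 choice appears among the base model's top
            # `nchecks` candidates. Bucket 0 (blue) = also the base model's top-1;
            # buckets 1-2 (yellow) = the base model's 2nd or 3rd choice; bucket 3
            # (red) = 4th choice or not in the base model's top 4 at all.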
            _, fine_tuned_model_topk_indices = fine_tuned_model_logits.topk(10)
            base_model_logits = base_model_output.logits[:, context_length - 1 : -1]
            _, base_model_topk_indices = base_model_logits.topk(10)
            matches = (
                fine_tuned_model_topk_indices[:, :, 0:1].expand(-1, -1, nchecks)
                == base_model_topk_indices[:, :, 0:nchecks]
            )
            matched = matches.sum(2)
            match_idx = matches.float().argmax(2)
            final_matches = torch.where(matched > 0, match_idx, nchecks - 1)
            stats = torch.stack([(final_matches == i).sum(1) for i in range(nchecks)]).T

            # per-token log-probabilities of the sampled response under both models;
            # their difference is a per-token KL estimate (not shown in the tables)
            fine_tuned_model_all_logprob = F.log_softmax(fine_tuned_model_logits, dim=-1)
            fine_tuned_model_logprob = torch.gather(
                fine_tuned_model_all_logprob, 2, fine_tuned_model_response.unsqueeze(-1)
            ).squeeze(-1)
            base_model_all_logprob = F.log_softmax(base_model_logits, dim=-1)
            base_model_logprob = torch.gather(base_model_all_logprob, 2, fine_tuned_model_response.unsqueeze(-1)).squeeze(-1)
            kl = fine_tuned_model_logprob - base_model_logprob

            final_matches = final_matches.tolist()
            fine_tuned_model_response = fine_tuned_model_response.tolist()
            for table in [rich_table, latex_table]:
                table["Type"].append(f"{fine_tuned_model_name} Response")
            latex_table["Content"].append(
                "".join(
                    [
                        f"{latex_colors[jt]}" r"\hl{" f"{tokenizer.decode(it)}" "}"
                        for it, jt in zip(fine_tuned_model_response[0], final_matches[0])
                    ]
                )
            )
            rich_table["Content"].append(
                "".join(
                    [
                        f"[{colors[jt]}]{tokenizer.decode(it)}[/{colors[jt]}]"
                        for it, jt in zip(fine_tuned_model_response[0], final_matches[0])
                    ]
                )
            )
            for table in [rich_table, latex_table]:
                table["Score (RM)"].append(str(round(fine_tuned_model_reward[0][0].item(), 4)))
                if include_logits:
                    table["Type"].append(f"{fine_tuned_model_name} Reward Logits")
                    table["Content"].append([round(logit, 4) for logit in fine_tuned_model_reward_logits[0].tolist()])
                    table["Score (RM)"].append(str(round(fine_tuned_model_reward[0][0].item(), 4)))
                # table["Type"].append("Matched Color Counts")
                # table["Content"].append(stats[0])

        reference_reward, reference_reward_logits = get_reward(rm, query_reference_response, tokenizer)
        reference_reward_logits = reference_reward_logits.squeeze(-1)[:, context_length - 1 :]
        for table in [rich_table, latex_table]:
            table["Type"].append("Reference response")
            table["Content"].append(sft_dataset["validation"][i]["reference_response"])
            table["Score (RM)"].append(str(round(reference_reward[0][0].item(), 4)))
            if include_logits:
                table["Type"].append("Reference Reward Logits")
                table["Content"].append([round(logit, 4) for logit in reference_reward_logits[0].tolist()])
                table["Score (RM)"].append(str(round(reference_reward[0][0].item(), 4)))

        base_model_query_response = generate(base_model, query, tokenizer, validation_generation_config)
        base_model_response = base_model_query_response[:, context_length:]
        base_model_reward, base_model_reward_logits = get_reward(rm, base_model_query_response, tokenizer)
        base_model_reward_logits = base_model_reward_logits.squeeze(-1)[:, context_length - 1 :]
        for table in [rich_table, latex_table]:
            table["Type"].append("Base Model Response")
            table["Content"].append(tokenizer.decode(base_model_response[0], skip_special_tokens=True))
            table["Score (RM)"].append(str(round(base_model_reward[0][0].item(), 4)))
            if include_logits:
                table["Type"].append("Base Model Reward Logits")
                table["Content"].append([round(logit, 4) for logit in base_model_reward_logits[0].tolist()])
                table["Score (RM)"].append(str(round(base_model_reward[0][0].item(), 4)))
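    # Render the same table twice: colorized for the terminal via rich, and as
    # LaTeX using the \sethlcolor/\hl highlighting defined in `latex_colors`
    # (the LightBlue/LightYellow/LightRed colors are assumed to be defined in
    # the surrounding LaTeX document).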
    rich_df = pd.DataFrame(rich_table)
    latex_df = pd.DataFrame(latex_table)
    print_rich_table("Results", rich_df, console)
    print(latex_df.to_latex(index=False))
    # interactive gate: stop on `n`, and stop automatically after 5 samples
    if input("Continue? (press `n` to stop) ") == "n":
        break
    if i == 4:
        break
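

######
# Example: scoring one query/response pair with the loaded RM
######
# A minimal sketch (not part of the original pipeline): it reuses the `rm`,
# `tokenizer`, `get_reward`, and `device` objects defined above to score a
# single hand-written summary. The function name and the usage shown in the
# trailing comment are illustrative assumptions, not something the script calls.
def score_pair(query_text: str, response_text: str) -> float:
    """Return the scalar RM score for one query/response string pair."""
    token_ids = tokenizer.encode(query_text + response_text)
    query_response = torch.LongTensor(token_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        # with no padding present, get_reward indexes the final token's logit
        reward, _ = get_reward(rm, query_response, tokenizer)
    return reward[0][0].item()


# e.g.: score_pair("SUBREDDIT: r/travel\nPOST: ...\nTL;DR:", " A short summary.")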