|
|
from dataclasses import dataclass |
|
|
|
|
|
import pandas as pd |
|
|
import torch.nn as nn |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
AutoModelForSequenceClassification, |
|
|
AutoTokenizer, |
|
|
PretrainedConfig, |
|
|
PreTrainedModel, |
|
|
) |
|
|
|
|
|
# NOTE: removed three bare-name expression statements (AutoTokenizer,
# AutoModelForSequenceClassification, AutoModelForCausalLM). They were no-op
# references to names already imported above — dead code left over from a
# refactor — and had no runtime effect.
|
|
|
|
|
|
|
|
class ScalarModelConfig(PretrainedConfig):
    """Config for :class:`ScalarModel`: names the LM backbone and the scalar head.

    Attributes set here:
        base_model: Hub id of the backbone checkpoint to load.
        base_config: the backbone's ``PretrainedConfig``; resolved lazily from
            ``base_model`` when not supplied.
        hidden_size: width of the backbone's final hidden states, i.e. the
            input dimension of the scalar head.
        bias: constant subtracted from the scalar head's output.
    """

    def __init__(
        self,
        base_model: str = "EleutherAI/pythia-160m",
        base_config: "PretrainedConfig | None" = None,
        hidden_size: int = 768,
        bias: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        # BUGFIX: the original default was `AutoConfig.from_pretrained(...)`
        # evaluated at class-definition time — a network/disk access on import,
        # and a single mutable config object silently shared by every instance
        # that used the default. It also always fetched the pythia-160m config
        # even when the caller overrode `base_model`. Resolve lazily from the
        # actual `base_model` instead; passing an explicit `base_config` is
        # unchanged.
        if base_config is None:
            base_config = AutoConfig.from_pretrained(base_model)
        self.base_config = base_config
        self.hidden_size = hidden_size
        self.bias = bias
|
|
|
|
|
|
|
|
class ScalarModel(PreTrainedModel):
    """A pretrained LM backbone topped with a single-output linear head.

    Used as a scalar reward model: the head maps the backbone's last hidden
    states to one value per position, shifted by a constant bias from the
    config.
    """

    config_class = ScalarModelConfig

    def __init__(self, config: ScalarModelConfig):
        super().__init__(config)
        self.config = config
        # trust_remote_code allows backbones that ship custom modeling code.
        self.lm_backbone = AutoModel.from_pretrained(
            config.base_model,
            config=self.config.base_config,
            trust_remote_code=True,
        )
        self.scalar_head = nn.Linear(self.config.hidden_size, 1)

    def forward(self, **kwargs):
        """Run the backbone and return the biased scalar-head output.

        NOTE(review): indexes ``output.hidden_states[-1]`` — assumes the
        backbone config enables ``output_hidden_states``; confirm at call site.
        """
        backbone_out = self.lm_backbone(**kwargs)
        last_hidden = backbone_out.hidden_states[-1]
        return self.scalar_head(last_hidden) - self.config.bias
|
|
|
|
|
|
|
|
@dataclass
class RunRecord:
    """Locator for one released training run (one row of release_runs.csv)."""

    wandb_url: str  # Weights & Biases run page URL
    hf_repo_url: str  # Hugging Face repository URL
    hf_repo_id: str  # HF repo id, e.g. "<org>/<name>"; passed to push_to_hub
    revision: str  # revision/commit identifying the artifact in the HF repo
|
|
|
|
|
|
|
|
# Load the release-run index and collapse it to one row per
# (base_model, exp, seed), keeping the first value in each duplicate group.
# NOTE(review): assumes release_runs.csv is in the CWD and its columns match
# the RunRecord fields — verify before running.
df = pd.read_csv("release_runs.csv")
df = df.groupby(["base_model", "exp", "seed"]).agg(lambda x: x.tolist()[0])
|
|
|
|
|
|
|
|
# SFT run for pythia-1b-deduped, seed 77713. NOTE(review): not referenced in
# this chunk — presumably used further down the file; do not remove blindly.
sft_record = RunRecord(**df.loc[("EleutherAI/pythia-1b-deduped", "sft", 77713)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reward-model run for pythia-2.8b-deduped (seed 77713) and the tokenizer of
# its base checkpoint (downloads from the Hub on first use).
rm_record = RunRecord(**df.loc[("EleutherAI/pythia-2.8b-deduped", "reward", 77713)])
rm_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b-deduped")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Upload the tokenizer to the mirrored repo under the "cleanrl" org (the repo
# id in the CSV lives under "vwxyzjn"). Network side effect: writes to the Hub.
rm_tokenizer.push_to_hub(rm_record.hf_repo_id.replace("vwxyzjn", "cleanrl"), use_temp_dir=True)
|
|
|