# AdGPT / reward_model.py
# Author: goodmodeler — commit 696ae63 ("ADD: LLM techs")
# Fine-tunes CLIP as a reward model on human image-preference pairs.
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
import datasets, torch, json, glob
# Load the pretrained CLIP backbone and its matching processor; the reward
# model is fine-tuned from these weights on human preference data.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Each JSON file holds one preference record:
#   {"prompt": ..., "good": img_path, "bad": img_path}
data = []
for f in glob.glob("human_prefs/*.json"):
    # Use a context manager: the original `json.load(open(f))` leaked the
    # file handle on every iteration.
    with open(f, encoding="utf-8") as fh:
        data.append(json.load(fh))
dataset = datasets.Dataset.from_list(data)
def preprocess(ex):
    """Tokenize one preference pair for CLIP reward-model training.

    Pairs the prompt with both the preferred ("good") and rejected ("bad")
    image and attaches labels [1, 0] so the model learns to score the
    preferred image higher.
    """
    # BUG FIX: the original passed text=[ex["prompt"]*2], which
    # string-concatenates the prompt with itself inside a ONE-element list,
    # so CLIP received a single doubled text for two images.
    # [ex["prompt"]] * 2 supplies the same prompt once per image.
    # NOTE(review): ex["good"] / ex["bad"] are file paths per the record
    # comment above — CLIPProcessor expects PIL images or arrays, so verify
    # images are loaded before this is called. TODO confirm upstream.
    inputs = processor(
        text=[ex["prompt"]] * 2,
        images=[ex["good"], ex["bad"]],
        return_tensors="pt",
    )
    inputs["labels"] = torch.tensor([1, 0])  # 1 = preferred, 0 = rejected
    return inputs
# Tokenize every preference pair; drop the raw columns so the Trainer only
# sees model inputs.
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

# BUG FIX: TrainingArguments has no `epochs` keyword — the original raised a
# TypeError at construction. The correct parameter is `num_train_epochs`.
args = TrainingArguments(
    "rm_ckpt",
    per_device_train_batch_size=2,
    fp16=True,  # NOTE(review): fp16 requires CUDA — confirm target hardware
    learning_rate=5e-6,
    num_train_epochs=3,
)
trainer = Trainer(model, args, train_dataset=dataset)
trainer.train()
model.save_pretrained("rm")