Hugging Face Spaces — build status: Runtime error
(The script below fails at runtime; see inline review notes in the code for the likely causes.)
import glob
import json

import datasets
import torch
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer

# Pretrained CLIP backbone + matching processor; CLIP is used here as the
# starting point for a human-preference reward model.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Each JSON file holds one preference record:
#   {"prompt": ..., "good": img_path, "bad": img_path}
data = []
for path in glob.glob("human_prefs/*.json"):
    # BUG FIX: the original used bare open(f) with no close, leaking one file
    # handle per record; the context manager closes each file deterministically.
    with open(path) as fh:
        data.append(json.load(fh))
dataset = datasets.Dataset.from_list(data)
def preprocess(ex):
    """Convert one preference record into CLIP model inputs.

    ex: {"prompt": str, "good": image, "bad": image}. Returns the processor
    output (input_ids, attention_mask, pixel_values) for the prompt paired
    with both images, plus a binary preference label where index 0 is the
    preferred ("good") image.
    """
    # BUG FIX: the original passed text=[ex["prompt"]*2], which *string-repeats*
    # the prompt ("abcabc") and yields ONE text entry for TWO images, so texts
    # and images were misaligned. CLIP needs the prompt once per image.
    inputs = processor(
        text=[ex["prompt"], ex["prompt"]],
        images=[ex["good"], ex["bad"]],
        return_tensors="pt",
    )
    # NOTE(review): "good"/"bad" look like file *paths* in the JSON records,
    # but CLIPProcessor expects decoded images (PIL / numpy arrays) — confirm
    # they are decoded upstream (e.g. a datasets.Image column) or load them
    # before this call; passing raw path strings will raise at runtime.
    inputs["labels"] = torch.tensor([1, 0])
    # NOTE(review): CLIPModel.forward has no `labels` argument, so the stock
    # Trainer cannot derive a loss from this label; a custom compute_loss (or
    # return_loss=True with a pairwise objective) is needed — TODO confirm.
    return inputs
# Tokenize every record; drop the raw JSON columns so only model inputs remain.
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

args = TrainingArguments(
    "rm_ckpt",                       # checkpoint output directory
    per_device_train_batch_size=2,
    fp16=True,                       # NOTE(review): fp16 requires a CUDA device — confirm the Space has a GPU
    learning_rate=5e-6,
    # BUG FIX: TrainingArguments has no `epochs` keyword — the original raised
    # TypeError at construction. The correct parameter is `num_train_epochs`.
    num_train_epochs=3,
)
trainer = Trainer(model, args, train_dataset=dataset)
trainer.train()
model.save_pretrained("rm")  # export the trained reward model