# Page header residue: ThaoTran7 - "incomplete commit" (485127c)
import torch.nn.functional as F
def calculate_SPO_loss(model_prefered_logprob, model_disprefered_logprob,
                       ref_prefered_logprob, ref_disprefered_logprob,
                       beta=0.5):
    """Compute a pairwise preference loss from policy and reference log-probs.

    Each response's log-probability is first expressed relative to the
    reference model (``model - ref``).  The loss is the negative log-sigmoid
    of ``beta`` times the gap between the preferred and dispreferred relative
    log-probabilities, averaged over the last dimension.
    NOTE(review): this has the same form as the DPO objective; the function
    name says "SPO" -- confirm which is intended.

    Args:
        model_prefered_logprob: Log-probs of the preferred responses under
            the model being trained.  Must be a torch tensor (tensor methods
            are called on values derived from it); presumably one entry per
            example along the last dimension -- confirm against the caller.
        model_disprefered_logprob: Same, for the dispreferred responses.
        ref_prefered_logprob: Log-probs of the preferred responses under the
            reference model.
        ref_disprefered_logprob: Same, for the dispreferred responses.
        beta: Scale applied to the relative log-prob gap before the
            log-sigmoid.

    Returns:
        A 5-tuple, every element reduced with ``mean(dim=-1)``:
        ``(loss, mean preferred relative log-prob,
        mean dispreferred relative log-prob, reward_accuracies,
        reward_margins)``.  ``reward_accuracies`` is the fraction of entries
        whose preferred relative log-prob is strictly greater than the
        dispreferred one; ``reward_margins`` is the mean gap between them.
    """
    # Log-prob of each response relative to the reference model.
    prefered_relative_logprob = model_prefered_logprob - ref_prefered_logprob
    disprefered_relative_logprob = model_disprefered_logprob - ref_disprefered_logprob

    # The preferred-minus-dispreferred gap feeds both the margin and the loss.
    relative_gap = prefered_relative_logprob - disprefered_relative_logprob

    # Strict comparison: ties count as incorrect.
    reward_accuracies = (
        prefered_relative_logprob > disprefered_relative_logprob
    ).float().mean(dim=-1)
    reward_margins = relative_gap.mean(dim=-1)

    loss = -F.logsigmoid(beta * relative_gap).mean(dim=-1)

    return (
        loss,
        prefered_relative_logprob.mean(dim=-1),
        disprefered_relative_logprob.mean(dim=-1),
        reward_accuracies,
        reward_margins,
    )