| import torch.nn.functional as F | |
def calculate_SPO_loss(model_prefered_logprob, model_disprefered_logprob,
                       ref_prefered_logprob, ref_disprefered_logprob,
                       beta=0.5):
    """Compute a pairwise preference loss plus summary statistics.

    The loss is ``-logsigmoid(beta * margin)`` averaged over the last
    dimension, where ``margin`` is the policy-vs-reference log-ratio of the
    preferred response minus that of the dispreferred response (the DPO
    objective).

    NOTE(review): inputs are presumably per-example (summed) log-probability
    tensors of identical shape; every reduction below is ``mean(dim=-1)``,
    so a 1-D batch reduces to scalars — confirm shapes with the caller.

    Args:
        model_prefered_logprob: policy log-probabilities of the preferred responses.
        model_disprefered_logprob: policy log-probabilities of the dispreferred responses.
        ref_prefered_logprob: reference log-probabilities of the preferred responses.
        ref_disprefered_logprob: reference log-probabilities of the dispreferred responses.
        beta: scale applied to the margin inside the log-sigmoid.

    Returns:
        A 5-tuple, each element reduced with ``mean(dim=-1)``:
        ``(loss, preferred log-ratio, dispreferred log-ratio,
        reward accuracy, reward margin)``.
        The log-ratios and margin are returned unscaled (not multiplied by
        ``beta``).
    """
    # Policy log-probability relative to the reference model, per side.
    chosen_ratio = model_prefered_logprob - ref_prefered_logprob
    rejected_ratio = model_disprefered_logprob - ref_disprefered_logprob

    # Shared by the margin statistic and the loss.
    margin = chosen_ratio - rejected_ratio

    # Fraction of pairs where the preferred side has the larger log-ratio.
    accuracy = (chosen_ratio > rejected_ratio).float().mean(dim=-1)

    # logsigmoid is the numerically stable form of log(sigmoid(x)).
    loss = -F.logsigmoid(beta * margin).mean(dim=-1)

    return (
        loss,
        chosen_ratio.mean(dim=-1),
        rejected_ratio.mean(dim=-1),
        accuracy,
        margin.mean(dim=-1),
    )