Spaces:
Build error
Build error
Balaji S committed on
Delete loss_utils.py
Browse files- loss_utils.py +0 -63
loss_utils.py
DELETED
|
@@ -1,63 +0,0 @@
|
|
| 1 |
-
import torch as t
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
-
|
| 4 |
-
def cal_bpr_loss(anc_embeds, pos_embeds, neg_embeds):
    """Bayesian Personalized Ranking (BPR) loss.

    Scores each (anchor, item) pair by a dot product and penalizes triples
    where the negative item scores close to, or above, the positive item.

    Args:
        anc_embeds: anchor (e.g. user) embeddings, shape (..., dim).
        pos_embeds: positive-item embeddings, same shape as anc_embeds.
        neg_embeds: negative-item embeddings, same shape as anc_embeds.

    Returns:
        Scalar tensor: the softplus ranking loss summed over all triples.
    """
    score_gap = (anc_embeds * neg_embeds).sum(-1) - (anc_embeds * pos_embeds).sum(-1)
    # softplus(neg - pos) is a smooth upper bound on the 0/1 ranking error
    return F.softplus(score_gap).sum()
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def reg_pick_embeds(embeds_list):
    """Sum-of-squares (L2) regularization over a list of embedding tensors.

    Args:
        embeds_list: iterable of tensors to regularize.

    Returns:
        Scalar tensor with the total of all squared entries
        (plain 0 when the list is empty, matching the accumulator start).
    """
    return sum(embeds.square().sum() for embeds in embeds_list)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def cal_infonce_loss(embeds1, embeds2, all_embeds2, temp=1.0):
    """InfoNCE contrastive loss with `all_embeds2` as the candidate pool.

    Rows of `embeds1` and `embeds2` are treated as positive pairs; every row
    of `all_embeds2` serves as a negative candidate for each anchor.

    Args:
        embeds1: anchor embeddings, shape (batch, dim).
        embeds2: positive counterpart embeddings, shape (batch, dim).
        all_embeds2: full candidate embeddings, shape (num_candidates, dim).
        temp: softmax temperature; smaller values sharpen the distribution.

    Returns:
        Scalar tensor: the InfoNCE loss summed over the batch.
    """
    def _unit(x):
        # Manual L2 normalization; the 1e-8 inside the sqrt guards
        # against division by zero on an all-zero row.
        return x / t.sqrt(1e-8 + x.square().sum(-1, keepdim=True))

    anchor = _unit(embeds1)
    positive = _unit(embeds2)
    candidates = _unit(all_embeds2)

    pos_term = -(anchor * positive / temp).sum(-1)
    neg_term = t.log(t.exp(anchor @ candidates.T / temp).sum(dim=-1))
    return (pos_term + neg_term).sum()
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def cal_infonce_loss_spec_nodes(embeds1, embeds2, nodes, temp):
    """InfoNCE loss restricted to a chosen subset of node indices.

    Only the rows selected by `nodes` act as anchors/positives; every row of
    `embeds2` remains in the denominator as a candidate.

    Args:
        embeds1: first-view embeddings, shape (num_nodes, dim).
        embeds2: second-view embeddings, shape (num_nodes, dim).
        nodes: index tensor selecting which nodes contribute to the loss.
        temp: softmax temperature.

    Returns:
        Scalar tensor: the mean InfoNCE loss over the selected nodes.
    """
    # dim=1 is F.normalize's default, written out explicitly; the +1e-8
    # shift keeps an all-zero row from producing NaNs when normalized.
    view1 = F.normalize(embeds1 + 1e-8, p=2, dim=1)
    view2 = F.normalize(embeds2 + 1e-8, p=2, dim=1)
    sel1 = view1[nodes]
    sel2 = view2[nodes]
    pos_sim = t.exp((sel1 * sel2).sum(dim=-1) / temp)
    # +1e-8 in the denominator protects the log below from a zero argument
    all_sim = t.exp(sel1 @ view2.T / temp).sum(-1) + 1e-8
    return -t.log(pos_sim / all_sim).mean()
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def cal_sce_loss(x, y, alpha):
    """Scaled cosine error (SCE) loss between two sets of embeddings.

    Measures 1 - cosine_similarity per row, raised to the power `alpha`
    (which down-weights already-well-aligned pairs), then averaged.

    Args:
        x: first embedding tensor, shape (..., dim).
        y: second embedding tensor, same shape as x.
        alpha: exponent applied to the per-row cosine error.

    Returns:
        Scalar tensor: the mean scaled cosine error.
    """
    cos_sim = (F.normalize(x, p=2, dim=-1) * F.normalize(y, p=2, dim=-1)).sum(dim=-1)
    return (1 - cos_sim).pow(alpha).mean()
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def cal_rank_loss(stu_anc_emb, stu_pos_emb, stu_neg_emb, tea_anc_emb, tea_pos_emb, tea_neg_emb):
    """Rank-distillation loss between a student and a teacher model.

    Both models score a (pos, neg) item pair for each anchor via dot
    products; the sigmoid of the score gap is the probability of ranking
    pos above neg. The student's probability is trained toward the
    teacher's with a soft binary cross-entropy.

    Args:
        stu_anc_emb / stu_pos_emb / stu_neg_emb: student anchor, positive,
            and negative embeddings, each of shape (..., dim).
        tea_anc_emb / tea_pos_emb / tea_neg_emb: teacher counterparts,
            same shapes as the student tensors.

    Returns:
        Scalar tensor: the mean soft binary cross-entropy between the
        teacher's and the student's ranking probabilities.
    """
    eps = 1e-8  # keeps log() finite when a probability saturates at 0 or 1

    stu_pos_score = (stu_anc_emb * stu_pos_emb).sum(dim=-1)
    stu_neg_score = (stu_anc_emb * stu_neg_emb).sum(dim=-1)
    # t.sigmoid replaces the deprecated F.sigmoid alias; identical math.
    stu_r_score = t.sigmoid(stu_pos_score - stu_neg_score)

    tea_pos_score = (tea_anc_emb * tea_pos_emb).sum(dim=-1)
    tea_neg_score = (tea_anc_emb * tea_neg_emb).sum(dim=-1)
    tea_r_score = t.sigmoid(tea_pos_score - tea_neg_score)

    # Binary cross-entropy with the teacher's soft probability as target.
    rank_loss = -(tea_r_score * t.log(stu_r_score + eps)
                  + (1 - tea_r_score) * t.log(1 - stu_r_score + eps)).mean()

    return rank_loss
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def reg_params(model):
    """L2 weight-decay term: sum of squared L2 norms of all model parameters.

    Args:
        model: any torch.nn.Module (anything exposing .parameters()).

    Returns:
        Scalar tensor with the total squared-norm penalty
        (plain 0 for a parameter-less model, matching the accumulator start).
    """
    return sum(W.norm(2).square() for W in model.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|