Balaji S commited on
Commit
3e36908
·
verified ·
1 Parent(s): 2c732ed

Delete loss_utils.py

Browse files
Files changed (1) hide show
  1. loss_utils.py +0 -63
loss_utils.py DELETED
@@ -1,63 +0,0 @@
1
- import torch as t
2
- import torch.nn.functional as F
3
-
4
- def cal_bpr_loss(anc_embeds, pos_embeds, neg_embeds):
5
- pos_preds = (anc_embeds * pos_embeds).sum(-1)
6
- neg_preds = (anc_embeds * neg_embeds).sum(-1)
7
- return t.sum(F.softplus(neg_preds - pos_preds))
8
-
9
-
10
- def reg_pick_embeds(embeds_list):
11
- reg_loss = 0
12
- for embeds in embeds_list:
13
- reg_loss += embeds.square().sum()
14
- return reg_loss
15
-
16
-
17
- def cal_infonce_loss(embeds1, embeds2, all_embeds2, temp=1.0):
18
- normed_embeds1 = embeds1 / t.sqrt(1e-8 + embeds1.square().sum(-1, keepdim=True))
19
- normed_embeds2 = embeds2 / t.sqrt(1e-8 + embeds2.square().sum(-1, keepdim=True))
20
- normed_all_embeds2 = all_embeds2 / t.sqrt(1e-8 + all_embeds2.square().sum(-1, keepdim=True))
21
- nume_term = -(normed_embeds1 * normed_embeds2 / temp).sum(-1)
22
- deno_term = t.log(t.sum(t.exp(normed_embeds1 @ normed_all_embeds2.T / temp), dim=-1))
23
- cl_loss = (nume_term + deno_term).sum()
24
- return cl_loss
25
-
26
-
27
- def cal_infonce_loss_spec_nodes(embeds1, embeds2, nodes, temp):
28
- embeds1 = F.normalize(embeds1 + 1e-8, p=2)
29
- embeds2 = F.normalize(embeds2 + 1e-8, p=2)
30
- pckEmbeds1 = embeds1[nodes]
31
- pckEmbeds2 = embeds2[nodes]
32
- nume = t.exp(t.sum(pckEmbeds1 * pckEmbeds2, dim=-1) / temp)
33
- deno = t.exp(pckEmbeds1 @ embeds2.T / temp).sum(-1) + 1e-8
34
- return -t.log(nume / deno).mean()
35
-
36
-
37
- def cal_sce_loss(x, y, alpha):
38
- x = F.normalize(x, p=2, dim=-1)
39
- y = F.normalize(y, p=2, dim=-1)
40
- loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
41
- loss = loss.mean()
42
- return loss
43
-
44
-
45
- def cal_rank_loss(stu_anc_emb, stu_pos_emb, stu_neg_emb, tea_anc_emb, tea_pos_emb, tea_neg_emb):
46
- stu_pos_score = (stu_anc_emb * stu_pos_emb).sum(dim=-1)
47
- stu_neg_score = (stu_anc_emb * stu_neg_emb).sum(dim=-1)
48
- stu_r_score = F.sigmoid(stu_pos_score - stu_neg_score)
49
-
50
- tea_pos_score = (tea_anc_emb * tea_pos_emb).sum(dim=-1)
51
- tea_neg_score = (tea_anc_emb * tea_neg_emb).sum(dim=-1)
52
- tea_r_score = F.sigmoid(tea_pos_score - tea_neg_score)
53
-
54
- rank_loss = -(tea_r_score * t.log(stu_r_score + 1e-8) + (1 - tea_r_score) * t.log(1 - stu_r_score + 1e-8)).mean()
55
-
56
- return rank_loss
57
-
58
-
59
- def reg_params(model):
60
- reg_loss = 0
61
- for W in model.parameters():
62
- reg_loss += W.norm(2).square()
63
- return reg_loss