roshbeed committed on
Commit
148ebb0
·
verified ·
1 Parent(s): e7e6c82

Upload src/simple_dual_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/simple_dual_encoder.py +99 -0
src/simple_dual_encoder.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pickle
5
+ import glob
6
+ import os
7
+ import json
8
+ import torch.optim as optim
9
+
10
# Load the tokenizer vocabulary (word -> id mapping).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load vocabulary files you produced yourself or otherwise trust.
with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
    words_to_ids = pickle.load(f)
vocab_size = len(words_to_ids)
embedding_dim = 128  # Use the dimension you trained with

# Find the latest CBOW checkpoint (ranked by os.path.getctime).
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
if not checkpoint_files:
    # Without this guard max() fails with the opaque
    # "max() arg is an empty sequence"; name the missing files instead.
    raise FileNotFoundError(
        "No CBOW checkpoint found matching 'cbow/checkpoints/*.pth'"
    )
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA.
state_dict = torch.load(latest_checkpoint, map_location='cpu')

# Create the embedding layer and load the pretrained weights into it.
embedding_layer = nn.Embedding(vocab_size, embedding_dim)
with torch.no_grad():
    # copy_ raises if the checkpoint shape differs from
    # (vocab_size, embedding_dim), surfacing a vocab/checkpoint mismatch early.
    embedding_layer.weight.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False  # freeze weights
25
+
26
def average_embedding(token_ids, embedding_layer):
    """Return the mean embedding vector of a token-id sequence.

    Args:
        token_ids: Sequence of integer token ids (list, tuple, or tensor).
            ``None`` or an empty sequence is treated as "no input".
        embedding_layer: ``nn.Embedding`` used to look up the ids.

    Returns:
        The embeddings averaged over the first dimension (a 1-D tensor for a
        flat sequence of ids), or ``None`` when there are no tokens.
    """
    # An explicit length check avoids the ambiguous-truth-value error that
    # ``not token_ids`` raises for multi-element arrays and tensors.
    if token_ids is None or len(token_ids) == 0:  # skip empty
        return None
    # Build the index tensor on the same device as the embedding weights so
    # the lookup also works after the layer has been moved off the CPU.
    tokens_tensor = torch.as_tensor(
        token_ids, dtype=torch.long, device=embedding_layer.weight.device
    )
    vectors = embedding_layer(tokens_tensor)
    return vectors.mean(dim=0)
33
+
34
def triplet_loss(qry, pos, neg, margin=0.2):
    """Compute a cosine-similarity triplet margin loss for one triple.

    Args:
        qry: 1-D query vector.
        pos: 1-D vector of the positive (relevant) document.
        neg: 1-D vector of the negative (irrelevant) document.
        margin: Minimum amount by which the positive similarity should exceed
            the negative similarity before the loss reaches zero.

    Returns:
        A tuple ``(loss, sim_pos, sim_neg)`` where ``loss`` is a tensor of
        shape ``(1,)`` equal to ``max(0, margin - (sim_pos - sim_neg))`` and
        the two similarities are plain Python floats.
    """
    # unsqueeze(0) adds a batch dimension of 1 so cosine_similarity compares
    # the vectors along its default dim=1.
    sim_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
    sim_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
    loss = torch.clamp(margin - (sim_pos - sim_neg), min=0.0)
    return loss, sim_pos.item(), sim_neg.item()
39
+
40
# Load real tokenized triples (small subset for speed).
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's locale encoding.
with open('tokenized_triples.json', 'r', encoding='utf-8') as f:
    triples_data = json.load(f)
43
+
44
# Sanity check: score the first five training triples with the frozen,
# averaged CBOW embeddings and print the similarities and the triplet loss.
print("\nRunning on first 5 real triples from the train split:\n")
for i, triple in enumerate(triples_data['train'][:5]):
    qry_tokens = triple['query_tokens']
    pos_tokens = triple['positive_document_tokens']
    neg_tokens = triple['negative_document_tokens']
    qry_text = triple['query']
    pos_text = triple['positive_document']
    neg_text = triple['negative_document']

    qry_vec = average_embedding(qry_tokens, embedding_layer)
    pos_vec = average_embedding(pos_tokens, embedding_layer)
    neg_vec = average_embedding(neg_tokens, embedding_layer)

    # average_embedding returns None for an empty token list; such a triple
    # cannot be scored, so report it and move on.
    if qry_vec is None or pos_vec is None or neg_vec is None:
        print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")
        continue

    loss, sim_pos, sim_neg = triplet_loss(qry_vec, pos_vec, neg_vec)
    print(f"Example {i+1}:")
    print(f"Query: {qry_text}")
    # Documents are truncated to their first 100 characters for display.
    print(f"Positive doc: {pos_text[:100]}...")
    print(f"Negative doc: {neg_text[:100]}...")
    print(f"Cosine similarity (pos): {sim_pos:.4f}")
    print(f"Cosine similarity (neg): {sim_neg:.4f}")
    print(f"Triplet loss: {loss.item():.4f}\n")
68
+
69
class _RNNTower(nn.Module):
    """Encode a token-id sequence into a single vector with an RNN.

    NOTE(review): the original script used ``qryTower`` and ``docTower``
    without defining them, which raised NameError. This class is a minimal
    stand-in that satisfies how the towers are used below (callable on a
    token-id list, returns ``None`` for empty input, exposes ``.rnn``).
    The layer type and hidden size are assumptions -- TODO confirm against
    the intended architecture.
    """

    def __init__(self, embedding, hidden_dim=128):
        super().__init__()
        self.emb = embedding  # shared embedding layer; not passed to the optimizer
        self.rnn = nn.RNN(embedding.embedding_dim, hidden_dim, batch_first=True)

    def forward(self, token_ids):
        """Return the final RNN hidden state (1-D), or None for empty input."""
        if not token_ids:
            return None
        ids = torch.tensor(
            token_ids, dtype=torch.long, device=self.emb.weight.device
        )
        # unsqueeze(0) adds the batch dimension expected with batch_first=True.
        _, hidden = self.rnn(self.emb(ids).unsqueeze(0))
        # hidden is (num_layers, batch, hidden_dim); take last layer, only item.
        return hidden[-1, 0]


qryTower = _RNNTower(embedding_layer)
docTower = _RNNTower(embedding_layer)

# Only optimize the RNNs, not the embedding layer
params = list(qryTower.rnn.parameters()) + list(docTower.rnn.parameters())
optimizer = optim.Adam(params, lr=0.001)

num_epochs = 3
dst_mrg = 0.2  # triplet margin; loop-invariant, so defined once up front
for epoch in range(num_epochs):
    total_loss = 0
    count = 0
    for triple in triples_data['train']:  # Use all triples
        qry_tokens = triple['query_tokens']
        pos_tokens = triple['positive_document_tokens']
        neg_tokens = triple['negative_document_tokens']

        qry = qryTower(qry_tokens)
        pos = docTower(pos_tokens)
        neg = docTower(neg_tokens)

        # A tower returns None for an empty token list; skip such triples.
        if qry is None or pos is None or neg is None:
            continue

        dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
        dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
        # clamp keeps the computation on the inputs' device instead of mixing
        # in a CPU-only torch.tensor(0.0) constant.
        loss = torch.clamp(dst_mrg - (dst_pos - dst_neg), min=0.0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1

    # max(count, 1) avoids division by zero when every triple was skipped.
    print(f"Epoch {epoch+1}, Avg Loss: {total_loss / max(count,1):.4f}")