roshbeed committed
Commit d566338 · verified · 1 Parent(s): eabf707

Upload src/simple_dual_encoder_rnn.py with huggingface_hub

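For reference, uploads like this are typically produced with huggingface_hub's HfApi.upload_file; a minimal sketch (the repo_id below is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="src/simple_dual_encoder_rnn.py",  # local file to upload
    path_in_repo="src/simple_dual_encoder_rnn.py",     # destination path in the repo
    repo_id="roshbeed/your-repo",                      # placeholder repo id
    commit_message="Upload src/simple_dual_encoder_rnn.py with huggingface_hub",
)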
Files changed (1):
  1. src/simple_dual_encoder_rnn.py (+148, -0)
src/simple_dual_encoder_rnn.py ADDED
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import glob
import os
import json
import wandb
import yaml


def train(config=None):
    with wandb.init(config=config):
        config = wandb.config

        # Load the tokenizer (vocab)
        with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
            words_to_ids = pickle.load(f)
        vocab_size = len(words_to_ids)
        embedding_dim = 128  # must match the dimension the CBOW embeddings were trained with

        # Find the latest CBOW checkpoint
        checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
        latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
        state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'))

        # Create the embedding layer, load the pretrained weights, and freeze them
        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        embedding_layer.weight.data.copy_(state_dict['emb.weight'])
        embedding_layer.weight.requires_grad = False

        class QryTower(nn.Module):
            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim,
                                  hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:  # empty token list: nothing to encode
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
                embeds = self.embedding(x)                          # (1, seq_len, emb_dim)
                _, h_n = self.rnn(embeds)                           # h_n: (num_layers=1, batch=1, hidden_size)
                return h_n.squeeze(0).squeeze(0)                    # (hidden_size,)

        class DocTower(nn.Module):
            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim,
                                  hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)
                embeds = self.embedding(x)
                _, h_n = self.rnn(embeds)
                return h_n.squeeze(0).squeeze(0)

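        # NOTE: QryTower and DocTower are architecturally identical, but each
        # keeps its own GRU weights, so queries and documents are encoded by
        # separately learned functions over the same frozen CBOW embeddings.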
        qryTower = QryTower(embedding_layer, config.hidden_size)
        docTower = DocTower(embedding_layer, config.hidden_size)

        # Load the real tokenized triples
        with open('tokenized_triples.json', 'r') as f:
            triples_data = json.load(f)

        # TRAINING LOOP
        params = list(qryTower.rnn.parameters()) + list(docTower.rnn.parameters())
        optimizer = torch.optim.Adam(params, lr=config.learning_rate)
        num_epochs = config.num_epochs
        margin = config.margin
        print(f"\nTraining on all real triples from the train split with RNN towers for {num_epochs} epochs:\n")
        for epoch in range(num_epochs):
            total_loss = 0.0
            count = 0
            for triple in triples_data['train']:  # use every triple in the train split
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                if qry is not None and pos is not None and neg is not None:
                    dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    # Hinge-style triplet loss: push cos(qry, pos) above cos(qry, neg) by at least `margin`
                    loss = torch.clamp(margin - (dst_pos - dst_neg), min=0.0)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    count += 1
            avg_loss = total_loss / max(count, 1)
            print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch + 1, 'avg_loss': avg_loss})

        # EVALUATE ON 5 EXAMPLES
        print("\nEvaluating on first 5 real triples after training:\n")
        for i, triple in enumerate(triples_data['train'][:5]):
            qry_tokens = triple['query_tokens']
            pos_tokens = triple['positive_document_tokens']
            neg_tokens = triple['negative_document_tokens']
            qry_text = triple['query']
            pos_text = triple['positive_document']
            neg_text = triple['negative_document']

            qry = qryTower(qry_tokens)
            pos = docTower(pos_tokens)
            neg = docTower(neg_tokens)

            if qry is not None and pos is not None and neg is not None:
                dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                loss = torch.clamp(margin - (dst_pos - dst_neg), min=0.0)
                print(f"Example {i+1}:")
                print(f"Query: {qry_text}")
                print(f"Positive doc: {pos_text[:100]}...")
                print(f"Negative doc: {neg_text[:100]}...")
                print(f"Cosine similarity (pos): {dst_pos.item():.4f}")
                print(f"Cosine similarity (neg): {dst_neg.item():.4f}")
                print(f"Triplet loss: {loss.item():.4f}\n")
            else:
                print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")

if __name__ == "__main__":
    # Load configuration from sweep.yaml
    with open('sweep.yaml', 'r') as f:
        sweep_config = yaml.safe_load(f)

    # Build a default config by taking a single value for each sweep parameter
    default_config = {
        'learning_rate': sweep_config['parameters']['learning_rate']['values'][0],
        'margin': sweep_config['parameters']['margin']['values'][0],
        'num_epochs': sweep_config['parameters']['num_epochs']['value'],
        'num_triples': sweep_config['parameters']['num_triples']['values'][0],
        'hidden_size': sweep_config['parameters']['hidden_size']['values'][0],
    }

    # Run a single training run; train() calls wandb.init with this config
    train(config=default_config)
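
To run this as an actual hyperparameter sweep rather than a single default configuration, the same train function can be handed to a wandb agent. A minimal sketch, assuming sweep.yaml carries the nested parameters structure the __main__ block indexes into (the project name is a placeholder):

import wandb
import yaml

with open('sweep.yaml') as f:
    sweep_config = yaml.safe_load(f)

# Register the sweep, then let an agent invoke train() once per sampled
# configuration; wandb.init() inside train() receives each sampled config.
sweep_id = wandb.sweep(sweep=sweep_config, project="dual-encoder-rnn")  # placeholder project name
wandb.agent(sweep_id, function=train, count=10)  # e.g. try 10 configurations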