import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import glob
import os
import json
import wandb
import yaml

def train(config=None):
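    """Train GRU query/document towers on triplet data over frozen CBOW embeddings.

    `config` may be a plain dict (single run, as in __main__ below) or None;
    when launched under a wandb sweep agent, wandb.init supplies the sampled
    hyperparameters through wandb.config.
    """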
    with wandb.init(config=config):
        config = wandb.config

        # Load the tokenizer (vocab)
        with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
            words_to_ids = pickle.load(f)
        vocab_size = len(words_to_ids)
        embedding_dim = 128  # must match the dimension the CBOW embeddings were trained with

        # Find the latest CBOW checkpoint
        checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
        if not checkpoint_files:
            raise FileNotFoundError('No CBOW checkpoints found in cbow/checkpoints/')
        latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
        state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'))

        # Create the embedding layer and load weights
        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
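        # 'emb.weight' is the embedding parameter's key in the CBOW checkpoint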
        embedding_layer.weight.data.copy_(state_dict['emb.weight'])
        embedding_layer.weight.requires_grad = False  # freeze weights
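        # The frozen embedding layer is shared by both towers below, so only
        # the GRU parameters receive gradient updates during training.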

        class QryTower(nn.Module):
            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False
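                # GRU over the token embeddings; its final hidden state is
                # used as the fixed-size encoding of the sequence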
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim, hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
                embeds = self.embedding(x)  # (1, seq_len, emb_dim)
                _, h_n = self.rnn(embeds)  # h_n: (1, batch, hidden_size)
                return h_n.squeeze(0).squeeze(0)  # (hidden_size,)

        class DocTower(nn.Module):
            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim, hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)
                embeds = self.embedding(x)
                _, h_n = self.rnn(embeds)
                return h_n.squeeze(0).squeeze(0)

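        # The two towers are architecturally identical but hold separate GRU
        # weights, so queries and documents learn different encodings.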
        qryTower = QryTower(embedding_layer, config.hidden_size)
        docTower = DocTower(embedding_layer, config.hidden_size)

        # Load the full set of tokenized triples for the train split
        with open('tokenized_triples.json', 'r') as f:
            triples_data = json.load(f)

        # TRAINING LOOP
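        # Triplet objective: for each (query, positive doc, negative doc),
        # push cos(q, d+) above cos(q, d-) by at least `margin`:
        #     loss = max(0, margin - (cos(q, d+) - cos(q, d-)))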
        params = list(qryTower.rnn.parameters()) + list(docTower.rnn.parameters())
        optimizer = torch.optim.Adam(params, lr=config.learning_rate)
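        # NOTE: triples are processed one at a time (batch size 1) for clarity;
        # padding the token sequences and batching them through the GRU would
        # be considerably faster.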
        num_epochs = config.num_epochs
        margin = config.margin
        print(f"\nTraining on all real triples from the train split with RNN towers for {num_epochs} epochs:\n")
        for epoch in range(num_epochs):
            total_loss = 0
            count = 0
            for triple in triples_data['train']:  # Use all triples
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                if qry is not None and pos is not None and neg is not None:
                    # Cosine similarities (higher = more similar)
                    sim_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    sim_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    # Hinge on the similarity gap; F.relu(x) == max(0, x)
                    loss = F.relu(margin - (sim_pos - sim_neg))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    count += 1
            avg_loss = total_loss / max(count, 1)
            print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch+1, 'avg_loss': avg_loss})

        # EVALUATE ON 5 EXAMPLES
        print("\nEvaluating on first 5 real triples after training:\n")
        with torch.no_grad():  # inference only: no need to build autograd graphs
            for i, triple in enumerate(triples_data['train'][:5]):
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']
                qry_text = triple['query']
                pos_text = triple['positive_document']
                neg_text = triple['negative_document']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                if qry is not None and pos is not None and neg is not None:
                    sim_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    sim_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    loss = F.relu(margin - (sim_pos - sim_neg))
                    print(f"Example {i+1}:")
                    print(f"Query: {qry_text}")
                    print(f"Positive doc: {pos_text[:100]}...")
                    print(f"Negative doc: {neg_text[:100]}...")
                    print(f"Cosine similarity (pos): {sim_pos.item():.4f}")
                    print(f"Cosine similarity (neg): {sim_neg.item():.4f}")
                    print(f"Triplet loss: {loss.item():.4f}\n")
                else:
                    print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")

if __name__ == "__main__":
    # Load configuration from sweep.yaml
    with open('sweep.yaml', 'r') as f:
        sweep_config = yaml.safe_load(f)
    
    # Create a default config with a single value for each parameter
    default_config = {
        'learning_rate': sweep_config['parameters']['learning_rate']['values'][0],
        'margin': sweep_config['parameters']['margin']['values'][0],
        'num_epochs': sweep_config['parameters']['num_epochs']['value'],
        'num_triples': sweep_config['parameters']['num_triples']['values'][0],
        'hidden_size': sweep_config['parameters']['hidden_size']['values'][0]
    }
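    # NOTE: num_triples is read above for completeness, but train() does not
    # currently use it; every triple in the train split is used each epoch.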
    
    # Run a single training pass; train() handles wandb.init itself
    train(config=default_config)
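
# To run a hyperparameter sweep instead of a single run (assuming sweep.yaml
# names this script as its program), something like:
#   wandb sweep sweep.yaml                  # prints a sweep ID
#   wandb agent <entity>/<project>/<sweep-id>
# The agent then relaunches this script for each sampled configuration, and
# wandb.init inside train() picks the values up via wandb.config.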