"""Two-tower retrieval training script.

Loads frozen CBOW word embeddings, encodes queries and documents with two
GRU towers, and trains them on (query, positive, negative) triples with a
cosine-similarity triplet margin loss, logging metrics to Weights & Biases.
"""

import glob
import json
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import yaml


def train(config=None):
    # When launched by a sweep agent, wandb.init injects the sweep's chosen
    # hyperparameters into wandb.config; otherwise the passed-in config is used.
    with wandb.init(config=config):
        config = wandb.config

        # Vocabulary produced by the CBOW tokenizer: word -> integer id.
        with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
            words_to_ids = pickle.load(f)
        vocab_size = len(words_to_ids)
        embedding_dim = 128  # must match the dimension the CBOW model was trained with

        # Load the most recently written CBOW checkpoint on CPU.
        checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
        latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
        state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'))

        # Copy the pretrained embedding weights into a frozen layer shared by both towers.
        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        embedding_layer.weight.data.copy_(state_dict['emb.weight'])
        embedding_layer.weight.requires_grad = False
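
        # The key 'emb.weight' assumes the CBOW model exposed its embedding
        # table as an attribute named `emb`; to check an unfamiliar
        # checkpoint's layout, inspect its keys:
        #
        #     print(list(state_dict.keys()))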

        class QryTower(nn.Module):
            """Encodes a query's token ids into a single vector: frozen
            embeddings -> GRU -> final hidden state."""

            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False  # keep CBOW embeddings frozen
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim,
                                  hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:  # empty token list: nothing to encode
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
                embeds = self.embedding(x)                          # (1, seq_len, embedding_dim)
                _, h_n = self.rnn(embeds)                           # h_n: (1, 1, hidden_size)
                return h_n.squeeze(0).squeeze(0)                    # (hidden_size,)

        class DocTower(nn.Module):
            """Structurally identical to QryTower but with its own GRU, so
            documents are encoded by separately learned weights."""

            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False
                self.rnn = nn.GRU(input_size=self.embedding.embedding_dim,
                                  hidden_size=hidden_size, batch_first=True)

            def forward(self, x):
                if not x:
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)
                embeds = self.embedding(x)
                _, h_n = self.rnn(embeds)
                return h_n.squeeze(0).squeeze(0)

        qryTower = QryTower(embedding_layer, config.hidden_size)
        docTower = DocTower(embedding_layer, config.hidden_size)
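
        # Each tower maps a list of token ids to a single (hidden_size,) vector,
        # e.g. (illustrative ids): qryTower([3, 17, 42]).shape == (config.hidden_size,)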

        # Pre-tokenized (query, positive document, negative document) triples.
        with open('tokenized_triples.json', 'r') as f:
            triples_data = json.load(f)
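
        # For reference, 'tokenized_triples.json' is expected to hold a 'train'
        # split of triples whose fields match the accesses below (the token ids
        # here are illustrative, not real data):
        #
        # {
        #     "train": [
        #         {
        #             "query": "...",
        #             "query_tokens": [3, 17, 42],
        #             "positive_document": "...",
        #             "positive_document_tokens": [5, 9, 12, 7],
        #             "negative_document": "...",
        #             "negative_document_tokens": [8, 2, 31]
        #         }
        #     ]
        # }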

        # Only the two GRUs are trainable; the shared embedding layer stays frozen.
        params = list(qryTower.rnn.parameters()) + list(docTower.rnn.parameters())
        optimizer = torch.optim.Adam(params, lr=config.learning_rate)
        num_epochs = config.num_epochs
        margin = config.margin

        print(f"\nTraining on all real triples from the train split with RNN towers for {num_epochs} epochs:\n")
        for epoch in range(num_epochs):
            total_loss = 0
            count = 0
            for triple in triples_data['train']:
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                # Skip the triple if any of the three sequences was empty.
                if qry is not None and pos is not None and neg is not None:
                    # Triplet margin loss on cosine similarity: push the query
                    # at least `margin` closer to the positive than to the negative.
                    dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    loss = F.relu(margin - (dst_pos - dst_neg))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    count += 1

            avg_loss = total_loss / max(count, 1)
            print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch + 1, 'avg_loss': avg_loss})

        # Qualitative check on a handful of training triples after training.
        print("\nEvaluating on first 5 real triples after training:\n")
        with torch.no_grad():  # no gradients needed for evaluation
            for i, triple in enumerate(triples_data['train'][:5]):
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']
                qry_text = triple['query']
                pos_text = triple['positive_document']
                neg_text = triple['negative_document']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                if qry is not None and pos is not None and neg is not None:
                    dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    loss = F.relu(margin - (dst_pos - dst_neg))
                    print(f"Example {i+1}:")
                    print(f"Query: {qry_text}")
                    print(f"Positive doc: {pos_text[:100]}...")
                    print(f"Negative doc: {neg_text[:100]}...")
                    print(f"Cosine similarity (pos): {dst_pos.item():.4f}")
                    print(f"Cosine similarity (neg): {dst_neg.item():.4f}")
                    print(f"Triplet loss: {loss.item():.4f}\n")
                else:
                    print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")


if __name__ == "__main__":
    # Standalone entry point: read the sweep definition and run a single
    # training job with its first/default values (a sweep agent would
    # instead call train() with parameters chosen per run).
    with open('sweep.yaml', 'r') as f:
        sweep_config = yaml.safe_load(f)
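
    # For reference, a minimal 'sweep.yaml' consistent with the keys accessed
    # below (the method and values here are illustrative assumptions, not the
    # real sweep; note num_epochs uses a single `value`, the rest `values`):
    #
    # method: grid
    # parameters:
    #   learning_rate:
    #     values: [0.001, 0.0001]
    #   margin:
    #     values: [0.2, 0.5]
    #   num_epochs:
    #     value: 5
    #   num_triples:
    #     values: [10000]
    #   hidden_size:
    #     values: [128, 256]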

    default_config = {
        'learning_rate': sweep_config['parameters']['learning_rate']['values'][0],
        'margin': sweep_config['parameters']['margin']['values'][0],
        'num_epochs': sweep_config['parameters']['num_epochs']['value'],
        'num_triples': sweep_config['parameters']['num_triples']['values'][0],
        'hidden_size': sweep_config['parameters']['hidden_size']['values'][0],
    }

    train(config=default_config)
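
    # To run this as an actual W&B sweep rather than a single default run,
    # one option is:
    #
    #     sweep_id = wandb.sweep(sweep_config, project="two-tower-retrieval")
    #     wandb.agent(sweep_id, function=train)
    #
    # where "two-tower-retrieval" is a placeholder project name.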