# mlx2 / src/simple_dual_encoder_rnn.py
# Uploaded by roshbeed via huggingface_hub (commit d566338, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import glob
import os
import json
import wandb
import yaml
def _load_frozen_cbow_embedding(vocab_path='cbow/tkn_words_to_ids.pkl',
                                checkpoint_glob='cbow/checkpoints/*.pth'):
    """Build a frozen ``nn.Embedding`` from the most recent CBOW checkpoint.

    The vocabulary size comes from the pickled word->id mapping and the
    embedding dimension from the checkpoint's ``emb.weight`` tensor, so the
    dimension no longer has to be hard-coded to match training.

    Args:
        vocab_path: path to the pickled word->id dict.
        checkpoint_glob: glob pattern matching CBOW ``.pth`` checkpoints.

    Returns:
        An ``nn.Embedding`` initialised from ``emb.weight`` with gradients off.

    Raises:
        FileNotFoundError: if no checkpoint matches ``checkpoint_glob``.
    """
    # pickle.load can execute arbitrary code: only point this at trusted files.
    with open(vocab_path, 'rb') as f:
        words_to_ids = pickle.load(f)
    vocab_size = len(words_to_ids)

    checkpoint_files = glob.glob(checkpoint_glob)
    if not checkpoint_files:
        raise FileNotFoundError(f"No CBOW checkpoints found matching {checkpoint_glob!r}")
    # getmtime is the last write time; getctime is inode-change time on Unix.
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'))
    weights = state_dict['emb.weight']

    embedding_layer = nn.Embedding(vocab_size, weights.shape[1])
    embedding_layer.weight.data.copy_(weights)
    embedding_layer.weight.requires_grad = False  # freeze weights
    return embedding_layer


class _RnnTower(nn.Module):
    """Encode a list of token ids as the final hidden state of a GRU.

    The embedding layer is frozen; only the GRU parameters are trainable.
    """

    def __init__(self, embedding_layer, hidden_size):
        super().__init__()
        self.embedding = embedding_layer
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(input_size=self.embedding.embedding_dim,
                          hidden_size=hidden_size, batch_first=True)

    def forward(self, x):
        """Return a ``(hidden_size,)`` encoding of token ids *x*, or None if *x* is empty."""
        if not x:
            return None
        ids = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
        embeds = self.embedding(ids)  # (1, seq_len, emb_dim)
        _, h_n = self.rnn(embeds)  # h_n: (1, 1, hidden_size) for a 1-layer GRU
        return h_n.squeeze(0).squeeze(0)  # (hidden_size,)


def _triplet_loss(qry, pos, neg, margin):
    """Cosine-similarity hinge triplet loss for one (query, pos, neg) triple.

    loss = max(0, margin - (cos(qry, pos) - cos(qry, neg)))

    Returns:
        (pos_sim, neg_sim, loss), each a 1-element tensor.
    """
    pos_sim = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
    neg_sim = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
    loss = torch.clamp(margin - (pos_sim - neg_sim), min=0.0)
    return pos_sim, neg_sim, loss


def train(config=None):
    """Train query/document GRU towers with a triplet loss, logging to W&B.

    Reads ``cbow/tkn_words_to_ids.pkl``, the newest ``cbow/checkpoints/*.pth``
    and ``tokenized_triples.json`` from the working directory, trains on the
    ``'train'`` split, logs the per-epoch average loss to wandb and prints
    similarities/losses for the first 5 training triples.

    Args:
        config: hyperparameters passed to ``wandb.init``. The keys read back
            from ``wandb.config`` are ``hidden_size``, ``learning_rate``,
            ``num_epochs`` and ``margin``.
    """
    with wandb.init(config=config):
        config = wandb.config

        embedding_layer = _load_frozen_cbow_embedding()
        # The towers share the frozen embedding but each owns its own GRU.
        qry_tower = _RnnTower(embedding_layer, config.hidden_size)
        doc_tower = _RnnTower(embedding_layer, config.hidden_size)

        # Load all tokenized triples.
        with open('tokenized_triples.json', 'r') as f:
            triples_data = json.load(f)
        # NOTE(review): config.num_triples is set by the __main__ block but is
        # not used here; every triple in the 'train' split is used.
        train_triples = triples_data['train']

        # TRAINING LOOP: only the GRU weights are optimised.
        params = list(qry_tower.rnn.parameters()) + list(doc_tower.rnn.parameters())
        optimizer = torch.optim.Adam(params, lr=config.learning_rate)
        num_epochs = config.num_epochs
        margin = config.margin

        print(f"\nTraining on all real triples from the train split with RNN towers for {num_epochs} epochs:\n")
        for epoch in range(num_epochs):
            total_loss = 0
            count = 0
            for triple in train_triples:
                qry = qry_tower(triple['query_tokens'])
                pos = doc_tower(triple['positive_document_tokens'])
                neg = doc_tower(triple['negative_document_tokens'])
                if qry is None or pos is None or neg is None:
                    continue  # a tower returns None for an empty token list
                _, _, loss = _triplet_loss(qry, pos, neg, margin)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                count += 1
            avg_loss = total_loss / max(count, 1)
            print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch+1, 'avg_loss': avg_loss})

        # EVALUATE ON 5 EXAMPLES (inference only, so skip autograd bookkeeping)
        print("\nEvaluating on first 5 real triples after training:\n")
        with torch.no_grad():
            for i, triple in enumerate(train_triples[:5]):
                qry = qry_tower(triple['query_tokens'])
                pos = doc_tower(triple['positive_document_tokens'])
                neg = doc_tower(triple['negative_document_tokens'])
                if qry is None or pos is None or neg is None:
                    print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")
                    continue
                pos_sim, neg_sim, loss = _triplet_loss(qry, pos, neg, margin)
                print(f"Example {i+1}:")
                print(f"Query: {triple['query']}")
                print(f"Positive doc: {triple['positive_document'][:100]}...")
                print(f"Negative doc: {triple['negative_document'][:100]}...")
                print(f"Cosine similarity (pos): {pos_sim.item():.4f}")
                print(f"Cosine similarity (neg): {neg_sim.item():.4f}")
                print(f"Triplet loss: {loss.item():.4f}\n")
if __name__ == "__main__":
import yaml
# Load configuration from sweep.yaml
with open('sweep.yaml', 'r') as f:
sweep_config = yaml.safe_load(f)
# Create a default config with a single value for each parameter
default_config = {
'learning_rate': sweep_config['parameters']['learning_rate']['values'][0],
'margin': sweep_config['parameters']['margin']['values'][0],
'num_epochs': sweep_config['parameters']['num_epochs']['value'],
'num_triples': sweep_config['parameters']['num_triples']['values'][0],
'hidden_size': sweep_config['parameters']['hidden_size']['values'][0]
}
# Initialize wandb with the default config
train(config=default_config)