mlx2 / src /save_doc_embeddings_to_redis.py
roshbeed's picture
Upload src/save_doc_embeddings_to_redis.py with huggingface_hub
9c79fa8 verified
import torch
import torch.nn as nn
import pickle
import glob
import json
import redis
import numpy as np
from tqdm import tqdm
import os
from dotenv import load_dotenv
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

# Redis Cloud connection settings. Each value is read from the environment and
# falls back to a placeholder, so replace the placeholders or export the
# variables before running against a real instance.
_env = os.environ
REDIS_HOST = _env.get('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(_env.get('REDIS_PORT', 12345))
REDIS_PASSWORD = _env.get('REDIS_PASSWORD', 'your-redis-password')

# decode_responses=False keeps the client binary-safe, which matters because
# the embeddings are stored as raw bytes.
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
# Load tokenizer (word -> id mapping); its size defines the vocabulary size.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only point
# this at a vocabulary file you produced yourself or otherwise trust.
with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
    words_to_ids = pickle.load(f)
vocab_size = len(words_to_ids)
embedding_dim = 128  # Change if needed

# Load latest CBOW checkpoint for embedding layer
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
if not checkpoint_files:
    # Without this guard max() fails with an opaque "empty sequence" ValueError.
    raise FileNotFoundError(
        "No CBOW checkpoint found matching 'cbow/checkpoints/*.pth'"
    )
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
state_dict = torch.load(latest_checkpoint, map_location='cpu')

# Copy the checkpoint's 'emb.weight' tensor into a fresh embedding table and
# freeze it. copy_ requires the checkpoint tensor to be compatible with
# (vocab_size, embedding_dim).
embedding_layer = nn.Embedding(vocab_size, embedding_dim)
embedding_layer.weight.data.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False
# Define DocTower (copy from simple_dual_encoder_rnn.py)
class DocTower(nn.Module):
    """Encode one document's token ids into a vector with a frozen embedding and a GRU."""

    def __init__(self, embedding_layer, hidden_size):
        """Store the given embedding layer (frozen) and build a single GRU on top of it.

        Args:
            embedding_layer: an ``nn.Embedding``; its weight is marked as not
                requiring gradients.
            hidden_size: hidden size of the GRU.
        """
        super().__init__()
        embedding_layer.weight.requires_grad = False
        self.embedding = embedding_layer
        self.rnn = nn.GRU(
            input_size=embedding_layer.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        """Return the GRU's final hidden state for the token-id sequence ``x``.

        ``x`` is expected to be a flat sequence of integer token ids (callers
        in this file pass a list). A falsy ``x`` (e.g. an empty list) yields
        ``None`` instead of a tensor.
        """
        if not x:
            return None
        token_ids = torch.tensor(x, dtype=torch.long)
        # Add a leading batch dimension of 1 for the batch_first GRU.
        batch = token_ids.unsqueeze(0)
        _, final_hidden = self.rnn(self.embedding(batch))
        # Drop the two leading singleton dimensions of the hidden state.
        return final_hidden.squeeze(0).squeeze(0)
# Build the document tower on top of the frozen CBOW embedding layer.
# NOTE(review): no trained tower weights are loaded here -- the load_state_dict
# call below is commented out -- so the GRU keeps its default initialization
# unless you enable it and point it at your saved weights.
hidden_size = 128  # Set to your trained hidden size
# Optionally: docTower.load_state_dict(torch.load('doc_tower.pth'))
# eval() returns the module itself, so construction and mode switch are chained.
docTower = DocTower(embedding_layer, hidden_size).eval()
# Load tokenized triples
with open('tokenized_triples.json', 'r') as f:
    triples_data = json.load(f)

# Collect all unique positive documents across the three splits. Uniqueness is
# decided by the token sequence, and the first text seen for a sequence wins.
seen = set()
documents = []
all_triples = (
    triple
    for split in ['train', 'validation', 'test']
    for triple in triples_data[split]
)
for triple in all_triples:
    doc_text = triple['positive_document']
    # Lists are unhashable, so the tokens become a tuple to serve as the set key.
    doc_tokens = tuple(triple['positive_document_tokens'])
    if doc_tokens in seen:
        continue
    seen.add(doc_tokens)
    documents.append((doc_tokens, doc_text))

print(f"Found {len(documents)} unique positive documents.")
def save_doc_embedding_to_redis(doc_id, embedding, text):
    """Write one document as a Redis hash stored under the key ``doc_id``.

    The hash has three fields: ``embedding`` (the array cast to float32 and
    serialized as raw bytes), ``text`` and ``doc_id``. Uses the module-level
    Redis client ``r``.
    """
    raw_vector = embedding.astype(np.float32).tobytes()
    fields = {
        'embedding': raw_vector,
        'text': text,
        'doc_id': doc_id,
    }
    r.hset(doc_id, mapping=fields)
# Compute and save embeddings
docTower.eval()
saved_count = 0
skipped_count = 0
with torch.no_grad():
    for idx, (doc_tokens, doc_text) in enumerate(tqdm(documents, desc='Saving doc embeddings to Redis')):
        doc_vector = docTower(list(doc_tokens))
        if doc_vector is None:
            # DocTower returns None for an empty token list; calling .detach()
            # on it would crash the whole run, so skip the document instead.
            skipped_count += 1
            continue
        embedding = doc_vector.detach().cpu().numpy()
        # idx comes from enumerate over all documents, so ids stay tied to a
        # document's position even when earlier documents are skipped.
        doc_id = f"doc:{idx}"
        save_doc_embedding_to_redis(doc_id, embedding, doc_text)
        saved_count += 1

# Report what was actually written rather than the number of candidates.
print(f"Saved {saved_count} doc embeddings to Redis Cloud.")
if skipped_count:
    print(f"Skipped {skipped_count} documents with no tokens.")