|
|
import torch |
|
|
import torch.nn as nn |
|
|
import pickle |
|
|
import glob |
|
|
import json |
|
|
import redis |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Read any .env file first so the lookups below can see its values.
load_dotenv()

# Connection settings; each falls back to a placeholder when the variable is unset.
REDIS_HOST = os.getenv('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(os.getenv('REDIS_PORT', 12345))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', 'your-redis-password')

# decode_responses=False keeps replies as raw bytes (the embeddings are stored as bytes).
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
|
|
|
|
|
|
|
|
# Load the pickled word -> id mapping; its length sets the number of embedding rows.
with open('cbow/tkn_words_to_ids.pkl', 'rb') as vocab_file:
    words_to_ids = pickle.load(vocab_file)

embedding_dim = 128
vocab_size = len(words_to_ids)
|
|
|
|
|
|
|
|
# Find the newest checkpoint (by os.path.getctime) and load it onto the CPU.
# NOTE(review): on Unix getctime is the last metadata change, not creation time —
# confirm this is the intended notion of "latest".
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
if not checkpoint_files:
    # max() on an empty list would raise an opaque "arg is an empty sequence" error.
    raise FileNotFoundError("No checkpoint files found matching 'cbow/checkpoints/*.pth'")
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
state_dict = torch.load(latest_checkpoint, map_location='cpu')

# Build the embedding table, copy in the checkpoint's 'emb.weight', and freeze it.
embedding_layer = nn.Embedding(vocab_size, embedding_dim)
embedding_layer.weight.data.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False
|
|
|
|
|
|
|
|
class DocTower(nn.Module):
    """Encode one tokenised document into a fixed-size vector with a GRU.

    Token ids are looked up in a frozen embedding layer, the resulting
    sequence is fed through a single-layer GRU, and the final hidden state
    is returned as the document vector.
    """

    def __init__(self, embedding_layer, hidden_size):
        """Store the embedding layer (frozen here) and build the GRU.

        Args:
            embedding_layer: ``nn.Embedding`` used to look up token ids; its
                weight is marked ``requires_grad = False``.
            hidden_size: Size of the GRU hidden state, and therefore of the
                vector returned by ``forward``.
        """
        super().__init__()
        self.embedding = embedding_layer
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        """Return the final GRU hidden state for a single document.

        Args:
            x: One document's token ids as a list, tuple, or 1-D tensor.

        Returns:
            A 1-D tensor of length ``hidden_size``, or ``None`` when ``x`` is
            ``None`` or contains no tokens.
        """
        if x is None:
            return None
        # as_tensor accepts lists, tuples and tensors alike, and places the ids
        # on the same device as the embedding table. The previous `if not x`
        # check raised for tensors with more than one element.
        token_ids = torch.as_tensor(
            x, dtype=torch.long, device=self.embedding.weight.device
        )
        if token_ids.numel() == 0:
            return None
        # Add a batch dimension of 1 because the GRU is batch_first.
        embeds = self.embedding(token_ids.unsqueeze(0))
        _, h_n = self.rnn(embeds)
        # h_n is (num_layers=1, batch=1, hidden); drop both leading dimensions.
        return h_n.squeeze(0).squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
# Size of the GRU hidden state passed to DocTower.
hidden_size = 128

# Build the tower and switch it to evaluation mode; Module.eval() returns the module itself.
docTower = DocTower(embedding_layer, hidden_size).eval()
|
|
|
|
|
|
|
|
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale default (which is not UTF-8 on every system).
with open('tokenized_triples.json', 'r', encoding='utf-8') as f:
    triples_data = json.load(f)
|
|
|
|
|
|
|
|
# Gather each distinct positive document once, in first-seen order.
# Two documents count as the same when their token sequences are equal.
seen = set()
documents = []
all_triples = (
    triple
    for split in ('train', 'validation', 'test')
    for triple in triples_data[split]
)
for triple in all_triples:
    text = triple['positive_document']
    token_key = tuple(triple['positive_document_tokens'])
    if token_key in seen:
        continue
    seen.add(token_key)
    documents.append((token_key, text))

print(f"Found {len(documents)} unique positive documents.")
|
|
|
|
|
def save_doc_embedding_to_redis(doc_id, embedding, text):
    """Write one document's vector and text into the Redis hash at ``doc_id``.

    Args:
        doc_id: Redis key for the hash; also stored in its ``doc_id`` field.
        embedding: Array exposing ``astype``/``tobytes``; saved as raw
            float32 bytes in the ``embedding`` field.
        text: Document text saved in the ``text`` field.
    """
    vector_bytes = embedding.astype(np.float32).tobytes()
    fields = {
        'embedding': vector_bytes,
        'text': text,
        'doc_id': doc_id,
    }
    r.hset(doc_id, mapping=fields)
|
|
|
|
|
|
|
|
# Encode every unique document and store it in Redis under "doc:<index>".
docTower.eval()
saved_count = 0
with torch.no_grad():
    for idx, (doc_tokens, doc_text) in enumerate(tqdm(documents, desc='Saving doc embeddings to Redis')):
        doc_vector = docTower(list(doc_tokens))
        if doc_vector is None:
            # The tower returns None for an empty token list; there is nothing
            # to store, and calling .detach() on None would crash the export.
            continue
        embedding = doc_vector.detach().cpu().numpy()
        # idx comes from enumerate, so ids stay tied to the position in
        # `documents` even when some entries are skipped.
        doc_id = f"doc:{idx}"
        save_doc_embedding_to_redis(doc_id, embedding, doc_text)
        saved_count += 1

# Report what was actually written, not how many documents were attempted.
print(f"Saved {saved_count} doc embeddings to Redis Cloud.")