roshbeed commited on
Commit
9c79fa8
·
verified ·
1 Parent(s): 40f7c82

Upload src/save_doc_embeddings_to_redis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/save_doc_embeddings_to_redis.py +96 -0
src/save_doc_embeddings_to_redis.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pickle
4
+ import glob
5
+ import json
6
+ import redis
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import os
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
+ # Redis Cloud connection (replace with your actual credentials or use environment variables)
15
+ REDIS_HOST = os.environ.get('REDIS_HOST', 'your-redis-host')
16
+ REDIS_PORT = int(os.environ.get('REDIS_PORT', 12345))
17
+ REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'your-redis-password')
18
+ r = redis.Redis(
19
+ host=REDIS_HOST,
20
+ port=REDIS_PORT,
21
+ password=REDIS_PASSWORD,
22
+ decode_responses=False # binary-safe
23
+ )
24
+
25
+ # Load tokenizer
26
+ with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
27
+ words_to_ids = pickle.load(f)
28
+ vocab_size = len(words_to_ids)
29
+ embedding_dim = 128 # Change if needed
30
+
31
+ # Load latest CBOW checkpoint for embedding layer
32
+ checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
33
+ latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
34
+ state_dict = torch.load(latest_checkpoint, map_location='cpu')
35
+
36
+ embedding_layer = nn.Embedding(vocab_size, embedding_dim)
37
+ embedding_layer.weight.data.copy_(state_dict['emb.weight'])
38
+ embedding_layer.weight.requires_grad = False
39
+
40
+ # Define DocTower (copy from simple_dual_encoder_rnn.py)
41
+ class DocTower(nn.Module):
42
+ def __init__(self, embedding_layer, hidden_size):
43
+ super().__init__()
44
+ self.embedding = embedding_layer
45
+ self.embedding.weight.requires_grad = False
46
+ self.rnn = nn.GRU(input_size=self.embedding.embedding_dim, hidden_size=hidden_size, batch_first=True)
47
+
48
+ def forward(self, x):
49
+ if not x:
50
+ return None
51
+ x = torch.tensor(x, dtype=torch.long).unsqueeze(0)
52
+ embeds = self.embedding(x)
53
+ _, h_n = self.rnn(embeds)
54
+ return h_n.squeeze(0).squeeze(0)
55
+
56
+ # Load doc tower weights (if you saved them separately, load here)
57
+ # Otherwise, use the same initialization as in training
58
+ hidden_size = 128 # Set to your trained hidden size
59
+
60
+ docTower = DocTower(embedding_layer, hidden_size)
61
+ # Optionally: docTower.load_state_dict(torch.load('doc_tower.pth'))
62
+ docTower.eval()
63
+
64
+ # Load tokenized triples
65
+ with open('tokenized_triples.json', 'r') as f:
66
+ triples_data = json.load(f)
67
+
68
+ # Collect all unique positive documents
69
+ seen = set()
70
+ documents = []
71
+ for split in ['train', 'validation', 'test']:
72
+ for triple in triples_data[split]:
73
+ doc_text = triple['positive_document']
74
+ doc_tokens = tuple(triple['positive_document_tokens'])
75
+ if doc_tokens not in seen:
76
+ seen.add(doc_tokens)
77
+ documents.append((doc_tokens, doc_text))
78
+
79
+ print(f"Found {len(documents)} unique positive documents.")
80
+
81
+ def save_doc_embedding_to_redis(doc_id, embedding, text):
82
+ r.hset(doc_id, mapping={
83
+ 'embedding': embedding.astype(np.float32).tobytes(),
84
+ 'text': text,
85
+ 'doc_id': doc_id
86
+ })
87
+
88
+ # Compute and save embeddings
89
+ docTower.eval()
90
+ with torch.no_grad():
91
+ for idx, (doc_tokens, doc_text) in enumerate(tqdm(documents, desc='Saving doc embeddings to Redis')):
92
+ embedding = docTower(list(doc_tokens)).detach().cpu().numpy()
93
+ doc_id = f"doc:{idx}"
94
+ save_doc_embedding_to_redis(doc_id, embedding, doc_text)
95
+
96
+ print(f"Saved {len(documents)} doc embeddings to Redis Cloud.")