roshbeed committed
Commit 3e28790 · verified · 1 Parent(s): 148ebb0

Upload src/create_triple_embeddings.py with huggingface_hub

Files changed (1)
  1. src/create_triple_embeddings.py +146 -0
src/create_triple_embeddings.py ADDED
@@ -0,0 +1,146 @@
+ import json
+ import pickle
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+ import glob
+ import os
+ import redis
+ import numpy as np
+
+ # Redis Cloud connection (replace with your actual credentials or use environment variables)
+ REDIS_HOST = 'your-redis-host'
+ REDIS_PORT = 12345  # your-redis-port
+ REDIS_PASSWORD = 'your-redis-password'
+ INDEX_NAME = 'doc_index'
+ VECTOR_DIM = 128  # Change if your embedding size is different
+
+ r = redis.Redis(
+     host=REDIS_HOST,
+     port=REDIS_PORT,
+     password=REDIS_PASSWORD,
+     decode_responses=False  # binary-safe
+ )
+
+ def load_latest_checkpoint():
+     """Load the latest CBOW model checkpoint."""
+     print("Loading latest CBOW checkpoint...")
+     checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
+     if not checkpoint_files:
+         raise FileNotFoundError("No checkpoint files found in cbow/checkpoints/")
+
+     # Get the latest checkpoint
+     latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
+     print(f"Using checkpoint: {latest_checkpoint}")
+
+     # Load the model state
+     state_dict = torch.load(latest_checkpoint, map_location='cpu')  # avoids errors if a GPU-trained checkpoint is loaded on a CPU-only machine
+     return state_dict
+
+ def load_tokenizer():
+     """Load the CBOW tokenizer mappings."""
+     print("Loading tokenizer...")
+     with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
+         words_to_ids = pickle.load(f)
+     with open('cbow/tkn_ids_to_words.pkl', 'rb') as f:
+         ids_to_words = pickle.load(f)
+     return words_to_ids, ids_to_words
+
+ def load_tokenized_triples():
+     """Load the tokenized triples."""
+     print("Loading tokenized triples...")
+     with open('tokenized_triples.json', 'r') as f:
+         data = json.load(f)
+     return data
+
+ def create_embedding_layer(state_dict, vocab_size, embedding_dim=128):
+     """Create embedding layer from CBOW weights."""
+     embedding = nn.Embedding(vocab_size, embedding_dim)
+     # Extract embedding weights from state dict
+     embedding.weight.data.copy_(state_dict['emb.weight'])
+     # Freeze the embeddings
+     embedding.weight.requires_grad = False
+     return embedding
+
+ def average_pool(tokens, embedding_layer):
+     """Create average pooled vector for a list of tokens."""
+     # Convert tokens to tensor
+     tokens_tensor = torch.tensor(tokens, dtype=torch.long)
+     # Get embeddings
+     embeddings = embedding_layer(tokens_tensor)
+     # Average the embeddings
+     return torch.mean(embeddings, dim=0).detach().numpy()
+
+ def save_doc_embedding_to_redis(doc_id, embedding, text):
+     # Save as a Redis hash for vector search
+     r.hset(doc_id, mapping={
+         'embedding': embedding.astype(np.float32).tobytes(),
+         'text': text,
+         'doc_id': doc_id
+     })
+
+     # Optionally, you can print or log
+     # print(f"Saved doc {doc_id} to Redis.")
+
+ def process_triples(data, embedding_layer):
+     """Process triples and create average pooled vectors. Save positive doc embeddings to Redis."""
+     processed_data = {
+         'train': [],
+         'validation': [],
+         'test': []
+     }
+     doc_counter = 0
+     for split in ['train', 'validation', 'test']:
+         print(f"\nProcessing {split} split...")
+         for triple in tqdm(data[split]):
+             # Get average pooled vectors
+             query_vector = average_pool(triple['query_tokens'], embedding_layer)
+             pos_doc_vector = average_pool(triple['positive_document_tokens'], embedding_layer)
+             neg_doc_vector = average_pool(triple['negative_document_tokens'], embedding_layer)
+
+             # Save positive doc embedding to Redis
+             doc_id = f"doc:{doc_counter}"
+             save_doc_embedding_to_redis(doc_id, pos_doc_vector, triple['positive_document'])
+             doc_counter += 1
+
+             processed_data[split].append({
+                 'query_vector': query_vector.tolist(),
+                 'positive_document_vector': pos_doc_vector.tolist(),
+                 'negative_document_vector': neg_doc_vector.tolist(),
+                 'query': triple['query'],  # Keep original text for reference
+                 'positive_document': triple['positive_document'],
+                 'negative_document': triple['negative_document']
+             })
+     return processed_data
+
+ def main():
+     # Load data and model
+     state_dict = load_latest_checkpoint()
+     words_to_ids, ids_to_words = load_tokenizer()
+     data = load_tokenized_triples()
+
+     # Create embedding layer from CBOW weights
+     vocab_size = len(words_to_ids)
+     embedding_layer = create_embedding_layer(state_dict, vocab_size)
+
+     # Process triples
+     processed_data = process_triples(data, embedding_layer)
+
+     # Save processed data
+     print("\nSaving processed data...")
+     with open('triple_embeddings_cbow.json', 'w') as f:
+         json.dump(processed_data, f)
+
+     # Print statistics
+     for split in ['train', 'validation', 'test']:
+         print(f"\n{split.upper()} split:")
+         print(f"Number of processed triples: {len(processed_data[split])}")
+         if processed_data[split]:
+             sample = processed_data[split][0]
+             print("\nSample vector shapes:")
+             print("Query vector shape:", len(sample['query_vector']))
+             print("Positive doc vector shape:", len(sample['positive_document_vector']))
+             print("Negative doc vector shape:", len(sample['negative_document_vector']))
+
+ if __name__ == "__main__":
+     main()
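
Note: the script stores each positive-document embedding as a Redis hash under the doc: prefix, but INDEX_NAME and VECTOR_DIM are defined without ever being used, which suggests the vector search index is created in a separate step not included in this commit. Below is a minimal sketch of what that companion step could look like, assuming redis-py 4.x and a RediSearch-enabled Redis server; the file name create_doc_index.py, the field layout, and the FLAT index choice are assumptions, not part of this upload.

# create_doc_index.py -- hypothetical companion script (not part of this commit).
# Creates the 'doc_index' vector index over the doc:* hashes written by
# create_triple_embeddings.py, then runs a KNN query against it.
import numpy as np
import redis
from redis.commands.search.field import TagField, TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query

r = redis.Redis(host='your-redis-host', port=12345,
                password='your-redis-password', decode_responses=False)

INDEX_NAME = 'doc_index'
VECTOR_DIM = 128  # must match the CBOW embedding size

def create_index():
    # FLAT = exact nearest-neighbour search; adequate for small corpora.
    r.ft(INDEX_NAME).create_index(
        fields=[
            TextField('text'),
            TagField('doc_id'),
            VectorField('embedding', 'FLAT', {
                'TYPE': 'FLOAT32',          # matches embedding.astype(np.float32).tobytes()
                'DIM': VECTOR_DIM,
                'DISTANCE_METRIC': 'COSINE',
            }),
        ],
        definition=IndexDefinition(prefix=['doc:'], index_type=IndexType.HASH),
    )

def knn_search(query_vector, k=5):
    # query_vector: np.ndarray of shape (VECTOR_DIM,), e.g. from average_pool()
    q = (
        Query(f'*=>[KNN {k} @embedding $vec AS score]')
        .sort_by('score')
        .return_fields('doc_id', 'text', 'score')
        .dialect(2)
    )
    params = {'vec': query_vector.astype(np.float32).tobytes()}
    return r.ft(INDEX_NAME).search(q, query_params=params).docs

With an index like this in place, query embeddings produced by average_pool() can be matched against the stored positive documents directly in Redis, rather than by scanning triple_embeddings_cbow.json.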