roshbeed committed on
Commit
40f7c82
·
verified ·
1 Parent(s): d566338

Upload src/query_redis_ann.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/query_redis_ann.py +131 -0
src/query_redis_ann.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import redis
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
# Connection settings are read from the environment (load_dotenv() has
# already run above); each falls back to a placeholder when unset.
REDIS_HOST = os.getenv('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(os.getenv('REDIS_PORT', 12345))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD', 'your-redis-password')

# Name of the search index, width of the stored vectors, and how many
# nearest neighbours each query asks for.
INDEX_NAME = 'doc_index'
EMBEDDING_DIM = 128
TOP_K = 5

# decode_responses=False keeps replies as raw bytes; the result parsing
# further down compares field names against bytes literals.
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
24
+
25
# Ensure RediSearch index exists.
# Command tokens are passed as separate arguments instead of one
# interpolated string, so the call does not depend on the client splitting
# on whitespace and an index name containing a space cannot corrupt it.
try:
    r.execute_command("FT.INFO", INDEX_NAME)
except redis.ResponseError:
    # FT.INFO was rejected by the server; treated as "index is missing".
    print(f"Creating index '{INDEX_NAME}'...")
    r.execute_command(
        "FT.CREATE", INDEX_NAME,
        "ON", "HASH",
        "PREFIX", 1, "doc:",
        "SCHEMA",
        # The 6 counts the attribute tokens that follow: TYPE/DIM/DISTANCE_METRIC pairs.
        "embedding", "VECTOR", "HNSW", 6,
        "TYPE", "FLOAT32",
        "DIM", EMBEDDING_DIM,
        "DISTANCE_METRIC", "COSINE",
        "text", "TEXT",
        "doc_id", "TAG",
    )
    print(f"Index '{INDEX_NAME}' created.")
else:
    print(f"Index '{INDEX_NAME}' already exists.")
37
+
38
# Load tokenizer
# NOTE: pickle.load can execute arbitrary code; only point this at a
# vocabulary file you produced yourself.
with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
    words_to_ids = pickle.load(f)
vocab_size = len(words_to_ids)

# Load latest CBOW checkpoint for embedding layer
import glob
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
if not checkpoint_files:
    # Without this guard max() fails with an unhelpful
    # "max() arg is an empty sequence" ValueError.
    raise FileNotFoundError(
        "No CBOW checkpoint found matching 'cbow/checkpoints/*.pth'"
    )
# "Latest" is decided by os.path.getctime of each checkpoint file.
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
# NOTE: torch.load also unpickles; load trusted checkpoints only.
state_dict = torch.load(latest_checkpoint, map_location='cpu')

# Copy the checkpoint's 'emb.weight' into a frozen embedding table.
embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
embedding_layer.weight.data.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False
52
+
53
# Define DocTower (same as in save_doc_embeddings_to_redis.py)
class DocTower(nn.Module):
    """Encode a sequence of token ids into a single vector with a GRU.

    Args:
        embedding_layer: ``nn.Embedding`` used to look up token vectors.
            Its weight is frozen (``requires_grad = False``) in place.
        hidden_size: Size of the GRU hidden state, which is also the
            length of the vector returned by ``forward``.
    """

    def __init__(self, embedding_layer, hidden_size):
        super().__init__()
        self.embedding = embedding_layer
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        """Return the final GRU hidden state for one token-id sequence.

        Args:
            x: A list of token ids, or a 1-D tensor of token ids.

        Returns:
            A 1-D tensor of length ``hidden_size``, or ``None`` when ``x``
            is empty (or ``None``).
        """
        if torch.is_tensor(x):
            # `not tensor` is ambiguous for multi-element tensors, so test
            # emptiness explicitly.
            if x.numel() == 0:
                return None
            ids = x.to(dtype=torch.long)
        else:
            if not x:
                return None
            ids = torch.tensor(x, dtype=torch.long)
        # Put the ids on the same device as the embedding table so the
        # module keeps working after `.to(device)`; add a batch dim of 1.
        ids = ids.to(self.embedding.weight.device).unsqueeze(0)
        embeds = self.embedding(ids)
        _, h_n = self.rnn(embeds)
        # h_n is (num_layers=1, batch=1, hidden); drop both leading dims.
        return h_n.squeeze(0).squeeze(0)
68
+
69
# The tower's output width equals EMBEDDING_DIM, the same value used as the
# DIM of the index's vector field.
hidden_size = EMBEDDING_DIM

# NOTE(review): in the code visible here only the embedding weights come
# from a checkpoint; the GRU keeps its random initialisation. Confirm
# whether trained tower weights should be loaded before querying.
docTower = DocTower(embedding_layer, hidden_size).eval()
73
+
74
# Tokenize query
def tokenize(text):
    """Map each whitespace-separated word of ``text`` to its vocabulary id.

    Words missing from ``words_to_ids`` are mapped to id 0.
    """
    token_ids = []
    for word in text.strip().split():
        token_ids.append(words_to_ids.get(word, 0))
    return token_ids
77
+
78
# Interactive query loop: read a query, embed it, run a KNN search against
# the index, and print the matches ordered by ascending score.
while True:
    try:
        query = input("Enter your query (or 'exit' to quit): ").strip()
    except EOFError:
        # stdin was closed (piped input or Ctrl-D); stop like 'exit'.
        break
    if query.lower() == 'exit':
        break
    tokens = tokenize(query)
    if not tokens:
        # docTower returns None for an empty token list, which previously
        # crashed on `.detach()`.
        print("Empty query; please enter at least one word.")
        continue
    with torch.no_grad():
        query_emb = docTower(tokens).detach().cpu().numpy().astype(np.float32)
    # Redis expects bytes for VECTOR field
    query_emb_bytes = query_emb.tobytes()
    # Perform ANN search
    res = r.execute_command(
        "FT.SEARCH",
        INDEX_NAME,
        f"*=>[KNN {TOP_K} @embedding $vec as score]",
        "RETURN", 2, "text", "score",
        "PARAMS", 2, "vec", query_emb_bytes,
        "DIALECT", 2
    )
    if len(res) <= 1:
        print("No results found.")
        continue
    print(f"Top {TOP_K} results:")
    results = []
    # RediSearch result: [count, doc_id1, [fields...], doc_id2, [fields...], ...]
    # so the field lists sit at indices 2, 4, 6, ...
    for doc_fields in res[2::2]:
        if not isinstance(doc_fields, list) or len(doc_fields) < 2:
            continue
        # Pair up [name, value, name, value, ...]; zip drops a dangling
        # name instead of raising IndexError on an odd-length list.
        fields = dict(zip(doc_fields[::2], doc_fields[1::2]))
        raw_score = fields.get(b'score')
        if raw_score is None:
            continue
        try:
            score = float(raw_score)
        except (TypeError, ValueError):
            # An unparsable score means the hit cannot be ranked; skip it.
            continue
        raw_text = fields.get(b'text')
        if raw_text is None:
            text = '[No text found]'
        else:
            text = raw_text.decode('utf-8', errors='ignore')
        results.append((score, text))
    if not results:
        print("[Debug] Raw RediSearch result:", res)
    else:
        # Sort by score (ascending: lower cosine distance = more similar)
        results.sort(key=lambda x: x[0])
        for idx, (score, text) in enumerate(results, 1):
            print(f"Rank {idx}: Score={score:.4f}\n{text}\n---")
129
+
130
# Pick the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The module built in this file is `docTower`; the previous `model` name was
# never defined and raised NameError as soon as the query loop ended.
docTower.to(device)