import glob
import os
import pickle

import numpy as np
import redis
import torch
import torch.nn as nn
from dotenv import load_dotenv

load_dotenv()
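
# Connection and index settings; credentials come from environment variables.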
REDIS_HOST = os.environ.get('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(os.environ.get('REDIS_PORT', 12345))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'your-redis-password')
INDEX_NAME = 'doc_index'
EMBEDDING_DIM = 128
TOP_K = 5
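
# decode_responses=False keeps replies as raw bytes, so stored vector blobs
# round-trip unchanged.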
r = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
    decode_responses=False,
)
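
# Create the vector index on first run; FT.INFO raises a ResponseError when
# the index does not exist yet.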
try:
    r.execute_command("FT.INFO", INDEX_NAME)
    print(f"Index '{INDEX_NAME}' already exists.")
except redis.ResponseError:
    print(f"Creating index '{INDEX_NAME}'...")
    # HNSW takes 6 arguments here: three key/value pairs (TYPE, DIM,
    # DISTANCE_METRIC).
    r.execute_command(
        "FT.CREATE", INDEX_NAME, "ON", "HASH", "PREFIX", 1, "doc:",
        "SCHEMA",
        "embedding", "VECTOR", "HNSW", 6,
        "TYPE", "FLOAT32",
        "DIM", EMBEDDING_DIM,
        "DISTANCE_METRIC", "COSINE",
        "text", "TEXT",
        "doc_id", "TAG",
    )
    print(f"Index '{INDEX_NAME}' created.")
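
# Load the CBOW tokenizer's word-to-id vocabulary.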
with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
    words_to_ids = pickle.load(f)
vocab_size = len(words_to_ids)
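
# Pick the most recent CBOW checkpoint by file creation time.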
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
if not checkpoint_files:
    raise FileNotFoundError("No checkpoint found in cbow/checkpoints/")
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
state_dict = torch.load(latest_checkpoint, map_location='cpu')
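
# Frozen embedding layer initialised from the pretrained CBOW weights.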
embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
embedding_layer.weight.data.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False


class DocTower(nn.Module):
    """Encodes a token-id sequence into a single vector with a GRU."""

    def __init__(self, embedding_layer, hidden_size):
        super().__init__()
        self.embedding = embedding_layer
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        # x is a non-empty list of token ids; empty queries are filtered
        # out by the caller before reaching this point.
        device = self.embedding.weight.device
        x = torch.tensor(x, dtype=torch.long, device=device).unsqueeze(0)
        embeds = self.embedding(x)
        _, h_n = self.rnn(embeds)
        # Final GRU hidden state, squeezed to shape (hidden_size,).
        return h_n.squeeze(0).squeeze(0)


# Run inference on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

hidden_size = EMBEDDING_DIM

doc_tower = DocTower(embedding_layer, hidden_size).to(device)
doc_tower.eval()


def tokenize(text):
    # Whitespace tokenization; out-of-vocabulary words fall back to id 0.
    return [words_to_ids.get(w, 0) for w in text.strip().split()]
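

# Interactive query loop: embed each query with the doc tower, then run a
# KNN vector search against the index.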
while True:
    query = input("Enter your query (or 'exit' to quit): ").strip()
    if query.lower() == 'exit':
        break
    tokens = tokenize(query)
    if not tokens:
        print("Empty query, please try again.")
        continue
    with torch.no_grad():
        query_emb = doc_tower(tokens).cpu().numpy().astype(np.float32)
    query_emb_bytes = query_emb.tobytes()

    # KNN search; with COSINE, `score` is a distance (lower is closer).
    res = r.execute_command(
        "FT.SEARCH",
        INDEX_NAME,
        f"*=>[KNN {TOP_K} @embedding $vec AS score]",
        "RETURN", 2, "text", "score",
        "PARAMS", 2, "vec", query_emb_bytes,
        "DIALECT", 2,
    )
    if len(res) <= 1:
        print("No results found.")
        continue

    print(f"Top {TOP_K} results:")
    results = []
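
    # res is [count, id1, fields1, id2, fields2, ...]; each fields entry is
    # a flat key/value list of bytes because decode_responses=False.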
    for i in range(1, len(res) - 1, 2):
        doc_fields = res[i + 1]
        if not isinstance(doc_fields, list) or len(doc_fields) < 2:
            continue
        text = None
        score = None
        for j in range(0, len(doc_fields), 2):
            key = doc_fields[j]
            value = doc_fields[j + 1]
            if key == b'text':
                text = value.decode('utf-8', errors='ignore')
            elif key == b'score':
                try:
                    score = float(value)
                except (TypeError, ValueError):
                    score = None
        if score is not None:
            results.append((score, text if text is not None else '[No text found]'))

    if not results:
        print("[Debug] Raw RediSearch result:", res)
    else:
        # Cosine distance: smaller scores are better matches.
        results.sort(key=lambda x: x[0])
        for idx, (score, text) in enumerate(results, 1):
            print(f"Rank {idx}: Score={score:.4f}\n{text}\n---")