skodan commited on
Commit
f03af4a
·
verified ·
1 Parent(s): 77b159a

Update models/resnet_lstm_attention/retrieval.py

Browse files
models/resnet_lstm_attention/retrieval.py CHANGED
@@ -57,3 +57,21 @@ class RetrievalService:
57
  results = [self.text_id_map[i] for i in idxs[0]]
58
  print(f"DEBUG: Returning results: {results}")
59
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  results = [self.text_id_map[i] for i in idxs[0]]
58
  print(f"DEBUG: Returning results: {results}")
59
  return results
60
+
61
+ def text_to_text(self, text: str, top_k: int = 5):
62
+ with torch.no_grad():
63
+ emb = self.clip_model.encode_text(text).cpu().numpy()
64
+ emb = self._normalize(emb)
65
+
66
+ scores, idxs = self.text_index.search(emb, top_k)
67
+
68
+ results = []
69
+ for j, i in enumerate(idxs[0]):
70
+ caption = self.text_id_map[i] # assuming text_id_map stores the actual caption string
71
+ results.append({
72
+ "caption": caption,
73
+ "score": float(scores[0][j])
74
+ })
75
+
76
+ print(f"DEBUG: Text-to-text results: {results}")
77
+ return results