Spaces:
Sleeping
Sleeping
Update models/resnet_lstm_attention/retrieval.py
Browse files
models/resnet_lstm_attention/retrieval.py
CHANGED
|
@@ -57,3 +57,21 @@ class RetrievalService:
|
|
| 57 |
results = [self.text_id_map[i] for i in idxs[0]]
|
| 58 |
print(f"DEBUG: Returning results: {results}")
|
| 59 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
results = [self.text_id_map[i] for i in idxs[0]]
|
| 58 |
print(f"DEBUG: Returning results: {results}")
|
| 59 |
return results
|
| 60 |
+
|
| 61 |
+
def text_to_text(self, text: str, top_k: int = 5):
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
emb = self.clip_model.encode_text(text).cpu().numpy()
|
| 64 |
+
emb = self._normalize(emb)
|
| 65 |
+
|
| 66 |
+
scores, idxs = self.text_index.search(emb, top_k)
|
| 67 |
+
|
| 68 |
+
results = []
|
| 69 |
+
for j, i in enumerate(idxs[0]):
|
| 70 |
+
caption = self.text_id_map[i] # assuming text_id_map stores the actual caption string
|
| 71 |
+
results.append({
|
| 72 |
+
"caption": caption,
|
| 73 |
+
"score": float(scores[0][j])
|
| 74 |
+
})
|
| 75 |
+
|
| 76 |
+
print(f"DEBUG: Text-to-text results: {results}")
|
| 77 |
+
return results
|