Spaces:
Sleeping
Sleeping
Commit
·
7743187
1
Parent(s):
9813925
try to fix ndcg bug
Browse files
app.py
CHANGED
|
@@ -94,7 +94,7 @@ class RepLlamaModel:
|
|
| 94 |
model.eval()
|
| 95 |
return model
|
| 96 |
|
| 97 |
-
def encode(self, texts, batch_size=
|
| 98 |
self.model = self.model.cuda()
|
| 99 |
all_embeddings = []
|
| 100 |
for i in range(0, len(texts), batch_size):
|
|
@@ -108,6 +108,7 @@ class RepLlamaModel:
|
|
| 108 |
outputs = self.model(**batch_dict)
|
| 109 |
embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
|
| 110 |
embeddings = F.normalize(embeddings, p=2, dim=-1)
|
|
|
|
| 111 |
all_embeddings.append(embeddings.cpu().numpy())
|
| 112 |
|
| 113 |
self.model = self.model.cpu()
|
|
@@ -118,7 +119,7 @@ def load_faiss_index(dataset_name):
|
|
| 118 |
index_path = f"{dataset_name}/faiss_index.bin"
|
| 119 |
if os.path.exists(index_path):
|
| 120 |
logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
|
| 121 |
-
return faiss.read_index(index_path
|
| 122 |
return None
|
| 123 |
|
| 124 |
def search_queries(dataset_name, q_reps, depth=1000):
|
|
@@ -126,16 +127,15 @@ def search_queries(dataset_name, q_reps, depth=1000):
|
|
| 126 |
if faiss_index is None:
|
| 127 |
raise ValueError(f"No FAISS index found for dataset {dataset_name}")
|
| 128 |
|
| 129 |
-
|
| 130 |
-
q_reps = np.ascontiguousarray(q_reps.astype('float16'))
|
| 131 |
|
| 132 |
# Perform the search
|
| 133 |
all_scores, all_indices = faiss_index.search(q_reps, depth)
|
| 134 |
|
| 135 |
-
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
del faiss_index
|
| 139 |
|
| 140 |
return all_scores, np.array(psg_indices)
|
| 141 |
|
|
@@ -149,6 +149,7 @@ def load_corpus_lookups(dataset_name):
|
|
| 149 |
with open(file, 'rb') as f:
|
| 150 |
_, p_lookup = pickle.load(f)
|
| 151 |
corpus_lookups[dataset_name] += p_lookup
|
|
|
|
| 152 |
|
| 153 |
def load_queries(dataset_name):
|
| 154 |
global queries, q_lookups, qrels
|
|
@@ -166,6 +167,9 @@ def load_queries(dataset_name):
|
|
| 166 |
qrels[dataset_name][qrel.query_id] = {}
|
| 167 |
qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
|
| 168 |
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
def evaluate(qrels, results, k_values):
|
| 171 |
evaluator = pytrec_eval.RelevanceEvaluator(
|
|
|
|
| 94 |
model.eval()
|
| 95 |
return model
|
| 96 |
|
| 97 |
+
def encode(self, texts, batch_size=16, **kwargs):
|
| 98 |
self.model = self.model.cuda()
|
| 99 |
all_embeddings = []
|
| 100 |
for i in range(0, len(texts), batch_size):
|
|
|
|
| 108 |
outputs = self.model(**batch_dict)
|
| 109 |
embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
|
| 110 |
embeddings = F.normalize(embeddings, p=2, dim=-1)
|
| 111 |
+
logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
|
| 112 |
all_embeddings.append(embeddings.cpu().numpy())
|
| 113 |
|
| 114 |
self.model = self.model.cpu()
|
|
|
|
| 119 |
index_path = f"{dataset_name}/faiss_index.bin"
|
| 120 |
if os.path.exists(index_path):
|
| 121 |
logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
|
| 122 |
+
return faiss.read_index(index_path)
|
| 123 |
return None
|
| 124 |
|
| 125 |
def search_queries(dataset_name, q_reps, depth=1000):
|
|
|
|
| 127 |
if faiss_index is None:
|
| 128 |
raise ValueError(f"No FAISS index found for dataset {dataset_name}")
|
| 129 |
|
| 130 |
+
logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
|
|
|
|
| 131 |
|
| 132 |
# Perform the search
|
| 133 |
all_scores, all_indices = faiss_index.search(q_reps, depth)
|
| 134 |
|
| 135 |
+
logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
|
| 136 |
+
logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
|
| 137 |
|
| 138 |
+
psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
|
|
|
|
| 139 |
|
| 140 |
return all_scores, np.array(psg_indices)
|
| 141 |
|
|
|
|
| 149 |
with open(file, 'rb') as f:
|
| 150 |
_, p_lookup = pickle.load(f)
|
| 151 |
corpus_lookups[dataset_name] += p_lookup
|
| 152 |
+
logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
|
| 153 |
|
| 154 |
def load_queries(dataset_name):
|
| 155 |
global queries, q_lookups, qrels
|
|
|
|
| 167 |
qrels[dataset_name][qrel.query_id] = {}
|
| 168 |
qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
|
| 169 |
|
| 170 |
+
logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
|
| 171 |
+
logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
|
| 172 |
+
|
| 173 |
|
| 174 |
def evaluate(qrels, results, k_values):
|
| 175 |
evaluator = pytrec_eval.RelevanceEvaluator(
|