Spaces:
Sleeping
Sleeping
Commit
·
c837e28
1
Parent(s):
05ff7af
load faiss
Browse files
app.py
CHANGED
|
@@ -7,7 +7,6 @@ import torch
|
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from transformers import AutoTokenizer, AutoModel, set_seed
|
| 9 |
from peft import PeftModel
|
| 10 |
-
from tevatron.retriever.searcher import FaissFlatSearcher
|
| 11 |
import logging
|
| 12 |
import os
|
| 13 |
import json
|
|
@@ -47,7 +46,6 @@ current_dataset = "scifact"
|
|
| 47 |
def log_system_info():
|
| 48 |
logger.info("System Information:")
|
| 49 |
logger.info(f"Python version: {sys.version}")
|
| 50 |
-
# logger.info(f"Platform: {platform.platform()}")
|
| 51 |
|
| 52 |
logger.info("\nPackage Versions:")
|
| 53 |
logger.info(f"torch: {torch.__version__}")
|
|
@@ -55,7 +53,6 @@ def log_system_info():
|
|
| 55 |
logger.info(f"peft: {peft.__version__}")
|
| 56 |
logger.info(f"faiss: {faiss.__version__}")
|
| 57 |
logger.info(f"gradio: {gr.__version__}")
|
| 58 |
-
# logger.info(f"pytrec_eval: {pytrec_eval.__version__}")
|
| 59 |
logger.info(f"ir_datasets: {ir_datasets.__version__}")
|
| 60 |
|
| 61 |
if torch.cuda.is_available():
|
|
@@ -70,11 +67,8 @@ def log_system_info():
|
|
| 70 |
logger.info("\nCUDA Information:")
|
| 71 |
logger.info("CUDA available: No")
|
| 72 |
|
| 73 |
-
|
| 74 |
log_system_info()
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
def pool(last_hidden_states, attention_mask, pool_type="last"):
|
| 79 |
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 80 |
|
|
@@ -151,18 +145,45 @@ class RepLlamaModel:
|
|
| 151 |
self.model = self.model.cpu()
|
| 152 |
return np.concatenate(all_embeddings, axis=0)
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
def
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
return None
|
| 161 |
|
| 162 |
def search_queries(dataset_name, q_reps, depth=1000):
|
| 163 |
-
faiss_index =
|
| 164 |
-
if faiss_index is None:
|
| 165 |
-
raise ValueError(f"No FAISS index found for dataset {dataset_name}")
|
| 166 |
|
| 167 |
logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
|
| 168 |
|
|
@@ -171,28 +192,11 @@ def search_queries(dataset_name, q_reps, depth=1000):
|
|
| 171 |
|
| 172 |
logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
|
| 173 |
logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
|
| 174 |
-
|
| 175 |
|
| 176 |
psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
|
| 177 |
|
| 178 |
return all_scores, np.array(psg_indices)
|
| 179 |
|
| 180 |
-
def load_corpus_lookups(dataset_name):
    """Populate ``corpus_lookups[dataset_name]`` from the pickled corpus shards.

    Shards are matched by ``{dataset_name}/corpus_emb.*.pkl`` and read in
    ascending order of the integer that precedes the ``.pkl`` suffix. Only the
    second element of each unpickled pair (the lookup) is kept; the first
    element is discarded.
    """
    global corpus_lookups
    shard_pattern = f"{dataset_name}/corpus_emb.*.pkl"
    shard_paths = sorted(
        glob.glob(shard_pattern),
        key=lambda path: int(path.split('.')[-2]),
    )

    # Start from an empty list so a reload replaces any previous entries.
    corpus_lookups[dataset_name] = []
    for shard_path in shard_paths:
        with open(shard_path, 'rb') as handle:
            _, shard_lookup = pickle.load(handle)
        corpus_lookups[dataset_name] += shard_lookup

    logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
    logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][:10]}")
|
| 195 |
-
|
| 196 |
def load_queries(dataset_name):
|
| 197 |
global queries, q_lookups, qrels, query2qid
|
| 198 |
dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
|
|
@@ -214,7 +218,6 @@ def load_queries(dataset_name):
|
|
| 214 |
logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
|
| 215 |
logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
|
| 216 |
|
| 217 |
-
|
| 218 |
def evaluate(qrels, results, k_values):
|
| 219 |
qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
|
| 220 |
results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
|
|
@@ -273,7 +276,6 @@ def run_evaluation(dataset, postfix):
|
|
| 273 |
logger.info(f"Number of results: {len(results)}")
|
| 274 |
logger.info(f"Sample result: {list(results.items())[0]}")
|
| 275 |
|
| 276 |
-
# Add these lines
|
| 277 |
logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
|
| 278 |
logger.info(f"Sample qrel: {list(qrels[dataset].items())[0]}")
|
| 279 |
logger.info(f"Number of queries in results: {len(results)}")
|
|
@@ -293,13 +295,10 @@ def run_evaluation(dataset, postfix):
|
|
| 293 |
def gradio_interface(dataset, postfix):
|
| 294 |
return run_evaluation(dataset, postfix)
|
| 295 |
|
| 296 |
-
|
| 297 |
if model is None:
|
| 298 |
model = RepLlamaModel(model_name_or_path=CUR_MODEL)
|
| 299 |
-
load_corpus_lookups(current_dataset)
|
| 300 |
load_queries(current_dataset)
|
| 301 |
|
| 302 |
-
|
| 303 |
# Create Gradio interface
|
| 304 |
iface = gr.Interface(
|
| 305 |
fn=gradio_interface,
|
|
@@ -318,4 +317,4 @@ iface = gr.Interface(
|
|
| 318 |
)
|
| 319 |
|
| 320 |
# Launch the interface
|
| 321 |
-
iface.launch()
|
|
|
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from transformers import AutoTokenizer, AutoModel, set_seed
|
| 9 |
from peft import PeftModel
|
|
|
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
import json
|
|
|
|
| 46 |
def log_system_info():
|
| 47 |
logger.info("System Information:")
|
| 48 |
logger.info(f"Python version: {sys.version}")
|
|
|
|
| 49 |
|
| 50 |
logger.info("\nPackage Versions:")
|
| 51 |
logger.info(f"torch: {torch.__version__}")
|
|
|
|
| 53 |
logger.info(f"peft: {peft.__version__}")
|
| 54 |
logger.info(f"faiss: {faiss.__version__}")
|
| 55 |
logger.info(f"gradio: {gr.__version__}")
|
|
|
|
| 56 |
logger.info(f"ir_datasets: {ir_datasets.__version__}")
|
| 57 |
|
| 58 |
if torch.cuda.is_available():
|
|
|
|
| 67 |
logger.info("\nCUDA Information:")
|
| 68 |
logger.info("CUDA available: No")
|
| 69 |
|
|
|
|
| 70 |
log_system_info()
|
| 71 |
|
|
|
|
|
|
|
| 72 |
def pool(last_hidden_states, attention_mask, pool_type="last"):
|
| 73 |
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 74 |
|
|
|
|
| 145 |
self.model = self.model.cpu()
|
| 146 |
return np.concatenate(all_embeddings, axis=0)
|
| 147 |
|
| 148 |
+
def load_corpus_embeddings(dataset_name):
    """Load and concatenate every pickled corpus-embedding shard of a dataset.

    Shards are matched by the glob ``{dataset_name}/corpus_emb.*.pkl`` and
    read in ascending order of the integer that precedes the ``.pkl`` suffix.
    Each shard is unpickled into an ``(embeddings, lookup)`` pair.

    Args:
        dataset_name: Directory (relative to the working directory) that
            contains the shard files.

    Returns:
        A tuple ``(all_embeddings, corpus_lookups)``: the shard embeddings
        concatenated along axis 0, and the shard lookups chained in the same
        shard order.

    Raises:
        ValueError: If no shard file matches the pattern.
    """
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    index_files = sorted(
        glob.glob(corpus_path),
        key=lambda path: int(path.split('.')[-2]),
    )

    # Fail early with a descriptive message; otherwise np.concatenate([])
    # raises an opaque "need at least one array to concatenate" ValueError.
    if not index_files:
        raise ValueError(
            f"No corpus embedding files found for dataset {dataset_name} "
            f"(pattern: {corpus_path})"
        )

    shard_embeddings = []
    lookups = []

    for file in index_files:
        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # safe if the shard files come from a trusted source.
        with open(file, 'rb') as f:
            embeddings, p_lookup = pickle.load(f)
        shard_embeddings.append(embeddings)
        lookups.extend(p_lookup)

    all_embeddings = np.concatenate(shard_embeddings, axis=0)
    logger.info(f"Loaded corpus embeddings for {dataset_name}. Shape: {all_embeddings.shape}")

    return all_embeddings, lookups
|
| 166 |
+
|
| 167 |
+
def create_faiss_index(embeddings):
    """Build an exact inner-product FAISS index over the given vectors.

    Args:
        embeddings: 2-D array-like of shape ``(num_vectors, dimension)``.

    Returns:
        A ``faiss.IndexFlatIP`` containing every row of ``embeddings``.
    """
    # FAISS operates on C-contiguous float32 matrices. The unpickled shards
    # may carry a different dtype or layout, so normalise before adding
    # (this is a no-op copy-wise when the input already qualifies).
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    dimension = vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(vectors)
    logger.info(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
    return index
|
| 173 |
+
|
| 174 |
+
def load_or_create_faiss_index(dataset_name):
    """Load a dataset's corpus embeddings and index them with FAISS.

    Args:
        dataset_name: Dataset whose shards are read by
            ``load_corpus_embeddings``.

    Returns:
        A tuple ``(index, corpus_lookups)``: the index built by
        ``create_faiss_index`` and the lookups returned alongside the
        embeddings.
    """
    vectors, lookups = load_corpus_embeddings(dataset_name)
    return create_faiss_index(vectors), lookups
|
| 178 |
|
| 179 |
+
def initialize_faiss_and_corpus(dataset_name):
    """Return the FAISS index for a dataset, building it only on first use.

    The first call for a dataset loads its embeddings, builds the index and
    stores the lookups in the module-level ``corpus_lookups`` dict. Later
    calls reuse the cached index instead of re-reading every shard and
    rebuilding the index on each search.

    Args:
        dataset_name: Dataset to initialise.

    Returns:
        The FAISS index for ``dataset_name``.
    """
    # The cache lives on the function object so no new module-level name is
    # needed; it maps dataset name -> built index.
    cache = initialize_faiss_and_corpus.__dict__.setdefault("_index_cache", {})

    # Reuse only when the matching lookups are still present, so the index
    # and corpus_lookups[dataset_name] can never get out of sync.
    if dataset_name in cache and dataset_name in corpus_lookups:
        return cache[dataset_name]

    # Item assignment mutates the existing dict, so no `global` statement is
    # required here.
    index, corpus_lookups[dataset_name] = load_or_create_faiss_index(dataset_name)
    cache[dataset_name] = index
    logger.info(f"Initialized FAISS index and corpus lookups for {dataset_name}")
    return index
|
|
|
|
| 184 |
|
| 185 |
def search_queries(dataset_name, q_reps, depth=1000):
|
| 186 |
+
faiss_index = initialize_faiss_and_corpus(dataset_name)
|
|
|
|
|
|
|
| 187 |
|
| 188 |
logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
|
| 189 |
|
|
|
|
| 192 |
|
| 193 |
logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
|
| 194 |
logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
|
|
|
|
| 195 |
|
| 196 |
psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
|
| 197 |
|
| 198 |
return all_scores, np.array(psg_indices)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def load_queries(dataset_name):
|
| 201 |
global queries, q_lookups, qrels, query2qid
|
| 202 |
dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
|
|
|
|
| 218 |
logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
|
| 219 |
logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
|
| 220 |
|
|
|
|
| 221 |
def evaluate(qrels, results, k_values):
|
| 222 |
qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
|
| 223 |
results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
|
|
|
|
| 276 |
logger.info(f"Number of results: {len(results)}")
|
| 277 |
logger.info(f"Sample result: {list(results.items())[0]}")
|
| 278 |
|
|
|
|
| 279 |
logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
|
| 280 |
logger.info(f"Sample qrel: {list(qrels[dataset].items())[0]}")
|
| 281 |
logger.info(f"Number of queries in results: {len(results)}")
|
|
|
|
| 295 |
def gradio_interface(dataset, postfix):
    """Gradio callback: delegate to ``run_evaluation`` and return its result."""
    outcome = run_evaluation(dataset, postfix)
    return outcome
|
| 297 |
|
|
|
|
| 298 |
if model is None:
|
| 299 |
model = RepLlamaModel(model_name_or_path=CUR_MODEL)
|
|
|
|
| 300 |
load_queries(current_dataset)
|
| 301 |
|
|
|
|
| 302 |
# Create Gradio interface
|
| 303 |
iface = gr.Interface(
|
| 304 |
fn=gradio_interface,
|
|
|
|
| 317 |
)
|
| 318 |
|
| 319 |
# Launch the interface
|
| 320 |
+
iface.launch(share=False)
|