Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,12 +2,10 @@ import gradio as gr
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
-
from tqdm.auto import tqdm
|
| 6 |
-
import os
|
| 7 |
-
import ir_datasets
|
| 8 |
-
import random # Added for random selection
|
| 9 |
|
| 10 |
-
# --- Model Loading
|
| 11 |
tokenizer_splade = None
|
| 12 |
model_splade = None
|
| 13 |
tokenizer_splade_lexical = None
|
|
@@ -48,44 +46,7 @@ except Exception as e:
|
|
| 48 |
print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
|
| 49 |
|
| 50 |
|
| 51 |
-
# ---
|
| 52 |
-
document_representations = {} # Stores {doc_id: sparse_vector}
|
| 53 |
-
document_texts = {} # Stores {doc_id: doc_text}
|
| 54 |
-
queries_texts = {} # Stores {query_id: query_text}
|
| 55 |
-
qrels_data = {} # Stores {query_id: [{doc_id: str, relevance: int}, ...]}
|
| 56 |
-
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# --- Load Cranfield Corpus, Queries, and Qrels using ir_datasets ---
|
| 60 |
-
def load_cranfield_corpus_ir_datasets():
|
| 61 |
-
global document_texts, queries_texts, qrels_data
|
| 62 |
-
print("Loading Cranfield corpus, queries, and qrels using ir_datasets...")
|
| 63 |
-
try:
|
| 64 |
-
dataset = ir_datasets.load("cranfield")
|
| 65 |
-
|
| 66 |
-
# Load documents
|
| 67 |
-
for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
|
| 68 |
-
document_texts[doc.doc_id] = doc.text.strip()
|
| 69 |
-
print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
|
| 70 |
-
|
| 71 |
-
# Load queries
|
| 72 |
-
for query in tqdm(dataset.queries_iter(), desc="Loading Cranfield queries"):
|
| 73 |
-
queries_texts[query.query_id] = query.text.strip()
|
| 74 |
-
print(f"Loaded {len(queries_texts)} queries from Cranfield corpus.")
|
| 75 |
-
|
| 76 |
-
# Load qrels
|
| 77 |
-
for qrel in tqdm(dataset.qrels_iter(), desc="Loading Cranfield qrels"):
|
| 78 |
-
if qrel.query_id not in qrels_data:
|
| 79 |
-
qrels_data[qrel.query_id] = []
|
| 80 |
-
qrels_data[qrel.query_id].append({"doc_id": qrel.doc_id, "relevance": qrel.relevance})
|
| 81 |
-
print(f"Loaded qrels for {len(qrels_data)} queries.")
|
| 82 |
-
|
| 83 |
-
except Exception as e:
|
| 84 |
-
print(f"Error loading Cranfield corpus with ir_datasets: {e}")
|
| 85 |
-
print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# --- Helper function for lexical mask (now handles batches) ---
|
| 89 |
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
|
| 90 |
"""
|
| 91 |
Creates a batch of lexical BOW masks.
|
|
@@ -118,7 +79,7 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
|
|
| 118 |
|
| 119 |
|
| 120 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
| 121 |
-
# These functions
|
| 122 |
def get_splade_cocondenser_representation(text):
|
| 123 |
if tokenizer_splade is None or model_splade is None:
|
| 124 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
|
@@ -284,270 +245,10 @@ def predict_representation_explorer(model_choice, text):
|
|
| 284 |
return "Please select a model."
|
| 285 |
|
| 286 |
|
| 287 |
-
# --- Internal Core Representation Functions (now handle batches) ---
|
| 288 |
-
def get_splade_cocondenser_representation_internal(texts, tokenizer, model):
|
| 289 |
-
"""
|
| 290 |
-
Generates SPLADE representations for a batch of texts.
|
| 291 |
-
texts: list of strings
|
| 292 |
-
tokenizer: the tokenizer object
|
| 293 |
-
model: the SPLADE model
|
| 294 |
-
Returns: torch.Tensor of shape (batch_size, vocab_size) or None
|
| 295 |
-
"""
|
| 296 |
-
if tokenizer is None or model is None: return None
|
| 297 |
-
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
|
| 298 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 299 |
-
|
| 300 |
-
with torch.no_grad():
|
| 301 |
-
output = model(**inputs)
|
| 302 |
-
|
| 303 |
-
if hasattr(output, 'logits'):
|
| 304 |
-
# torch.max(..., dim=1)[0] reduces along sequence_length dimension,
|
| 305 |
-
# resulting in (batch_size, vocab_size)
|
| 306 |
-
splade_vectors = torch.max(
|
| 307 |
-
torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
|
| 308 |
-
dim=1
|
| 309 |
-
)[0]
|
| 310 |
-
return splade_vectors
|
| 311 |
-
else:
|
| 312 |
-
print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
|
| 313 |
-
return None
|
| 314 |
-
|
| 315 |
-
def get_splade_lexical_representation_internal(texts, tokenizer, model):
|
| 316 |
-
"""
|
| 317 |
-
Generates SPLADE-Lexical representations for a batch of texts.
|
| 318 |
-
texts: list of strings
|
| 319 |
-
tokenizer: the tokenizer object
|
| 320 |
-
model: the SPLADE-Lexical model
|
| 321 |
-
Returns: torch.Tensor of shape (batch_size, vocab_size) or None
|
| 322 |
-
"""
|
| 323 |
-
if tokenizer is None or model is None: return None
|
| 324 |
-
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
|
| 325 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 326 |
-
with torch.no_grad(): output = model(**inputs)
|
| 327 |
-
if hasattr(output, 'logits'):
|
| 328 |
-
splade_vectors = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0]
|
| 329 |
-
vocab_size = tokenizer.vocab_size
|
| 330 |
-
# create_lexical_bow_mask now returns (batch_size, vocab_size)
|
| 331 |
-
bow_masks = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
|
| 332 |
-
splade_vectors = splade_vectors * bow_masks # Element-wise multiplication, shapes (batch_size, vocab_size)
|
| 333 |
-
return splade_vectors
|
| 334 |
-
else:
|
| 335 |
-
print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
|
| 336 |
-
return None
|
| 337 |
-
|
| 338 |
-
def get_splade_doc_representation_internal(texts, tokenizer, model):
|
| 339 |
-
"""
|
| 340 |
-
Generates SPLADE-Doc (binary) representations for a batch of texts.
|
| 341 |
-
texts: list of strings
|
| 342 |
-
tokenizer: the tokenizer object
|
| 343 |
-
model: the SPLADE-Doc model (not directly used for logits, but for device)
|
| 344 |
-
Returns: torch.Tensor of shape (batch_size, vocab_size) or None
|
| 345 |
-
"""
|
| 346 |
-
if tokenizer is None or model is None: return None
|
| 347 |
-
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
|
| 348 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()} # Ensure inputs are on the correct device
|
| 349 |
-
vocab_size = tokenizer.vocab_size
|
| 350 |
-
# create_lexical_bow_mask now returns (batch_size, vocab_size)
|
| 351 |
-
binary_splade_vectors = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
|
| 352 |
-
return binary_splade_vectors
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
# --- Document Indexing Function (now uses batching) ---
|
| 356 |
-
def index_documents(doc_model_choice):
|
| 357 |
-
global document_representations
|
| 358 |
-
if document_representations:
|
| 359 |
-
print("Documents already indexed. Skipping re-indexing.")
|
| 360 |
-
return True
|
| 361 |
-
|
| 362 |
-
tokenizer_to_use = None
|
| 363 |
-
model_to_use = None
|
| 364 |
-
representation_func_to_use = None
|
| 365 |
-
|
| 366 |
-
if doc_model_choice == "SPLADE-cocondenser-distil":
|
| 367 |
-
if tokenizer_splade is None or model_splade is None:
|
| 368 |
-
print("SPLADE-cocondenser-distil model not loaded for indexing.")
|
| 369 |
-
return False
|
| 370 |
-
tokenizer_to_use = tokenizer_splade
|
| 371 |
-
model_to_use = model_splade
|
| 372 |
-
representation_func_to_use = get_splade_cocondenser_representation_internal
|
| 373 |
-
elif doc_model_choice == "SPLADE-v3-Lexical":
|
| 374 |
-
if tokenizer_splade_lexical is None or model_splade_lexical is None:
|
| 375 |
-
print("SPLADE-v3-Lexical model not loaded for indexing.")
|
| 376 |
-
return False
|
| 377 |
-
tokenizer_to_use = tokenizer_splade_lexical
|
| 378 |
-
model_to_use = model_splade_lexical
|
| 379 |
-
representation_func_to_use = get_splade_lexical_representation_internal
|
| 380 |
-
elif doc_model_choice == "SPLADE-v3-Doc":
|
| 381 |
-
if tokenizer_splade_doc is None or model_splade_doc is None:
|
| 382 |
-
print("SPLADE-v3-Doc model not loaded for indexing.")
|
| 383 |
-
return False
|
| 384 |
-
tokenizer_to_use = tokenizer_splade_doc
|
| 385 |
-
model_to_use = model_splade_doc
|
| 386 |
-
representation_func_to_use = get_splade_doc_representation_internal
|
| 387 |
-
else:
|
| 388 |
-
print(f"Invalid model choice for document indexing: {doc_model_choice}")
|
| 389 |
-
return False
|
| 390 |
-
|
| 391 |
-
print(f"Indexing documents using {doc_model_choice}...")
|
| 392 |
-
|
| 393 |
-
doc_ids_list = list(document_texts.keys())
|
| 394 |
-
doc_texts_list = list(document_texts.values())
|
| 395 |
-
|
| 396 |
-
# --- BATCH SIZE FOR INDEXING ---
|
| 397 |
-
batch_size = 32 # You can adjust this value based on memory and performance
|
| 398 |
-
|
| 399 |
-
document_representations = {} # Ensure it's clear we're (re)building the index
|
| 400 |
-
|
| 401 |
-
# Iterate through documents in batches
|
| 402 |
-
for i in tqdm(range(0, len(doc_ids_list), batch_size), desc="Indexing Documents in Batches"):
|
| 403 |
-
batch_doc_ids = doc_ids_list[i:i + batch_size]
|
| 404 |
-
batch_doc_texts = doc_texts_list[i:i + batch_size]
|
| 405 |
-
|
| 406 |
-
sparse_vectors_batch = representation_func_to_use(batch_doc_texts, tokenizer_to_use, model_to_use)
|
| 407 |
-
|
| 408 |
-
if sparse_vectors_batch is not None:
|
| 409 |
-
# sparse_vectors_batch will have shape (batch_size, vocab_size)
|
| 410 |
-
for j, doc_id in enumerate(batch_doc_ids):
|
| 411 |
-
# Store each document's vector
|
| 412 |
-
document_representations[doc_id] = sparse_vectors_batch[j].cpu()
|
| 413 |
-
else:
|
| 414 |
-
print(f"Warning: Failed to get representation for a batch starting with doc_id {batch_doc_ids[0]}")
|
| 415 |
-
|
| 416 |
-
print(f"Finished indexing {len(document_representations)} documents.")
|
| 417 |
-
return True
|
| 418 |
-
|
| 419 |
-
# --- Retrieval Function (for Retrieval Tab) ---
|
| 420 |
-
def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
|
| 421 |
-
if not document_representations:
|
| 422 |
-
return "Document index is not loaded or empty. Please ensure documents are indexed.", []
|
| 423 |
-
|
| 424 |
-
query_vector = None
|
| 425 |
-
query_tokenizer = None
|
| 426 |
-
query_model = None
|
| 427 |
-
|
| 428 |
-
# These internal calls still use single text input for the query
|
| 429 |
-
if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
|
| 430 |
-
query_tokenizer = tokenizer_splade
|
| 431 |
-
query_model = model_splade
|
| 432 |
-
query_vector = get_splade_cocondenser_representation_internal([query_text], query_tokenizer, query_model)
|
| 433 |
-
elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
|
| 434 |
-
query_tokenizer = tokenizer_splade_lexical
|
| 435 |
-
query_model = model_splade_lexical
|
| 436 |
-
query_vector = get_splade_lexical_representation_internal([query_text], query_tokenizer, query_model)
|
| 437 |
-
elif query_model_choice == "SPLADE-v3-Doc (binary)":
|
| 438 |
-
query_tokenizer = tokenizer_splade_doc
|
| 439 |
-
query_model = model_splade_doc
|
| 440 |
-
query_vector = get_splade_doc_representation_internal([query_text], query_tokenizer, query_model)
|
| 441 |
-
else:
|
| 442 |
-
return "Invalid query model choice.", []
|
| 443 |
-
|
| 444 |
-
if query_vector is None:
|
| 445 |
-
return "Failed to get query representation. Check console for model loading errors.", []
|
| 446 |
-
|
| 447 |
-
# Since internal functions now return batches, take the first (and only) item for single query
|
| 448 |
-
query_vector = query_vector.squeeze(0).cpu()
|
| 449 |
-
|
| 450 |
-
scores = {}
|
| 451 |
-
for doc_id, doc_vec in document_representations.items():
|
| 452 |
-
score = torch.dot(query_vector, doc_vec).item()
|
| 453 |
-
scores[doc_id] = score
|
| 454 |
-
|
| 455 |
-
sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
|
| 456 |
-
top_results = sorted_scores[:top_k]
|
| 457 |
-
|
| 458 |
-
formatted_output = f"Retrieval Results for Query: '{query_text}'\n"
|
| 459 |
-
formatted_output += f"Using Query Model: **{query_model_choice}**\n"
|
| 460 |
-
formatted_output += f"Documents Indexed with: **{indexed_doc_model_name}**\n\n"
|
| 461 |
-
|
| 462 |
-
if not top_results:
|
| 463 |
-
formatted_output += "No documents found or scored.\n"
|
| 464 |
-
else:
|
| 465 |
-
for i, (doc_id, score) in enumerate(top_results):
|
| 466 |
-
doc_text = document_texts.get(doc_id, "Document text not available.")
|
| 467 |
-
formatted_output += f"**{i+1}. Document ID: {doc_id}** (Score: {score:.4f})\n"
|
| 468 |
-
formatted_output += f"> {doc_text[:300]}...\n\n"
|
| 469 |
-
|
| 470 |
-
return formatted_output, top_results
|
| 471 |
-
|
| 472 |
-
# --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
|
| 473 |
-
def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
|
| 474 |
-
formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
|
| 475 |
-
return formatted_output
|
| 476 |
-
|
| 477 |
-
# --- New function to get specific retrieval examples ---
|
| 478 |
-
def get_specific_retrieval_examples():
|
| 479 |
-
if not queries_texts or not qrels_data or not document_texts:
|
| 480 |
-
return "Queries, qrels, or documents not loaded. Please check initial loading."
|
| 481 |
-
|
| 482 |
-
high_qrel_threshold = 3 # Relevance score of 3 or 4 for Cranfield is generally considered high
|
| 483 |
-
low_qrel_threshold = 1 # Relevance score of 0 or 1 for Cranfield is generally considered low
|
| 484 |
-
|
| 485 |
-
eligible_query_ids = []
|
| 486 |
-
for qid, qrels in qrels_data.items():
|
| 487 |
-
has_high_qrel = any(item['relevance'] >= high_qrel_threshold for item in qrels)
|
| 488 |
-
has_low_qrel = any(item['relevance'] <= low_qrel_threshold for item in qrels)
|
| 489 |
-
if has_high_qrel and has_low_qrel:
|
| 490 |
-
eligible_query_ids.append(qid)
|
| 491 |
-
|
| 492 |
-
if not eligible_query_ids:
|
| 493 |
-
return "Could not find a query with both high and low relevance documents in the loaded qrels."
|
| 494 |
-
|
| 495 |
-
# Pick a random eligible query
|
| 496 |
-
random_query_id = random.choice(eligible_query_ids)
|
| 497 |
-
full_query_text = queries_texts.get(random_query_id, "Query text not found.")
|
| 498 |
-
query_snippet = full_query_text[:300] + "..." if len(full_query_text) > 300 else full_query_text
|
| 499 |
-
|
| 500 |
-
qrels_for_query = qrels_data[random_query_id]
|
| 501 |
-
|
| 502 |
-
high_qrel_docs = [item for item in qrels_for_query if item['relevance'] >= high_qrel_threshold]
|
| 503 |
-
low_qrel_docs = [item for item in qrels_for_query if item['relevance'] <= low_qrel_threshold]
|
| 504 |
-
|
| 505 |
-
selected_high_doc_id = random.choice(high_qrel_docs)['doc_id'] if high_qrel_docs else None
|
| 506 |
-
selected_low_doc_id = random.choice(low_qrel_docs)['doc_id'] if low_qrel_docs else None
|
| 507 |
-
|
| 508 |
-
output_str = f"### Random Query Example\n\n"
|
| 509 |
-
output_str += f"**Query ID:** {random_query_id}\n"
|
| 510 |
-
output_str += f"**Query Snippet:** {query_snippet}\n\n" # Changed to snippet
|
| 511 |
-
|
| 512 |
-
if selected_high_doc_id:
|
| 513 |
-
full_doc_text = document_texts.get(selected_high_doc_id, "Document text not available.")
|
| 514 |
-
doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
|
| 515 |
-
output_str += f"### Highly Relevant Document (Qrel >= {high_qrel_threshold})\n"
|
| 516 |
-
output_str += f"**Document ID:** {selected_high_doc_id}\n"
|
| 517 |
-
output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
|
| 518 |
-
else:
|
| 519 |
-
output_str += "No highly relevant document found for this query.\n\n"
|
| 520 |
-
|
| 521 |
-
if selected_low_doc_id:
|
| 522 |
-
full_doc_text = document_texts.get(selected_low_doc_id, "Document text not available.")
|
| 523 |
-
doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
|
| 524 |
-
output_str += f"### Lowly Relevant Document (Qrel <= {low_qrel_threshold})\n"
|
| 525 |
-
output_str += f"**Document ID:** {selected_low_doc_id}\n"
|
| 526 |
-
output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
|
| 527 |
-
else:
|
| 528 |
-
output_str += "No lowly relevant document found for this query.\n\n"
|
| 529 |
-
|
| 530 |
-
return output_str
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# --- Initial Load and Indexing Calls ---
|
| 534 |
-
# This part runs once when the app starts.
|
| 535 |
-
load_cranfield_corpus_ir_datasets()
|
| 536 |
-
|
| 537 |
-
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
|
| 538 |
-
index_documents(initial_doc_model_for_indexing)
|
| 539 |
-
elif initial_doc_model_for_indexing == "SPLADE-v3-Lexical" and model_splade_lexical is not None:
|
| 540 |
-
index_documents(initial_doc_model_for_indexing)
|
| 541 |
-
elif initial_doc_model_for_indexing == "SPLADE-v3-Doc" and model_splade_doc is not None:
|
| 542 |
-
index_documents(initial_doc_model_for_indexing)
|
| 543 |
-
else:
|
| 544 |
-
print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")
|
| 545 |
-
|
| 546 |
-
|
| 547 |
# --- Gradio Interface Setup with Tabs ---
|
| 548 |
with gr.Blocks(title="SPLADE Demos") as demo:
|
| 549 |
-
gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer
|
| 550 |
-
gr.Markdown("Explore different SPLADE models and their sparse representation types
|
| 551 |
|
| 552 |
with gr.Tabs():
|
| 553 |
with gr.TabItem("Sparse Representation Explorer"):
|
|
@@ -575,49 +276,4 @@ with gr.Blocks(title="SPLADE Demos") as demo:
|
|
| 575 |
# live=True # Setting live=True might be slow for complex models on every keystroke
|
| 576 |
)
|
| 577 |
|
| 578 |
-
with gr.TabItem("Document Retrieval Demo"):
|
| 579 |
-
gr.Markdown("### Retrieve Documents from Cranfield Collection")
|
| 580 |
-
gr.Interface(
|
| 581 |
-
fn=predict_retrieval_gradio,
|
| 582 |
-
inputs=[
|
| 583 |
-
gr.Textbox(
|
| 584 |
-
lines=3,
|
| 585 |
-
label="Enter your query text here:",
|
| 586 |
-
placeholder="e.g., Does high-dose vitamin C cure cancer?"
|
| 587 |
-
),
|
| 588 |
-
gr.Radio(
|
| 589 |
-
[
|
| 590 |
-
"SPLADE-cocondenser-distil (weighting and expansion)",
|
| 591 |
-
"SPLADE-v3-Lexical (weighting)",
|
| 592 |
-
"SPLADE-v3-Doc (binary)"
|
| 593 |
-
],
|
| 594 |
-
label="Choose Query Representation Model",
|
| 595 |
-
value="SPLADE-cocondenser-distil (weighting and expansion)"
|
| 596 |
-
),
|
| 597 |
-
gr.Radio(
|
| 598 |
-
[
|
| 599 |
-
"SPLADE-cocondenser-distil",
|
| 600 |
-
"SPLADE-v3-Lexical",
|
| 601 |
-
"SPLADE-v3-Doc"
|
| 602 |
-
],
|
| 603 |
-
label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
|
| 604 |
-
value=initial_doc_model_for_indexing,
|
| 605 |
-
interactive=False # This radio is fixed for simplicity
|
| 606 |
-
)
|
| 607 |
-
],
|
| 608 |
-
outputs=gr.Markdown(),
|
| 609 |
-
allow_flagging="never",
|
| 610 |
-
# live=True # retrieval is too heavy for live
|
| 611 |
-
)
|
| 612 |
-
|
| 613 |
-
gr.Markdown("---") # Separator
|
| 614 |
-
gr.Markdown("### Get Specific Retrieval Examples")
|
| 615 |
-
specific_example_output = gr.Markdown()
|
| 616 |
-
specific_example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
|
| 617 |
-
specific_example_button.click(
|
| 618 |
-
fn=get_specific_retrieval_examples,
|
| 619 |
-
inputs=[],
|
| 620 |
-
outputs=specific_example_output
|
| 621 |
-
)
|
| 622 |
-
|
| 623 |
demo.launch()
|
|
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
+
from tqdm.auto import tqdm # Still useful for model loading progress if desired, but not strictly necessary for this simplified version
|
| 6 |
+
import os # Still useful for general purpose, but not explicitly used in this simplified version
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# --- Model Loading ---
|
| 9 |
tokenizer_splade = None
|
| 10 |
model_splade = None
|
| 11 |
tokenizer_splade_lexical = None
|
|
|
|
| 46 |
print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
|
| 47 |
|
| 48 |
|
| 49 |
+
# --- Helper function for lexical mask (now handles batches, but used for single input here) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
|
| 51 |
"""
|
| 52 |
Creates a batch of lexical BOW masks.
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
| 82 |
+
# These functions take single text input for the Explorer tab
|
| 83 |
def get_splade_cocondenser_representation(text):
|
| 84 |
if tokenizer_splade is None or model_splade is None:
|
| 85 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
|
|
|
| 245 |
return "Please select a model."
|
| 246 |
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
# --- Gradio Interface Setup with Tabs ---
|
| 249 |
with gr.Blocks(title="SPLADE Demos") as demo:
|
| 250 |
+
gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer") # Updated title
|
| 251 |
+
gr.Markdown("Explore different SPLADE models and their sparse representation types.") # Updated description
|
| 252 |
|
| 253 |
with gr.Tabs():
|
| 254 |
with gr.TabItem("Sparse Representation Explorer"):
|
|
|
|
| 276 |
# live=True # Setting live=True might be slow for complex models on every keystroke
|
| 277 |
)
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
demo.launch()
|