Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -244,7 +244,8 @@ def predict_representation_explorer(model_choice, text):
|
|
| 244 |
else:
|
| 245 |
return "Please select a model."
|
| 246 |
|
| 247 |
-
# ---
|
|
|
|
| 248 |
def get_splade_cocondenser_vector(text):
|
| 249 |
if tokenizer_splade is None or model_splade is None:
|
| 250 |
return None
|
|
@@ -307,7 +308,8 @@ def get_splade_doc_vector(text):
|
|
| 307 |
return None
|
| 308 |
|
| 309 |
|
| 310 |
-
# ---
|
|
|
|
| 311 |
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
| 312 |
if splade_vector is None:
|
| 313 |
return "Failed to generate vector."
|
|
@@ -353,48 +355,42 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
|
| 353 |
return formatted_output
|
| 354 |
|
| 355 |
|
| 356 |
-
# --- NEW:
|
| 357 |
-
def
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
|
| 366 |
-
query_vector = get_splade_cocondenser_vector(query_text)
|
| 367 |
-
doc_vector = get_splade_cocondenser_vector(doc_text)
|
| 368 |
-
selected_tokenizer = tokenizer_splade
|
| 369 |
-
query_rep_str = "Query SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
|
| 370 |
-
doc_rep_str = "Document SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
|
| 371 |
-
is_binary = False
|
| 372 |
-
elif model_choice == "SPLADE-v3-Lexical (weighting)":
|
| 373 |
-
query_vector = get_splade_lexical_vector(query_text)
|
| 374 |
-
doc_vector = get_splade_lexical_vector(doc_text)
|
| 375 |
-
selected_tokenizer = tokenizer_splade_lexical
|
| 376 |
-
query_rep_str = "Query SPLADE-v3-Lexical Representation (Weighting):\n"
|
| 377 |
-
doc_rep_str = "Document SPLADE-v3-Lexical Representation (Weighting):\n"
|
| 378 |
-
is_binary = False
|
| 379 |
-
elif model_choice == "SPLADE-v3-Doc (binary)":
|
| 380 |
-
query_vector = get_splade_doc_vector(query_text)
|
| 381 |
-
doc_vector = get_splade_doc_vector(doc_text)
|
| 382 |
-
selected_tokenizer = tokenizer_splade_doc
|
| 383 |
-
query_rep_str = "Query SPLADE-v3-Doc Representation (Binary):\n"
|
| 384 |
-
doc_rep_str = "Document SPLADE-v3-Doc Representation (Binary):\n"
|
| 385 |
-
is_binary = True
|
| 386 |
else:
|
| 387 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
if query_vector is None or doc_vector is None:
|
| 390 |
-
return "Failed to generate one or both vectors. Please check model loading.", "", ""
|
| 391 |
|
| 392 |
# Calculate dot product
|
|
|
|
|
|
|
| 393 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
| 394 |
|
| 395 |
# Format representations
|
| 396 |
-
query_rep_str
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
| 398 |
|
| 399 |
# Combine output
|
| 400 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
|
@@ -437,18 +433,27 @@ with gr.Blocks(title="SPLADE Demos") as demo:
|
|
| 437 |
|
| 438 |
with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
|
| 439 |
gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
|
| 440 |
-
gr.Markdown("Select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
gr.Interface(
|
| 442 |
-
fn=
|
| 443 |
inputs=[
|
| 444 |
gr.Radio(
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
|
|
|
| 452 |
),
|
| 453 |
gr.Textbox(
|
| 454 |
lines=3,
|
|
|
|
| 244 |
else:
|
| 245 |
return "Please select a model."
|
| 246 |
|
| 247 |
+
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
|
| 248 |
+
# These functions remain unchanged from the previous iteration, as they return the raw tensors.
|
| 249 |
def get_splade_cocondenser_vector(text):
|
| 250 |
if tokenizer_splade is None or model_splade is None:
|
| 251 |
return None
|
|
|
|
| 308 |
return None
|
| 309 |
|
| 310 |
|
| 311 |
+
# --- Function to get formatted representation from a raw vector and tokenizer ---
|
| 312 |
+
# This function remains unchanged as it's a generic formatter for any sparse vector.
|
| 313 |
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
| 314 |
if splade_vector is None:
|
| 315 |
return "Failed to generate vector."
|
|
|
|
| 355 |
return formatted_output
|
| 356 |
|
| 357 |
|
| 358 |
+
# --- NEW/MODIFIED: Helper to get the correct vector function, tokenizer, and binary flag ---
|
| 359 |
+
def get_model_assets(model_choice_str):
    """Resolve a model label to the assets needed to encode and display text.

    Args:
        model_choice_str: The model label to look up; it is compared with
            ``==`` against the three known SPLADE model labels.

    Returns:
        A 4-tuple ``(vector_fn, tokenizer, is_binary, display_name)``:
        the function that turns text into a vector, the tokenizer paired
        with that model, whether the representation is binary, and a
        human-readable model name. For an unrecognised label the tuple is
        ``(None, None, False, "Unknown Model")``.
    """
    # Each builder is a thunk so the module-level functions/tokenizers are
    # only looked up for the model that is actually selected.
    known_models = (
        (
            "SPLADE-cocondenser-distil (weighting and expansion)",
            lambda: (
                get_splade_cocondenser_vector,
                tokenizer_splade,
                False,
                "SPLADE-cocondenser-distil (Weighting and Expansion)",
            ),
        ),
        (
            "SPLADE-v3-Lexical (weighting)",
            lambda: (
                get_splade_lexical_vector,
                tokenizer_splade_lexical,
                False,
                "SPLADE-v3-Lexical (Weighting)",
            ),
        ),
        (
            "SPLADE-v3-Doc (binary)",
            lambda: (
                get_splade_doc_vector,
                tokenizer_splade_doc,
                True,
                "SPLADE-v3-Doc (Binary)",
            ),
        ),
    )

    for label, build_assets in known_models:
        if model_choice_str == label:
            return build_assets()

    # No label matched: signal "unknown" with empty assets.
    return None, None, False, "Unknown Model"
|
| 368 |
+
|
| 369 |
+
# --- MODIFIED: Dot Product Calculation Function for the new tab ---
|
| 370 |
+
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
|
| 371 |
+
query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
|
| 372 |
+
doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
|
| 373 |
+
|
| 374 |
+
if query_vector_fn is None or doc_vector_fn is None:
|
| 375 |
+
return "Please select valid models for both query and document encoding.", "", ""
|
| 376 |
+
|
| 377 |
+
query_vector = query_vector_fn(query_text)
|
| 378 |
+
doc_vector = doc_vector_fn(doc_text)
|
| 379 |
|
| 380 |
if query_vector is None or doc_vector is None:
|
| 381 |
+
return "Failed to generate one or both vectors. Please check model loading and input text.", "", ""
|
| 382 |
|
| 383 |
# Calculate dot product
|
| 384 |
+
# Ensure both vectors are on CPU before dot product to avoid device mismatch issues
|
| 385 |
+
# and to ensure .item() works reliably for conversion to float.
|
| 386 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
| 387 |
|
| 388 |
# Format representations
|
| 389 |
+
query_rep_str = f"Query Representation ({query_model_name_display}):\n"
|
| 390 |
+
query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
|
| 391 |
+
|
| 392 |
+
doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
|
| 393 |
+
doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
|
| 394 |
|
| 395 |
# Combine output
|
| 396 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
|
|
|
| 433 |
|
| 434 |
with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
|
| 435 |
gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
|
| 436 |
+
gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
|
| 437 |
+
|
| 438 |
+
# Define the common model choices for cleaner code
|
| 439 |
+
model_choices = [
|
| 440 |
+
"SPLADE-cocondenser-distil (weighting and expansion)",
|
| 441 |
+
"SPLADE-v3-Lexical (weighting)",
|
| 442 |
+
"SPLADE-v3-Doc (binary)"
|
| 443 |
+
]
|
| 444 |
+
|
| 445 |
gr.Interface(
|
| 446 |
+
fn=calculate_dot_product_and_representations_independent, # MODIFIED FUNCTION NAME
|
| 447 |
inputs=[
|
| 448 |
gr.Radio(
|
| 449 |
+
model_choices,
|
| 450 |
+
label="Choose Query Encoding Model",
|
| 451 |
+
value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
|
| 452 |
+
),
|
| 453 |
+
gr.Radio(
|
| 454 |
+
model_choices,
|
| 455 |
+
label="Choose Document Encoding Model",
|
| 456 |
+
value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
|
| 457 |
),
|
| 458 |
gr.Textbox(
|
| 459 |
lines=3,
|