import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os
import ir_datasets
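# Dependency note (a sketch, not verified against a requirements.txt): this
# single-file Space assumes the imports above are satisfied, e.g.
#   pip install gradio transformers torch numpy tqdm ir_datasets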
# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None
# Load the SPLADE cocondenser model (naver/splade-cocondenser-selfdistil,
# labeled "SPLADE-cocondenser-distil" in the UI below)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# Load SPLADE v3 Doc model
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Global Variables for Document Index ---
document_representations = {}  # Stores {doc_id: sparse_vector}
document_texts = {}  # Stores {doc_id: doc_text}
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil"  # Fixed for the initial demo index
# --- Load SciFact Corpus using ir_datasets ---
def load_scifact_corpus_ir_datasets():
    global document_texts
    print("Loading SciFact corpus using ir_datasets...")
    try:
        dataset = ir_datasets.load("scifact")
        for doc in tqdm(dataset.docs_iter(), desc="Loading SciFact documents"):
            document_texts[doc.doc_id] = doc.text.strip()
        print(f"Loaded {len(document_texts)} documents from SciFact corpus.")
    except Exception as e:
        print(f"Error loading SciFact corpus with ir_datasets: {e}")
        print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
# --- Helper function for lexical mask ---
def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
    meaningful_token_ids = []
    for token_id in input_ids.squeeze().tolist():
        if token_id not in [
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.mask_token_id,
            tokenizer.unk_token_id,
        ]:
            meaningful_token_ids.append(token_id)
    if meaningful_token_ids:
        bow_mask[list(set(meaningful_token_ids))] = 1
    return bow_mask.unsqueeze(0)
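# Illustration (hypothetical BERT-style IDs): for input_ids [[101, 2054, 2003, 102]],
# i.e. "[CLS] what is [SEP]", the returned mask has shape (1, vocab_size) and is
# 1.0 only at the positions for "what" and "is"; the special tokens stay zero.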
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
def get_splade_cocondenser_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if hasattr(output, 'logits'):
        # SPLADE pooling: per vocab term j, w_j = max_i log(1 + relu(logits[i, j])),
        # with padded positions zeroed out by the attention mask.
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
    # nonzero() on a 1-D tensor returns shape (k, 1); squeeze(-1).tolist() always
    # yields a proper list, even when there are 0 or 1 non-zero entries.
    indices = torch.nonzero(splade_vector).squeeze(-1).cpu().tolist()
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
    formatted_output = "SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation:
            formatted_output += f"- **{term}**: {weight:.4f}\n"
    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
    return formatted_output
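# Quick sanity check (a minimal sketch; the exact terms and weights depend on the
# loaded checkpoint):
# if model_splade is not None:
#     print(get_splade_cocondenser_representation("effect of vitamin c on cancer"))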
def get_splade_lexical_representation(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_lexical(**inputs)
    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
    # Always apply the lexical mask: this model only weights terms present in the input.
    vocab_size = tokenizer_splade_lexical.vocab_size
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = splade_vector * bow_mask
    indices = torch.nonzero(splade_vector).squeeze(-1).cpu().tolist()
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
    formatted_output = "SPLADE-v3-Lexical Representation (Weighting):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation:
            formatted_output += f"- **{term}**: {weight:.4f}\n"
    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
    return formatted_output
def get_splade_doc_representation(text):
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade_doc(**inputs)
    if not hasattr(output, "logits"):
        return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."
    # The logits are only checked for presence; the displayed representation is a
    # binary bag of words over the input tokens.
    vocab_size = tokenizer_splade_doc.vocab_size
    binary_splade_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()
    indices = torch.nonzero(binary_splade_vector).squeeze(-1).cpu().tolist()
    values = [1.0] * len(indices)  # All values are 1 for the binary representation
    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for clarity
    formatted_output = "SPLADE-v3-Doc Representation (Binary):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, _) in enumerate(sorted_representation):
            if i >= 50:  # Limit display for very long lists
                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
                break
            formatted_output += f"- **{term}**\n"
    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
    formatted_output += f"Total activated terms: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
    return formatted_output
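# The three explorer functions above differ only in how the vector is built:
#   - SPLADE-cocondenser-distil: real-valued weights over the whole vocab, so terms
#     absent from the input can be activated (expansion).
#   - SPLADE-v3-Lexical: real-valued weights, masked to input tokens (no expansion).
#   - SPLADE-v3-Doc: binary presence/absence over input tokens only.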
# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
        return get_splade_cocondenser_representation(text)
    elif model_choice == "SPLADE-v3-Lexical (weighting)":
        return get_splade_lexical_representation(text)
    elif model_choice == "SPLADE-v3-Doc (binary)":
        return get_splade_doc_representation(text)
    else:
        return "Please select a model."
# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
def get_splade_cocondenser_representation_internal(text, tokenizer, model):
    if tokenizer is None or model is None:
        return None
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)
    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        return splade_vector
    else:
        print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
        return None

def get_splade_lexical_representation_internal(text, tokenizer, model):
    if tokenizer is None or model is None:
        return None
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)
    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        vocab_size = tokenizer.vocab_size
        bow_mask = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
        return splade_vector * bow_mask
    else:
        print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
        return None

def get_splade_doc_representation_internal(text, tokenizer, model):
    # No forward pass is needed here: the doc representation is a binary bag of words.
    if tokenizer is None or model is None:
        return None
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    vocab_size = tokenizer.vocab_size
    binary_splade_vector = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
    return binary_splade_vector
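# Note: these "sparse" vectors are materialized as dense (vocab_size,) tensors.
# With a BERT-style ~30K-entry vocab that is roughly 120 KB per document in
# float32, so a {token_id: weight} dict or an inverted index would scale better.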
# --- Document Indexing Function (for Retrieval Tab) ---
def index_documents(doc_model_choice):
    global document_representations
    if document_representations:
        print("Documents already indexed. Skipping re-indexing.")
        return True
    tokenizer_to_use = None
    model_to_use = None
    representation_func_to_use = None
    if doc_model_choice == "SPLADE-cocondenser-distil":
        if tokenizer_splade is None or model_splade is None:
            print("SPLADE-cocondenser-distil model not loaded for indexing.")
            return False
        tokenizer_to_use = tokenizer_splade
        model_to_use = model_splade
        representation_func_to_use = get_splade_cocondenser_representation_internal
    elif doc_model_choice == "SPLADE-v3-Lexical":
        if tokenizer_splade_lexical is None or model_splade_lexical is None:
            print("SPLADE-v3-Lexical model not loaded for indexing.")
            return False
        tokenizer_to_use = tokenizer_splade_lexical
        model_to_use = model_splade_lexical
        representation_func_to_use = get_splade_lexical_representation_internal
    elif doc_model_choice == "SPLADE-v3-Doc":
        if tokenizer_splade_doc is None or model_splade_doc is None:
            print("SPLADE-v3-Doc model not loaded for indexing.")
            return False
        tokenizer_to_use = tokenizer_splade_doc
        model_to_use = model_splade_doc
        representation_func_to_use = get_splade_doc_representation_internal
    else:
        print(f"Invalid model choice for document indexing: {doc_model_choice}")
        return False
    print(f"Indexing documents using {doc_model_choice}...")
    doc_items = list(document_texts.items())
    for doc_id, doc_text in tqdm(doc_items, desc="Indexing Documents"):
        sparse_vector = representation_func_to_use(doc_text, tokenizer_to_use, model_to_use)
        if sparse_vector is not None:
            document_representations[doc_id] = sparse_vector.cpu()
        else:
            print(f"Warning: Failed to get representation for doc_id {doc_id}")
    print(f"Finished indexing {len(document_representations)} documents.")
    return True
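# Indexing performs one forward pass per document, so on CPU this can take several
# minutes for a corpus of a few thousand abstracts.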
# --- Retrieval Function (for Retrieval Tab) ---
def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
    if not document_representations:
        return "Document index is not loaded or empty. Please ensure documents are indexed.", []
    if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
        query_vector = get_splade_cocondenser_representation_internal(query_text, tokenizer_splade, model_splade)
    elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
        query_vector = get_splade_lexical_representation_internal(query_text, tokenizer_splade_lexical, model_splade_lexical)
    elif query_model_choice == "SPLADE-v3-Doc (binary)":
        query_vector = get_splade_doc_representation_internal(query_text, tokenizer_splade_doc, model_splade_doc)
    else:
        return "Invalid query model choice.", []
    if query_vector is None:
        return "Failed to get query representation. Check console for model loading errors.", []
    query_vector = query_vector.cpu()
    scores = {}
    for doc_id, doc_vec in document_representations.items():
        scores[doc_id] = torch.dot(query_vector, doc_vec).item()
    sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    top_results = sorted_scores[:top_k]
    formatted_output = f"Retrieval Results for Query: '{query_text}'\n"
    formatted_output += f"Using Query Model: **{query_model_choice}**\n"
    formatted_output += f"Documents Indexed with: **{indexed_doc_model_name}**\n\n"
    if not top_results:
        formatted_output += "No documents found or scored.\n"
    else:
        for i, (doc_id, score) in enumerate(top_results):
            doc_text = document_texts.get(doc_id, "Document text not available.")
            formatted_output += f"**{i+1}. Document ID: {doc_id}** (Score: {score:.4f})\n"
            formatted_output += f"> {doc_text[:300]}...\n\n"
    return formatted_output, top_results
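# Scoring above is an exhaustive dot product against every stored (vocab_size,)
# vector, i.e. O(num_docs * vocab_size) per query. Fine at demo scale; a real
# deployment would score only the query's non-zero terms via an inverted index.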
# --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
    # The third input is display-only; retrieval always uses the pre-built index.
    formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
    return formatted_output
# --- Initial Load and Indexing Calls ---
# This part runs once when the app starts.
load_scifact_corpus_ir_datasets()
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
    index_documents(initial_doc_model_for_indexing)
elif initial_doc_model_for_indexing == "SPLADE-v3-Lexical" and model_splade_lexical is not None:
    index_documents(initial_doc_model_for_indexing)
elif initial_doc_model_for_indexing == "SPLADE-v3-Doc" and model_splade_doc is not None:
    index_documents(initial_doc_model_for_indexing)
else:
    print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")
# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos") as demo:
    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer & Document Retrieval")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, or perform document retrieval on a test collection.")
    with gr.Tabs():
        with gr.TabItem("Sparse Representation Explorer"):
            gr.Markdown("### Explore Raw SPLADE Representations for Any Text")
            gr.Interface(
                fn=predict_representation_explorer,
                inputs=[
                    gr.Radio(
                        [
                            "SPLADE-cocondenser-distil (weighting and expansion)",
                            "SPLADE-v3-Lexical (weighting)",
                            "SPLADE-v3-Doc (binary)"
                        ],
                        label="Choose Representation Model",
                        value="SPLADE-cocondenser-distil (weighting and expansion)"
                    ),
                    gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # Don't show a redundant title/description within the tab, as it's above.
                # Setting live=True might be slow for complex models on every keystroke:
                # live=True
            )
        with gr.TabItem("Document Retrieval Demo"):
            gr.Markdown("### Retrieve Documents from SciFact Collection")
            gr.Interface(
                fn=predict_retrieval_gradio,
                inputs=[
                    gr.Textbox(
                        lines=3,
                        label="Enter your query text here:",
                        placeholder="e.g., Does high-dose vitamin C cure cancer?"
                    ),
                    gr.Radio(
                        [
                            "SPLADE-cocondenser-distil (weighting and expansion)",
                            "SPLADE-v3-Lexical (weighting)",
                            "SPLADE-v3-Doc (binary)"
                        ],
                        label="Choose Query Representation Model",
                        value="SPLADE-cocondenser-distil (weighting and expansion)"
                    ),
                    gr.Radio(
                        [
                            "SPLADE-cocondenser-distil",
                            "SPLADE-v3-Lexical",
                            "SPLADE-v3-Doc"
                        ],
                        label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
                        value=initial_doc_model_for_indexing,
                        interactive=False  # This radio is fixed for simplicity
                    )
                ],
                outputs=gr.Markdown(),
                allow_flagging="never",
                # live=True  # retrieval is too heavy for live updates
            )
demo.launch()