import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# --- Model Loading ---
# Load the SPLADE v3 model and tokenizer.
# "naver/splade-v3" is a common and robust SPLADE v3 checkpoint.
# Make sure you've accepted any user access agreements on its Hugging Face Hub page.
try:
    tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")
    model = AutoModelForMaskedLM.from_pretrained("naver/splade-v3")
    model.eval()  # Set the model to evaluation mode for inference
    print("SPLADE v3 model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model or tokenizer: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-v3' (https://huggingface.co/naver/splade-v3).")
    print("If the problem persists, check your internet connection or try a different SPLADE model.")
    tokenizer = None
    model = None
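
# Optional sketch (an assumption, not part of the original demo): on a GPU-enabled
# Space you could move the model to the GPU once at load time; the inference code
# below adapts either way via model.device. The free tier is CPU-only, so this
# stays commented out.
# if model is not None and torch.cuda.is_available():
#     model = model.to("cuda")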

# --- Core SPLADE Representation Function ---
def get_splade_representation(text):
    if tokenizer is None or model is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    # Tokenize the input text.
    # return_tensors="pt" returns PyTorch tensors.
    # padding=True pads to the longest sequence in the batch (a no-op for a single input).
    # truncation=True truncates text longer than the model's maximum input length.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
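    # Illustrative: `inputs` is a dict-like BatchEncoding along the lines of
    # {"input_ids": tensor([[101, ..., 102]]), "attention_mask": tensor([[1, ..., 1]])}
    # (the IDs shown are made up; 101/102 are the BERT-style [CLS]/[SEP] IDs).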

    # Move the inputs to the same device as the model (CPU or GPU); this matters
    # when running on a GPU in production. On the Hugging Face Spaces free tier
    # the device is usually CPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Run the model without computing gradients (inference mode).
    with torch.no_grad():
        output = model(**inputs)

    # SPLADE derives term importance from the masked-language-modeling head's logits.
    # Aggregation: log(1 + ReLU(logits)), max-pooled over the sequence dimension,
    # yields a sparse vocabulary-sized vector where higher values mean greater importance.
    # The attention mask ensures only real tokens (not padding) contribute.
    # 'logits' is the standard output attribute for AutoModelForMaskedLM.
    if hasattr(output, "logits"):
        # output.logits has shape [batch_size, sequence_length, vocab_size].
        # Max-pooling over sequence_length gives a [batch_size, vocab_size] vector;
        # unsqueeze(-1) expands the attention mask so it broadcasts over vocab_size.
        # Note: `inputs` is a plain dict here, so use key access, not attribute access.
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
            dim=1,
        )[0].squeeze()
    else:
        # Fallback/error message if the output structure is unexpected.
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # Convert the sparse vector to a human-readable format. Only the non-zero
    # entries matter: they are the activated vocabulary terms.
    # torch.nonzero returns the coordinates of non-zero elements; squeeze() drops
    # size-1 dimensions and .cpu().tolist() converts the tensor to a Python list.
    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()

    # For very short texts there may be a single index; wrap it in a list for consistency.
    if not isinstance(indices, list):
        indices = [indices]

    # Get the corresponding weights for these non-zero indices.
    values = splade_vector[indices].cpu().tolist()

    # Map each token ID to its weight.
    token_weights = dict(zip(indices, values))
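    # Illustrative shape of token_weights (IDs and weights are made up):
    # {3007: 2.1043, 2364: 1.8771, ...} mapping vocabulary IDs to SPLADE weights.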

    # Decode token IDs back into actual words/subwords, filtering out special tokens
    # that are not meaningful for retrieval (e.g., [CLS], [SEP], [PAD], [UNK]).
    # Add more tokens to this list if they appear frequently and are not helpful.
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        # Skip special tokens and empty or whitespace-only strings.
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    # Sort the meaningful tokens by weight in descending order.
    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    # Format the output for display.
    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 20:  # Limit to the top 20 terms for readability
                break
            formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"  # Fraction of zero entries

    return formatted_output
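
# Illustrative sketch (an assumption, not used by the demo): retrieval with SPLADE
# reduces to a sparse dot product between query and document vectors. The helper
# name `splade_encode` is ours; it re-applies the same max-pooled
# log(1 + ReLU(logits)) aggregation as above.
def splade_encode(text):
    enc = tokenizer(text, return_tensors="pt", truncation=True)
    enc = {k: v.to(model.device) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**enc).logits
    return torch.max(
        torch.log(1 + torch.relu(logits)) * enc["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()

# Example usage:
#   score = torch.dot(splade_encode("capital of France"),
#                     splade_encode("Paris is the capital of France."))
# Higher scores indicate greater estimated relevance.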

# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=get_splade_representation,
    inputs=gr.Textbox(
        lines=5,
        label="Enter your query or document text here:",
        placeholder="e.g., What are the capital cities of Europe?",
    ),
    outputs=gr.Markdown(),  # Markdown allows richer formatting (e.g., bold terms)
    title="🌌 SPLADE v3 Sparse Representation Generator",
    description="Enter any text (query or document) to see its SPLADE v3 sparse vector representation. The output highlights the most important terms with their learned weights.",
    allow_flagging="never",  # Disable flagging for this demo
)

# Launch the Gradio app
demo.launch()
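
# To run this demo locally (assuming gradio, transformers, and torch are installed):
#   python app.py
# then open the printed local URL. On Hugging Face Spaces the script is launched
# automatically when the Space starts; the filename "app.py" is the usual Spaces
# convention and an assumption here.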