Spaces:
Running
Running
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel | |
| import torch | |
# --- Model Loading ---
# Both tokenizer/model pairs are loaded once, at import time. Each name stays
# None when its load fails, so the representation functions can report the
# problem to the user instead of crashing the app.
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None

# Load the SPLADE checkpoint with its masked-LM head.
splade_model_name = "naver/splade-cocondenser-selfdistil"
try:
    tokenizer_splade = AutoTokenizer.from_pretrained(splade_model_name)
    model_splade = AutoModelForMaskedLM.from_pretrained(splade_model_name)
    model_splade.eval()  # Set to evaluation mode for inference
    # Report the checkpoint that was actually loaded rather than a fixed label.
    print(f"SPLADE model '{splade_model_name}' loaded successfully!")
except Exception as e:  # startup boundary: log the failure and keep the app alive
    print(f"Error loading SPLADE model: {e}")
    print(
        "Please ensure you have accepted any user access agreements on the "
        f"Hugging Face Hub page for '{splade_model_name}'."
    )

# Load the UNICOIL checkpoint used for the binary sparse encoding.
# The name is bound before the try so the except branch can always refer to it.
unicoil_model_name = "castorini/unicoil-msmarco-passage"
try:
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    # NOTE(review): the checkpoint is loaded with a masked-LM head. Confirm that
    # it ships weights for that head; if it does not, transformers initialises
    # the head randomly and the scores derived from its logits are not meaningful.
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # Set to evaluation mode for inference
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as e:  # startup boundary: log the failure and keep the app alive
    print(f"Error loading UNICOIL model: {e}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
| # --- Core Representation Functions --- | |
def get_splade_representation(text):
    """Return a Markdown-formatted summary of the SPLADE sparse vector of *text*.

    The text is tokenized with ``tokenizer_splade`` and run through
    ``model_splade``. The logits are turned into one weight per vocabulary
    entry via ``log(1 + relu(logit))``, masked with the attention mask and
    max-pooled over the sequence positions. Every non-zero entry is decoded
    back to its token string and listed by descending weight.

    Args:
        text: The query or document text to encode (a single string).

    Returns:
        A string with one ``- **term**: weight`` line per non-zero term
        (special tokens and whitespace-only tokens are left out of the list),
        followed by the non-zero count and the sparsity relative to the
        tokenizer's vocabulary size. A plain message is returned instead when
        the text is empty, the model is not loaded, or the model output has
        no ``logits`` attribute.
    """
    # Nothing to encode: without this guard the "representation" would consist
    # solely of whatever the special tokens expand to.
    if not text or not text.strip():
        return "Please enter some text to encode."
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_splade(**inputs)
    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # log(1 + relu(x)), zeroed on padding positions, then max over the
    # sequence dimension. A single string was encoded, so row 0 is the vector.
    activations = torch.log1p(torch.relu(output.logits))
    activations = activations * inputs["attention_mask"].unsqueeze(-1)
    splade_vector = activations.max(dim=1).values[0]

    # flatten() always yields a 1-D index tensor, whether zero, one or many
    # entries are non-zero, so no scalar special case is needed.
    nonzero_positions = torch.nonzero(splade_vector).flatten()
    indices = nonzero_positions.cpu().tolist()
    values = splade_vector[nonzero_positions].cpu().tolist()

    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in special_tokens and decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(
        meaningful_tokens.items(), key=lambda item: item[1], reverse=True
    )

    lines = ["SPLADE Representation (All Non-Zero Terms):"]
    if sorted_representation:
        lines.extend(f"- **{term}**: {weight:.4f}" for term, weight in sorted_representation)
    else:
        lines.append("No significant terms found for this input.")
    lines.append("")
    lines.append("--- Raw SPLADE Vector Info ---")
    lines.append(f"Total non-zero terms in vector: {len(indices)}")
    lines.append(f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}")
    return "\n".join(lines) + "\n"
def get_unicoil_binary_representation(text, *, threshold=0.5, max_listed_terms=50):
    """Return a Markdown-formatted list of the input terms activated for *text*.

    The text is tokenized with ``tokenizer_unicoil`` and run through
    ``model_unicoil``. For every input position, the logit the model assigns
    to that position's own token id is passed through softplus, padding
    positions are zeroed, and the token counts as activated when its score
    exceeds ``threshold``. Activated tokens are decoded, de-duplicated and
    listed alphabetically.

    Args:
        text: The query or document text to encode (a single string).
        threshold: Softplus score a token must exceed to be activated.
        max_listed_terms: Maximum number of terms written out; the remainder
            is summarised in a single "...and N more terms." line.

    Returns:
        A string with one ``- **term**`` line per listed term, followed by the
        number of activated terms and the sparsity relative to the tokenizer's
        vocabulary size. A plain message is returned instead when the text is
        empty, the model is not loaded, or the model output has no ``logits``
        attribute.
    """
    # Nothing to encode: only special tokens would be scored.
    if not text or not text.strip():
        return "Please enter some text to encode."
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model_unicoil(**inputs)
    if not hasattr(output, "logits"):
        return "UNICOIL model output structure not as expected. 'logits' not found."

    # Read ids and mask from the device-moved inputs so every tensor combined
    # below lives on the same device as the logits.
    logits = output.logits.squeeze(0)  # [seq_len, vocab_size]
    token_ids = inputs["input_ids"].squeeze(0)  # [seq_len]
    mask = inputs["attention_mask"].squeeze(0)  # [seq_len]

    # Pick each position's logit for its own token first, then apply softplus
    # to those seq_len values only. softplus is the overflow-safe form of
    # log(1 + exp(x)).
    own_logits = logits.gather(1, token_ids.unsqueeze(1)).squeeze(1)
    token_scores = torch.nn.functional.softplus(own_logits) * mask  # mask out padding

    binary_mask = token_scores > threshold
    activated_token_ids = token_ids[binary_mask].cpu().tolist()

    # Map token ids to strings; a set is enough because the encoding is binary.
    special_tokens = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}
    decoded_tokens = (tokenizer_unicoil.decode([token_id]) for token_id in activated_token_ids)
    sorted_binary_terms = sorted(
        {token for token in decoded_tokens if token not in special_tokens and token.strip()}
    )

    lines = ["UNICOIL Binary Sparse Representation (Activated Terms):"]
    if not sorted_binary_terms:
        lines.append("No significant terms activated for this input.")
    else:
        lines.extend(f"- **{term}**" for term in sorted_binary_terms[:max_listed_terms])
        remaining = len(sorted_binary_terms) - max_listed_terms
        if remaining > 0:
            lines.append(f"...and {remaining} more terms.")
    lines.append("")
    lines.append("--- Raw Binary Sparse Vector Info ---")
    lines.append(f"Total activated terms: {len(sorted_binary_terms)}")
    lines.append(
        f"Sparsity: {1 - (len(sorted_binary_terms) / tokenizer_unicoil.vocab_size):.2%}"
    )
    return "\n".join(lines) + "\n"
| # --- Unified Prediction Function for Gradio --- | |
def predict_representation(model_choice, text):
    """Route *text* to the encoder selected by *model_choice*.

    Args:
        model_choice: The label chosen in the UI, either ``"SPLADE"`` or
            ``"UNICOIL (Binary Sparse)"``.
        text: The text to encode.

    Returns:
        The formatted output of the matching representation function, or a
        prompt to select a model when the label is not recognised.
    """
    if model_choice == "SPLADE":
        return get_splade_representation(text)
    if model_choice == "UNICOIL (Binary Sparse)":
        return get_unicoil_binary_representation(text)
    # Any other value (including no selection) falls through to the prompt.
    return "Please select a model."
| # --- Gradio Interface Setup --- | |
# Input widgets are built first so the Interface call below stays compact.
_model_selector = gr.Radio(
    ["SPLADE", "UNICOIL (Binary Sparse)"],
    label="Choose Representation Model",
    value="SPLADE",  # pre-selected option
)
_text_input = gr.Textbox(
    lines=5,
    label="Enter your query or document text here:",
    placeholder="e.g., Why is Padua the nicest city in Italy?",
)

# One interface: (model choice, text) -> Markdown produced by predict_representation.
demo = gr.Interface(
    fn=predict_representation,
    inputs=[_model_selector, _text_input],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    allow_flagging="never",
)

# Start serving the app.
demo.launch()