Spaces:
Build error
Build error
| import os | |
| import streamlit as st | |
| from pinecone import Pinecone | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| from splade.models.transformer_rep import Splade | |
| from transformers import AutoTokenizer | |
# Module-level handle to the Pinecone index. It remains None unless the
# Pinecone client is created and the index lookup succeeds later on.
index = None

# Header shown at the top of the Streamlit page.
st.title("Medical Hybrid Search")
def initialize_pinecone():
    """Build a Pinecone client from the ``PINECONE_API_KEY`` environment variable.

    Returns:
        A ``Pinecone`` client when the variable is set to a non-empty value;
        otherwise ``None``, after displaying an error in the Streamlit page.
    """
    key = os.getenv('PINECONE_API_KEY')
    if not key:
        # Unset or empty key: report it in the UI instead of raising.
        st.error("Pinecone API key not found! Please set the PINECONE_API_KEY environment variable.")
        return None
    return Pinecone(api_key=key)
def connect_to_index(pc):
    """Open the hard-coded ``'pubmed-splade'`` index on a Pinecone client.

    Args:
        pc: Pinecone client exposing ``list_indexes().names()`` and ``Index(name)``.

    Returns:
        The index handle from ``pc.Index``, or ``None`` (after a Streamlit
        error message) when no index with that name exists.
    """
    name = 'pubmed-splade'
    available = pc.list_indexes().names()
    if name not in available:
        st.error(f"Index '{name}' not found!")
        return None
    return pc.Index(name)
def encode_query(model, query_text):
    """Embed ``query_text`` with ``model`` and return the embedding as a plain list.

    ``model.encode`` is expected to return an array-like object with a
    ``tolist()`` method (as SentenceTransformer does).
    """
    embedding = model.encode(query_text)
    as_list = embedding.tolist()
    return as_list
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense and a sparse query vector for convex hybrid search.

    The dense components are multiplied by ``alpha`` and the sparse values by
    ``1 - alpha``. So ``alpha=1`` keeps only the dense signal and ``alpha=0``
    keeps only the sparse signal.

    Args:
        dense: Sequence of numbers (the dense embedding).
        sparse: Mapping with ``'indices'`` and ``'values'`` sequences.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        Tuple ``(hdense, hsparse)``. ``hdense`` is a new list of scaled dense
        values. ``hsparse`` is a new dict with copied indices and scaled values.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or is NaN.
    """
    # A chained comparison evaluates False for NaN, so NaN is rejected as
    # well. The previous `alpha < 0 or alpha > 1` test let NaN through.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        # Copy so the result does not alias the caller's index list.
        'indices': list(sparse['indices']),
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
@st.cache_resource
def _load_dense_model():
    """Load the dense query encoder once per server process.

    Streamlit re-executes this whole script on every widget interaction.
    Caching avoids reloading the model weights on each rerun.
    """
    return SentenceTransformer('msmarco-bert-base-dot-v5')


@st.cache_resource
def _load_sparse_model(model_id):
    """Load the SPLADE model (in eval mode) and its tokenizer once per process."""
    splade = Splade(model_id, agg='max')
    splade.eval()  # evaluation mode for inference
    return splade, AutoTokenizer.from_pretrained(model_id)


def _sparse_to_dict(sparse_vector):
    """Convert a SPLADE weight tensor into Pinecone's ``{'indices', 'values'}`` dict.

    ``nonzero(as_tuple=True)[0]`` is always a 1-D tensor, so both fields are
    lists even when exactly one weight is non-zero. The previous
    ``nonzero().squeeze()`` collapsed that case to a 0-d tensor, which made
    ``tolist()`` return a bare int/float.
    """
    flat = sparse_vector.reshape(-1)
    nz = flat.nonzero(as_tuple=True)[0]
    return {"indices": nz.cpu().tolist(), "values": flat[nz].cpu().tolist()}


# Initialize Pinecone. If the client was created, connect to the index.
pc = initialize_pinecone()
if pc:
    index = connect_to_index(pc)

# Query encoders (dense + sparse), cached across Streamlit reruns.
model = _load_dense_model()
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
sparse_model, tokenizer = _load_sparse_model(sparse_model_id)

# Query input
query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
# Alpha input: 1.0 = dense only, 0.0 = sparse only (see hybrid_scale)
alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)

# Button to encode the query and search the Pinecone index
if st.button("Search Query"):
    if query_text and index is not None:
        # Dense embedding of the query
        dense_vector = encode_query(model, query_text)

        # Sparse SPLADE representation of the query
        input_ids = tokenizer(query_text, return_tensors='pt')
        with torch.no_grad():
            sparse_vector = sparse_model(d_kwargs=input_ids.to('cpu'))['d_rep'].squeeze()
        sparse_dict = _sparse_to_dict(sparse_vector)

        # Scale dense and sparse vectors, then search the index
        hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)
        results = index.query(
            vector=hdense,
            sparse_vector=hsparse,
            top_k=3,
            include_metadata=True,
        )

        st.write("Results:")
        for match in results.matches:
            st.markdown(f"Score: {match.score:.4f}")
            # Guard against a match that carries no metadata at all.
            metadata = match.metadata or {}
            st.write(f"Answer: {metadata.get('context', 'No context available.')}")
            st.write("---")
    else:
        st.error("Please enter a query and ensure the index is initialized.")