Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pinecone | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| from splade.models.transformer_rep import Splade | |
| from transformers import AutoTokenizer | |
| from datasets import load_dataset | |
| import os | |
| from pinecone import Pinecone | |
# --- Pinecone connection ----------------------------------------------------
# The API key is read from the environment (e.g. a Space secret or a local
# shell export) instead of being committed to the source file.
# NOTE: any key that was previously hard-coded here should be treated as
# compromised and rotated in the Pinecone console.
_pinecone_api_key = os.environ.get('PINECONE_API_KEY')
if not _pinecone_api_key:
    raise RuntimeError(
        "PINECONE_API_KEY is not set. Export it in the environment "
        "(or add it as a secret) before starting the app."
    )
pc = Pinecone(api_key=_pinecone_api_key)
index = pc.Index('pubmed-splade')

# --- Device selection -------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# check device being run on
if device != 'cuda':
    print(
        "==========\n"
        "WARNING: You are not running on GPU so this may be slow.\n"
        "\n=========="
    )

# --- Dense encoder ----------------------------------------------------------
dense_model = SentenceTransformer(
    'msmarco-bert-base-dot-v5',
    device=device
)

# --- Sparse (SPLADE) encoder ------------------------------------------------
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.to(device)  # moves to GPU if possible
sparse_model.eval()  # inference only
# The tokenizer is loaded from the same checkpoint id as the sparse model.
tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)

# --- Listings data ----------------------------------------------------------
# The full training split is loaded into a DataFrame so search() can look up
# listing rows by their '_id' column.
data = load_dataset('Binaryy/cream_listings', split='train')
df = data.to_pandas()
def encode(text: str):
    """Encode *text* into a dense vector and a sparse vector.

    The dense vector comes from ``dense_model``; the sparse vector comes from
    the SPLADE ``sparse_model`` (its ``'d_rep'`` output), reduced to only its
    non-zero entries.

    Args:
        text: The text to encode.

    Returns:
        A ``(dense_vec, sparse_dict)`` tuple. ``dense_vec`` is a list of
        floats. ``sparse_dict`` is ``{"indices": [...], "values": [...]}``,
        two parallel lists describing the non-zero sparse entries.
    """
    # create dense vec
    dense_vec = dense_model.encode(text).tolist()

    # create sparse vec
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        # assumes a single input text, so the squeezed 'd_rep' is 1-D
        # (one weight per vocabulary entry) — the indexing below relies on it
        sparse_vec = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()

    # convert to dictionary format
    # nonzero() yields a (k, 1) tensor. flatten() always gives shape (k,),
    # whereas squeeze() would collapse the k == 1 case to a 0-d tensor and
    # turn both "indices" and "values" into bare scalars instead of lists.
    nonzero_idx = sparse_vec.nonzero().flatten()
    indices = nonzero_idx.cpu().tolist()
    values = sparse_vec[nonzero_idx].cpu().tolist()
    sparse_dict = {"indices": indices, "values": values}

    # return vecs
    return dense_vec, sparse_dict
# Columns of the listings DataFrame that are included in search results.
# 'postedBy.password' is deliberately NOT listed: credentials must never be
# returned to a UI client.
# NOTE(review): other seller fields below (email, phone numbers, bank account
# details, date of birth, address) also look sensitive — confirm that they
# are really meant to be exposed by a public search endpoint.
_RESULT_COLUMNS = [
    '_id', 'title', 'location', 'features', 'description', 'images',
    'videos', 'available', 'price', 'attachedDocument', 'year',
    'carCondition', 'engineType', 'colour', 'model', 'noOfBed',
    'noOfBathroom', 'locationISO', 'forRent', 'views', 'thoseWhoSaved',
    'createdAt', 'updatedAt', '__v', 'category._id', 'category.title',
    'category.slug', 'category.isAdminAllowed', 'category.createdAt',
    'category.updatedAt', 'category.__v', 'postedBy.pageViews.value',
    'postedBy.pageViews.users', 'postedBy.totalSaved.value',
    'postedBy.totalSaved.users', 'postedBy._id', 'postedBy.firstName',
    'postedBy.lastName', 'postedBy.about', 'postedBy.cover',
    'postedBy.email', 'postedBy.isAdmin',
    'postedBy.savedListing', 'postedBy.isVerified',
    'postedBy.verifiedProfilePicture', 'postedBy.profilePicture',
    'postedBy.pronoun', 'postedBy.userType', 'postedBy.accountType',
    'postedBy.subscribed', 'postedBy.noOfSubscription',
    'postedBy.totalListing', 'postedBy.sellerType', 'postedBy.createdAt',
    'postedBy.updatedAt', 'postedBy.__v', 'postedBy.address',
    'postedBy.city', 'postedBy.country', 'postedBy.gender',
    'postedBy.nationality', 'postedBy.verificationType', 'postedBy.dob',
    'postedBy.locationISO', 'postedBy.state', 'postedBy.zipCode',
    'postedBy.otherNames', 'postedBy.facebookUrl', 'postedBy.instagramUrl',
    'postedBy.phoneNumber1', 'postedBy.phoneNumber2', 'postedBy.websiteUrl',
    'postedBy.accountName', 'postedBy.accountNo', 'postedBy.bankName',
    'string_features', 'complete_description',
]


def search(query):
    """Run a hybrid (dense + sparse) search and return matching listings.

    The query is encoded with ``encode``, the Pinecone ``index`` is asked for
    the top 5 matches, and the matching rows of the listings DataFrame ``df``
    are returned in the order the matches were returned.

    Args:
        query: Free-text search query.

    Returns:
        A JSON string (``orient='records'``) with one object per matched
        listing, restricted to ``_RESULT_COLUMNS``.

    Raises:
        KeyError: If ``df`` lacks any column named in ``_RESULT_COLUMNS``.
    """
    dense, sparse = encode(query)

    # query
    xc = index.query(
        vector=dense,
        sparse_vector=sparse,
        top_k=5,  # how many results to return
        include_metadata=True
    )

    # A match id is reduced to the text before its first '-', which is the
    # value looked up in df['_id']. Several matches can therefore map to the
    # same listing; dict.fromkeys de-duplicates while keeping first-seen
    # (i.e. best-ranked, assuming matches arrive most-similar first) order.
    match_ids = list(dict.fromkeys(
        match['id'].split('-')[0] for match in xc['matches']
    ))
    rank_of = {listing_id: rank for rank, listing_id in enumerate(match_ids)}

    # Query the existing DataFrame based on 'id'
    filtered_df = df[df['_id'].isin(match_ids)]

    # isin() keeps DataFrame order, which discards the search ranking;
    # restore it by sorting the selected rows by their match rank.
    ranks = filtered_df['_id'].map(rank_of).to_numpy()
    filtered_df = filtered_df.iloc[ranks.argsort(kind='stable')]

    extracted_data = filtered_df[_RESULT_COLUMNS]
    return extracted_data.to_json(orient='records')
# Build the Gradio front end: one text box in, JSON out, backed by search().
interface_options = {
    "fn": search,
    "inputs": "text",
    "outputs": "json",
    "title": "Semantic Search Prototype",
    "description": "Enter your query to search.",
}
iface = gr.Interface(**interface_options)

# Start the app and request a public share link.
iface.launch(share=True)