Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import pandas as pd | |
| import torch | |
| from datasets import load_dataset | |
| from sentence_transformers.util import semantic_search | |
| from sentence_transformers import SentenceTransformer, util | |
# Names of the available embedding builds. Each build is read from a CSV file
# named `embeddings_<build>.csv`.
BUILDS = ['demographics300', 'uncurated3000']

# Sentence encoder used to embed incoming queries.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Per-build lookup tables, both keyed by build name:
#   dataset_embeddings_maps[build] -> float tensor built from every CSV column
#                                     except `dcid`
#   dcid_maps[build]               -> values of the `dcid` column, in row order
dataset_embeddings_maps = {}
dcid_maps = {}

for build in BUILDS:
    print('Loading build ', build)
    csv_path = f'embeddings_{build}.csv'
    frame = load_dataset('csv', data_files=csv_path)["train"].to_pandas()

    # Keep the DCID column aside; the remaining columns form the embedding
    # matrix for this build.
    dcid_maps[build] = frame['dcid'].values.tolist()
    vectors = frame.drop(columns=['dcid']).to_numpy()
    dataset_embeddings_maps[build] = torch.from_numpy(vectors).to(torch.float)
def inference(build, query):
    """Return the StatVars that best match `query` in one embeddings build.

    The query is encoded with the module-level `model` and compared against
    `dataset_embeddings_maps[build]` using `semantic_search`, keeping the top
    15 hits. Each hit's `corpus_id` indexes `dcid_maps[build]`, whose entry is
    split on ',' into individual DCIDs.

    Args:
        build: Key into `dataset_embeddings_maps` and `dcid_maps`.
        query: Free-text search string.

    Returns:
        A pandas DataFrame with one row per distinct score, highest score
        first. Column 'SV' holds the DCIDs that share the score, joined by
        ' : '; column 'Cosine Score' holds the score itself.

    Raises:
        KeyError: If `build` is not a key of the loaded build maps.
    """
    query_embeddings = model.encode([query])
    hits = semantic_search(query_embeddings,
                           dataset_embeddings_maps[build],
                           top_k=15)
    dcids_by_row = dcid_maps[build]

    # Note: multiple results may map to the same DCID. As well, the same string
    # may map to multiple DCIDs with the same score.
    seen_svs = set()
    score2svs = {}
    for hit in hits[0]:
        # The score belongs to the hit, not to the individual DCID, so look it
        # up once per hit.
        score = hit['score']
        for dcid in dcids_by_row[hit['corpus_id']].split(','):
            # Prefer the top score: hits are processed in the order they were
            # returned, and a DCID is only listed under the score of the first
            # hit that mentions it.
            if dcid in seen_svs:
                continue
            seen_svs.add(dcid)
            score2svs.setdefault(score, []).append(dcid)

    # Sort by scores, best match first.
    scores = sorted(score2svs, reverse=True)
    svs = [' : '.join(score2svs[s]) for s in scores]

    # Add to Pandas
    return pd.DataFrame({'SV': svs, 'Cosine Score': scores})
# Build and launch a small Gradio search UI around `inference`.
title = "DC Search Demo"
description = """
Try querying for StatVars.
- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF)
related to demographics
- "uncurated3000": 3000 SVs with only auto-generated name related to
demographics, crime, agriculture, households, housing, emissions, health
"""

# TODO: make logging work
# HF_TOKEN = os.getenv('HF_TOKEN')
# hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "dc-statvar-demo-log")

# Input widgets, in the same order as the parameters of `inference`.
build_selector = gr.Dropdown(choices=BUILDS,
                             value='uncurated3000',
                             label='Embeddings Build')
query_box = gr.Textbox(label='Query',
                       placeholder='how long do people live?')

# Output table; the headers match the columns of the DataFrame that
# `inference` returns.
results_table = gr.Dataframe(headers=['SV', 'Cosine Score'],
                             label='Search Results')

# Labels offered to the user when manually flagging a result.
flag_choices = ["not at all related", "related but not ranked right"]

iface = gr.Interface(
    fn=inference,
    inputs=[build_selector, query_box],
    outputs=results_table,
    title=title,
    description=description,
    allow_flagging="manual",
    flagging_options=flag_choices,
)
iface.launch()