# dc_statvar_demo / app.py
# Prashanth Radhakrishnan
# Deploy to HF
# fc54c76
import gradio as gr
import os
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers.util import semantic_search
from sentence_transformers import SentenceTransformer, util
# Embedding builds selectable in the UI; each one is read from a file named
# embeddings_<build>.csv.
BUILDS = ['demographics300', 'uncurated3000']

# Download the sentence-embedding model used to encode user queries.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Per-build lookup tables, keyed by build name and filled in below:
#   dcid_maps[build]               -> one DCID string per corpus row (a cell
#                                     may hold several comma-separated DCIDs)
#   dataset_embeddings_maps[build] -> float32 tensor of the row embeddings,
#                                     row-aligned with dcid_maps[build]
dataset_embeddings_maps = {}
dcid_maps = {}


def _read_build(build_name):
    """Load ``embeddings_<build_name>.csv`` and split it into DCIDs and vectors.

    The 'dcid' column becomes a list; every remaining column is treated as an
    embedding dimension and packed into a float32 torch tensor.
    """
    table = load_dataset('csv', data_files=f'embeddings_{build_name}.csv')
    frame = table["train"].to_pandas()
    dcids = frame['dcid'].values.tolist()
    vector_columns = frame.drop(columns='dcid')
    vectors = torch.from_numpy(vector_columns.to_numpy()).to(torch.float)
    return dcids, vectors


for build in BUILDS:
    print('Loading build ', build)
    dcid_maps[build], dataset_embeddings_maps[build] = _read_build(build)
def inference(build, query, top_k=15):
    """Search one embeddings build for the StatVars closest to ``query``.

    Args:
      build: Name of the embeddings build to search; must be a key of
        ``dataset_embeddings_maps`` and ``dcid_maps``.
      query: Free-text query, encoded with the module-level ``model``.
      top_k: Number of nearest corpus rows to retrieve (default 15).

    Returns:
      A pandas DataFrame with columns 'SV' and 'Cosine Score', one row per
      distinct score in descending order. 'SV' joins with ' : ' every DCID
      that shares that score.
    """
    query_embeddings = model.encode([query])
    hits = semantic_search(query_embeddings, dataset_embeddings_maps[build],
                           top_k=top_k)

    # Note: multiple results may map to the same DCID. As well, the same
    # string may map to multiple DCIDs with the same score.
    seen_dcids = set()
    score2svs = {}
    # semantic_search returns each query's hits sorted by decreasing score,
    # so the first time a DCID appears is its top score; later (lower-scored)
    # occurrences are skipped.
    for hit in hits[0]:
        score = hit['score']
        for dcid in dcid_maps[build][hit['corpus_id']].split(','):
            if dcid in seen_dcids:
                continue
            seen_dcids.add(dcid)
            score2svs.setdefault(score, []).append(dcid)

    # Sort by score, highest first.
    scores = sorted(score2svs, reverse=True)
    svs = [' : '.join(score2svs[s]) for s in scores]

    # Add to Pandas.
    return pd.DataFrame({'SV': svs, 'Cosine Score': scores})
# Create a simple search interface.
title = "DC Search Demo"
description = """
Try querying for StatVars.
- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF)
related to demographics
- "uncurated3000": 3000 SVs with only auto-generated names, related to
demographics, crime, agriculture, households, housing, emissions, health
"""
# TODO: make logging work
# HF_TOKEN = os.getenv('HF_TOKEN')
# hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "dc-statvar-demo-log")

# Inputs are passed positionally to inference(build, query); the returned
# DataFrame is shown in the results table.
iface = gr.Interface(fn=inference,
                     inputs=[
                         gr.Dropdown(choices=BUILDS,
                                     value='uncurated3000',
                                     label='Embeddings Build'),
                         gr.Textbox(label='Query',
                                    placeholder='how long do people live?')
                     ],
                     outputs=gr.Dataframe(headers=['SV', 'Cosine Score'],
                                          label='Search Results'),
                     title=title,
                     description=description,
                     # Manual flagging: users flag a result and pick one of
                     # the reasons below.
                     allow_flagging="manual",
                     flagging_options=["not at all related",
                                       "related but not ranked right"])
iface.launch()