Spaces:
Runtime error
Runtime error
from typing import Dict, List

import pandas as pd
import streamlit as st
import torch
from findkit import retrieval_pipeline

import config
from search_utils import (
    RetrievalPipelineWrapper,
    get_doc_cols,
    get_repos_with_descriptions,
    get_retrieval_df,
    merge_cols,
)
class RetrievalApp:
    """Streamlit app for semantic retrieval over a code-repository dataset.

    Sidebar widgets choose the torch device, the encoder model pair and the
    text columns to display; the main pane runs free-text queries through a
    findkit retrieval pipeline and renders the matches as an HTML table.
    """

    def is_cuda_available(self):
        """Return True iff torch can actually place a tensor on a CUDA device.

        Probes by moving a tiny tensor to the GPU; any failure (no driver,
        no device, misconfigured runtime) counts as "unavailable".

        BUGFIX: the original used ``finally: return True``, which overrides
        the ``return False`` in the except branch, so it always returned True.
        """
        try:
            torch.Tensor([1]).cuda()
        except Exception:
            return False
        return True

    def get_device_options(self):
        """Device choices for the sidebar selectbox, preferred device first."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    def get_retrieval_df(self):
        """Load the retrieval dataframe for ``self.data_path``.

        Delegates to the module-level ``get_retrieval_df`` helper with the
        text-list columns declared in ``config``.
        """
        return get_retrieval_df(self.data_path, config.text_list_cols)

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Build the sidebar controls and load data/model configuration.

        Args:
            data_path: Hugging Face dataset path holding the repositories
                with their descriptions and dependency metadata.
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")
        # .copy() so sidebar-driven mutations never touch the cached frame.
        self.retrieval_df = self.get_retrieval_df().copy()
        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        doc_cols = get_doc_cols(model_name)
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        # BUGFIX: scheme was "HTTP://", producing a malformed link and
        # clashing with the query-encoder URL above.
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: Dict[str, pd.DataFrame],
        additional_shown_cols: List[str],
    ):
        """Run one query through the pipeline and render the results.

        Marked ``@staticmethod`` (it never touched ``self`` and is invoked as
        ``RetrievalApp.show_retrieval_results(...)``).

        Args:
            retrieval_pipe: wrapper exposing ``search`` and the indexed ``X_df``.
            query: free-text query from the main input box.
            k: number of results to fetch.
            all_queries: gold-standard query set; membership toggles the
                expander with reference results.
            description_length: word budget for displayed descriptions.
            repos_by_query: groupby of repositories keyed by task/query.
            additional_shown_cols: extra text columns to include in the table.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # escape=False keeps repo links clickable in the rendered HTML.
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Wire up the main-pane widgets and dispatch the current query.

        Args:
            retrieval_pipeline: a ``RetrievalPipelineWrapper`` built by
                ``get_retrieval_pipeline``.
        """
        retrieved_results = st.sidebar.number_input("number of results", value=10)
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )
        # One row per distinct task with its document count.
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build the retrieval pipeline from the selected encoder names.

        Args:
            displayed_retrieval_df: dataframe whose ``document`` column holds
                the merged text to index; also used for result display.
        """
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Entry point: build the pipeline over the selected columns and run."""
        print("setting up retrieval_pipe")
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)