Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Dict, List | |
| import datasets | |
| import pandas as pd | |
| import sentence_transformers | |
| import streamlit as st | |
| from findkit import feature_extractors, indexes, retrieval_pipeline | |
| from toolz import partial | |
| import config | |
def truncate_description(description, length=50):
    """Return the first ``length`` whitespace-separated words of ``description``.

    The kept words are re-joined with single spaces, so any run of
    whitespace in the input (including newlines) collapses to one space.

    Args:
        description: Text to shorten. Anything that is not a ``str`` (for
            example ``None`` or a ``NaN`` standing in for a missing value)
            is treated as empty text instead of raising.
        length: Maximum number of words to keep. Values below zero are
            treated as zero.

    Returns:
        The shortened text; ``""`` for missing input or a non-positive
        ``length``.
    """
    if not isinstance(description, str):
        return ""
    # A negative slice bound would drop words from the END of the text
    # rather than return nothing, so clamp it.
    word_limit = max(length, 0)
    return " ".join(description.split()[:word_limit])
def get_repos_with_descriptions(repos_df, repos):
    """Select the rows of ``repos_df`` identified by ``repos``.

    Args:
        repos_df: Frame to select from; rows are looked up by index label.
        repos: Either index labels accepted by ``DataFrame.loc`` (a list,
            Index, Series of labels, ...) or a ``DataFrame`` of previously
            selected rows, in which case that frame's own row labels are
            used.
            NOTE(review): the DataFrame case assumes ``repos`` shares its
            row labels with ``repos_df`` -- confirm against the caller.

    Returns:
        The matching rows of ``repos_df``, in the order given by ``repos``.

    Raises:
        KeyError: If a requested label is not present in ``repos_df``.
    """
    if isinstance(repos, pd.DataFrame):
        # ``.loc`` rejects a 2-D key ("Cannot index with multidimensional
        # key"), so reduce the frame to the labels of its rows.
        repos = repos.index
    return repos_df.loc[repos]
def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_col: List[str],
):
    """Search for ``query`` and shape the hits into a display-ready frame.

    Args:
        retrieval_pipe: Pipeline whose ``find_similar(query, k)`` result is
            used; that result is indexed here by the columns ``"repo"``,
            ``"tasks"`` and ``"distance"`` plus every name in ``doc_col``.
        query: Free-text search query.
        k: Number of results requested from the pipeline.
        description_length: Word limit applied to each column in ``doc_col``.
        doc_col: Names of the text columns to shorten and append to the
            output.

    Returns:
        The hits with a fresh 0-based index and the columns ``repo``,
        ``tasks``, ``link``, ``distance`` followed by ``doc_col``.
    """

    def shorten(text):
        return truncate_description(text, description_length)

    hits = retrieval_pipe.find_similar(query, k)
    # GitHub URL built from the "owner/name" style repo identifier.
    hits["link"] = "https://github.com/" + hits["repo"]
    for text_column in doc_col:
        hits[text_column] = hits[text_column].apply(shorten)
    display_columns = ["repo", "tasks", "link", "distance"] + doc_col
    return hits.reset_index(drop=True)[display_columns]
def show_retrieval_results(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    all_queries: List[str],
    description_length: int,
    repos_by_query: "pd.core.groupby.DataFrameGroupBy",
    doc_col: List[str],
):
    """Render the search results for ``query`` on the Streamlit page.

    When ``query`` is one of ``all_queries``, an expander with the
    gold-standard repositories for that query is rendered first. The
    retrieved results are then written as an HTML table.

    Args:
        retrieval_pipe: Pipeline used for the search; its ``X_df`` attribute
            is the frame the gold-standard rows are looked up in.
        query: Free-text search query.
        k: Number of results to retrieve.
        all_queries: Queries that have gold-standard results.
        description_length: Word limit for the text columns in ``doc_col``.
        repos_by_query: Grouping of repositories keyed by query; only
            ``get_group(query)`` is used.
        doc_col: Names of the extra text columns to show (forwarded to
            ``search_f``, which expects a list).
    """
    print("started retrieval")
    if query in all_queries:
        with st.expander(
            "query is in gold standard set queries. Toggle viewing gold standard results?"
        ):
            st.write("gold standard results")
            # get_group returns the grouped rows for this query (a frame,
            # not a list of labels).
            task_repos = repos_by_query.get_group(query)
            st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
    with st.spinner(text="fetching results"):
        results = search_f(retrieval_pipe, query, k, description_length, doc_col)
        # The table is injected as raw HTML, so cell text must be escaped:
        # none of the columns holds intentional markup, and stray "<" / "&"
        # in repository text would otherwise be interpreted by the browser.
        results_html = results.to_html(escape=True, index=False)
        st.write(results_html, unsafe_allow_html=True)
    print("finished retrieval")
def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    """Build a retrieval pipeline over one text column of ``documents_df``.

    The same ``extractor`` encodes both documents and queries, and the
    index uses the same NMSLIB cosine-similarity configuration as
    ``setup_retrieval_pipeline``.

    Args:
        extractor: Feature extractor used for documents and queries.
        documents_df: Frame holding the documents; it is also passed to the
            factory as the metadata.
        text_col: Name of the column in ``documents_df`` to index.

    Returns:
        The pipeline produced by ``RetrievalPipelineFactory.build``.
    """
    # Previously ``build`` was invoked on the factory class itself, the
    # extractor was never used and the result was dropped. Configure an
    # instance the same way setup_retrieval_pipeline does.
    # NOTE(review): assumes the factory accepts one extractor in both
    # roles -- confirm against findkit.
    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=extractor,
        query_feature_extractor=extractor,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return factory.build(documents_df[text_col], metadata=documents_df)
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    """Build a retrieval pipeline with separate query and document encoders.

    Both encoders are ``SentenceTransformer`` models loaded on the CPU and
    wrapped in ``SentenceEncoderFeatureExtractor``. The index is created by
    ``NMSLIBIndex.build`` with ``distance="cosinesimil"``.

    Args:
        query_encoder_path: Model name or path for the query encoder.
        document_encoder_path: Model name or path for the document encoder.
        documents: Documents handed to the factory's ``build``.
        metadata: Passed to ``build`` as ``metadata``.

    Returns:
        The result of ``RetrievalPipelineFactory.build``.
    """
    document_model = sentence_transformers.SentenceTransformer(
        document_encoder_path, device="cpu"
    )
    document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        document_model
    )

    query_model = sentence_transformers.SentenceTransformer(
        query_encoder_path, device="cpu"
    )
    query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(query_model)

    cosine_index_builder = partial(indexes.NMSLIBIndex.build, distance="cosinesimil")
    pipeline_factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_encoder,
        query_feature_extractor=query_encoder,
        index_factory=cosine_index_builder,
    )
    return pipeline_factory.build(documents, metadata=metadata)
def app(retrieval_pipeline, retrieval_df, doc_col):
    """Lay out the Streamlit page: sidebar controls, query box and results.

    Note that the ``retrieval_pipeline`` parameter shadows the imported
    module of the same name inside this function.

    Args:
        retrieval_pipeline: Built pipeline that is searched for the query.
        retrieval_df: Frame of repositories; its ``"tasks"`` column is
            exploded to list the known queries and to group repositories
            by query.
        doc_col: Name of the text column offered as an additional column.
    """
    # Lower bounds keep 0 or negative values out of the search and the
    # description truncation.
    retrieved_results = st.sidebar.number_input(
        "number of results", min_value=1, value=10
    )
    description_length = st.sidebar.number_input(
        "number of used description words", min_value=0, value=10
    )

    # One row per task with the number of documents carrying it.
    tasks_deduped = retrieval_df["tasks"].explode().value_counts().reset_index()
    tasks_deduped.columns = ["task", "documents per task"]
    with st.sidebar.expander("View test set queries"):
        st.table(tasks_deduped.explode("task"))

    additional_shown_cols = st.sidebar.multiselect(
        label="additional cols", options=[doc_col], default=doc_col
    )
    repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
    query = st.text_input("input query", value="metric learning")
    show_retrieval_results(
        retrieval_pipeline,
        query,
        retrieved_results,
        tasks_deduped["task"].to_list(),
        description_length,
        repos_by_query,
        additional_shown_cols,
    )
def app_main(
    query_encoder_path,
    document_encoder_path,
    data_path,
):
    """Load the repository dataset, build the pipeline and run the app.

    Args:
        query_encoder_path: Model name or path for the query encoder.
        document_encoder_path: Model name or path for the document encoder.
        data_path: Dataset identifier passed to ``datasets.load_dataset``;
            its ``"train"`` split is used.
    """
    print("loading data")
    train_split = datasets.load_dataset(data_path)["train"]
    all_rows = train_split.to_pandas()
    # Keep one row per repository and renumber the rows from zero.
    retrieval_df = all_rows.drop_duplicates(subset=["repo"]).reset_index(drop=True)

    print("setting up retrieval_pipe")
    doc_col = "dependencies"
    pipe = setup_retrieval_pipeline(
        query_encoder_path, document_encoder_path, retrieval_df[doc_col], retrieval_df
    )
    app(pipe, retrieval_df, doc_col)
# Script entry point: runs at import/execution time with the encoder names
# taken from the local ``config`` module.
app_main(
    query_encoder_path=config.query_encoder_model_name,
    document_encoder_path=config.document_encoder_model_name,
    data_path="lambdaofgod/pwc_repositories_with_dependencies",
)