Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Dict, List | |
| from dataclasses import dataclass | |
| import datasets | |
| import ast | |
| import pandas as pd | |
| import sentence_transformers | |
| import streamlit as st | |
| from findkit import feature_extractors, indexes, retrieval_pipeline | |
| from toolz import partial | |
| import config | |
def get_doc_cols(model_name):
    """Derive the document column names encoded in an encoder model name.

    Every occurrence of ``"query-"`` and then ``"document-"`` is removed,
    the text before the first remaining ``"-"`` is kept, and that text is
    split on ``"_"``.

    >>> get_doc_cols("query-titles_readme-model")
    ['titles', 'readme']
    """
    stripped = model_name
    # Same removal order as before: "query-" first, then "document-".
    for marker in ("query-", "document-"):
        stripped = stripped.replace(marker, "")
    column_spec, _, _ = stripped.partition("-")
    return column_spec.split("_")
def merge_cols(df, cols):
    """Set ``df["document"]`` to the space-joined values of ``cols``, in order.

    The frame is modified in place and also returned for chaining.

    Args:
        df: DataFrame holding the text columns listed in ``cols``
            (values are presumably strings — confirm with callers).
        cols: non-empty sequence of column names to concatenate.

    Returns:
        ``df`` itself, with the new ``document`` column.

    Raises:
        IndexError: if ``cols`` is empty.
    """
    merged = df[cols[0]]
    # Start from the second column: the first one already seeds ``merged``,
    # so iterating over all of ``cols`` would duplicate it.
    for col in cols[1:]:
        merged = merged + " " + df[col]
    df["document"] = merged
    return df
def get_retrieval_df(
    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
):
    """Load the ``train`` split of ``data_path`` as a de-duplicated DataFrame.

    Only the first row for each ``repo`` value is kept, and the index is reset
    to 0..n-1. When ``text_list_cols`` is non-empty, those columns are
    additionally converted with ``merge_text_list_cols``.
    """
    train_split = datasets.load_dataset(data_path)["train"]
    deduplicated = train_split.to_pandas().drop_duplicates(subset=["repo"])
    raw_retrieval_df = deduplicated.reset_index(drop=True)
    if not text_list_cols:
        return raw_retrieval_df
    return merge_text_list_cols(raw_retrieval_df, text_list_cols)
def truncate_description(description, length=50):
    """Keep the first ``length`` whitespace-separated words of ``description``.

    Runs of whitespace collapse to single spaces in the result.
    """
    words = description.split()
    kept_words = words[:length]
    return " ".join(kept_words)
def get_repos_with_descriptions(repos_df, repos):
    """Return the rows of ``repos_df`` selected by label via ``.loc[repos]``.

    NOTE(review): presumably ``repos_df`` is indexed by repository name;
    confirm against callers.
    """
    selected_rows = repos_df.loc[repos]
    return selected_rows
def merge_text_list_cols(retrieval_df, text_list_cols):
    """Return a copy of ``retrieval_df`` with list-literal columns flattened.

    Each cell of every column in ``text_list_cols`` is parsed with
    ``ast.literal_eval`` (expected to yield an iterable of strings) and
    joined with single spaces. The input frame is not modified.
    """

    def _join_literal_list(serialized):
        return " ".join(ast.literal_eval(serialized))

    merged_df = retrieval_df.copy()
    for column_name in text_list_cols:
        merged_df[column_name] = merged_df[column_name].apply(_join_literal_list)
    return merged_df
@dataclass
class RetrievalPipelineWrapper:
    """Wrap a findkit retrieval pipeline and format its results for display.

    Usually constructed through ``build_from_encoders`` or
    ``setup_from_encoder_names`` rather than directly.
    """

    # Underlying findkit pipeline; ``search`` delegates to its
    # ``find_similar`` method. The annotation is quoted so it is not
    # evaluated eagerly at class-creation time.
    pipeline: "retrieval_pipeline.RetrievalPipeline"

    @classmethod
    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
        """Build a wrapper from already-constructed feature extractors.

        Args:
            query_encoder: feature extractor applied to search queries.
            document_encoder: feature extractor applied to ``documents``.
            documents: documents to index.
            metadata: forwarded to the pipeline build as ``metadata=``
                (presumably one record per document — confirm).

        Returns:
            A new instance wrapping the built pipeline.
        """
        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
            feature_extractor=document_encoder,
            query_feature_extractor=query_encoder,
            # NMSLIB index built with the "cosinesimil" space.
            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
        )
        pipeline = retrieval_pipe.build(documents, metadata=metadata)
        return cls(pipeline)

    def search(
        self,
        query: str,
        k: int,
        description_length: int,
        additional_shown_cols: List[str],
    ):
        """Return the top-``k`` matches for ``query`` as a display DataFrame.

        Args:
            query: free-text search query.
            k: number of results requested from the pipeline.
            description_length: maximum number of words kept in each of
                ``additional_shown_cols``.
            additional_shown_cols: extra text columns to truncate and append
                to the output (any sequence of column names).

        Returns:
            DataFrame with columns ``repo``, ``tasks``, ``link``, ``distance``
            followed by ``additional_shown_cols``, re-indexed from 0.
        """
        results = self.pipeline.find_similar(query, k)
        # Assumes ``repo`` holds "owner/name" slugs — TODO confirm.
        results["link"] = "https://github.com/" + results["repo"]
        for col in additional_shown_cols:
            results[col] = results[col].apply(
                lambda desc: truncate_description(desc, description_length)
            )
        shown_cols = ["repo", "tasks", "link", "distance"] + list(additional_shown_cols)
        return results.reset_index(drop=True)[shown_cols]

    @classmethod
    def setup_from_encoder_names(
        cls, query_encoder_path, document_encoder_path, documents, metadata, device
    ):
        """Load SentenceTransformer encoders and build a wrapper around them.

        Args:
            query_encoder_path: model name/path for the query encoder.
            document_encoder_path: model name/path for the document encoder.
            documents: documents to index.
            metadata: forwarded to ``build_from_encoders``.
            device: device passed to both ``SentenceTransformer`` models.

        Returns:
            A new instance built via ``build_from_encoders``.
        """
        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(
                document_encoder_path, device=device
            )
        )
        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
        )
        return cls.build_from_encoders(
            query_encoder, document_encoder, documents, metadata
        )