| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers import util | |
| import torch | |
| import json | |
| import os | |
class PipelineWrapper:
    """Classify government dataset titles into the Musterdatenkatalog taxonomy.

    Titles and taxonomy labels are encoded into embeddings with a
    sentence-transformers model. Each title is then assigned the taxonomy
    label that ranks first in a semantic search over the label embeddings.
    """

    # Catch-all taxonomy entry that is dropped from the candidate labels, so
    # it can never be returned as a prediction.
    _EXCLUDED_LABEL = "Sonstiges - Sonstiges"

    def __init__(self, path=""):
        """Load the model and the taxonomy, then pre-compute label embeddings.

        Parameters
        ----------
        path : str
            Handed to ``SentenceTransformer`` as the model to load and also
            used as the directory expected to contain ``taxonomy.json``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): newer sentence-transformers releases appear to prefer
        # ``token`` over ``use_auth_token`` -- confirm against the pinned version.
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = os.path.join(path, "taxonomy.json")
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        # Order matters: the embeddings are computed from the loaded labels.
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: list) -> list:
        """Shorthand for :meth:`predict`."""
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        """Read the taxonomy file and store the candidate labels.

        Every entry of the JSON list must provide the keys "group" and
        "label"; they are combined as ``"<group> - <label>"``. The catch-all
        label is removed when present and the result is stored in
        ``self.taxonomy_labels``.
        """
        # Decode explicitly as UTF-8: the default encoding of open() depends
        # on the platform and the labels contain non-ASCII characters.
        with open(self.taxonomy, "r", encoding="utf-8") as f:
            taxonomy = json.load(f)
        labels = [el["group"] + " - " + el["label"] for el in taxonomy]
        # A taxonomy without the catch-all entry is valid; list.remove()
        # alone would raise ValueError in that case.
        if self._EXCLUDED_LABEL in labels:
            labels.remove(self._EXCLUDED_LABEL)
        self.taxonomy_labels = labels

    def get_taxonomy_embeddings(self) -> None:
        """Encode ``self.taxonomy_labels`` and store the resulting tensor.

        The result is kept in ``self.taxonomy_embeddings``; its rows follow
        the order of ``self.taxonomy_labels``, which ``predict`` relies on
        when mapping a search hit back to its label.
        """
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: list) -> list:
        """Predicts the taxonomy labels for the given queries.

        Parameters
        ----------
        queries : list
            List of queries to predict. Format is a list of dictionaries
            with the following keys: "id", "title"

        Returns
        -------
        list
            List of dictionaries with the following keys: "id", "title",
            "prediction". An empty input yields an empty list.
        """
        # Nothing to classify: skip the model entirely instead of pushing an
        # empty batch through encode() and semantic_search().
        if not queries:
            return []

        texts = [el["title"] for el in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # One hit list per query; top_k=1 keeps only the best-ranked label.
        predictions = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )
        return [
            {
                "id": query["id"],
                "title": query["title"],
                # corpus_id indexes into the label list the embeddings were
                # built from.
                "prediction": self.taxonomy_labels[hits[0]["corpus_id"]],
            }
            for query, hits in zip(queries, predictions)
        ]