| from sentence_transformers import SentenceTransformer |
| from sentence_transformers import util |
| import torch |
| import json |
| import os |
|
|
class PipelineWrapper:
    """Classify government dataset titles into the Musterdatenkatalog taxonomy.

    Titles and taxonomy labels are encoded with a sentence-transformers model.
    Each title is assigned the taxonomy label that ranks first in a semantic
    search of its embedding against the label embeddings.
    """

    # Catch-all taxonomy entry that is never offered as a prediction.
    EXCLUDED_LABEL = "Sonstiges - Sonstiges"

    def __init__(self, path=""):
        """Load the model and the taxonomy found under ``path``.

        Parameters
        ----------
        path : str
            Handed to ``SentenceTransformer`` as the model to load. The
            taxonomy is read from ``taxonomy.json`` joined onto this path.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): use_auth_token=True presumably lets private Hugging
        # Face Hub models load with the locally stored token -- confirm.
        self.model = SentenceTransformer(path, device=device, use_auth_token=True)
        self.taxonomy = os.path.join(path, "taxonomy.json")
        self.taxonomy_labels = None
        self.taxonomy_embeddings = None
        # Labels must be loaded before they can be embedded.
        self.load_taxonomy_labels()
        self.get_taxonomy_embeddings()

    def __call__(self, queries: list) -> list:
        """Shorthand for :meth:`predict`."""
        return self.predict(queries)

    def load_taxonomy_labels(self) -> None:
        """Read the taxonomy file and store its labels in ``taxonomy_labels``.

        Each entry of the JSON list must provide the keys ``"group"`` and
        ``"label"``; they are combined into ``"<group> - <label>"``. The
        catch-all entry ``EXCLUDED_LABEL`` is dropped when present.
        """
        # Explicit UTF-8: the labels contain non-ASCII characters and the
        # platform default encoding is not guaranteed to decode them.
        with open(self.taxonomy, "r", encoding="utf-8") as f:
            taxonomy = json.load(f)
        labels = [el["group"] + " - " + el["label"] for el in taxonomy]
        # Filtering instead of list.remove() tolerates a taxonomy without the
        # catch-all entry and drops every occurrence, not just the first.
        self.taxonomy_labels = [
            label for label in labels if label != self.EXCLUDED_LABEL
        ]

    def get_taxonomy_embeddings(self) -> None:
        """Encode ``taxonomy_labels`` and store the result in ``taxonomy_embeddings``."""
        self.taxonomy_embeddings = self.model.encode(
            self.taxonomy_labels, convert_to_tensor=True
        )

    def predict(self, queries: list) -> list:
        """Predict the taxonomy label for each of the given queries.

        Parameters
        ----------
        queries : list
            List of dictionaries with the keys "id" and "title".

        Returns
        -------
        list
            List of dictionaries with the keys "id", "title" and "prediction",
            in the same order as ``queries``. Empty when ``queries`` is empty.
        """
        # Nothing to classify: skip the model instead of encoding an empty batch.
        if not queries:
            return []

        texts = [el["title"] for el in queries]
        query_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # One hit list per query; top_k=1 keeps only the best-ranked label.
        predictions = util.semantic_search(
            query_embeddings=query_embeddings,
            corpus_embeddings=self.taxonomy_embeddings,
            top_k=1,
        )

        return [
            {
                "id": query["id"],
                "title": query["title"],
                # corpus_id indexes into the list the embeddings were built from.
                "prediction": self.taxonomy_labels[prediction[0]["corpus_id"]],
            }
            for query, prediction in zip(queries, predictions)
        ]
|
|