from typing import Any, List, Mapping, Optional, Tuple, Union

import pandas as pd
from scipy.sparse import csr_matrix
from transformers import pipeline
from transformers.pipelines.base import Pipeline

from bertopic.representation._base import BaseRepresentation
|
|
|
|
class ZeroShotClassification(BaseRepresentation):
    """ Zero-shot Classification on topic keywords with candidate labels

    Each topic's keywords are joined into a single space-separated string and
    classified against `candidate_topics`. Labels whose score exceeds
    `min_prob` replace the keyword representation of that topic; if no label
    passes the threshold, the original keywords are kept.

    Arguments:
        candidate_topics: A list of labels to assign to the topics if they
                          exceed `min_prob`
        model: Either the name of a HF model, in which case a pipeline is
               created with `pipeline("zero-shot-classification", model=model)`,
               or an already initialized `transformers` pipeline for the
               "zero-shot-classification" task. For example,
               `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. NOTE: Use `{"multi_label": True}`
                         to extract multiple labels for each topic.
                         Defaults to no extra kwargs.
        min_prob: The minimum probability to assign a candidate label to a topic.
                  A label is only kept if its score is strictly greater.

    Raises:
        ValueError: If `model` is neither a string nor a `transformers` Pipeline.

    Usage:

    ```python
    from bertopic.representation import ZeroShotClassification
    from bertopic import BERTopic

    # Create your representation model
    candidate_topics = ["space and nasa", "bicycles", "sports"]
    representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """
    def __init__(self,
                 candidate_topics: List[str],
                 model: Union[str, Pipeline] = "facebook/bart-large-mnli",
                 pipeline_kwargs: Optional[Mapping[str, Any]] = None,
                 min_prob: float = 0.8
                 ):
        self.candidate_topics = candidate_topics
        if isinstance(model, str):
            self.model = pipeline("zero-shot-classification", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError("Make sure that the HF model that you "
                             "pass is either a string referring to a "
                             "HF model or a `transformers.pipeline` object.")
        # Copy the mapping so that neither a shared default nor later changes
        # to the caller's dict can alter this instance's behavior.
        self.pipeline_kwargs = dict(pipeline_kwargs) if pipeline_kwargs is not None else {}
        self.min_prob = min_prob

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each topic maps to a
                            new list of (label, score) pairs, or to a copy of its
                            original keywords if no label exceeds `min_prob`. Lists
                            shorter than 10 are padded with ("", 0). The input
                            `topics` is never modified.
        """
        if not topics:
            return {}

        # The classifier sees each topic as its keywords joined by spaces
        topic_descriptions = [" ".join(word for word, _ in topics[topic]) for topic in topics]
        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)

        # NOTE(review): some transformers versions return a bare dict instead of a
        # one-element list when a single sequence is classified; normalize so the
        # zip below pairs every topic with its classification.
        if isinstance(classifications, dict):
            classifications = [classifications]

        multi_label = self.pipeline_kwargs.get("multi_label")
        updated_topics = {}
        for topic, classification in zip(topics.keys(), classifications):
            labels = classification["labels"]
            scores = classification["scores"]

            if multi_label:
                # Multi-label assignment: keep every label above the threshold
                topic_description = [(label, score)
                                     for label, score in zip(labels, scores)
                                     if score > self.min_prob]
            elif scores[0] > self.min_prob:
                # Single label assignment: only the top-scoring label
                topic_description = [(labels[0], scores[0])]
            else:
                topic_description = []

            # No label passed the threshold: fall back to a copy of the keywords
            # so the caller's list is never aliased into the output.
            if not topic_description:
                topic_description = list(topics[topic])

            # Make sure that at least 10 items are returned; build a new list
            # rather than extending in place.
            if len(topic_description) < 10:
                topic_description = topic_description + [("", 0)] * (10 - len(topic_description))
            updated_topics[topic] = topic_description

        return updated_topics
|
|