|
|
from typing import Any, List, Mapping, Optional, Tuple, Union

import pandas as pd
from scipy.sparse import csr_matrix
from transformers import pipeline
from transformers.pipelines.base import Pipeline

from bertopic.representation._base import BaseRepresentation
|
|
|
|
|
|
|
|
class ZeroShotClassification(BaseRepresentation):
    """ Zero-shot Classification on topic keywords with candidate labels

    Arguments:
        candidate_topics: A list of labels to assign to the topics if they
                          exceed `min_prob`
        model: Either the name of a Hugging Face model (a string) or a
               transformers pipeline that should be initialized as
               "zero-shot-classification". For example,
               `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. NOTE: Use `{"multi_label": True}`
                         to extract multiple labels for each topic.
                         Defaults to no extra kwargs.
        min_prob: The minimum probability to assign a candidate label to a topic

    Raises:
        ValueError: If `model` is neither a string nor a `Pipeline` instance.

    Usage:

    ```python
    from bertopic.representation import ZeroShotClassification
    from bertopic import BERTopic

    # Create your representation model
    candidate_topics = ["space and nasa", "bicycles", "sports"]
    representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """

    # Representations shorter than this are padded with ("", 0) entries.
    _MIN_REPRESENTATION_SIZE = 10

    def __init__(self,
                 candidate_topics: List[str],
                 model: Union[str, Pipeline] = "facebook/bart-large-mnli",
                 pipeline_kwargs: Optional[Mapping[str, Any]] = None,
                 min_prob: float = 0.8
                 ):
        self.candidate_topics = candidate_topics

        if isinstance(model, str):
            self.model = pipeline("zero-shot-classification", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError("Make sure that the HF model that you "
                             "pass is either a string referring to a "
                             "HF model or a `transformers.pipeline` object.")

        # Copy into a fresh dict so that instances never share (or leak
        # changes into) a default or caller-owned mapping.
        self.pipeline_kwargs = dict(pipeline_kwargs) if pipeline_kwargs else {}
        self.min_prob = min_prob

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Each topic's keywords are joined into one string and classified
        against `candidate_topics`. Labels scoring above `min_prob` replace
        the keyword representation; otherwise the original representation
        is kept. Every returned representation is padded with `("", 0)`
        entries up to 10 items. The `topics` argument is never modified.

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations
        """
        # Nothing to classify; avoid calling the pipeline with no sequences.
        if not topics:
            return {}

        topic_ids = list(topics)

        # One space-separated keyword string per topic. Unlike indexing into
        # `zip(*...)`, this also tolerates a topic without any keywords.
        topic_descriptions = [
            " ".join(word for word, _ in topics[topic]) for topic in topic_ids
        ]
        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)

        # The pipeline is documented to return "a dict or a list of dict";
        # normalise a bare dict (single sequence) so that zip() below pairs
        # topics with results instead of iterating over the dict's keys.
        if isinstance(classifications, dict):
            classifications = [classifications]

        multi_label = self.pipeline_kwargs.get("multi_label")
        updated_topics = {}
        for topic, classification in zip(topic_ids, classifications):
            labels = classification["labels"]
            scores = classification["scores"]

            if multi_label:
                # Keep every label that clears the threshold
                topic_description = [
                    (label, score)
                    for label, score in zip(labels, scores)
                    if score > self.min_prob
                ]
            elif scores[0] > self.min_prob:
                # Keep only the first (label, score) pair reported by the pipeline
                topic_description = [(labels[0], scores[0])]
            else:
                topic_description = []

            # No label was confident enough: fall back to the c-TF-IDF
            # representation. Copy it so the padding below does not mutate
            # the caller's `topics`.
            if not topic_description:
                topic_description = list(topics[topic])

            missing = self._MIN_REPRESENTATION_SIZE - len(topic_description)
            if missing > 0:
                topic_description = topic_description + [("", 0)] * missing

            updated_topics[topic] = topic_description

        return updated_topics
|
|
|