| | import pandas as pd |
| | from tqdm import tqdm |
| | from scipy.sparse import csr_matrix |
| | from transformers import pipeline, set_seed |
| | from transformers.pipelines.base import Pipeline |
| | from typing import Mapping, List, Tuple, Any, Union, Callable |
| | from bertopic.representation._base import BaseRepresentation |
| | from bertopic.representation._utils import truncate_document |
| |
|
| |
|
# Prompt used when the caller does not supply one; "[KEYWORDS]" is replaced
# with the comma-separated topic keywords before generation.
DEFAULT_PROMPT = (
    "\n"
    "I have a topic described by the following keywords: [KEYWORDS].\n"
    "The name of this topic is:\n"
)
| |
|
| |
|
class TextGeneration(BaseRepresentation):
    """ Text2Text or text generation with transformers

    Arguments:
        model: A transformers pipeline that should be initialized as "text-generation"
               for gpt-like models or "text2text-generation" for T5-like models.
               For example, `pipeline('text-generation', model='gpt2')`. If a string
               is passed, "text-generation" will be selected by default.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
                to decide where the keywords and documents need to be
                inserted.
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. If None, no extra kwargs are passed.
        random_state: A random state to be passed to `transformers.set_seed`
        nr_docs: The number of documents to pass to the model if a prompt
                 with the `"[DOCUMENTS]"` tag is used.
        diversity: The diversity of documents to pass to the model.
                   Accepts values between 0 and 1. A higher
                   value results in passing more diverse documents
                   whereas lower values pass more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments
                   that are counted to determine the length of a document.
                   * If tokenizer is 'char', then the document is split up
                     into characters which are counted to adhere to `doc_length`
                   * If tokenizer is 'whitespace', the document is split up
                     into words separated by whitespaces. These words are counted
                     and truncated depending on `doc_length`
                   * If tokenizer is 'vectorizer', then the internal CountVectorizer
                     is used to tokenize the document. These tokens are counted
                     and truncated depending on `doc_length`
                   * If tokenizer is a callable, then that callable is used to tokenize
                     the document. These tokens are counted and truncated depending
                     on `doc_length`

    Attributes:
        prompts_: Every prompt sent to the model. Prompts are appended on each
                  call to `extract_topics` and never cleared.

    Usage:

    To use a gpt-like model:

    ```python
    from transformers import pipeline
    from bertopic.representation import TextGeneration
    from bertopic import BERTopic

    # Create your representation model
    generator = pipeline('text-generation', model='gpt2')
    representation_model = TextGeneration(generator)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can use a custom prompt and decide where the keywords should
    be inserted by using the `[KEYWORDS]` tag or documents with the `[DOCUMENTS]` tag:

    ```python
    from transformers import pipeline
    from bertopic.representation import TextGeneration

    prompt = "I have a topic described by the following keywords: [KEYWORDS]. Based on the previous keywords, what is this topic about?"

    # Create your representation model
    generator = pipeline('text2text-generation', model='google/flan-t5-base')
    representation_model = TextGeneration(generator, prompt=prompt)
    ```
    """
    def __init__(self,
                 model: Union[str, Pipeline],
                 prompt: str = None,
                 pipeline_kwargs: Mapping[str, Any] = None,
                 random_state: int = 42,
                 nr_docs: int = 4,
                 diversity: float = None,
                 doc_length: int = None,
                 tokenizer: Union[str, Callable] = None
                 ):
        set_seed(random_state)
        if isinstance(model, str):
            # A bare model name defaults to a "text-generation" pipeline
            self.model = pipeline("text-generation", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError("Make sure that the HF model that you "
                             "pass is either a string referring to a "
                             "HF model or a `transformers.pipeline` object.")
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        # `None` default avoids sharing one mutable dict across all instances
        self.pipeline_kwargs = pipeline_kwargs if pipeline_kwargs is not None else {}
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer

        # Log of every prompt sent to the model (accumulates across calls)
        self.prompts_ = []

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topic representations and return a single label

        Arguments:
            topic_model: A BERTopic model
            documents: The input documents; only used (to select representative
                       documents) when the prompt contains "[DOCUMENTS]"
            c_tf_idf: The c-TF-IDF matrix; only used (to select representative
                      documents) when the prompt contains "[DOCUMENTS]"
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each topic maps to the
                            generated text(s) with weight 1, padded with ("", 0)
                            entries up to a length of 10.
        """
        # Representative documents are only needed when a custom prompt asks for them
        if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt:
            # NOTE(review): 500 presumably caps the candidate documents considered
            # per topic — confirm against BERTopic._extract_representative_docs
            repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
                c_tf_idf,
                documents,
                topics,
                500,
                self.nr_docs,
                self.diversity
            )
        else:
            repr_docs_mappings = {topic: None for topic in topics}

        updated_topics = {}
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            # Prepare the prompt, truncating any documents to `doc_length`
            if docs is not None:
                truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc)
                                  for doc in docs]
            else:
                truncated_docs = None
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

            # Generate the label; the prompt is stripped from the output because
            # text-generation pipelines include it in "generated_text" by default
            topic_description = self.model(prompt, **self.pipeline_kwargs)
            topic_description = [(description["generated_text"].replace(prompt, ""), 1)
                                 for description in topic_description]

            # Pad with empty, zero-weight entries up to 10 items
            if len(topic_description) < 10:
                topic_description += [("", 0) for _ in range(10 - len(topic_description))]

            updated_topics[topic] = topic_description

        return updated_topics

    def _create_prompt(self, docs, topic, topics):
        """ Fill in the prompt template for a single topic.

        Arguments:
            docs: The (truncated) representative documents of the topic, or None
                  when no documents were extracted
            topic: The topic whose keywords are inserted
            topics: Mapping of each topic to its (keyword, weight) pairs

        Returns:
            prompt: The prompt with "[KEYWORDS]" replaced by the comma-separated
                    keywords and, for custom prompts, "[DOCUMENTS]" replaced by
                    one "- <doc>" line per document
        """
        # Also works for a topic with no keywords (yields an empty string)
        keywords = ", ".join(word for word, _ in topics[topic])

        # str.replace is a no-op when the tag is absent
        prompt = self.prompt.replace("[KEYWORDS]", keywords)

        # Only custom prompts may reference documents; skip when none were provided
        if self.prompt != DEFAULT_PROMPT and docs is not None and "[DOCUMENTS]" in prompt:
            to_replace = "".join(f"- {doc}\n" for doc in docs)
            prompt = prompt.replace("[DOCUMENTS]", to_replace)

        return prompt
| |
|