File size: 7,932 Bytes
19b102a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from transformers import pipeline, set_seed
from transformers.pipelines.base import Pipeline
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
# Prompt used when the caller does not supply one. The "[KEYWORDS]" tag is
# substituted with the topic's comma-separated keywords before generation.
DEFAULT_PROMPT = (
    "\n"
    "I have a topic described by the following keywords: [KEYWORDS].\n"
    "The name of this topic is:\n"
)
class TextGeneration(BaseRepresentation):
    """ Text2Text or text generation with transformers

    Arguments:
        model: A transformers pipeline that should be initialized as "text-generation"
               for gpt-like models or "text2text-generation" for T5-like models.
               For example, `pipeline('text-generation', model='gpt2')`. If a string
               is passed, "text-generation" will be selected by default.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
                to decide where the keywords and documents need to be
                inserted.
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. If None, no extra kwargs are passed.
        random_state: A random state to be passed to `transformers.set_seed`
        nr_docs: The number of documents to pass to the model if a prompt
                 with the `[DOCUMENTS]` tag is used.
        diversity: The diversity of documents to pass to the model.
                   Accepts values between 0 and 1. A higher
                   value results in passing more diverse documents
                   whereas lower values pass more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments,
                   which are counted to determine the length of a document.
                       * If tokenizer is 'char', then the document is split up
                         into characters which are counted to adhere to `doc_length`
                       * If tokenizer is 'whitespace', the document is split up
                         into words separated by whitespaces. These words are counted
                         and truncated depending on `doc_length`
                       * If tokenizer is 'vectorizer', then the internal CountVectorizer
                         is used to tokenize the document. These tokens are counted
                         and truncated depending on `doc_length`
                       * If tokenizer is a callable, then that callable is used to tokenize
                         the document. These tokens are counted and truncated depending
                         on `doc_length`

    Usage:

    To use a gpt-like model:

    ```python
    from transformers import pipeline
    from bertopic.representation import TextGeneration
    from bertopic import BERTopic

    # Create your representation model
    generator = pipeline('text-generation', model='gpt2')
    representation_model = TextGeneration(generator)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can use a custom prompt and decide where the keywords should
    be inserted by using the `[KEYWORDS]` tag, or where the documents should
    be inserted by using the `[DOCUMENTS]` tag:

    ```python
    from transformers import pipeline
    from bertopic.representation import TextGeneration

    prompt = "I have a topic described by the following keywords: [KEYWORDS]. Based on the previous keywords, what is this topic about?"

    # Create your representation model
    generator = pipeline('text2text-generation', model='google/flan-t5-base')
    representation_model = TextGeneration(generator, prompt=prompt)
    ```
    """
    def __init__(self,
                 model: Union[str, Pipeline],
                 prompt: str = None,
                 pipeline_kwargs: Mapping[str, Any] = None,
                 random_state: int = 42,
                 nr_docs: int = 4,
                 diversity: float = None,
                 doc_length: int = None,
                 tokenizer: Union[str, Callable] = None
                 ):
        set_seed(random_state)
        if isinstance(model, str):
            self.model = pipeline("text-generation", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError("Make sure that the HF model that you "
                             "pass is either a string referring to a "
                             "HF model or a `transformers.pipeline` object.")
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        # Create a fresh dict per instance instead of sharing one mutable
        # default object across every TextGeneration instance.
        self.pipeline_kwargs = pipeline_kwargs if pipeline_kwargs is not None else {}
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer
        # Every prompt sent to the model is appended here, in topic order.
        self.prompts_ = []

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topic representations and return a single label

        Arguments:
            topic_model: A BERTopic model
            documents: The input documents; only used to select representative
                       documents when the prompt contains `[DOCUMENTS]`
            c_tf_idf: The c-TF-IDF representation; only used to select
                      representative documents when the prompt contains `[DOCUMENTS]`
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each topic maps to a
                            list of (generated label, 1) pairs, padded with
                            ("", 0) up to 10 entries.
        """
        # Only look up the `nr_docs` representative documents per topic when
        # the prompt actually has a place to insert them.
        if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt:
            repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
                c_tf_idf,
                documents,
                topics,
                500,  # NOTE(review): hard-coded; presumably the nr. of candidate docs sampled per topic — confirm
                self.nr_docs,
                self.diversity
            )
        else:
            repr_docs_mappings = {topic: None for topic in topics.keys()}

        updated_topics = {}
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            # Prepare prompt
            truncated_docs = None
            if docs is not None:
                truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc)
                                  for doc in docs]
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

            # Extract result from generator and use that as label. Text-generation
            # pipelines may include the prompt in "generated_text", so strip it.
            outputs = self.model(prompt, **self.pipeline_kwargs)
            topic_description = [(output["generated_text"].replace(prompt, ""), 1)
                                 for output in outputs]

            # Pad with empty, zero-weight entries so each topic has at least 10 items.
            if len(topic_description) < 10:
                topic_description += [("", 0) for _ in range(10 - len(topic_description))]
            updated_topics[topic] = topic_description

        return updated_topics

    def _create_prompt(self, docs, topic, topics):
        """ Fill the `[KEYWORDS]` and `[DOCUMENTS]` tags of `self.prompt` for `topic`.

        Arguments:
            docs: Representative (truncated) documents for the topic, or None when
                  the prompt has no `[DOCUMENTS]` tag
            topic: The topic whose keywords are inserted
            topics: Mapping of topic to (word, weight) pairs

        Returns:
            prompt: The prompt with all tags substituted
        """
        # An empty keyword list yields "" instead of raising IndexError.
        keywords = ", ".join(word for word, _ in topics[topic])

        # `replace` is a no-op when the tag is absent, so the default prompt
        # and custom prompts share the same path.
        prompt = self.prompt.replace("[KEYWORDS]", keywords)

        # Check the template itself so that a keyword which happens to contain
        # "[DOCUMENTS]" does not trigger document insertion.
        if "[DOCUMENTS]" in self.prompt:
            documents = "".join(f"- {doc}\n" for doc in docs)
            prompt = prompt.replace("[DOCUMENTS]", documents)
        return prompt
|