|
|
from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union

import pandas as pd
from langchain.docstore.document import Document
from scipy.sparse import csr_matrix

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
|
|
|
|
|
# Prompt used by the LangChain representation when no custom prompt is supplied.
DEFAULT_PROMPT = (
    "What are these documents about? "
    "Please give a single label."
)
|
|
|
|
|
|
|
|
class LangChain(BaseRepresentation):
    """ Using chains in langchain to generate topic labels.

    The classic example uses `langchain.chains.question_answering.load_qa_chain`.
    This returns a chain that takes a list of documents and a question as input.

    You can also use Runnables such as those composed using the LangChain Expression Language.

    Arguments:
        chain: The langchain chain or Runnable with a `batch` method.
               Input keys must be `input_documents` and `question`.
               Output key must be `output_text`. Alternatively, the chain
               may return a plain string, which is then used as the label
               directly.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead. The prompt is passed
                to the chain as its `question` input.
        nr_docs: The number of representative documents per topic that are
                 passed to the chain as `input_documents`.
        diversity: The diversity of documents to pass to LangChain.
                   Accepts values between 0 and 1. Higher values
                   result in passing more diverse documents,
                   whereas lower values pass more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments
                   that are counted to determine the length of a document.
                       * If tokenizer is 'char', then the document is split up
                         into characters which are counted to adhere to `doc_length`
                       * If tokenizer is 'whitespace', the document is split up
                         into words separated by whitespaces. These words are counted
                         and truncated depending on `doc_length`
                       * If tokenizer is 'vectorizer', then the internal CountVectorizer
                         is used to tokenize the document. These tokens are counted
                         and truncated depending on `doc_length`. They are decoded with
                         whitespaces.
                       * If tokenizer is a callable, then that callable is used to tokenize
                         the document. These tokens are counted and truncated depending
                         on `doc_length`
        chain_config: The configuration for the langchain chain. Can be used to set options
                      like max_concurrency to avoid rate limiting errors.

    Usage:

    To use this, you will need to install the langchain package first.
    Additionally, you will need an underlying LLM to support langchain,
    like openai:

    `pip install langchain`
    `pip install openai`

    Then, you can create your chain as follows:

    ```python
    from langchain.chains.question_answering import load_qa_chain
    from langchain.llms import OpenAI
    chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
    ```

    Finally, you can pass the chain to BERTopic as follows:

    ```python
    from bertopic import BERTopic
    from bertopic.representation import LangChain

    # Create your representation model
    representation_model = LangChain(chain)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can also use a custom prompt:

    ```python
    prompt = "What are these documents about? Please give a single label."
    representation_model = LangChain(chain, prompt=prompt)
    ```

    You can also use a Runnable instead of a chain.
    The example below uses the LangChain Expression Language:

    ```python
    from bertopic.representation import LangChain
    from langchain.chains.question_answering import load_qa_chain
    from langchain.chat_models import ChatAnthropic
    from langchain.schema.document import Document
    from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer

    prompt = ...
    llm = ...

    # We will construct a special privacy-preserving chain using Microsoft Presidio

    pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])

    chain = (
        {
            "input_documents": (
                lambda inp: [
                    Document(
                        page_content=pii_handler.anonymize(
                            d.page_content,
                            language="en",
                        ),
                    )
                    for d in inp["input_documents"]
                ]
            ),
            # Forward only the prompt string; every branch of this dict
            # receives the whole input dict, so pick out the question here.
            "question": lambda inp: inp["question"],
        }
        | load_qa_chain(llm, chain_type="stuff")
        | (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
    )

    representation_model = LangChain(chain, prompt=prompt)
    ```
    """
    def __init__(self,
                 chain,
                 prompt: Optional[str] = None,
                 nr_docs: int = 4,
                 diversity: Optional[float] = None,
                 doc_length: Optional[int] = None,
                 tokenizer: Union[str, Callable, None] = None,
                 chain_config=None,
                 ):
        self.chain = chain
        # Fall back to the module-level default prompt when none is given.
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.chain_config = chain_config
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, int]]]:
        """ Extract topics

        Arguments:
            topic_model: A BERTopic model
            documents: All input documents
            c_tf_idf: The topic c-TF-IDF representation
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each topic maps to
                            a list of 10 (label, weight) pairs: the label
                            generated by the chain with weight 1, followed by
                            nine ("", 0) placeholders.
        """
        # Select `self.nr_docs` representative documents per topic.
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
            c_tf_idf=c_tf_idf,
            documents=documents,
            topics=topics,
            nr_samples=500,
            nr_repr_docs=self.nr_docs,
            diversity=self.diversity
        )

        # Wrap each (optionally truncated) document in a langchain Document.
        chain_docs: List[List[Document]] = [
            [
                Document(
                    page_content=truncate_document(
                        topic_model,
                        self.doc_length,
                        self.tokenizer,
                        doc
                    )
                )
                for doc in docs
            ]
            for docs in repr_docs_mappings.values()
        ]

        # One chain input per topic; `self.chain` must accept the
        # `input_documents` and `question` keys.
        inputs = [
            {"input_documents": docs, "question": self.prompt}
            for docs in chain_docs
        ]

        # A single batch call so that `chain_config` (e.g. max_concurrency)
        # governs all topics at once.
        outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
        labels = [self._output_to_label(output) for output in outputs]

        # Pad each representation to 10 entries: the label plus nine empty placeholders.
        updated_topics = {
            topic: [(label, 1)] + [("", 0) for _ in range(9)]
            for topic, label in zip(repr_docs_mappings.keys(), labels)
        }

        return updated_topics

    @staticmethod
    def _output_to_label(output) -> str:
        """ Convert a single chain output into a stripped label.

        Arguments:
            output: Either a mapping containing an `output_text` key (e.g. the
                    output of `load_qa_chain`) or a plain string (e.g. a
                    Runnable ending in a string output parser).

        Returns:
            The label text with surrounding whitespace removed.
        """
        if isinstance(output, str):
            return output.strip()
        return output["output_text"].strip()
|
|
|