import json

import tiktoken
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ToolMessage
from langchain_core.messages.base import BaseMessage
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import (
    MarkdownHeaderTextSplitter,
    RecursiveCharacterTextSplitter,
)

from config.settings import config


def parse_mark_down(data: str) -> list:
    """Split markdown into Documents scoped by H1/H2 headers.

    The enclosing header titles are recorded in each Document's metadata.
    """
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
    ]

    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    return markdown_splitter.split_text(data)
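
# Example (hypothetical input): parse_mark_down("# Guide\nIntro\n## Setup\nSteps")
# yields two Documents, one with metadata {"Header 1": "Guide"} and one with
# {"Header 1": "Guide", "Header 2": "Setup"}, each holding that section's text.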


class OversizedContentHandler:
    """Main handler for content that exceeds context limits."""

    def __init__(self,
                 model_name: str = "gpt-4.1",
                 max_context_tokens: int = 8000,
                 reserved_tokens: int = 2000):
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Older tiktoken releases may not recognize newer model names;
            # fall back to the o200k_base encoding used by the GPT-4.1/4o family.
            self.encoding = tiktoken.get_encoding("o200k_base")
        self.max_context_tokens = max_context_tokens
        self.reserved_tokens = reserved_tokens
        # Token budget left for content after reserving room for the prompt
        # and the model's response.
        self.max_chunk_tokens = max_context_tokens - reserved_tokens

    def count_tokens(self, text: str) -> int:
        """Return the token count of `text` under the model's encoding."""
        return len(self.encoding.encode(text))

    def extract_relevant_chunks(self, content: str, query: str) -> list[dict]:
        """Embed `content` chunk by chunk and return the pieces closest to `query`."""
        # Split along markdown headers first so chunks follow the document's
        # structure, then cap each piece with a character-based splitter
        # (chunk_size and chunk_overlap are measured in characters, not tokens).
        md_chunks = parse_mark_down(content)

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=15000, chunk_overlap=500)
        final_chunks = text_splitter.split_documents(md_chunks)

        # Build an in-memory FAISS index over the chunks and retrieve the
        # three nearest to the query.
        embeddings = OpenAIEmbeddings()
        vector_db = FAISS.from_documents(final_chunks, embeddings)

        relevant_chunks = vector_db.similarity_search(query, k=3)

        # The markdown splitter stores the enclosing header titles in each
        # Document's metadata; surface that as the chunk's source, since a
        # "source" key is never set for text split from a raw string.
        context_with_metadata = [
            {"text": doc.page_content, "source": doc.metadata}
            for doc in relevant_chunks
        ]
        return context_with_metadata

    def process_oversized_message(self, message: BaseMessage, query: str) -> bool:
        """Replace an oversized Tavily extract payload with retrieved chunks.

        Returns True when the message content was rewritten.
        """
        chunked = False

        json_content = None
        if isinstance(message, ToolMessage) and message.name == "tavily_extract":
            try:
                json_content = json.loads(message.content)
            except (json.JSONDecodeError, TypeError) as e:
                print(f"cannot parse message {message.id}: {e}")

        if json_content and json_content.get("results"):
            result = json_content["results"][0]
            raw_content = result["raw_content"]

            content_size = self.count_tokens(raw_content)
            if content_size > config.max_tokens:
                print(f"Proceeding with chunking: {content_size} tokens in message {message.id}")
                chunked = True
                result["raw_content"] = self.extract_relevant_chunks(raw_content, query=query)
                message.content = json.dumps(json_content)
        return chunked
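

if __name__ == "__main__":
    # Minimal smoke test, assuming config.max_tokens is defined in
    # config/settings.py. The Tavily-style payload, query, and tool_call_id
    # are fabricated for illustration, not real tool output; no OpenAI call
    # is made unless the payload is large enough to trigger chunking.
    payload = {"results": [{"raw_content": "# Guide\nIntro\n## Setup\nSteps"}]}
    message = ToolMessage(
        content=json.dumps(payload),
        name="tavily_extract",
        tool_call_id="demo-call-1",
    )
    handler = OversizedContentHandler()
    print("chunked:", handler.process_oversized_message(message, query="How do I set it up?"))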