import json

import tiktoken
from langchain_community.vectorstores import FAISS
from langchain_core.messages import ToolMessage
from langchain_core.messages.base import BaseMessage
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter

from config.settings import config


def parse_mark_down(data: str) -> list:
    """Split Markdown text into sections along level-1 and level-2 headers."""
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
    ]

    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
    return markdown_splitter.split_text(data)
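
# Illustrative only (hypothetical input): parse_mark_down("# Title\nintro\n## Section\nbody")
# would return roughly
#   [Document(page_content="intro", metadata={"Header 1": "Title"}),
#    Document(page_content="body", metadata={"Header 1": "Title", "Header 2": "Section"})]
# since MarkdownHeaderTextSplitter strips headers into metadata by default.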


class OversizedContentHandler:
    """Main handler for content that exceeds context limits"""

    def __init__(self,
                 model_name: str = "gpt-4.1",
                 max_context_tokens: int = 8000,
                 reserved_tokens: int = 2000):
        self.encoding = tiktoken.encoding_for_model(model_name)
        self.max_context_tokens = max_context_tokens
        self.reserved_tokens = reserved_tokens
        # Token budget left for content after reserving room for the model's response
        self.max_chunk_tokens = max_context_tokens - reserved_tokens

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using the model's tiktoken encoding."""
        return len(self.encoding.encode(text))

    def extract_relevant_chunks(self, content: str, query: str) -> list:
        # Split along Markdown headers first so logical sections stay intact
        md_chunks = parse_mark_down(content)
        # Further split any sections that are still too large
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=15000, chunk_overlap=500)
        final_chunks = text_splitter.split_documents(md_chunks)

        # Embed the chunks and index them for similarity search
        embeddings = OpenAIEmbeddings()
        vector_db = FAISS.from_documents(final_chunks, embeddings)

        # Keep only the chunks most relevant to the query, with their source metadata
        relevant_chunks = vector_db.similarity_search(query, k=3)
        context_with_metadata = [
            {"text": doc.page_content, "source": doc.metadata.get("source")}
            for doc in relevant_chunks
        ]
        return context_with_metadata

    def process_oversized_message(self, message: BaseMessage, query: str) -> bool:
        """Replace an oversized tool result with its most relevant chunks.

        Returns True if the message content was chunked.
        """
        chunked = False
        # For now, only tavily_extract tool results are chunked
        json_content = None
        if isinstance(message, ToolMessage) and message.name == "tavily_extract":
            try:
                json_content = json.loads(message.content)
            except Exception as e:
                print(f"Cannot parse content of message {message.id}: {e}")
            if json_content:
                result = json_content['results'][0]
                raw_content = result['raw_content']

                content_size = self.count_tokens(raw_content)
                if content_size > config.max_tokens:
                    print(f"Proceeding with chunking: {content_size} tokens in message {message.id}")
                    chunked = True
                    result['raw_content'] = self.extract_relevant_chunks(raw_content, query=query)
                    message.content = json.dumps(json_content)
        return chunked
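

# A minimal usage sketch, assuming OPENAI_API_KEY is set and config.max_tokens
# is below the sample size; the payload, tool_call_id, and query below are
# hypothetical stand-ins for a real tavily_extract tool result.
if __name__ == "__main__":
    handler = OversizedContentHandler(model_name="gpt-4.1")

    # Fake tavily_extract result, large enough to trigger chunking
    sample_payload = {"results": [{"raw_content": "# Title\n" + "Some text about widgets. " * 3000}]}
    sample_message = ToolMessage(
        content=json.dumps(sample_payload),
        name="tavily_extract",
        tool_call_id="call_0",  # hypothetical id
    )

    was_chunked = handler.process_oversized_message(sample_message, query="What are widgets?")
    print(f"chunked: {was_chunked}")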