File size: 3,731 Bytes
d2570c2
6ad997d
 
 
 
d2570c2
 
6ad997d
d2570c2
 
6ad997d
d2570c2
 
 
 
 
6ad997d
d2570c2
 
 
 
 
6ad997d
 
 
 
 
d2570c2
 
 
 
 
 
 
6ad997d
d2570c2
 
 
6ad997d
d2570c2
 
 
 
6ad997d
 
 
 
 
d2570c2
 
6ad997d
 
 
 
 
d2570c2
6ad997d
 
 
d2570c2
6ad997d
 
 
d2570c2
 
 
 
 
6ad997d
d2570c2
 
 
 
 
6ad997d
d2570c2
 
 
b4bfa19
 
 
d2570c2
 
 
 
6ad997d
 
 
 
 
d2570c2
 
 
 
 
 
 
 
 
6ad997d
d2570c2
 
 
 
 
6ad997d
 
d2570c2
 
6ad997d
d2570c2
 
6ad997d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import List, Any

from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage

from src.core.rag.llm import LLMFactory
from src.utils import setup_logger


logger = setup_logger(__name__)


class ContextCompressor:
    """
    Service to compress RAG context and Conversation History.
    Reduces token usage and 'Lost in the Middle' phenomenon.
    """

    def __init__(self):
        # We use a cheaper/faster model for summarization if possible
        # For now, we reuse the default provider from LLMFactory
        pass

    async def compress_history(
        self,
        history: List[BaseMessage],
        max_token_limit: int = 2000,  # noqa: ARG002
    ) -> List[BaseMessage]:
        """
        Compress conversation history if it exceeds limits.

        Strategy: keep the last 4 messages (2 turns) raw and replace
        everything older with a single LLM-generated summary message.

        Args:
            history: Full conversation, oldest message first.
            max_token_limit: Reserved for future token-based budgeting;
                the current heuristic counts messages only.

        Returns:
            The original list unchanged when it is short enough (or when
            summarization fails), otherwise
            ``[SystemMessage(summary)] + last 4 messages``.
        """
        # Simple heuristic: only compress once there are more than 6 messages.
        if len(history) <= 6:
            return history

        # Keep last 4 messages (2 turns) intact
        recent_history = history[-4:]
        older_history = history[:-4]

        # If older history is small, just return (avoid unnecessary summarization calls)
        if len(older_history) < 2:
            return history

        logger.info(
            "Compressing history: %d messages -> Summary + 4 recent",
            len(history),
        )

        try:
            summary = await self._summarize_messages(older_history)
            return [
                SystemMessage(
                    content=f"Previous Conversation Summary: {summary}"
                )
            ] + recent_history
        except Exception as e:
            # Compression is an optimization: never fail the request over it.
            logger.error("History compression failed: %s", e)
            # Fallback: return full history (or could slice)
            return history

    async def _summarize_messages(
        self, messages: List[BaseMessage]
    ) -> str:
        """Use LLM to summarize a list of messages into one string."""
        # Build the transcript with join instead of repeated += (linear, not quadratic).
        conversation_text = "".join(
            f"{'User' if isinstance(msg, HumanMessage) else 'AI'}: {msg.content}\n"
            for msg in messages
        )

        prompt = (
            "Summarize the following conversation concisely, focusing on key user preferences and questions. "
            "Do not lose important details.\n\n"
            f"{conversation_text}"
        )

        # Use simple mock if running in test environment/benchmark without keys
        try:
            llm = LLMFactory.create(temperature=0.3)
        except Exception:
            # Fallback to mock when env/keys not set (e.g. tests, benchmarks)
            llm = LLMFactory.create(provider="mock")

        # NOTE(review): invoke() is synchronous and blocks the event loop inside
        # this async method — consider ainvoke() if the provider supports it.
        response = llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def format_docs(
        self,
        docs: List[Any],
        max_len_per_doc: int = 500,
    ) -> str:
        """
        Format retrieved documents for the LLM Prompt.
        Truncates content to avoid context overflow.

        Args:
            docs: Retrieved documents; each must expose ``page_content``
                and ``metadata`` (langchain Document shape).
            max_len_per_doc: Character cap per document before truncation.

        Returns:
            One "[i] content (Relevance: s)" line per document.
        """
        lines = []
        for i, doc in enumerate(docs):
            # Flatten newlines so each doc occupies exactly one prompt line.
            content = doc.page_content.replace("\n", " ")
            if len(content) > max_len_per_doc:
                content = content[:max_len_per_doc] + "..."

            # Add Relevance Score if available (from Reranker)
            score_info = ""
            if doc.metadata and "relevance_score" in doc.metadata:
                score = doc.metadata["relevance_score"]
                score_info = f" (Relevance: {score:.2f})"

            lines.append(f"[{i + 1}] {content}{score_info}\n")
        return "".join(lines)


# Singleton
# Module-level shared instance; ContextCompressor holds no per-request state
# (its __init__ is a no-op), so sharing one instance across callers is safe.
compressor = ContextCompressor()

__all__ = ["ContextCompressor", "compressor"]