File size: 5,218 Bytes
409c17a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""

Application Layer - Query Processing Use Case



Orchestrates the RAG pipeline for answering user queries.

"""
import hashlib
import time
from typing import List

from app.application.dto import QueryDTO, QueryResponseDTO, SourceDTO
from app.domain.entities import Query, QueryRequest, Source
from app.domain.interfaces import ILLM, ICache, IPromptBuilder, IReranker, IRetriever


class QueryProcessingUseCase:
    """Use case for processing user queries through the RAG pipeline.

    Pipeline stages: semantic-cache lookup -> hybrid retrieval ->
    reranking -> prompt construction -> LLM generation -> response
    assembly and caching.
    """

    # Candidates fetched from the index before the reranker narrows to top_k.
    _INITIAL_RETRIEVAL_K = 100
    # Dense/sparse interpolation weight passed to hybrid search.
    _HYBRID_ALPHA = 0.5
    # Seconds a generated answer stays in the semantic cache.
    _CACHE_TTL_SECONDS = 3600
    # Max characters of each source's content returned to the caller.
    _SOURCE_PREVIEW_CHARS = 500
    # Number of top results averaged for the confidence metric.
    _CONFIDENCE_TOP_N = 3

    def __init__(
        self,
        retriever: IRetriever,
        reranker: IReranker,
        llm: ILLM,
        prompt_builder: IPromptBuilder,
        cache: ICache,
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.llm = llm
        self.prompt_builder = prompt_builder
        self.cache = cache

    async def execute(self, query_dto: QueryDTO) -> QueryResponseDTO:
        """Execute the query-processing pipeline for a single user query.

        Args:
            query_dto: Incoming query with text, department, optional
                filters, and LLM sampling/sizing parameters.

        Returns:
            QueryResponseDTO carrying the generated answer, cited
            sources, a confidence score, and processing metrics.
        """
        start_time = time.time()

        # 1. Create the domain-level query request.
        query_request = QueryRequest(
            query_text=query_dto.query_text,
            department=query_dto.department,
            user_id=query_dto.user_id,
            session_id=query_dto.session_id,
            top_k=query_dto.top_k,
            temperature=query_dto.temperature,
            max_tokens=query_dto.max_tokens,
            filters=query_dto.filters,
        )

        # 2. Check the semantic cache; a hit skips the whole pipeline.
        cache_key = self._cache_key(query_dto)
        cached_response = await self.cache.get(cache_key)
        if cached_response:
            return cached_response

        # 3. Retrieve candidate documents, always scoped to the department;
        #    caller-supplied filters are layered on top.
        filters = {"department": query_dto.department}
        if query_dto.filters:
            filters.update(query_dto.filters)

        retrieval_results = await self.retriever.hybrid_search(
            query=query_dto.query_text,
            top_k=self._INITIAL_RETRIEVAL_K,
            alpha=self._HYBRID_ALPHA,
            filters=filters,
        )

        # 4. Rerank the wide candidate set down to the caller's top_k.
        reranked_results = await self.reranker.rerank(
            query=query_dto.query_text,
            results=retrieval_results,
            top_k=query_dto.top_k,
        )

        # 5-6. Assemble grounding context and build the chat prompt.
        context = [result.content for result in reranked_results]
        messages = self.prompt_builder.build_rag_prompt(
            query=query_dto.query_text,
            context=context,
            system_prompt=self._get_system_prompt(query_dto.department),
        )

        # 7. Generate the answer with the caller-supplied sampling settings.
        llm_response = await self.llm.generate(
            messages=messages,
            temperature=query_dto.temperature,
            max_tokens=query_dto.max_tokens,
        )

        # 8. Map reranked chunks to citable sources (content truncated
        #    so the response payload stays small).
        sources = [
            SourceDTO(
                title=f"Document {result.document_id}",
                content=result.content[: self._SOURCE_PREVIEW_CHARS],
                relevance_score=result.score,
                document_id=result.document_id,
                chunk_index=result.chunk_index,
                metadata=result.metadata,
            )
            for result in reranked_results
        ]

        # 9-10. Build the response with timing/usage metrics.
        processing_time_ms = int((time.time() - start_time) * 1000)
        response = QueryResponseDTO(
            query_id=str(query_request.id) if hasattr(query_request, "id") else "temp",
            answer=llm_response.content,
            sources=sources,
            confidence=self._calculate_confidence(reranked_results),
            processing_time_ms=processing_time_ms,
            tokens_used=llm_response.tokens_used,
            model=llm_response.model,
        )

        # 11. Cache so subsequent identical queries short-circuit at step 2.
        await self.cache.set(cache_key, response, ttl=self._CACHE_TTL_SECONDS)

        return response

    @staticmethod
    def _cache_key(query_dto: QueryDTO) -> str:
        """Build a process-stable cache key for a query.

        Uses SHA-256 rather than the builtin hash(), whose string output
        is randomized per process (PYTHONHASHSEED): with hash(), entries
        written by one worker were invisible to other workers and lost
        on every restart, silently defeating the cache.
        """
        digest = hashlib.sha256(query_dto.query_text.encode("utf-8")).hexdigest()
        return f"query:{digest}:{query_dto.department}"

    def _get_system_prompt(self, department: str) -> str:
        """Get department-specific system prompt, defaulting to General."""
        prompts = {
            "HR": "You are a helpful HR assistant for employee onboarding. Provide clear, accurate information about HR policies, benefits, and procedures.",
            "IT": "You are an IT support assistant for new employees. Help with technical setup, access, and IT policies.",
            "Legal": "You are a legal compliance assistant. Provide information about legal policies, regulations, and compliance requirements.",
            "Finance": "You are a finance assistant. Help with expense policies, financial procedures, and budget information.",
            "General": "You are a helpful corporate onboarding assistant. Provide accurate information to help new employees integrate successfully.",
        }
        return prompts.get(department, prompts["General"])

    def _calculate_confidence(self, results: List) -> float:
        """Calculate confidence as the mean score of the top results.

        Averages up to the top ``_CONFIDENCE_TOP_N`` result scores;
        returns 0.0 for an empty result set.
        """
        if not results:
            return 0.0
        top_scores = [r.score for r in results[: self._CONFIDENCE_TOP_N]]
        # results is non-empty here, so top_scores is never empty.
        return sum(top_scores) / len(top_scores)