File size: 1,620 Bytes
1f725d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
from src.MultiRag.models.rag_model import State
from utils.asyncHandler import asyncHandler
from src.MultiRag.llm.llm_loader import llm
from src.MultiRag.models.queries_model import Queries
from src.MultiRag.prompts.prompt_templates import QUERY_GENERATION_PROMPT
from src.MultiRag.nodes.retreiver_check_node import get_cached_retriever

from langchain_core.messages import SystemMessage, HumanMessage

@asyncHandler
async def query_generator(state: State):
    logging.info("Generating queries...")
    llm_ = llm.with_structured_output(Queries)

    system_content = QUERY_GENERATION_PROMPT

    prompt = [
        SystemMessage(content=system_content),
        SystemMessage(content=f"summary of the user uploaded content keywords with weightage: {state['summary']}"),
        HumanMessage(content=f"userQuery: {state['userQuery']}")
    ]
    logging.debug(f"Query generator prompt: {prompt}")
    res = await llm_.ainvoke(prompt)
    logging.info(f"Generated {len(res.queries)} queries.")

    # Fetch retriever from server-side cache (keyed by db_path, NOT stored in state)
    retreiver = get_cached_retriever(state['db_path'])
    if retreiver is None:
        logging.error(f"Retriever not found in cache for db_path={state['db_path']}")
        return {"retreiver_responses": [], "queries": res.queries}

    responses = []
    for r in res.queries:
        logging.info(f"Invoking retriever with query: {r}")
        responses.append(await retreiver.invoke(r))

    logging.info("Query generation and retrieval completed.")
    return {"retreiver_responses": responses, "queries": res.queries}