import logging
from src.MultiRag.models.rag_model import State
from utils.asyncHandler import asyncHandler
from src.MultiRag.llm.llm_loader import llm
from src.MultiRag.models.queries_model import Queries
from src.MultiRag.prompts.prompt_templates import QUERY_GENERATION_PROMPT
from src.MultiRag.nodes.retreiver_check_node import get_cached_retriever
from langchain_core.messages import SystemMessage, HumanMessage
@asyncHandler
async def query_generator(state: State):
    """Generate retrieval queries from the user query, then run each against
    the cached retriever.

    Uses the structured-output LLM to produce a `Queries` object from the
    user's query plus the uploaded-content summary, then fetches the retriever
    from the server-side cache (keyed by db_path) and retrieves documents for
    every generated query.

    Args:
        state: Graph state; reads 'summary', 'userQuery', and 'db_path'.

    Returns:
        dict with:
            'retreiver_responses': list of per-query retrieval results
                (empty list when the retriever is missing from the cache),
            'queries': the LLM-generated list of queries.
    """
    logging.info("Generating queries...")
    llm_ = llm.with_structured_output(Queries)
    prompt = [
        SystemMessage(content=QUERY_GENERATION_PROMPT),
        SystemMessage(content=f"summary of the user uploaded content keywords with weightage: {state['summary']}"),
        HumanMessage(content=f"userQuery: {state['userQuery']}"),
    ]
    # Lazy %-style args so the (potentially large) prompt is only formatted
    # when DEBUG logging is actually enabled.
    logging.debug("Query generator prompt: %s", prompt)
    res = await llm_.ainvoke(prompt)
    logging.info("Generated %d queries.", len(res.queries))

    # Fetch retriever from server-side cache (keyed by db_path, NOT stored in state)
    retreiver = get_cached_retriever(state['db_path'])
    if retreiver is None:
        logging.error("Retriever not found in cache for db_path=%s", state['db_path'])
        return {"retreiver_responses": [], "queries": res.queries}

    responses = []
    for query in res.queries:
        logging.info("Invoking retriever with query: %s", query)
        # BUG FIX: the original `await retreiver.invoke(query)` awaited the
        # synchronous Runnable entry point, whose return value (a document
        # list) is not awaitable and raises TypeError. Use the async
        # counterpart, matching the `llm_.ainvoke` call above.
        responses.append(await retreiver.ainvoke(query))
    logging.info("Query generation and retrieval completed.")
    return {"retreiver_responses": responses, "queries": res.queries}