File size: 1,675 Bytes
5a3b322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from __future__ import annotations
"""
LLM-based QueryPlan builder using Gemini + LangChain structured output.
Falls back to deterministic rewrite if LLM parsing fails.
"""
import os
from typing import Dict, Optional
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from schemas.query_plan import QueryPlan
from tools.query_plan_tool import build_query_plan as deterministic_plan
# Instructions given to the LLM planner: extract structured retrieval fields
# and emit keyword-heavy BM25/vector queries plus a full-text rerank query.
SYSTEM_PROMPT = """You are a retrieval planner for assessment recommendations.
Extract intent, role, skills, duration, language. Produce BM25/vec queries (keyword-heavy)
and a rerank query (full original). Keep to the schema exactly.
"""
# Prompt skeleton: {system} is filled from SYSTEM_PROMPT, {query} from the caller.
# NOTE(review): there is no {format_instructions} placeholder here, yet
# build_query_plan_llm partials one in via the parser — as written, the
# parser's schema instructions are never rendered into the prompt. Verify
# whether the placeholder should be added to this template.
TEMPLATE = """{system}
User query:
{query}
Return ONLY valid JSON for the QueryPlan model.
"""
def build_query_plan_llm(raw_text: str, vocab: Optional[Dict] = None, model_name: str = "gemini-pro") -> QueryPlan:
    """Build a QueryPlan from *raw_text* using a Gemini LLM with structured output.

    Any failure (missing API key, network error, unparseable LLM output) falls
    back to the deterministic rewriter, so this function never raises.

    Args:
        raw_text: The user's original query text.
        vocab: Optional vocabulary dict, forwarded to the deterministic fallback.
        model_name: Gemini model identifier passed to ChatGoogleGenerativeAI.

    Returns:
        A QueryPlan parsed from the LLM output, or the deterministic fallback plan.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=QueryPlan)
        # BUG FIX: TEMPLATE contains no {format_instructions} placeholder, so
        # the .partial(format_instructions=...) below was never rendered and
        # the model never saw the QueryPlan JSON schema. Append the
        # placeholder here so the instructions actually reach the prompt.
        prompt = PromptTemplate(
            template=TEMPLATE + "\n{format_instructions}\n",
            input_variables=["query"],
            partial_variables={"system": SYSTEM_PROMPT},
        ).partial(format_instructions=parser.get_format_instructions())
        llm = ChatGoogleGenerativeAI(
            model=model_name,
            temperature=0.2,          # low temperature: planning should be near-deterministic
            max_output_tokens=512,
            convert_system_message_to_human=True,
        )
        chain = prompt | llm | parser
        return chain.invoke({"query": raw_text})
    except Exception:
        # Deliberate best-effort boundary: keep the fallback, but log the
        # failure instead of silently discarding it (the old code bound the
        # exception to `e` and never used it).
        import logging
        logging.getLogger(__name__).exception(
            "LLM query-plan generation failed; using deterministic fallback"
        )
        return deterministic_plan(raw_text, vocab=vocab)
|