"""
LangGraph agent: orchestrates the RAG pipeline (retrieve → generate → end).
A single LLM call, so the user always gets a final verbal answer that aggregates the retrieved context.
"""
import logging
from typing import Optional, List, TypedDict, Literal
from langgraph.graph import StateGraph, END
from rag_engine import RAGEngine

# Use the same "pipeline" logger as rag_engine so all [PIPELINE] logs appear together
PIPELINE_LOG = logging.getLogger("pipeline")


class AgentState(TypedDict, total=False):
    query: str
    api_key: str
    refusal: Optional[str]
    system_prompt: Optional[str]
    user_prompt: Optional[str]
    steps_log: List[str]
    draft_answer: Optional[str]
    feedback: Optional[str]
    iteration: int
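
# Field flow (as wired below): `retrieve` fills either `refusal` or `system_prompt`/`user_prompt`
# plus `steps_log`; `generate` fills `draft_answer` and appends to `steps_log`. `feedback` and
# `iteration` are carried for a possible refine loop, which this graph does not currently wire up.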


def build_agent_graph(engine: RAGEngine):
    """Build the LangGraph: retrieve → generate → end (one LLM call for the final answer)."""

    def retrieve(state: AgentState) -> dict:
        """Run RAG up to (not including) the LLM. Fill refusal or prompts + steps_log."""
        query = state["query"]
        PIPELINE_LOG.info("retrieve START query=%r", query[:80] if query else "")
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            PIPELINE_LOG.info("retrieve END refusal len=%d", len(refusal or ""))
            return {"refusal": refusal, "steps_log": steps_log}
        PIPELINE_LOG.info("retrieve END prompts ready steps=%d", len(steps_log or []))
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }
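
    # Note: engine.prepare_generation is assumed (per rag_engine) to return the 4-tuple
    # (refusal, system_prompt, user_prompt, steps_log); a non-empty refusal short-circuits
    # the graph via route_after_retrieve below.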

    def generate(state: AgentState) -> dict:
        """Call the LLM with the current prompt + optional feedback. Set draft_answer and append to steps_log."""
        PIPELINE_LOG.info("generate START")
        if state.get("api_key"):
            engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        steps_log = list(state.get("steps_log") or [])
        if feedback:
            steps_log.append(f"🔄 Refining (iteration {state.get('iteration', 0) + 1}): {feedback[:80]}...")
        else:
            steps_log.append("💭 Generating response with Gemini...")
        full_prompt = user_prompt
        if feedback:
            full_prompt = user_prompt + "\n\n[Correction requested by quality check]: " + feedback + "\n\nRevised answer:"
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, models)
        PIPELINE_LOG.info("generate END draft_answer len=%d preview=%s", len(draft or ""), (draft or "")[:150])
        steps_log.append("✅ Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    # The string returned here is looked up in the mapping passed to add_conditional_edges below.
    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        if state.get("refusal"):
            return "end"
        return "generate"

    # Pipeline: retrieve → generate → end (no evaluate/refine – one LLM call so the user always gets a final answer)
    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges("retrieve", route_after_retrieve, {"end": END, "generate": "generate"})
    workflow.add_edge("generate", END)
    return workflow.compile()
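
# Illustrative sketch (an assumption, not used elsewhere in this module): the compiled graph can
# also be driven without streaming via `invoke`, which returns the final AgentState dict, e.g.
#
#     graph = build_agent_graph(engine)
#     final_state = graph.invoke({"query": "example question", "api_key": key})
#     answer = final_state.get("refusal") or final_state.get("draft_answer")
#
# run_stream below is the streaming variant intended for the Gradio UI.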


def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """
    Run the agent graph and yield progress (steps + draft) for each step.
    Updates the engine cache and history with the final answer. Yields strings for Gradio.
    Ensures the user always sees a final verbal answer (or a clear error message).
    """
    PIPELINE_LOG.info("run_stream START query=%r", query[:80] if query else "")
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    for state in graph.stream(initial, stream_mode="values"):
        last_state = state
        steps_log = state.get("steps_log") or []
        PIPELINE_LOG.info("run_stream state: steps=%d", len(steps_log))
        yield "\n".join(steps_log)  # only pipeline progress; the answer comes once in the final yield

    # Final state: ensure the user always sees the verbal answer (aggregated from the pipeline)
    final_answer = (last_state.get("refusal") or last_state.get("draft_answer") or "").strip()
    steps_log = list(last_state.get("steps_log") or [])
    steps_log.append("✅ Done")
    PIPELINE_LOG.info(
        "run_stream final_answer len=%d is_error=%s",
        len(final_answer),
        final_answer.startswith("⚠") or final_answer.startswith("❌") or final_answer.startswith("⏱"),
    )
    if not final_answer:
        # Hebrew fallback, roughly: "No answer was received from the model; it may have been
        # blocked or the request was too long. Try shortening the question or asking again."
        final_answer = (
            "לא התקבלה תשובה מהמודל. ייתכן שנחסמה או שהבקשה ארכה מדי. "
            "נסה לקצר את השאלה או לשאול שוב."
        )
        PIPELINE_LOG.warning("run_stream final_answer was empty, using fallback")
    if not any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️")):
        cache_key = engine._get_cache_key(query)
        engine.response_cache[cache_key] = final_answer
        engine._maintain_conversation_history(query, final_answer)

    # One final yield: pipeline steps first (for UX), then the answer exactly once (no duplicate)
    steps_text = chr(10).join(steps_log)
    is_error = any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️"))
    block = "--- פרטי עיבוד ---\n\n" + steps_text + "\n\n"  # "Processing details"
    if is_error:
        block += "--- בעיה זמנית (לא תשובה) ---\n\n"  # "Temporary problem (not an answer)"
        block += final_answer + "\n\n"
        # "This is not the answer to the question – try again in a minute or two."
        block += "זו לא התשובה לשאלה – נסה שוב בעוד דקה־שתיים."
    else:
        block += "--- התשובה ---\n\n" + final_answer  # "The answer"
    PIPELINE_LOG.info("run_stream yielding final block len=%d", len(block))
    yield block
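

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch, not part of the app). Assumptions: RAGEngine()
    # constructs with defaults and GEMINI_API_KEY in the environment holds a valid key;
    # adjust both to match your rag_engine setup.
    import os

    _engine = RAGEngine()
    _graph = build_agent_graph(_engine)
    for _update in run_stream(_engine, _graph, "test question", os.environ.get("GEMINI_API_KEY", "")):
        print(_update)
        print("---")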