"""
LangGraph agent: orchestrates RAG pipeline (retrieve → generate → end).
Single LLM call so the user always gets a final verbal answer that aggregates
the retrieved context.
"""
import logging
from typing import Optional, List, TypedDict, Literal

from langgraph.graph import StateGraph, END

from rag_engine import RAGEngine

# Use same "pipeline" logger as rag_engine so all [PIPELINE] logs appear together
PIPELINE_LOG = logging.getLogger("pipeline")

# Leading symbols that mark a message as an error rather than an answer:
# WARNING SIGN, CROSS MARK, STOPWATCH. Only the base code points are listed so
# that both the bare symbol and its emoji-presentation form (base + U+FE0F)
# are recognised by a single startswith() check.
_ERROR_PREFIXES = ("\u26a0", "\u274c", "\u23f1")


def _is_error_message(text: str) -> bool:
    """Return True when *text* begins with one of the error marker symbols."""
    return text.startswith(_ERROR_PREFIXES)


class AgentState(TypedDict, total=False):
    """State dict passed between graph nodes; every key is optional."""

    query: str  # user question, set by run_stream
    api_key: str  # forwarded to engine.configure_api when non-empty
    refusal: Optional[str]  # set by retrieve; a truthy value ends the graph early
    system_prompt: Optional[str]  # set by retrieve, read by generate
    user_prompt: Optional[str]  # set by retrieve, read by generate
    steps_log: List[str]  # progress lines shown to the user
    draft_answer: Optional[str]  # LLM output, set by generate
    feedback: Optional[str]  # optional correction text appended to the prompt
    iteration: int  # refinement counter; retrieve resets it to 0


def build_agent_graph(engine: RAGEngine):
    """Build the LangGraph: retrieve → generate → end (one LLM call for final answer).

    Args:
        engine: RAG engine used for prompt preparation and for the LLM call.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """

    def retrieve(state: AgentState) -> dict:
        """Run RAG up to (not including) LLM. Fill refusal or prompts + steps_log."""
        query = state["query"]
        PIPELINE_LOG.info("retrieve START query=%r", query[:80] if query else "")
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            # A refusal short-circuits the pipeline; no prompts are stored.
            PIPELINE_LOG.info("retrieve END refusal len=%d", len(refusal or ""))
            return {"refusal": refusal, "steps_log": steps_log}
        PIPELINE_LOG.info("retrieve END prompts ready steps=%d", len(steps_log or []))
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }

    def generate(state: AgentState) -> dict:
        """Call LLM with current prompt + optional feedback.

        Sets draft_answer and appends progress lines to a copy of steps_log.
        """
        PIPELINE_LOG.info("generate START")
        if state.get("api_key"):
            engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        # Copy so the list held by the incoming state is not mutated in place.
        steps_log = list(state.get("steps_log") or [])
        if feedback:
            steps_log.append(
                f"🔄 Refining (iteration {state.get('iteration', 0) + 1}): {feedback[:80]}..."
            )
        else:
            steps_log.append("💭 Generating response with Gemini...")
        full_prompt = user_prompt
        if feedback:
            full_prompt = (
                user_prompt
                + "\n\n[Correction requested by quality check]: "
                + feedback
                + "\n\nRevised answer:"
            )
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, models)
        PIPELINE_LOG.info(
            "generate END draft_answer len=%d preview=%s",
            len(draft or ""),
            (draft or "")[:150],
        )
        steps_log.append("✅ Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        """Pick the edge after retrieve: stop on refusal, otherwise generate."""
        if state.get("refusal"):
            return "end"
        return "generate"

    # Pipeline: retrieve → generate → end (no evaluate/refine – one LLM call so
    # user always gets final answer)
    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges(
        "retrieve", route_after_retrieve, {"end": END, "generate": "generate"}
    )
    workflow.add_edge("generate", END)
    return workflow.compile()


def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """Run the agent graph and yield progress strings, then one final block.

    Each streamed state yields its steps_log joined by newlines. After the
    stream ends, a single final block is yielded containing the steps followed
    by either the answer or an error section.

    The engine's response cache and conversation history are updated only when
    the final text is a real result (a refusal or a model answer). Error
    messages and the empty-answer fallback are never stored, so a failed run
    cannot be replayed from the cache as if it were an answer.

    Args:
        engine: RAG engine whose cache and history receive the final answer.
        graph: Compiled graph, as returned by build_agent_graph.
        query: The user question.
        api_key: API key placed in the initial state for the generate node.

    Yields:
        Strings for Gradio: progress text per step, then the final block.
    """
    PIPELINE_LOG.info("run_stream START query=%r", query[:80] if query else "")
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    for state in graph.stream(initial, stream_mode="values"):
        last_state = state
        steps_log = state.get("steps_log") or []
        PIPELINE_LOG.info("run_stream state: steps=%d", len(steps_log))
        # only pipeline progress; answer comes once in final yield
        yield "\n".join(steps_log)

    # Final state: ensure user always sees the verbal answer (aggregated from pipeline)
    final_answer = (last_state.get("refusal") or last_state.get("draft_answer") or "").strip()
    steps_log = list(last_state.get("steps_log") or [])
    steps_log.append("✅ Done")

    # One error test shared by logging, caching and display so they cannot disagree.
    is_error = _is_error_message(final_answer)
    PIPELINE_LOG.info(
        "run_stream final_answer len=%d is_error=%s", len(final_answer), is_error
    )

    is_fallback = not final_answer
    if is_fallback:
        final_answer = (
            "לא התקבלה תשובה מהמודל. ייתכן שנחסמה או שהבקשה ארכה מדי. "
            "נסה לקצר את השאלה או לשאול שוב."
        )
        PIPELINE_LOG.warning("run_stream final_answer was empty, using fallback")

    # The fallback is a placeholder, not an answer: storing it would make later
    # identical queries return "no answer received" straight from the cache.
    if not is_error and not is_fallback:
        cache_key = engine._get_cache_key(query)
        engine.response_cache[cache_key] = final_answer
        engine._maintain_conversation_history(query, final_answer)

    # One final yield: main pipeline (steps) for UX, then the answer once (no duplicate)
    steps_text = "\n".join(steps_log)
    block = "--- פרטי עיבוד ---\n\n" + steps_text + "\n\n"
    if is_error:
        block += "--- בעיה זמנית (לא תשובה) ---\n\n"
        block += final_answer + "\n\n"
        block += "זו לא התשובה לשאלה – נסה שוב בעוד דקה־שתיים."
    else:
        block += "--- התשובה ---\n\n" + final_answer
    PIPELINE_LOG.info("run_stream yielding final block len=%d", len(block))
    yield block