| """ | |
| LangGraph agent: orchestrates RAG pipeline (retrieve โ generate โ end). | |
| Single LLM call so the user always gets a final verbal answer that aggregates the retrieved context. | |
| """ | |
| import logging | |
| from typing import Optional, List, TypedDict, Literal | |
| from langgraph.graph import StateGraph, END | |
| from rag_engine import RAGEngine | |
| # Use same "pipeline" logger as rag_engine so all [PIPELINE] logs appear together | |
| PIPELINE_LOG = logging.getLogger("pipeline") | |
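
# Illustrative sketch (an assumption, not part of this module's behavior): a host
# application that wants the [PIPELINE] log lines to be visible could configure
# the shared logger roughly like this before building the graph:
#
#     logging.basicConfig(level=logging.INFO,
#                         format="%(asctime)s %(name)s %(message)s")
#     logging.getLogger("pipeline").setLevel(logging.INFO)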


class AgentState(TypedDict, total=False):
    query: str
    api_key: str
    refusal: Optional[str]
    system_prompt: Optional[str]
    user_prompt: Optional[str]
    steps_log: List[str]
    draft_answer: Optional[str]
    feedback: Optional[str]
    iteration: int
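
# Illustrative example (hypothetical values) of how the state evolves as it flows
# through the graph; the real prompts and step strings come from RAGEngine.prepare_generation:
#
#     after retrieve: {"query": "...", "system_prompt": "...", "user_prompt": "...",
#                      "steps_log": ["..."], "iteration": 0}
#     after generate: adds "draft_answer" and appends the generation steps to "steps_log"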


def build_agent_graph(engine: RAGEngine):
    """Build the LangGraph: retrieve → generate → end (one LLM call for final answer)."""

    def retrieve(state: AgentState) -> dict:
        """Run RAG up to (not including) the LLM. Fill refusal or prompts + steps_log."""
        query = state["query"]
        PIPELINE_LOG.info("retrieve START query=%r", query[:80] if query else "")
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            PIPELINE_LOG.info("retrieve END refusal len=%d", len(refusal or ""))
            return {"refusal": refusal, "steps_log": steps_log}
        PIPELINE_LOG.info("retrieve END prompts ready steps=%d", len(steps_log or []))
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }

    def generate(state: AgentState) -> dict:
        """Call LLM with current prompt + optional feedback. Set draft_answer and append to steps_log."""
        PIPELINE_LOG.info("generate START")
        if state.get("api_key"):
            engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        steps_log = list(state.get("steps_log") or [])
        if feedback:
            steps_log.append(f"🔄 Refining (iteration {state.get('iteration', 0) + 1}): {feedback[:80]}...")
        else:
            steps_log.append("💭 Generating response with Gemini...")
        full_prompt = user_prompt
        if feedback:
            full_prompt = user_prompt + "\n\n[Correction requested by quality check]: " + feedback + "\n\nRevised answer:"
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, models)
        PIPELINE_LOG.info("generate END draft_answer len=%d preview=%s", len(draft or ""), (draft or "")[:150])
        steps_log.append("✅ Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        if state.get("refusal"):
            return "end"
        return "generate"

    # Pipeline: retrieve → generate → end (no evaluate/refine; one LLM call so the user always gets a final answer)
    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges("retrieve", route_after_retrieve, {"end": END, "generate": "generate"})
    workflow.add_edge("generate", END)
    return workflow.compile()
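
# Minimal usage sketch (assumptions: RAGEngine() can be constructed without arguments
# and the caller wants a single synchronous result rather than streaming):
#
#     engine = RAGEngine()
#     graph = build_agent_graph(engine)
#     final_state = graph.invoke({"query": "...", "api_key": "..."})
#     print(final_state.get("refusal") or final_state.get("draft_answer"))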


def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """
    Run the agent graph and yield progress (steps + draft) for each step.
    Updates engine cache and history with the final answer. Yields strings for Gradio.
    Ensures the user always sees a final verbal answer (or a clear error message).
    """
    PIPELINE_LOG.info("run_stream START query=%r", query[:80] if query else "")
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    for state in graph.stream(initial, stream_mode="values"):
        last_state = state
        steps_log = state.get("steps_log") or []
        PIPELINE_LOG.info("run_stream state: steps=%d", len(steps_log))
        yield "\n".join(steps_log)  # only pipeline progress; answer comes once in the final yield

    # Final state: ensure the user always sees the verbal answer (aggregated from the pipeline)
    final_answer = (last_state.get("refusal") or last_state.get("draft_answer") or "").strip()
    steps_log = list(last_state.get("steps_log") or [])
    steps_log.append("✅ Done")
    is_error = any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️"))
    PIPELINE_LOG.info("run_stream final_answer len=%d is_error=%s", len(final_answer), is_error)
    if not final_answer:
        final_answer = (
            "No answer was received from the model. The request or its phrasing may be too long. "
            "Try shortening the question or the conversation."
        )
        PIPELINE_LOG.warning("run_stream final_answer was empty, using fallback")
    if not is_error:
        cache_key = engine._get_cache_key(query)
        engine.response_cache[cache_key] = final_answer
        engine._maintain_conversation_history(query, final_answer)
    # One final yield: main pipeline (steps) for UX, then the answer once (no duplicate)
    steps_text = "\n".join(steps_log)
| block = "--- ืคืจืื ืขืืืื ---\n\n" + steps_text + "\n\n" | |
| if is_error: | |
| block += "--- ืืขืื ืืื ืืช (ืื ืชืฉืืื) ---\n\n" | |
| block += final_answer + "\n\n" | |
| block += "ืื ืื ืืชืฉืืื ืืฉืืื โ ื ืกื ืฉืื ืืขืื ืืงืึพืฉืชืืื." | |
| else: | |
| block += "--- ืืชืฉืืื ---\n\n" + final_answer | |
| PIPELINE_LOG.info("run_stream yielding final block len=%d", len(block)) | |
| yield block | |
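

if __name__ == "__main__":
    # Smoke-test sketch, not part of the pipeline itself. Assumptions that may not
    # hold for your setup: RAGEngine() takes no required constructor arguments, and
    # the Gemini API key is provided in the GEMINI_API_KEY environment variable.
    import os
    import sys

    logging.basicConfig(level=logging.INFO)
    engine = RAGEngine()
    graph = build_agent_graph(engine)
    question = sys.argv[1] if len(sys.argv) > 1 else "Hello"
    for update in run_stream(engine, graph, question, os.environ.get("GEMINI_API_KEY", "")):
        print(update)
        print("-" * 40)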