"""
LangGraph agent: orchestrates RAG pipeline (retrieve → generate → end).
Single LLM call so the user always gets a final verbal answer that aggregates the retrieved context.
"""
import logging
from typing import Optional, List, TypedDict, Literal

from langgraph.graph import StateGraph, END

from rag_engine import RAGEngine

# Use same "pipeline" logger as rag_engine so all [PIPELINE] logs appear together
PIPELINE_LOG = logging.getLogger("pipeline")


class AgentState(TypedDict, total=False):
    query: str
    api_key: str
    refusal: Optional[str]
    system_prompt: Optional[str]
    user_prompt: Optional[str]
    steps_log: List[str]
    draft_answer: Optional[str]
    feedback: Optional[str]
    iteration: int

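
# Illustrative only: the approximate shape of the state between the "retrieve" and
# "generate" nodes, assuming prepare_generation produced prompts rather than a
# refusal. The steps_log entries below are placeholders; the real strings come
# from RAGEngine.
#
#   {
#       "query": "...",
#       "api_key": "...",
#       "system_prompt": "...",
#       "user_prompt": "...",
#       "steps_log": ["<retrieval step>", "<grounding step>"],
#       "iteration": 0,
#   }
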
def build_agent_graph(engine: RAGEngine):
    """Build the LangGraph: retrieve → generate → end (one LLM call for the final answer)."""

    def retrieve(state: AgentState) -> dict:
        """Run RAG up to (but not including) the LLM call. Fill refusal or prompts + steps_log."""
        query = state["query"]
        PIPELINE_LOG.info("retrieve START query=%r", query[:80] if query else "")
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            PIPELINE_LOG.info("retrieve END refusal len=%d", len(refusal or ""))
            return {"refusal": refusal, "steps_log": steps_log}
        PIPELINE_LOG.info("retrieve END prompts ready steps=%d", len(steps_log or []))
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }

    def generate(state: AgentState) -> dict:
        """Call the LLM with the current prompt plus optional feedback. Set draft_answer and append to steps_log."""
        PIPELINE_LOG.info("generate START")
        if state.get("api_key"):
            engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        steps_log = list(state.get("steps_log") or [])
        if feedback:
            steps_log.append(f"🔄 Refining (iteration {state.get('iteration', 0) + 1}): {feedback[:80]}...")
        else:
            steps_log.append("💭 Generating response with Gemini...")
        full_prompt = user_prompt
        if feedback:
            full_prompt = user_prompt + "\n\n[Correction requested by quality check]: " + feedback + "\n\nRevised answer:"
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, models)
        PIPELINE_LOG.info("generate END draft_answer len=%d preview=%s", len(draft or ""), (draft or "")[:150])
        steps_log.append("✅ Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        if state.get("refusal"):
            return "end"
        return "generate"

    # Pipeline: retrieve → generate → end (no evaluate/refine: one LLM call so the user always gets a final answer)
    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges("retrieve", route_after_retrieve, {"end": END, "generate": "generate"})
    workflow.add_edge("generate", END)
    return workflow.compile()
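
# Minimal sketch, assuming an already-initialised RAGEngine instance is available:
# the compiled graph can also be invoked synchronously when streaming progress is
# not needed. Kept as a comment so nothing executes at import time.
#
#   graph = build_agent_graph(engine)
#   final_state = graph.invoke({"query": "...", "api_key": "..."})
#   answer = final_state.get("refusal") or final_state.get("draft_answer")
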

def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """
    Run the agent graph and yield progress (steps + draft) for each step.
    Updates engine cache and history with the final answer. Yields strings for Gradio.
    Ensures the user always sees a final verbal answer (or a clear error message).
    """
    PIPELINE_LOG.info("run_stream START query=%r", query[:80] if query else "")
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    for state in graph.stream(initial, stream_mode="values"):
        last_state = state
        steps_log = state.get("steps_log") or []
        PIPELINE_LOG.info("run_stream state: steps=%d", len(steps_log))
        yield "\n".join(steps_log)  # only pipeline progress; answer comes once in final yield

    # Final state: ensure user always sees the verbal answer (aggregated from pipeline)
    final_answer = (last_state.get("refusal") or last_state.get("draft_answer") or "").strip()
    steps_log = list(last_state.get("steps_log") or [])
    steps_log.append("✅ Done")
    PIPELINE_LOG.info("run_stream final_answer len=%d is_error=%s", len(final_answer), final_answer.startswith("⚠") or final_answer.startswith("❌") or final_answer.startswith("⏱"))
    if not final_answer:
        final_answer = (
            "No answer was received from the model. The document or the request may be too long. "
            "Try shortening your question or the conversation."
        )
        PIPELINE_LOG.warning("run_stream final_answer was empty, using fallback")
    if not any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️")):
        cache_key = engine._get_cache_key(query)
        engine.response_cache[cache_key] = final_answer
        engine._maintain_conversation_history(query, final_answer)

    # One final yield: main pipeline (steps) for UX, then the answer once (no duplicate)
    steps_text = chr(10).join(steps_log)
    is_error = any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️"))
    block = "--- Processing details ---\n\n" + steps_text + "\n\n"
    if is_error:
        block += "--- Error message (not a final answer) ---\n\n"
        block += final_answer + "\n\n"
        block += "If no answer came through, try again in a minute or two."
    else:
        block += "--- Answer ---\n\n" + final_answer
    PIPELINE_LOG.info("run_stream yielding final block len=%d", len(block))
    yield block
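
# Hedged usage sketch (illustrative, not part of the Space's runtime path): how a
# caller such as the Gradio app might drive run_stream end to end. It assumes that
# RAGEngine() can be constructed with no arguments and that a Gemini API key sits in
# an environment variable named GOOGLE_API_KEY; both names are assumptions here.
if __name__ == "__main__":
    import os

    demo_engine = RAGEngine()                     # assumption: no-arg constructor
    demo_graph = build_agent_graph(demo_engine)
    demo_query = "What is this document about?"
    for update in run_stream(demo_engine, demo_graph, demo_query, os.environ.get("GOOGLE_API_KEY", "")):
        # Each yield is the steps log so far; the final yield appends the answer block.
        print(update)
        print("-" * 40)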