File size: 6,017 Bytes
37bbf25
2ebebf3
 
37bbf25
 
2fa5774
37bbf25
 
 
 
 
2fa5774
 
 
37bbf25
 
 
 
 
 
 
 
 
 
 
 
 
 
2ebebf3
37bbf25
 
 
 
2fa5774
37bbf25
 
2fa5774
37bbf25
2fa5774
37bbf25
 
 
 
 
 
 
 
 
2fa5774
b609489
 
37bbf25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fa5774
37bbf25
 
 
 
 
 
 
 
2ebebf3
37bbf25
 
 
 
 
 
2ebebf3
37bbf25
 
 
 
 
 
 
 
75c53f5
37bbf25
2fa5774
37bbf25
 
849c690
37bbf25
 
ea834ed
 
37bbf25
2ebebf3
75c53f5
 
 
2fa5774
75c53f5
 
 
 
 
 
2fa5774
75c53f5
37bbf25
 
 
 
75c53f5
ea834ed
2ebebf3
58029e3
ea834ed
58029e3
ea834ed
58029e3
ea834ed
58029e3
ea834ed
2fa5774
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
LangGraph agent: orchestrates RAG pipeline (retrieve โ†’ generate โ†’ end).
Single LLM call so the user always gets a final verbal answer that aggregates the retrieved context.
"""

import logging
from typing import Optional, List, TypedDict, Literal
from langgraph.graph import StateGraph, END

from rag_engine import RAGEngine

# Logger named "pipeline". Per the original intent, rag_engine logs to the same
# name, so all [PIPELINE] messages from both modules end up in one stream.
PIPELINE_LOG = logging.getLogger(name="pipeline")


# State shared by the LangGraph nodes. total=False makes every key optional:
# each node returns only the keys it updates.
AgentState = TypedDict(
    "AgentState",
    {
        "query": str,  # the user's question (set in run_stream's initial state)
        "api_key": str,  # forwarded to engine.configure_api by the generate node
        "refusal": Optional[str],  # set by retrieve when the engine declines; routes to END
        "system_prompt": Optional[str],  # prompts prepared by retrieve for generate
        "user_prompt": Optional[str],
        "steps_log": List[str],  # human-readable progress lines shown to the user
        "draft_answer": Optional[str],  # LLM output produced by generate
        "feedback": Optional[str],  # optional correction text appended to the prompt
        "iteration": int,  # refinement counter; retrieve initialises it to 0
    },
    total=False,
)


def build_agent_graph(engine: RAGEngine, models: Optional[List[str]] = None):
    """Build and compile the LangGraph pipeline: retrieve -> generate -> END.

    There is no evaluate/refine loop: at most one LLM call is made per run, so
    the user always gets a single final answer built from the retrieved context.

    Args:
        engine: RAG engine used for ``prepare_generation``, ``configure_api``
            and ``_call_api_with_backoff``.
        models: Model names handed, in order, to
            ``engine._call_api_with_backoff``. Defaults to
            ``["gemini-2.0-flash", "gemini-1.5-flash"]`` (the previously
            hard-coded list).

    Returns:
        The compiled workflow returned by ``StateGraph.compile()``.

    Raises:
        ValueError: if ``models`` is given but empty.
    """
    if models is None:
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
    elif not models:
        raise ValueError("models must contain at least one model name")
    # Snapshot the caller's list so mutating it later does not change an
    # already-built graph.
    model_names = list(models)

    def retrieve(state: AgentState) -> dict:
        """Run the RAG steps up to (not including) the LLM call.

        Returns either ``refusal`` + ``steps_log``, or the prepared prompts,
        ``steps_log`` and ``iteration`` reset to 0.
        """
        query = state["query"]
        PIPELINE_LOG.info("retrieve START query=%r", query[:80] if query else "")
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            PIPELINE_LOG.info("retrieve END refusal len=%d", len(refusal))
            return {"refusal": refusal, "steps_log": steps_log}
        PIPELINE_LOG.info("retrieve END prompts ready steps=%d", len(steps_log or []))
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }

    def generate(state: AgentState) -> dict:
        """Call the LLM once and record the result.

        Uses ``user_prompt``, plus ``feedback`` as a correction note when
        present. Returns ``draft_answer`` and an extended copy of ``steps_log``.
        """
        PIPELINE_LOG.info("generate START")
        if state.get("api_key"):
            engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        # Copy so the incoming state's list is never mutated in place.
        steps_log = list(state.get("steps_log") or [])

        # NOTE: no node in this graph sets "feedback"; the refining path is only
        # reachable if a caller puts it in the initial state.
        if feedback:
            steps_log.append(f"๐Ÿ”„ Refining (iteration {state.get('iteration', 0) + 1}): {feedback[:80]}...")
        else:
            steps_log.append("๐Ÿ’ญ Generating response with Gemini...")

        full_prompt = user_prompt
        if feedback:
            full_prompt = user_prompt + "\n\n[Correction requested by quality check]: " + feedback + "\n\nRevised answer:"

        # Hand over a fresh list on every call, as the original per-call
        # literal did, in case the engine modifies the list it receives.
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, list(model_names))
        PIPELINE_LOG.info("generate END draft_answer len=%d preview=%s", len(draft or ""), (draft or "")[:150])
        steps_log.append("โœ… Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        """Skip the LLM entirely when retrieval produced a refusal."""
        if state.get("refusal"):
            return "end"
        return "generate"

    # Pipeline: retrieve -> generate -> END (no evaluate/refine; one LLM call
    # so the user always gets a final answer).
    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)

    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges("retrieve", route_after_retrieve, {"end": END, "generate": "generate"})
    workflow.add_edge("generate", END)

    return workflow.compile()


def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """
    Run the agent graph and stream progress text for the Gradio UI.

    After every state emitted by ``graph.stream(..., stream_mode="values")``
    this yields the accumulated ``steps_log`` joined with newlines. It then
    yields one final block: the processing steps followed by either the
    answer or a clearly marked temporary-error notice.

    Side effects on ``engine``:
      * the final answer is stored in ``engine.response_cache`` under
        ``engine._get_cache_key(query)``, except for error messages and the
        empty-answer fallback;
      * ``engine._maintain_conversation_history(query, final_answer)`` is
        always called.

    Args:
        engine: RAG engine whose cache and conversation history are updated.
        graph: Compiled graph from ``build_agent_graph``.
        query: The user's question.
        api_key: API key passed to the graph's generate node.

    Yields:
        str: progress text, then the final formatted block.
    """
    # Prefixes that mark an error message rather than a real answer. These are
    # the shorter forms previously used only for logging. They are also
    # prefixes of the longer forms previously used for the cache/display
    # decisions, so one tuple covers both spellings. It now drives logging,
    # caching and display alike, so the three decisions always agree.
    error_prefixes = ("โš ", "โŒ", "โฑ")

    PIPELINE_LOG.info("run_stream START query=%r", query[:80] if query else "")
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    for state in graph.stream(initial, stream_mode="values"):
        last_state = state
        steps_log = state.get("steps_log") or []
        PIPELINE_LOG.info("run_stream state: steps=%d", len(steps_log))
        yield "\n".join(steps_log)  # progress only; the answer is yielded once at the end

    # A refusal takes precedence over a generated draft.
    final_answer = (last_state.get("refusal") or last_state.get("draft_answer") or "").strip()
    steps_log = list(last_state.get("steps_log") or [])
    steps_log.append("โœ… Done")
    is_error = final_answer.startswith(error_prefixes)
    PIPELINE_LOG.info("run_stream final_answer len=%d is_error=%s", len(final_answer), is_error)

    used_fallback = not final_answer
    if used_fallback:
        final_answer = (
            "ืœื ื”ืชืงื‘ืœื” ืชืฉื•ื‘ื” ืžื”ืžื•ื“ืœ. ื™ื™ืชื›ืŸ ืฉื ื—ืกืžื” ืื• ืฉื”ื‘ืงืฉื” ืืจื›ื” ืžื“ื™. "
            "ื ืกื” ืœืงืฆืจ ืืช ื”ืฉืืœื” ืื• ืœืฉืื•ืœ ืฉื•ื‘."
        )
        PIPELINE_LOG.warning("run_stream final_answer was empty, using fallback")

    # Error messages were already kept out of the cache. The fallback asks the
    # user to try again, so caching it would defeat the retry; skip it as well.
    if not (is_error or used_fallback):
        engine.response_cache[engine._get_cache_key(query)] = final_answer
    engine._maintain_conversation_history(query, final_answer)

    # One final yield: the processing steps for UX, then the answer exactly once.
    steps_text = "\n".join(steps_log)
    block = "--- ืคืจื˜ื™ ืขื™ื‘ื•ื“ ---\n\n" + steps_text + "\n\n"
    if is_error:
        block += "--- ื‘ืขื™ื” ื–ืžื ื™ืช (ืœื ืชืฉื•ื‘ื”) ---\n\n"
        block += final_answer + "\n\n"
        block += "ื–ื• ืœื ื”ืชืฉื•ื‘ื” ืœืฉืืœื” โ€“ ื ืกื” ืฉื•ื‘ ื‘ืขื•ื“ ื“ืงื”ึพืฉืชื™ื™ื."
    else:
        block += "--- ื”ืชืฉื•ื‘ื” ---\n\n" + final_answer
    PIPELINE_LOG.info("run_stream yielding final block len=%d", len(block))
    yield block