galbendavids committed on
Commit
37bbf25
·
verified ·
1 Parent(s): 0b17d83

agentic rag update

Browse files
__pycache__/agent.cpython-311.pyc ADDED
Binary file (9.78 kB). View file
 
__pycache__/app.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
__pycache__/rag_engine.cpython-311.pyc CHANGED
Binary files a/__pycache__/rag_engine.cpython-311.pyc and b/__pycache__/rag_engine.cpython-311.pyc differ
 
agent.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LangGraph agent: orchestrates RAG pipeline and verifies answer quality.
3
+ If the planned answer is not good, the agent directs the pipeline toward correction (retry with feedback).
4
+ """
5
+
6
+ from typing import Optional, List, TypedDict, Literal
7
+ from langgraph.graph import StateGraph, END
8
+
9
+ from rag_engine import RAGEngine
10
+
11
+
12
class AgentState(TypedDict, total=False):
    """State carried between LangGraph nodes (total=False: every key is optional,
    so each node may return only the keys it updates)."""
    query: str               # original user question
    api_key: str             # Gemini API key forwarded to the engine per call
    refusal: Optional[str]   # set by retrieve when the question concerns an unsupported car
    system_prompt: Optional[str]  # built by retrieve via engine.prepare_generation
    user_prompt: Optional[str]    # built by retrieve via engine.prepare_generation
    steps_log: List[str]     # human-readable progress lines streamed to the UI
    draft_answer: Optional[str]   # latest LLM output from generate
    is_good: bool            # evaluator verdict for the current draft
    feedback: Optional[str]  # evaluator's improvement hint (only when is_good is False)
    iteration: int           # number of completed evaluate passes


# Cap on evaluate passes: route_after_evaluate ends the loop once
# iteration >= this value, even if the draft is still judged "not good".
MAX_REFINE_ITERATIONS = 2
26
+
27
+
28
def build_agent_graph(engine: RAGEngine):
    """Build the LangGraph: retrieve -> generate -> evaluate -> (end | generate with feedback).

    Args:
        engine: RAGEngine supplying retrieval (prepare_generation), API setup
            (configure_api) and the Gemini call helper (_call_api_with_backoff).

    Returns:
        The compiled graph; drive it with an AgentState dict (see run_stream).
    """

    def retrieve(state: AgentState) -> dict:
        """Run RAG up to (not including) LLM. Fill refusal or prompts + steps_log."""
        query = state["query"]
        refusal, system_prompt, user_prompt, steps_log = engine.prepare_generation(query)
        if refusal:
            # Unsupported car / bad comparison: surface the refusal and let the
            # router short-circuit to END without calling the LLM.
            return {"refusal": refusal, "steps_log": steps_log}
        return {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "steps_log": steps_log,
            "iteration": 0,
        }

    def generate(state: AgentState) -> dict:
        """Call LLM with current prompt + optional feedback. Set draft_answer and append to steps_log."""
        engine.configure_api(state["api_key"])
        system_prompt = state["system_prompt"]
        user_prompt = state["user_prompt"]
        feedback = state.get("feedback") or ""
        steps_log = list(state.get("steps_log") or [])

        if feedback:
            # FIX: evaluate() has already incremented `iteration`, so it is the
            # 1-based refine number; the previous `+ 1` double-counted in the log.
            steps_log.append(f"🔄 Refining (iteration {state.get('iteration', 0)}): {feedback[:80]}...")
        else:
            steps_log.append("💭 Generating response with Gemini...")

        full_prompt = user_prompt
        if feedback:
            # Append the quality checker's correction so the model revises the draft.
            full_prompt = user_prompt + "\n\n[Correction requested by quality check]: " + feedback + "\n\nRevised answer:"

        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        draft = engine._call_api_with_backoff(system_prompt, full_prompt, models)
        steps_log.append("✅ Draft generated")
        return {"draft_answer": draft, "steps_log": steps_log}

    def evaluate(state: AgentState) -> dict:
        """Check if the answer is good. Set is_good and optionally feedback for refinement."""
        query = state["query"]
        draft = state.get("draft_answer") or ""
        steps_log = list(state.get("steps_log") or [])
        iteration = state.get("iteration", 0)

        # Error / timeout / rate limit responses are not "good" but we don't refine them
        if draft.startswith(("⚠️", "❌", "⏱️")):
            return {"is_good": True, "steps_log": steps_log}  # Treat as final

        steps_log.append("🔍 Evaluating answer quality...")
        engine.configure_api(state["api_key"])
        eval_prompt = f"""You are a quality checker. Given the user question and the assistant's answer, decide if the answer is good.

User question: {query[:300]}

Assistant answer: {draft[:1500]}

Reply with exactly one of:
- YES
- NO: <one short line explaining what to improve>

Reply:"""
        models = ["gemini-2.0-flash", "gemini-1.5-flash"]
        try:
            raw = engine._call_api_with_backoff(
                "You reply only with YES or NO: <feedback>. No other text.",
                eval_prompt,
                models,
            )
        except Exception:
            # Evaluation is best-effort: on any API failure accept the draft.
            raw = "YES"
        # FIX: keep the checker's original text and only compare the verdict
        # case-insensitively. The old code upper-cased `raw` before slicing the
        # feedback out, so the correction was sent back to the LLM in ALL CAPS.
        raw_text = (raw or "").strip()
        verdict = raw_text.upper()
        is_good = verdict.startswith("YES") or "NO" not in verdict[:10]
        feedback = ""
        if not is_good and "NO" in verdict:
            idx = raw_text.find(":")
            if idx != -1:
                feedback = raw_text[idx + 1:].strip()[:200]
            else:
                feedback = "Improve relevance and completeness."

        steps_log.append("✅ Good" if is_good else f"⚠️ Needs improvement: {feedback[:60]}...")
        return {
            "is_good": is_good,
            "feedback": feedback if not is_good else None,
            "iteration": iteration + 1,  # counts completed evaluate passes
            "steps_log": steps_log,
        }

    def route_after_retrieve(state: AgentState) -> Literal["end", "generate"]:
        """Skip generation entirely when retrieve produced a refusal."""
        if state.get("refusal"):
            return "end"
        return "generate"

    def route_after_evaluate(state: AgentState) -> Literal["end", "generate"]:
        """Stop when the draft is good or the refine budget is exhausted."""
        if state.get("is_good") or (state.get("iteration") or 0) >= MAX_REFINE_ITERATIONS:
            return "end"
        return "generate"

    workflow = StateGraph(AgentState)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("generate", generate)
    workflow.add_node("evaluate", evaluate)

    workflow.set_entry_point("retrieve")
    workflow.add_conditional_edges("retrieve", route_after_retrieve, {"end": END, "generate": "generate"})
    workflow.add_edge("generate", "evaluate")
    workflow.add_conditional_edges("evaluate", route_after_evaluate, {"end": END, "generate": "generate"})

    return workflow.compile()
138
+
139
+
140
def run_stream(engine: RAGEngine, graph, query: str, api_key: str):
    """
    Run the agent graph and yield progress (steps + draft) for each step.
    Updates engine cache and history with the final answer. Yields strings for Gradio.

    Args:
        engine: RAGEngine used for cache-key computation and history updates.
        graph: compiled LangGraph app from build_agent_graph().
        query: the user's question.
        api_key: Gemini API key placed into the initial agent state.

    Yields:
        Display strings: the joined steps log, optionally followed by a blank
        line and the current refusal or draft body; one final string at the end.
    """
    initial: AgentState = {"query": query, "api_key": api_key}
    last_state: AgentState = initial
    # NOTE(review): unpacking each streamed item into (node_name, state)
    # assumes graph.stream() yields 2-tuples; current langgraph releases
    # yield {node_name: state_update} dicts by default, which would fail to
    # unpack here — confirm against the pinned langgraph version.
    # Also, `state` looks like a per-node partial update rather than the
    # merged graph state, so `last_state` may only hold the last node's keys.
    for _node_name, state in graph.stream(initial):
        last_state = state
        steps_log = state.get("steps_log") or []
        refusal = state.get("refusal")
        draft = state.get("draft_answer") or ""
        text = "\n".join(steps_log)
        body = refusal if refusal else draft
        if body:
            text = f"{text}\n\n{body}"
        yield text

    # Final state: update cache and history
    final_answer = last_state.get("refusal") or last_state.get("draft_answer") or ""
    steps_log = last_state.get("steps_log") or []
    # Only cache/record real answers — error/timeout drafts start with these markers.
    if not any(final_answer.startswith(p) for p in ("⚠️", "❌", "⏱️")):
        cache_key = engine._get_cache_key(query)
        engine.response_cache[cache_key] = final_answer
        engine._maintain_conversation_history(query, final_answer)
    steps_log.append("✅ Done")
    # chr(10) is "\n" — avoids a backslash inside the f-string expression.
    yield f"{chr(10).join(steps_log)}\n\n{final_answer}"
app.py CHANGED
@@ -18,18 +18,22 @@ import gradio as gr
18
  import os
19
  import traceback
20
  from rag_engine import RAGEngine
 
21
 
22
- # Initialize RAG Engine - Loads data and builds indices
23
- # This runs once at startup for performance
 
24
  try:
25
  print("🚀 Initializing RAG Engine...")
26
  engine = RAGEngine()
27
  print(f"✅ Engine ready with {len(engine.chunks)} smart chunks")
 
 
28
  except Exception as e:
29
- print(f"❌ Error initializing RAG Engine: {e}")
30
- print("Full traceback:")
31
  traceback.print_exc()
32
  engine = None
 
33
 
34
 
35
  def chat_function(message: str, history: list) -> str:
@@ -58,10 +62,10 @@ The Gemini API key is not set. To fix this:
58
  Get your key from: [Google AI Studio](https://aistudio.google.com/apikey)"""
59
  return
60
 
61
- if not engine:
62
  yield """❌ **Initialization Error**
63
 
64
- The RAG Engine failed to load. This usually means:
65
  - Data files are missing
66
  - Environment is misconfigured
67
  - Check the Space logs for specific error details
@@ -73,9 +77,13 @@ Common solutions:
73
  return
74
 
75
  try:
76
- # Generate response using the RAG engine
77
- # Stream progress updates so the user sees what's happening live.
78
- yield from engine.generate_response_stream(message, history, api_key)
 
 
 
 
79
  except Exception as e:
80
  yield f"""❌ **Error Processing Query**
81
 
 
18
  import os
19
  import traceback
20
  from rag_engine import RAGEngine
21
+ from agent import build_agent_graph, run_stream
22
 
23
+ # Initialize RAG Engine and LangGraph agent (once at startup)
24
+ engine = None
25
+ agent_graph = None
26
  try:
27
  print("🚀 Initializing RAG Engine...")
28
  engine = RAGEngine()
29
  print(f"✅ Engine ready with {len(engine.chunks)} smart chunks")
30
+ agent_graph = build_agent_graph(engine)
31
+ print("✅ LangGraph agent ready (retrieve → generate → evaluate → refine)")
32
  except Exception as e:
33
+ print(f"❌ Error initializing: {e}")
 
34
  traceback.print_exc()
35
  engine = None
36
+ agent_graph = None
37
 
38
 
39
  def chat_function(message: str, history: list) -> str:
 
62
  Get your key from: [Google AI Studio](https://aistudio.google.com/apikey)"""
63
  return
64
 
65
+ if not engine or not agent_graph:
66
  yield """❌ **Initialization Error**
67
 
68
+ The RAG Engine or agent failed to load. This usually means:
69
  - Data files are missing
70
  - Environment is misconfigured
71
  - Check the Space logs for specific error details
 
77
  return
78
 
79
  try:
80
+ # Check cache before running the agent
81
+ cache_key = engine._get_cache_key(message)
82
+ if cache_key in engine.response_cache:
83
+ yield f"🔁 Returned cached result\n\n{engine.response_cache[cache_key]}"
84
+ return
85
+ # Run LangGraph agent: retrieve → generate → evaluate → (refine if needed)
86
+ yield from run_stream(engine, agent_graph, message, api_key)
87
  except Exception as e:
88
  yield f"""❌ **Error Processing Query**
89
 
rag_engine.py CHANGED
@@ -732,6 +732,107 @@ class RAGEngine:
732
 
733
  yield "❌ Failed to get response from API"
734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  def generate_response(self, query: str, history, api_key: str):
736
  """יצירת תשובה חכמה עם כל 10 העצות"""
737
  if not api_key:
 
732
 
733
  yield "❌ Failed to get response from API"
734
 
735
def configure_api(self, api_key: str) -> None:
    """Configure Gemini API key (for use by external agent).

    Note: genai.configure sets process-global SDK state, not per-instance state.
    """
    genai.configure(api_key=api_key)
738
+
739
def prepare_generation(self, query: str) -> Tuple[Optional[str], Optional[str], Optional[str], List[str]]:
    """
    Run RAG pipeline up to (but not including) the LLM call.

    Steps: normalize the car name, classify the question (comparison vs.
    single model), refuse unsupported cars, hybrid-search the knowledge
    base, then assemble system/user prompts.

    Returns:
        (refusal_message, system_prompt, user_prompt, steps_log).
        If refusal_message is set, the prompts are None; otherwise pass the
        prompts to the LLM. steps_log is always populated for UI streaming.
    """
    steps_log: List[str] = []

    # Map free-form car mentions onto a canonical id; if found, search with
    # the canonical id instead of the raw query for better retrieval.
    steps_log.append("🔍 Normalizing car names...")
    canonical = self._normalize_car_name(query)
    if canonical:
        steps_log.append(f"✅ Recognized canonical id: {canonical}")
        search_query = canonical
    else:
        steps_log.append("ℹ️ No canonical car found; using full query for search")
        search_query = query

    # Rule-based comparison detection (no LLM involved at this stage).
    is_comparison = self._is_comparison_question(query)
    if is_comparison:
        steps_log.append("📋 Detected: comparison question (rule-based)")
    else:
        steps_log.append("📋 Detected: single-model question (rule-based)")

    # Refuse early when the question targets cars the knowledge base lacks:
    # comparisons need at least two supported models; single-model questions
    # are refused only when they clearly name a specific, unsupported car.
    mentioned_supported = self._find_supported_canonicals_in_text(query)
    if is_comparison:
        if len(mentioned_supported) < 2:
            return (self._unsupported_car_refusal(query, is_comparison=True), None, None, steps_log)
    else:
        if not canonical and not mentioned_supported and self._looks_like_specific_car_question(query):
            return (self._unsupported_car_refusal(query, is_comparison=False), None, None, steps_log)

    steps_log.append("🔎 Searching knowledge base (vectors + keywords)...")
    comparison_prompt = ""
    context_results = []

    if is_comparison:
        # Tokenize Latin/Hebrew/digit words; the first two tokens are treated
        # as the two car names for structured extraction.
        # NOTE(review): this picks the first two tokens of the sentence, not
        # necessarily the two car names — verify against typical phrasings.
        cars = re.findall(r'\b[A-Za-z\u05D0-\u05EA0-9]+\b', query)
        if len(cars) >= 2:
            steps_log.append("📊 Extracting structured comparison data (regex)...")
            comparison_data = self._extract_comparison_data(cars[0], cars[1])
            context_results = self._hybrid_search(search_query, top_k=self.max_chunks_comparison)
            steps_log.append(f"✅ Retrieved {len(context_results)} chunks for comparison")
            # Hebrew section headers below are part of the required output
            # format ("advantages of X/Y", "recommendation by user profile").
            comparison_prompt = f"""
Based on the car reviews, create a structured comparison between {cars[0]} and {cars[1]}:

Format your response as:
**יתרונות {cars[0]}:**
- [list advantages]

**יתרונות {cars[1]}:**
- [list advantages]

**המלצה לפי פרופיל משתמש:**
- [personalized recommendation]

Structured Data:
{json.dumps(comparison_data, ensure_ascii=False, indent=2)}

Context from reviews:
"""
        else:
            # Comparison detected but fewer than two name tokens extracted:
            # fall back to a general search with a language instruction.
            context_results = self._hybrid_search(search_query, top_k=self.max_chunks_general)
            comparison_prompt = "Answer in the same language as the user's question. "
            steps_log.append(f"✅ Retrieved {len(context_results)} chunks")
    else:
        context_results = self._hybrid_search(search_query, top_k=self.max_chunks_general)
        steps_log.append(f"✅ Retrieved {len(context_results)} relevant chunks")

    # Flatten retrieved chunks into a context block; each chunk's text is
    # truncated to max_context_chars_per_chunk to bound prompt size.
    context_text = ""
    for r in context_results:
        meta = r['metadata']
        context_text += f"""
Source: {meta['title']}
Topic: {meta['topic']}
Content: {r['text'][:self.max_context_chars_per_chunk]}...

"""
    conversation_context = self._get_context_from_history()
    system_prompt = """You are an expert automotive assistant specializing in car comparisons.
Use the provided context and structured data to answer questions accurately.
Always respond in the same language as the user (Hebrew or English).
For comparison questions, provide a structured analysis with clear advantages for each vehicle.
Focus on facts from the reviews provided.
"""
    user_prompt = f"""Context from car reviews:
{context_text}

Previous conversation context (last turns):
{conversation_context}

User Question: {query}

{comparison_prompt}

Answer:"""
    return (None, system_prompt, user_prompt, steps_log)
835
+
836
  def generate_response(self, query: str, history, api_key: str):
837
  """יצירת תשובה חכמה עם כל 10 העצות"""
838
  if not api_key:
requirements.txt CHANGED
@@ -5,3 +5,5 @@ requests
5
  sentence-transformers
6
  numpy<2.0.0
7
  torch>=2.0.0
 
 
 
5
  sentence-transformers
6
  numpy<2.0.0
7
  torch>=2.0.0
8
+ langgraph>=0.2.0
9
+ langchain-core>=0.3.0
test_agent.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test script for the LangGraph agent pipeline.
3
+ Runs several queries with a short wait between them to verify the full flow.
4
+ Requires gemini_api in environment for real LLM calls; otherwise only tests prepare_generation (no API).
5
+ """
6
+
7
+ import os
8
+ import time
9
+ from rag_engine import RAGEngine
10
+ from agent import build_agent_graph, run_stream
11
+
12
+
13
def main():
    """Smoke-test the agent pipeline.

    With the gemini_api environment variable set, run the full LangGraph agent
    over several queries; without it, exercise only prepare_generation.
    """
    print("Loading RAG Engine and building agent graph...")
    engine = RAGEngine()
    graph = build_agent_graph(engine)
    print("OK.\n")

    api_key = os.environ.get("gemini_api")
    if api_key:
        _run_full_agent(engine, graph, api_key)
    else:
        print("⚠️ gemini_api not set. Testing only prepare_generation (no LLM calls).\n")
        _run_offline_checks(engine)


def _run_offline_checks(engine):
    """Exercise engine.prepare_generation only — no API key required."""
    queries = [
        "Tell me about the Audi RS3",
        "Compare Audi RS3 vs Hyundai Elantra N",
        "מה דעתך על BMW X5?",  # should trigger refusal
    ]
    for num, question in enumerate(queries, 1):
        print(f"--- Test {num}: prepare_generation ---")
        print(f"Query: {question!r}")
        refusal, sys_p, user_p, steps = engine.prepare_generation(question)
        if refusal:
            print(f"Refusal (expected for unsupported car): {refusal[:150]}...")
        else:
            print(f"Steps: {len(steps)}; system_prompt length: {len(sys_p or '')}; user_prompt length: {len(user_p or '')}")
        print()
    print("Done (prepare_generation only). Set gemini_api to run full agent.")


def _run_full_agent(engine, graph, api_key):
    """Run the complete agent graph per query, pausing between runs to spare the API."""
    queries = [
        "Tell me about the Audi RS3",
        "Compare Audi RS3 vs Hyundai Elantra N",
        "מה היתרונות של קיה EV9?",
        "מה דעתך על BMW X5?",  # should trigger refusal (unsupported model)
    ]
    wait_seconds = 8

    for num, question in enumerate(queries, 1):
        print(f"--- Test {num}/{len(queries)} ---")
        print(f"Query: {question!r}")
        final_text = None
        yield_count = 0
        try:
            for chunk in run_stream(engine, graph, question, api_key):
                final_text = chunk
                yield_count += 1
            if final_text:
                preview = final_text[:400] + "..." if len(final_text) > 400 else final_text
                print(f"Steps yielded: {yield_count}; final length: {len(final_text)}")
                print(f"Final preview:\n{preview}\n")
            else:
                print("No output received.\n")
        except Exception as e:
            print(f"Error: {e}\n")
        if num < len(queries):
            print(f"Waiting {wait_seconds}s before next query...")
            time.sleep(wait_seconds)

    print("All tests finished.")


if __name__ == "__main__":
    main()