cd@bziiit.com commited on
Commit
cba7f8e
·
1 Parent(s): 9760e1f

Fix variables with GraphState

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. graph_agentA.py +12 -3
  3. graph_agentB.py +7 -7
app.py CHANGED
@@ -50,8 +50,8 @@ def process_query(query, architecture: Literal["A", "B", "C"]):
50
  "messages": [],
51
  "relevant_docs": [],
52
  "response": "",
53
- "k": k,
54
- "alpha": alpha,
55
  }
56
 
57
  elif architecture in ["B", "C"]:
 
50
  "messages": [],
51
  "relevant_docs": [],
52
  "response": "",
53
+ "k": k,
54
+ "similarity_threshold": similarity_threshold,
55
  }
56
 
57
  elif architecture in ["B", "C"]:
graph_agentA.py CHANGED
@@ -31,16 +31,25 @@ def generate_response(state: GraphState) -> dict:
31
  response = llm.invoke(prompt)
32
  return {"response": response.content}
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  def post_process_response(state: GraphState) -> dict:
35
  """Post-process the response."""
36
  response = state["response"].strip() if isinstance(state["response"], str) else state["response"]
37
  return {"response": response}
38
 
39
- k = st.session_state.get("k", 30) # Valeur par défaut 30
40
- similarity_threshold = st.session_state.get('similarity_threshold', 0.7) # Valeur par défaut 0.7
41
  # Build the graph
42
  graph_builder = StateGraph(GraphState)
43
- graph_builder.add_node("retrieve", lambda state: {"relevant_docs": retrieve_documents(state["query"], k=k, similarity_threshold=similarity_threshold)})
44
  graph_builder.add_node("generate", generate_response)
45
  graph_builder.add_node("post_process", post_process_response)
46
 
 
31
  response = llm.invoke(prompt)
32
  return {"response": response.content}
33
 
34
+ def retrieve(state: GraphState) -> dict:
35
+ """Récupération sémantique : Pinecone (sémantique)"""
36
+
37
+ relevant_docs = retrieve_documents(
38
+ state["query"],
39
+ k=state.get("k"),
40
+ similarity_threshold=state.get("similarity_threshold")
41
+ )
42
+
43
+ return {"relevant_docs": relevant_docs}
44
+
45
  def post_process_response(state: GraphState) -> dict:
46
  """Post-process the response."""
47
  response = state["response"].strip() if isinstance(state["response"], str) else state["response"]
48
  return {"response": response}
49
 
 
 
50
  # Build the graph
51
  graph_builder = StateGraph(GraphState)
52
+ graph_builder.add_node("retrieve", retrieve)
53
  graph_builder.add_node("generate", generate_response)
54
  graph_builder.add_node("post_process", post_process_response)
55
 
graph_agentB.py CHANGED
@@ -18,14 +18,14 @@ class GraphState(TypedDict):
18
 
19
  def retrieve_combined(state: GraphState) -> dict:
20
  """Récupération hybride : Pinecone (sémantique) + BM25 (mots-clés)."""
21
- k = st.session_state.get("k", 30) # Valeur par défaut 30
22
- alpha = st.session_state.get("alpha", 0.5) # Valeur par défaut 0.5
23
- similarity_threshold = st.session_state.get('similarity_threshold', 0.7) # Valeur par défaut 0.7
24
 
25
- print(f"k: {k}")
26
- print(f"similarity_threshold: {similarity_threshold}")
27
- print(f"alpha: {alpha}")
28
- relevant_docs = hybrid_search(state["query"], alpha=alpha, k=k, similarity_threshold=similarity_threshold)
 
 
 
29
  return {"relevant_docs": relevant_docs}
30
 
31
  def generate_response(state: GraphState) -> dict:
 
18
 
19
  def retrieve_combined(state: GraphState) -> dict:
20
  """Récupération hybride : Pinecone (sémantique) + BM25 (mots-clés)."""
 
 
 
21
 
22
+ relevant_docs = hybrid_search(
23
+ state["query"],
24
+ alpha=state.get("alpha"),
25
+ k=state.get("k"),
26
+ similarity_threshold=state.get("similarity_threshold")
27
+ )
28
+
29
  return {"relevant_docs": relevant_docs}
30
 
31
  def generate_response(state: GraphState) -> dict: