File size: 2,167 Bytes
bf87c6c
 
 
 
 
 
9760e1f
bf87c6c
 
 
 
 
 
 
9760e1f
 
bf87c6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cba7f8e
 
 
 
 
 
 
 
 
 
 
bf87c6c
 
 
 
 
 
 
cba7f8e
bf87c6c
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import TypedDict, Annotated, Literal, Sequence
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from config import llm, client, langsmith_project
from pinecone_utilsA import *
import streamlit as st

# Graph state definition
class GraphState(TypedDict):
    """Shared state threaded through the retrieve -> generate -> post_process graph.

    The node functions in this module each return a partial dict containing
    only the key(s) they produce (``relevant_docs`` or ``response``).
    """

    # Conversation history; the ``add_messages`` reducer tells LangGraph to
    # merge updates into the existing sequence rather than overwrite it.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # The user's question; read by ``retrieve`` and ``generate_response``.
    query: str
    # Written by ``retrieve`` and joined with spaces by ``generate_response``,
    # so presumably a list of strings — TODO confirm against retrieve_documents.
    relevant_docs: list
    # Answer produced by ``generate_response`` and trimmed by ``post_process_response``.
    response: str
    # Number of documents to retrieve; ``retrieve`` reads it with ``.get``,
    # so a missing key is forwarded as None.
    k: int 
    # Similarity cut-off forwarded to ``retrieve_documents``; also read with
    # ``.get``, so a missing key is forwarded as None.
    similarity_threshold: float

def generate_response(state: GraphState) -> dict:
    """Ask the LLM to answer the user's question from the retrieved documents.

    The retrieved documents are concatenated with single spaces into one
    context string, which is embedded together with the question into a
    French instruction prompt and sent to ``llm``.

    Args:
        state: Current graph state; ``relevant_docs`` and ``query`` are read.

    Returns:
        A partial state update ``{"response": <content of the LLM reply>}``.
    """
    retrieved = state["relevant_docs"]
    context = " ".join(retrieved)
    question = state["query"]
    # The prompt text, including its indentation and blank lines, is sent
    # to the model verbatim.
    prompt = f"""

        Vous êtes un expert en analyse de texte. Votre tâche est de répondre à la question de l'utilisateur en utilisant les informations fournies.

        Si les informations ne suffisent pas, expliquez pourquoi et proposez une hypothèse si possible.



        Informations pertinentes : {context}



        Question : {question}



        Réponse :

        """
    reply = llm.invoke(prompt)
    return {"response": reply.content}

def retrieve(state: GraphState) -> dict:
    """Semantic retrieval: fetch documents related to the query (Pinecone).

    ``k`` and ``similarity_threshold`` are read with ``dict.get``, so a key
    missing from the state is forwarded to ``retrieve_documents`` as None.

    Args:
        state: Current graph state; ``query``, ``k`` and
            ``similarity_threshold`` are read.

    Returns:
        A partial state update ``{"relevant_docs": ...}`` holding whatever
        ``retrieve_documents`` returned.
    """
    question = state["query"]
    top_k = state.get("k")
    threshold = state.get("similarity_threshold")
    docs = retrieve_documents(question, k=top_k, similarity_threshold=threshold)
    return {"relevant_docs": docs}

def post_process_response(state: GraphState) -> dict:
    """Tidy the generated answer before it leaves the graph.

    A string response has its leading and trailing whitespace stripped;
    any non-string value is returned unchanged.

    Args:
        state: Current graph state; only ``response`` is read.

    Returns:
        A partial state update ``{"response": ...}``.
    """
    raw = state["response"]
    if not isinstance(raw, str):
        return {"response": raw}
    return {"response": raw.strip()}

# Build the graph: a linear pipeline retrieve -> generate -> post_process -> END.
_PIPELINE_NODES = (
    ("retrieve", retrieve),
    ("generate", generate_response),
    ("post_process", post_process_response),
)

graph_builder = StateGraph(GraphState)
for _node_name, _node_fn in _PIPELINE_NODES:
    graph_builder.add_node(_node_name, _node_fn)

graph_builder.set_entry_point("retrieve")

# Connect each node to its successor; the last node flows into END.
_PIPELINE_ROUTE = [name for name, _ in _PIPELINE_NODES] + [END]
for _src, _dst in zip(_PIPELINE_ROUTE, _PIPELINE_ROUTE[1:]):
    graph_builder.add_edge(_src, _dst)

agent = graph_builder.compile()