RCaz committed on
Commit
92766e6
·
1 Parent(s): 59b71ca

working app with LLM

Browse files
Files changed (1) hide show
  1. app.py +88 -57
app.py CHANGED
@@ -1,86 +1,117 @@
1
 
2
 
3
- from utils import _set_env
 
 
 
4
 
5
 
6
- _set_env("OPENAI_API_KEY")
7
-
8
-
9
- from utils import *
10
 
11
- def create_graph():
12
- from langgraph.graph import StateGraph, START, END
13
- from langgraph.prebuilt import ToolNode, tools_condition
14
 
15
 
16
- ## ADD TRACKING
17
- response_model = init_chat_model("gpt-4o", temperature=0)
18
- grader_model = init_chat_model("gpt-4o", temperature=0)
19
 
20
 
21
- workflow = StateGraph(MessagesState)
22
 
23
- # Define the nodes we will cycle between
24
- workflow.add_node(generate_query_or_respond)
25
- workflow.add_node("retrieve", ToolNode([retriever_tool]))
26
- workflow.add_node(rewrite_question)
27
- workflow.add_node(generate_answer)
28
 
29
- workflow.add_edge(START, "generate_query_or_respond")
30
 
31
- # Decide whether to retrieve
32
- workflow.add_conditional_edges(
33
- "generate_query_or_respond",
34
- # Assess LLM decision (call `retriever_tool` tool or respond to the user)
35
- tools_condition,
36
- {
37
- # Translate the condition outputs to nodes in our graph
38
- "tools": "retrieve",
39
- END: END,
40
- },
41
- )
42
 
43
- # Edges taken after the `action` node is called.
44
- workflow.add_conditional_edges(
45
- "retrieve",
46
- # Assess agent decision
47
- grade_documents,
 
 
 
 
 
 
 
 
 
 
48
  )
49
- workflow.add_edge("generate_answer", END)
50
- workflow.add_edge("rewrite_question", "generate_query_or_respond")
51
-
52
- # Compile
53
- graph = workflow.compile()
54
-
55
- return graph
56
-
57
-
58
-
59
-
60
- from langchain.schema import AIMessage, HumanMessage
61
- import gradio as gr
62
- from langchain.chat_models import init_chat_model
63
-
64
- ## ADD TRACKING
65
- response_model = init_chat_model("gpt-4o", temperature=0)
66
- grader_model = init_chat_model("gpt-4o", temperature=0)
67
 
68
- graph = create_graph()
 
 
69
 
70
- def predict(message, history):
71
  history_langchain_format = []
72
  for msg in history:
73
  if msg['role'] == "user":
74
  history_langchain_format.append(HumanMessage(content=msg['content']))
75
  elif msg['role'] == "assistant":
76
  history_langchain_format.append(AIMessage(content=msg['content']))
77
- history_langchain_format.append(HumanMessage(content=message))
 
 
 
 
 
 
 
 
 
78
 
79
 
80
- gpt_response = graph.invoke(history_langchain_format)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- return gpt_response.content
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  iface = gr.ChatInterface(
86
  predict,
 
1
 
2
 
3
#%% load llm
from dotenv import load_dotenv
import os
# Pull OPENAI_API_KEY (and the other secrets used below) from a local .env
# file into os.environ before any of them are read.
load_dotenv()


from langchain.chat_models import init_chat_model

# Single chat model instance shared by both the safeguard triage call and
# the RAG answer call in predict().
llm = init_chat_model("gpt-5-nano",
                      model_provider="openai",
                      api_key=os.environ['OPENAI_API_KEY'])


#%% load retriever
# NOTE(review): the module path really is spelled "create_retreiver" in the
# project package — keep the (misspelled) import in sync with that file name.
from agent.create_retreiver import load_vector_store
# FAISS vector store backed by intfloat/e5-base-v2 embeddings; the second
# argument is the on-disk index directory.
retriever = load_vector_store("intfloat/e5-base-v2","data/FAISS/512-intfloat-e5-base-v2-2026-01-16")
19
 
20
 
 
21
 
22
+ #%% setup chatbot
23
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
24
+ from langchain.chat_models import init_chat_model
 
 
25
 
 
26
 
27
def predict(message, history):
    """Answer one chat turn with RAG over Rémi Cazelles's document index.

    Pipeline: (1) a safeguard triage call flags off-topic/harmful requests,
    (2) the conversation history is converted to LangChain messages, (3) the
    top-3 relevant documents are retrieved and injected as context, (4) the
    LLM answers and the document sources are appended to the reply.

    Parameters
    ----------
    message : str
        The user's current message (Gradio ChatInterface contract).
    history : list[dict]
        Prior turns as ``{"role": "user"|"assistant", "content": str}`` dicts.

    Returns
    -------
    str
        A refusal string when the safeguard flags the request, otherwise the
        model's answer followed by the retrieved documents' sources.

    NOTE(review): relies on the module-level ``llm`` (chat model) and
    ``retriever`` (FAISS vector store) defined earlier in this file.
    """
    # --- Safeguard: triage the request before doing any retrieval ----------
    TRIAGE_PROMPT_TEMPLATE = """You are a Safeguard assistant making sure the user only asks for information related to Rémi Cazelles's projects, work and education.
If the question is not related to these subjects, or if the request is harmful you should flag the user by answering '*** FLAGGED ***' else simply answer '*** OK ***' """
    messages = [SystemMessage(content=TRIAGE_PROMPT_TEMPLATE)]
    messages.append(HumanMessage(content=message))

    safe_gpt_response = llm.invoke(
        messages,
        config={
            "tags": ["Testing", 'RAG-Bot', 'safeguard', 'V1'],
            "metadata": {
                "rag_llm": "gpt-5-nano",
                "message": message,
            }
        }
    )

    # The triage model is instructed to emit the literal '*** OK ***' marker
    # when the request is acceptable; anything else is treated as flagged.
    if "*** OK ***" not in safe_gpt_response.content:
        return "This app can only answer questions about Rémi Cazelles's projects, work and education."
    print("passed the safeguard")

    # --- Build conversation history in LangChain message format ------------
    history_langchain_format = []
    for msg in history:
        if msg['role'] == "user":
            history_langchain_format.append(HumanMessage(content=msg['content']))
        elif msg['role'] == "assistant":
            history_langchain_format.append(AIMessage(content=msg['content']))

    # --- Retrieve relevant documents for the current message ----------------
    relevant_docs = retriever.similarity_search(message, k=3)

    # Build context from retrieved documents.
    context = "\nExtracted documents:\n" + "\n".join([
        f"Document {i}: Content: {doc.page_content}\n\n---"
        for i, doc in enumerate(relevant_docs)
    ])

    # --- RAG answer ---------------------------------------------------------
    RAG_PROMPT_TEMPLATE = """Using the information contained in the context,
give a comprehensive answer to the question.
Respond only to the question asked, response should be concise and relevant to the question.
Provide the context source url and context date of the source document when relevant.
If the answer cannot be deduced from the context, do not give an answer.
"""

    # Prompt = system instructions + prior turns + (context, question) pair.
    messages = [SystemMessage(content=RAG_PROMPT_TEMPLATE)]
    messages.extend(history_langchain_format)
    combined_message = f"Context: {context}\n\nQuestion: {message}"
    messages.append(HumanMessage(content=combined_message))

    # Get response with tracking metadata (tags/metadata feed LangSmith).
    print("GPT about to answer")
    gpt_response = llm.invoke(
        messages,
        config={
            "tags": ["Testing", 'RAG-Bot', 'V1'],
            "metadata": {
                "rag_llm": "gpt-5-nano",
                "num_retrieved_docs": len(relevant_docs),
            }
        }
    )

    # Append the provenance (URL + date) of every retrieved document.
    source_context = "\nSources:\n" + "\n".join([
        f"{doc.metadata.get('source_url')} ({doc.metadata.get('date')})\n---"
        for doc in relevant_docs])

    print(gpt_response.content)
    print(source_context)

    return gpt_response.content + source_context
106
+
107
+
108
#%% setup tracking
# LangSmith tracing is configured purely through environment variables, which
# langchain reads at invoke time — so setting them here (after predict() is
# defined but before the app launches) still takes effect.
os.environ["LANGSMITH_PROJECT"] = "Testing_POC"
os.environ["LANGSMITH_TRACING"] = "true"
# Fail fast if the tracing key was not provided via the environment / .env
# file. (The original self-assignment of LANGSMITH_API_KEY was a no-op when
# the key existed and raised KeyError when it didn't; this check preserves
# that fail-fast behavior while making the intent explicit.)
if "LANGSMITH_API_KEY" not in os.environ:
    raise KeyError("LANGSMITH_API_KEY")

#%% launch gradio app
import gradio as gr
115
 
116
  iface = gr.ChatInterface(
117
  predict,