Komla21 commited on
Commit
04114ad
·
verified ·
1 Parent(s): f6f3710

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +72 -56
  2. requirements.txt +194 -0
  3. tools.py +134 -0
app.py CHANGED
@@ -1,70 +1,86 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
3
 
 
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- messages.extend(history)
 
 
 
 
 
 
22
 
23
- messages.append({"role": "user", "content": message})
 
 
24
 
25
- response = ""
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
41
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  import gradio as gr
4
+ from tools import create_agent
5
+ from langchain_core.messages import RemoveMessage
6
+ from langchain_core.messages import trim_messages
7
 
8
# Global params
AGENT = create_agent()  # single LangGraph agent instance — NOTE(review): shared by all visitors; confirm per-session isolation is not needed
theme = gr.themes.Default(primary_hue="red", secondary_hue="red")
# Assistant greeting (French) shown in the chatbot before the first user message.
default_msg = "Bonjour ! Je suis là pour répondre à vos questions sur l'actuariat. Comment puis-je vous aider aujourd'hui ?"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
def filter_msg(msg_list: list, keep_n: int) -> list:
    """Return the ids of the last ``keep_n`` messages of a chat history.

    Trimming preserves the conversational structure
    (user msg -> tool msg -> ai msg) and always keeps the system message.
    """
    trimmed = trim_messages(
        msg_list,
        max_tokens=keep_n,
        token_counter=len,  # each message counts as one "token"
        strategy="last",
        start_on="human",
        end_on=("tool", "ai"),
        include_system=True,
    )
    return [kept.id for kept in trimmed]
26
 
27
def agent_response(query, config, keep_n=10):
    """Run ``query`` through the global AGENT and return the final answer text.

    Before invoking the agent, the persisted thread history (identified by
    ``config``) is pruned down to its last ``keep_n`` messages so the context
    does not grow without bound across turns.

    Args:
        query: The user's message (plain text).
        config: LangGraph config dict carrying the thread id.
        keep_n: Maximum number of history messages to keep before invoking.

    Returns:
        The ``content`` of the last message produced by the agent.
    """
    messages = AGENT.get_state(config).values.get("messages", [])

    if len(messages) > keep_n:
        # Materialize the kept ids as a set: O(1) membership tests instead of
        # scanning a list once per message while building the removal batch.
        keep_msg_ids = set(filter_msg(messages, keep_n))
        AGENT.update_state(
            config,
            {"messages": [RemoveMessage(id=m.id) for m in messages if m.id not in keep_msg_ids]},
        )
        print("msg removed")

    # Generate answer
    answer = AGENT.invoke({"messages": query}, config=config)
    return answer["messages"][-1].content
38
 
 
39
 
40
# JS injected at page load: forces gradio's light theme by rewriting the
# `__theme` query parameter and reloading the page when it is not already set.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') != 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
50
 
51
 
52
def delete_agent():
    """Rebuild the global agent from scratch, dropping all chat memory.

    Registered as the Blocks ``unload`` callback so a closed session leaves
    a fresh agent behind for the next visitor.
    """
    global AGENT
    print("del agent")
    AGENT = create_agent()
    # print(AGENT.get_state(config).values.get("messages"), "\n\n")
57
+
58
# UI: chatbot with streaming replies; component creation order defines layout.
with gr.Blocks(theme=theme, js=js_func, title="Dataltist", fill_height=True) as iface:
    gr.Markdown("# Dataltist Chatbot 🚀")
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=False, value=[{"role":"assistant", "content":default_msg}], type="messages", scale=1)
    msg = gr.Textbox(lines=1, show_label=False, placeholder="Posez vos questions sur l'assurance") # submit_btn=True
    # clear = gr.ClearButton([msg, chatbot], value="Effacer 🗑")
    # Single fixed thread id — NOTE(review): every visitor shares this LangGraph
    # thread (and therefore chat memory); confirm that is intended.
    config = {"configurable": {"thread_id": "1"}}


    def user(user_message, history: list):
        # Append the user's message to the visible history and clear the textbox.
        return "", history + [{"role": "user", "content": user_message}]

    def bot(history: list):
        # Get the complete agent answer, then stream it to the UI one
        # character at a time by yielding the growing history.
        bot_message = agent_response(history[-1]["content"], config) #AGENT.invoke({"messages":history[-1]["content"]}, config=config)
        history.append({"role": "assistant", "content": ""})
        for character in bot_message:
            history[-1]['content'] += character
            # time.sleep(0.005)
            yield history

    # Submit pipeline: echo the user message first, then stream the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    iface.unload(delete_agent)

if __name__ == "__main__":
    # load_dotenv()
    # AUTH_ID = os.environ.get("AUTH_ID")
    # AUTH_PASS = os.environ.get("AUTH_PASS")
    iface.launch() #share=True, auth=(AUTH_ID, AUTH_PASS)
requirements.txt ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.0
3
+ aiohttp==3.10.5
4
+ aiosignal==1.3.1
5
+ altair==5.4.1
6
+ annotated-types==0.7.0
7
+ anyio==4.4.0
8
+ asgiref==3.8.1
9
+ asttokens==2.4.1
10
+ attrs==24.2.0
11
+ backoff==2.2.1
12
+ bcrypt==4.2.0
13
+ blinker==1.8.2
14
+ build==1.2.1
15
+ cachetools==5.5.0
16
+ certifi==2024.8.30
17
+ charset-normalizer==3.3.2
18
+ chroma-datasets==0.1.5
19
+ chroma-hnswlib==0.7.6
20
+ chromadb==0.5.7
21
+ click==8.1.7
22
+ colorama==0.4.6
23
+ coloredlogs==15.0.1
24
+ comm==0.2.2
25
+ contourpy==1.3.0
26
+ cycler==0.12.1
27
+ dataclasses-json==0.6.7
28
+ datasets==3.0.0
29
+ debugpy==1.8.5
30
+ decorator==5.1.1
31
+ Deprecated==1.2.14
32
+ dill==0.3.8
33
+ distro==1.9.0
34
+ executing==2.1.0
35
+ fastapi==0.112.2
36
+ ffmpy==0.4.0
37
+ filelock==3.15.4
38
+ flatbuffers==24.3.25
39
+ fonttools==4.54.0
40
+ frozenlist==1.4.1
41
+ fsspec==2024.6.1
42
+ gitdb==4.0.11
43
+ GitPython==3.1.43
44
+ google-auth==2.34.0
45
+ googleapis-common-protos==1.65.0
46
+ gradio==4.44.0
47
+ gradio_client==1.3.0
48
+ greenlet==3.0.3
49
+ grpcio==1.66.1
50
+ h11==0.14.0
51
+ httpcore==1.0.5
52
+ httptools==0.6.1
53
+ httpx==0.27.2
54
+ httpx-sse==0.4.0
55
+ huggingface-hub==0.24.6
56
+ humanfriendly==10.0
57
+ idna==3.8
58
+ importlib_metadata==8.4.0
59
+ importlib_resources==6.4.4
60
+ ipykernel==6.29.5
61
+ ipython==8.27.0
62
+ jedi==0.19.1
63
+ Jinja2==3.1.4
64
+ jiter==0.5.0
65
+ joblib==1.4.2
66
+ jsonpatch==1.33
67
+ jsonpointer==3.0.0
68
+ jsonschema==4.23.0
69
+ jsonschema-specifications==2023.12.1
70
+ jupyter_client==8.6.2
71
+ jupyter_core==5.7.2
72
+ kiwisolver==1.4.7
73
+ kubernetes==30.1.0
74
+ langchain==0.3.0
75
+ langchain-chroma==0.1.4
76
+ langchain-community==0.3.0
77
+ langchain-core==0.3.5
78
+ langchain-huggingface==0.1.0
79
+ langchain-openai==0.2.0
80
+ langchain-text-splitters==0.3.0
81
+ langsmith==0.1.126
82
+ langgraph
83
+ markdown-it-py==3.0.0
84
+ MarkupSafe==2.1.5
85
+ marshmallow==3.22.0
86
+ matplotlib==3.9.2
87
+ matplotlib-inline==0.1.7
88
+ mdurl==0.1.2
89
+ mixedbread-ai==2.2.6
90
+ mmh3==4.1.0
91
+ monotonic==1.6
92
+ mpmath==1.3.0
93
+ multidict==6.0.5
94
+ multiprocess==0.70.16
95
+ mypy-extensions==1.0.0
96
+ narwhals==1.6.0
97
+ nest-asyncio==1.6.0
98
+ networkx==3.3
99
+ numpy==1.26.4
100
+ oauthlib==3.2.2
101
+ onnxruntime==1.19.0
102
+ openai==1.43.0
103
+ opentelemetry-api==1.27.0
104
+ opentelemetry-exporter-otlp-proto-common==1.27.0
105
+ opentelemetry-exporter-otlp-proto-grpc==1.27.0
106
+ opentelemetry-instrumentation==0.48b0
107
+ opentelemetry-instrumentation-asgi==0.48b0
108
+ opentelemetry-instrumentation-fastapi==0.48b0
109
+ opentelemetry-proto==1.27.0
110
+ opentelemetry-sdk==1.27.0
111
+ opentelemetry-semantic-conventions==0.48b0
112
+ opentelemetry-util-http==0.48b0
113
+ orjson==3.10.7
114
+ overrides==7.7.0
115
+ packaging==24.1
116
+ pandas==2.2.2
117
+ parso==0.8.4
118
+ pillow==10.4.0
119
+ platformdirs==4.3.2
120
+ posthog==3.6.0
121
+ prompt_toolkit==3.0.47
122
+ protobuf==4.25.4
123
+ psutil==6.0.0
124
+ pure_eval==0.2.3
125
+ pyarrow==17.0.0
126
+ pyasn1==0.6.0
127
+ pyasn1_modules==0.4.0
128
+ pydantic==2.8.2
129
+ pydantic-settings==2.5.2
130
+ pydantic_core==2.20.1
131
+ pydeck==0.9.1
132
+ pydub==0.25.1
133
+ Pygments==2.18.0
134
+ pyparsing==3.1.4
135
+ pypdf==4.3.1
136
+ PyPika==0.48.9
137
+ pyproject_hooks==1.1.0
138
+ pyreadline3==3.4.1
139
+ python-dateutil==2.9.0.post0
140
+ python-dotenv==1.0.1
141
+ python-multipart==0.0.10
142
+ pytz==2024.1
143
+ PyYAML==6.0.2
144
+ pyzmq==26.2.0
145
+ referencing==0.35.1
146
+ regex==2024.7.24
147
+ requests==2.32.3
148
+ requests-oauthlib==2.0.0
149
+ rich==13.8.0
150
+ rpds-py==0.20.0
151
+ rsa==4.9
152
+ ruff==0.6.7
153
+ safetensors==0.4.4
154
+ scikit-learn==1.5.2
155
+ scipy==1.14.1
156
+ semantic-version==2.10.0
157
+ sentence-transformers==3.1.1
158
+ sentencepiece==0.2.0
159
+ setuptools==72.1.0
160
+ shellingham==1.5.4
161
+ six==1.16.0
162
+ smmap==5.0.1
163
+ sniffio==1.3.1
164
+ SQLAlchemy==2.0.32
165
+ stack-data==0.6.3
166
+ starlette==0.38.4
167
+ sympy==1.13.2
168
+ tenacity==8.5.0
169
+ threadpoolctl==3.5.0
170
+ tiktoken==0.7.0
171
+ tokenizers==0.19.1
172
+ toml==0.10.2
173
+ tomlkit==0.12.0
174
+ torch==2.4.0
175
+ tornado==6.4.1
176
+ tqdm==4.66.5
177
+ traitlets==5.14.3
178
+ transformers==4.44.2
179
+ typer==0.12.5
180
+ typing-inspect==0.9.0
181
+ typing_extensions==4.12.2
182
+ tzdata==2024.1
183
+ urllib3==2.2.2
184
+ uvicorn==0.30.6
185
+ watchdog==4.0.2
186
+ watchfiles==0.24.0
187
+ wcwidth==0.2.13
188
+ websocket-client==1.8.0
189
+ websockets==12.0
190
+ wheel==0.43.0
191
+ wrapt==1.16.0
192
+ xxhash==3.5.0
193
+ yarl==1.9.7
194
+ zipp==3.20.1
tools.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.tools import TavilySearchResults
2
+
3
+ from langchain_core.retrievers import BaseRetriever
4
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
5
+ from langchain_core.vectorstores import VectorStoreRetriever
6
+ from langgraph.prebuilt import create_react_agent
7
+ from langchain_core.documents import Document
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.checkpoint.memory import MemorySaver
10
+ from mixedbread_ai.client import MixedbreadAI
11
+ from langchain.chains import create_retrieval_chain
12
+ from langchain.chains.combine_documents import create_stuff_documents_chain
13
+ from langchain.prompts import ChatPromptTemplate
14
+ from dotenv import load_dotenv
15
+ import os
16
+ from langchain_chroma import Chroma
17
+ import chromadb
18
+ from typing import List
19
+ from datasets import load_dataset
20
+ from langchain_huggingface import HuggingFaceEmbeddings
21
+ from tqdm import tqdm
22
+ from datetime import datetime
23
+
24
load_dotenv()

# Global params (duplicate OPENAI_API_KEY read removed — it was assigned twice).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_EMB = "mxbai-embed-large"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# MixedbreadAI client, used by the reranker inside init_rag_tool().
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
model_emb = "mixedbread-ai/mxbai-embed-large-v1"

# # Set up ChromaDB: stream the precomputed embeddings dataset from the Hub
# into an in-memory Chroma collection at import time.
memoires_ds = load_dataset("DATALTIST/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")
for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.add(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")
# Drop the streamed dataset handles once the collection is populated.
del memoires_ds, batched_ds

# NOTE(review): this module-level llm_4o is shadowed by the local one in
# create_agent(); kept in case other modules import it — confirm and remove.
llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, temperature=0)
55
+
56
+
57
+
58
def init_rag_tool():
    """Init tools to allow an LLM to query the documents.

    Pipeline: Chroma similarity search (k=25) -> Mixedbread rerank (top 4)
    -> stuff-documents answer chain on LLM_NAME, exposed as a LangChain tool.
    """
    # client = chromadb.PersistentClient(path=CHROMA_PATH)
    db = Chroma(
        client=client,
        collection_name="embeddings_mxbai",  # was an f-string with no placeholder
        embedding_function=HuggingFaceEmbeddings(model_name=model_emb),
    )

    # Reranker class
    class Reranker(BaseRetriever):
        # Wraps a vector-store retriever and reranks its hits with the
        # Mixedbread reranking API, keeping only the top `k` documents.
        retriever: VectorStoreRetriever
        # model: CrossEncoder
        k: int

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            docs = self.retriever.invoke(query)
            results = mxbai_client.reranking(
                model=MODEL_RRK,
                query=query,
                input=[doc.page_content for doc in docs],
                return_input=True,
                top_k=self.k,
            )
            return [Document(page_content=res.input) for res in results.data]

    # Set up reranker + LLM
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
    reranker = Reranker(retriever=retriever, k=4)  # Reranker(retriever=retriever, model=model, k=4)
    llm = ChatOpenAI(model=LLM_NAME, verbose=True)

    # Leading newline added before the second sentence: the two literals are
    # concatenated, so without it the instruction fused directly onto {context}.
    system_prompt = (
        "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
        "\nSi tu ne connais pas la réponse, dis que tu ne sais pas."
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(reranker, question_answer_chain)

    rag_tool = rag_chain.as_tool(
        name="RAG_search",
        description="Recherche d'information dans les mémoires d'actuariat",
        arg_types={"input": str},
    )
    return rag_tool
106
+
107
+
108
def init_websearch_tool():
    """Return a Tavily web-search tool for information outside the corpus."""
    return TavilySearchResults(
        name="Web_search",
        description="Recherche d'informations sur le web",
        max_results=5,
        search_depth="advanced",
        include_answer=True,
        include_raw_content=True,
        include_images=False,
        verbose=False,
    )
120
+
121
+
122
def create_agent():
    """Assemble the ReAct agent: RAG + web-search tools, streaming gpt-4o-mini,
    with in-memory checkpointing for per-thread chat history."""
    tools = [init_rag_tool(), init_websearch_tool()]
    checkpointer = MemorySaver()
    llm = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
    system_message = """
    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
    """ # Dans la réponse finale, sépare les informations de l'outil RAG et de l'outil Web.

    return create_react_agent(llm, tools, state_modifier=system_message, checkpointer=checkpointer, debug=False)