themissingCRAM commited on
Commit
5520644
·
1 Parent(s): 193e8c6

first init

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +118 -25
  3. requirements.txt +5 -1
README.md CHANGED
@@ -10,4 +10,4 @@ pinned: false
10
  short_description: self correcting text to sql agent based on smolagents exampl
11
  ---
12
 
13
- self correcting text to sql agent based on https://huggingface.co/docs/smolagents/v1.12.0/examples/text_to_sql smolagents example
 
10
  short_description: self correcting text to sql agent based on smolagents exampl
11
  ---
12
 
13
+ bakery shops ordering system with recipe rag
app.py CHANGED
@@ -1,13 +1,8 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- import os
4
  from smolagents import (
5
  tool,
6
- CodeAgent,
7
- HfApiModel,
8
- GradioUI,
9
- MultiStepAgent,
10
- stream_to_gradio,
11
  )
12
  from sqlalchemy import (
13
  create_engine,
@@ -18,14 +13,17 @@ from sqlalchemy import (
18
  Integer,
19
  Float,
20
  insert,
21
- inspect,
22
  text,
23
- select,
24
- Engine,
25
  )
26
- import spaces
27
 
 
 
 
 
28
  from dotenv import load_dotenv
 
 
 
29
 
30
  load_dotenv()
31
  #sample questions
@@ -58,15 +56,80 @@ def sql_engine_tool(query: str) -> str:
58
  for row in rows:
59
  output += "\n" + str(row)
60
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
- def init_db(engine):
64
  metadata_obj = MetaData()
65
 
66
- def insert_rows_into_table(rows, table, engine=engine):
67
- for row in rows:
68
- stmt = insert(table).values(**row)
69
- with engine.begin() as connection:
70
  connection.execute(stmt)
71
 
72
  table_name = "receipts"
@@ -78,7 +141,7 @@ def init_db(engine):
78
  Column("price", Float),
79
  Column("tip", Float),
80
  )
81
- metadata_obj.create_all(engine)
82
 
83
  rows = [
84
  {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
@@ -105,7 +168,7 @@ def init_db(engine):
105
  Column("receipt_id", Integer, primary_key=True),
106
  Column("waiter_name", String(16), primary_key=True),
107
  )
108
- metadata_obj.create_all(engine)
109
 
110
  rows = [
111
  {"receipt_id": 1, "waiter_name": "Corey Johnson"},
@@ -114,7 +177,7 @@ def init_db(engine):
114
  {"receipt_id": 4, "waiter_name": "Margaret James"},
115
  ]
116
  insert_rows_into_table(rows, waiters)
117
- return engine
118
 
119
 
120
  if __name__ == "__main__":
@@ -126,17 +189,44 @@ if __name__ == "__main__":
126
  token=os.getenv("my_first_agents_hf_tokens"),
127
  )
128
 
129
- agent = CodeAgent(
130
  tools=[sql_engine_tool],
131
  model=model,
132
  max_steps=10,
133
  verbosity_level=1,
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  def enter_message(new_message, conversation_history):
137
  conversation_history.append(gr.ChatMessage(role="user", content=new_message))
138
  # yield "", conversation_history
139
- for msg in stream_to_gradio(agent, new_message):
140
  conversation_history.append(msg)
141
  yield "", conversation_history
142
 
@@ -145,14 +235,17 @@ if __name__ == "__main__":
145
  return chat_history.clear(), ""
146
 
147
  def stop_gen():
148
- agent = CodeAgent(
149
- tools=[sql_engine_tool],
150
  model=model,
 
 
151
  max_steps=10,
152
  verbosity_level=10,
153
  )
 
154
  with gr.Blocks() as b:
155
- gr.Markdown("# Demo text to sql on paying customers' receipts")
156
  chatbot = gr.Chatbot(type="messages", height=2000)
157
  message_box = gr.Textbox(lines=1, label="chat message (with default sample question)", value="What is the average each customer paid?")
158
  with gr.Row():
 
1
+ from langchain_community.document_loaders import HuggingFaceDatasetLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
3
  from smolagents import (
4
  tool,
5
+
 
 
 
 
6
  )
7
  from sqlalchemy import (
8
  create_engine,
 
13
  Integer,
14
  Float,
15
  insert,
 
16
  text,
 
 
17
  )
 
18
 
19
+ import gradio as gr
20
+ import os
21
+ from smolagents import Tool, CodeAgent, HfApiModel, stream_to_gradio
22
+ import spaces
23
  from dotenv import load_dotenv
24
+ from langchain.docstore.document import Document
25
+ import chromadb
26
+ from chromadb.utils import embedding_functions
27
 
28
  load_dotenv()
29
  #sample questions
 
56
  for row in rows:
57
  output += "\n" + str(row)
58
  return output
59
+ @tool
60
+ class RetrieverTool(Tool):
61
+ """Since we need to add a vectordb as an attribute of the tool,
62
+ we cannot simply use the simple tool constructor with a @tool decorator
63
+
64
+ Used bm25 retrival method because it is fast.
65
+ For more accuracy in retrival, you can replace it with semantic search
66
+ using vector representations for documents.
67
+
68
+ check out MTEB Leaderboard for accuracy ranking
69
+ """
70
+
71
+ name = "retriever"
72
+ description = """Uses semantic search to retrieve the parts of transformers documentation
73
+ that could be most relevant to answer your query.
74
+ Afterwards, this tool summaries the findings from the extracted document
75
+ """
76
+ inputs = {
77
+ "query": {
78
+ "type": "string",
79
+ "description": "The python list of queries to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
80
+ }
81
+ }
82
+ output_type = "string"
83
+
84
+ def __init__(self, docs: list[Document], **kwargs):
85
+ super().__init__(**kwargs)
86
+ chroma_data_path = "chroma_data/"
87
+
88
+ if not os.path.isdir(chroma_data_path):
89
+ print("in if clause")
90
+ os.makedirs(chroma_data_path, exist_ok=True)
91
+ collection_name = "demo_docs"
92
+ embedding_func = embedding_functions.DefaultEmbeddingFunction()
93
+ client = chromadb.PersistentClient(path=chroma_data_path)
94
+ collection = client.get_or_create_collection(
95
+ name=collection_name,
96
+ embedding_function=embedding_func,
97
+ metadata={"hnsw:space": "cosine"},
98
+ )
99
+ collection.upsert(
100
+ documents=[doc.page_content for doc in docs],
101
+ ids=[f"id{i}" for i in range(len(docs))],
102
+ )
103
+ self.collection = collection
104
+
105
+ def forward(self, query: str) -> str:
106
+ assert isinstance(query, str), "Your search query must be a string"
107
+ docs = self.collection.query(query_texts=[query], n_results=5)
108
+ retrieved_text = "\nRetrieved documents:\n" + "".join(
109
+ [
110
+ f"\n\n===== Document {str(i)} =====\n" + doc
111
+ for i, doc in zip(docs["ids"][0], docs["documents"][0])
112
+ ]
113
+ )
114
+ messages = [
115
+ {
116
+ "role": "user",
117
+ "content": [
118
+ {"type": "text", "text": "summaries this text:" + retrieved_text}
119
+ ],
120
+ }
121
+ ]
122
+ return retrieved_text + "\n" + model(messages).content
123
+
124
 
125
 
126
+ def init_db(_engine):
127
  metadata_obj = MetaData()
128
 
129
+ def insert_rows_into_table(_rows, _table, _engine=_engine):
130
+ for row in _rows:
131
+ stmt = insert(_table).values(**row)
132
+ with _engine.begin() as connection:
133
  connection.execute(stmt)
134
 
135
  table_name = "receipts"
 
141
  Column("price", Float),
142
  Column("tip", Float),
143
  )
144
+ metadata_obj.create_all(_engine)
145
 
146
  rows = [
147
  {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
 
168
  Column("receipt_id", Integer, primary_key=True),
169
  Column("waiter_name", String(16), primary_key=True),
170
  )
171
+ metadata_obj.create_all(_engine)
172
 
173
  rows = [
174
  {"receipt_id": 1, "waiter_name": "Corey Johnson"},
 
177
  {"receipt_id": 4, "waiter_name": "Margaret James"},
178
  ]
179
  insert_rows_into_table(rows, waiters)
180
+ return _engine
181
 
182
 
183
  if __name__ == "__main__":
 
189
  token=os.getenv("my_first_agents_hf_tokens"),
190
  )
191
 
192
+ text2sql_agent = CodeAgent(
193
  tools=[sql_engine_tool],
194
  model=model,
195
  max_steps=10,
196
  verbosity_level=1,
197
  )
198
+ source_docs = HuggingFaceDatasetLoader("MuskumPillerum/General-Knowledge", "Answer").load()[:100]
199
+ text_splitter = RecursiveCharacterTextSplitter(
200
+ chunk_size=500,
201
+ chunk_overlap=50,
202
+ add_start_index=True,
203
+ strip_whitespace=True,
204
+ separators=["\n\n", "\n", ".", " ", ""],
205
+ )
206
+ docs_processed = text_splitter.split_documents(source_docs)
207
+ retriever_tool = RetrieverTool(docs_processed)
208
+
209
+
210
+ retriever_agent = CodeAgent(
211
+ tools=[retriever_tool],
212
+ model=model,
213
+ max_steps=10,
214
+ verbosity_level=10,
215
+ )
216
+ manager_agent = CodeAgent(
217
+ tools=[retriever_tool],
218
+ model=model,
219
+ managed_agents=[retriever_agent
220
+ ,text2sql_agent],
221
+ max_steps=10,
222
+ verbosity_level=10,
223
+ )
224
+
225
 
226
  def enter_message(new_message, conversation_history):
227
  conversation_history.append(gr.ChatMessage(role="user", content=new_message))
228
  # yield "", conversation_history
229
+ for msg in stream_to_gradio(manager_agent, new_message):
230
  conversation_history.append(msg)
231
  yield "", conversation_history
232
 
 
235
  return chat_history.clear(), ""
236
 
237
  def stop_gen():
238
+ manager_agent = CodeAgent(
239
+ tools=[retriever_tool],
240
  model=model,
241
+ managed_agents=[retriever_agent
242
+ , text2sql_agent],
243
  max_steps=10,
244
  verbosity_level=10,
245
  )
246
+
247
  with gr.Blocks() as b:
248
+ gr.Markdown("# demo bakery shops ordering system with recipe rag")
249
  chatbot = gr.Chatbot(type="messages", height=2000)
250
  message_box = gr.Textbox(lines=1, label="chat message (with default sample question)", value="What is the average each customer paid?")
251
  with gr.Row():
requirements.txt CHANGED
@@ -4,4 +4,8 @@ python-dotenv==1.1.0
4
  sqlalchemy==2.0.40
5
  gradio>=5.23.1
6
  spaces>0.0.0
7
- smolagents[gradio]>=1.12.0
 
 
 
 
 
4
  sqlalchemy==2.0.40
5
  gradio>=5.23.1
6
  spaces>0.0.0
7
+ smolagents[gradio]>=1.12.0
8
+ sqlalchemy==2.0.40
9
+ langchain == 0.3.21
10
+ langchain_community == 0.3.20
11
+ chromadb == 0.6.3