Spaces:
Running
Running
Commit
·
77b60b2
1
Parent(s):
31abf01
updating to just directly call the structure
Browse files- app.py +70 -13
- poetry.lock +0 -0
- pyproject.toml +1 -0
app.py
CHANGED
|
@@ -4,8 +4,8 @@ import gradio as gr
|
|
| 4 |
from typing import Any
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import requests
|
| 7 |
-
from griptape.structures import Agent
|
| 8 |
-
from griptape.tasks import PromptTask
|
| 9 |
from griptape.drivers import (
|
| 10 |
LocalConversationMemoryDriver,
|
| 11 |
GriptapeCloudStructureRunDriver,
|
|
@@ -134,7 +134,7 @@ def build_talk_agent(session_id: str, message: str) -> Agent:
|
|
| 134 |
|
| 135 |
# Creates an agent for each run
|
| 136 |
# The agent uses local memory, which it differentiates between by session_hash.
|
| 137 |
-
def build_agent(session_id: str, message: str, kbs:str) -> Agent:
|
| 138 |
|
| 139 |
create_thread_id(session_id)
|
| 140 |
|
|
@@ -151,15 +151,13 @@ def build_agent(session_id: str, message: str, kbs:str) -> Agent:
|
|
| 151 |
Rule(
|
| 152 |
value="Do not perform the query unless the user has confirmed they are done with formulating."
|
| 153 |
),
|
| 154 |
-
Rule(
|
| 155 |
-
value="Only perform the query as one string argument."
|
| 156 |
-
),
|
| 157 |
Rule(
|
| 158 |
value="If the user says they want to start over, then you must delete the conversation memory file."
|
| 159 |
),
|
| 160 |
Rule(
|
| 161 |
value="Do not ever search conversation memory for a formulated query instead of querying. Query every time."
|
| 162 |
-
)
|
| 163 |
],
|
| 164 |
)
|
| 165 |
|
|
@@ -172,13 +170,16 @@ def build_agent(session_id: str, message: str, kbs:str) -> Agent:
|
|
| 172 |
structure_run_wait_time_interval=3,
|
| 173 |
structure_run_max_wait_time_attempts=30,
|
| 174 |
),
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
|
| 177 |
talk_client = StructureRunTool(
|
| 178 |
name="FormulateQueryFromUser",
|
| 179 |
description="Used to formulate a query from the user's input.",
|
| 180 |
-
|
| 181 |
-
|
| 182 |
),
|
| 183 |
)
|
| 184 |
return Agent(
|
|
@@ -199,11 +200,67 @@ def send_message(message: str, history, knowledge_bases, request: gr.Request) ->
|
|
| 199 |
response = agent.run(message)
|
| 200 |
return response.output.value
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
with gr.Blocks() as demo:
|
| 203 |
-
knowledge_bases = gr.CheckboxGroup(
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# Set it back to empty when a session is done
|
| 209 |
# Is there a better way?
|
|
|
|
| 4 |
from typing import Any
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import requests
|
| 7 |
+
from griptape.structures import Agent, Structure, Workflow
|
| 8 |
+
from griptape.tasks import PromptTask, StructureRunTask
|
| 9 |
from griptape.drivers import (
|
| 10 |
LocalConversationMemoryDriver,
|
| 11 |
GriptapeCloudStructureRunDriver,
|
|
|
|
| 134 |
|
| 135 |
# Creates an agent for each run
|
| 136 |
# The agent uses local memory, which it differentiates between by session_hash.
|
| 137 |
+
def build_agent(session_id: str, message: str, kbs: str) -> Agent:
|
| 138 |
|
| 139 |
create_thread_id(session_id)
|
| 140 |
|
|
|
|
| 151 |
Rule(
|
| 152 |
value="Do not perform the query unless the user has confirmed they are done with formulating."
|
| 153 |
),
|
| 154 |
+
Rule(value="Only perform the query as one string argument."),
|
|
|
|
|
|
|
| 155 |
Rule(
|
| 156 |
value="If the user says they want to start over, then you must delete the conversation memory file."
|
| 157 |
),
|
| 158 |
Rule(
|
| 159 |
value="Do not ever search conversation memory for a formulated query instead of querying. Query every time."
|
| 160 |
+
),
|
| 161 |
],
|
| 162 |
)
|
| 163 |
|
|
|
|
| 170 |
structure_run_wait_time_interval=3,
|
| 171 |
structure_run_max_wait_time_attempts=30,
|
| 172 |
),
|
| 173 |
+
# structure_run_driver = LocalStructureRunDriver(
|
| 174 |
+
# create_structure=create_structure
|
| 175 |
+
# )
|
| 176 |
)
|
| 177 |
|
| 178 |
talk_client = StructureRunTool(
|
| 179 |
name="FormulateQueryFromUser",
|
| 180 |
description="Used to formulate a query from the user's input.",
|
| 181 |
+
structure_run_driver=LocalStructureRunDriver(
|
| 182 |
+
create_structure=lambda: build_talk_agent(session_id, message),
|
| 183 |
),
|
| 184 |
)
|
| 185 |
return Agent(
|
|
|
|
| 200 |
response = agent.run(message)
|
| 201 |
return response.output.value
|
| 202 |
|
| 203 |
+
|
| 204 |
+
def send_message_call(message: str, history, knowledge_bases) -> Any:
|
| 205 |
+
|
| 206 |
+
structure_id = os.getenv("GT_STRUCTURE_ID")
|
| 207 |
+
api_key = os.getenv("GT_CLOUD_API_KEY")
|
| 208 |
+
structure_url = f"https://cloud.griptape.ai/api/structures/{structure_id}/runs"
|
| 209 |
+
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
| 210 |
+
payload = {"args": [message, *knowledge_bases]}
|
| 211 |
+
response = requests.post(structure_url, headers=headers, json=payload)
|
| 212 |
+
response.raise_for_status()
|
| 213 |
+
if response.status_code == 201:
|
| 214 |
+
data = response.json()
|
| 215 |
+
structure_run_id = data["structure_run_id"]
|
| 216 |
+
output = poll_structure(structure_run_id, headers)
|
| 217 |
+
return output["output_task_output"]["value"]
|
| 218 |
+
else:
|
| 219 |
+
return "Assistant Call Failed"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def poll_for_events(offset: int, structure_run_id: str, headers: dict):
|
| 223 |
+
url = f"https://cloud.griptape.ai/api/structure-runs/{structure_run_id}/events"
|
| 224 |
+
response = requests.get(
|
| 225 |
+
url=url, headers=headers, params={"offset": offset, "limit": 100}
|
| 226 |
+
)
|
| 227 |
+
response.raise_for_status()
|
| 228 |
+
|
| 229 |
+
return response
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def poll_structure(structure_run_id: str, headers: dict):
|
| 233 |
+
response = poll_for_events(0, structure_run_id, headers)
|
| 234 |
+
events = response.json()["events"]
|
| 235 |
+
offset = response.json()["next_offset"]
|
| 236 |
+
not_finished = True
|
| 237 |
+
output = ""
|
| 238 |
+
while not_finished:
|
| 239 |
+
time.sleep(0.5)
|
| 240 |
+
for event in events:
|
| 241 |
+
if event["type"] == "FinishStructureRunEvent":
|
| 242 |
+
not_finished = False
|
| 243 |
+
output = dict(event["payload"])
|
| 244 |
+
break
|
| 245 |
+
response = response = poll_for_events(offset, structure_run_id, headers)
|
| 246 |
+
response.raise_for_status()
|
| 247 |
+
events = response.json()["events"]
|
| 248 |
+
offset = response.json()["next_offset"]
|
| 249 |
+
return output
|
| 250 |
+
|
| 251 |
+
|
| 252 |
with gr.Blocks() as demo:
|
| 253 |
+
knowledge_bases = gr.CheckboxGroup(
|
| 254 |
+
label="Select Knowledge Bases",
|
| 255 |
+
choices=["skills", "demographics", "linked_in", "showreels"],
|
| 256 |
+
)
|
| 257 |
+
chatbot = gr.ChatInterface(
|
| 258 |
+
fn=send_message_call,
|
| 259 |
+
chatbot=gr.Chatbot(height=300),
|
| 260 |
+
additional_inputs=knowledge_bases,
|
| 261 |
+
)
|
| 262 |
+
# demo.launch(auth=(os.environ.get("GRADIO_USERNAME"), os.environ.get("GRADIO_PASSWORD")))
|
| 263 |
+
demo.launch()
|
| 264 |
|
| 265 |
# Set it back to empty when a session is done
|
| 266 |
# Is there a better way?
|
poetry.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
CHANGED
|
@@ -13,6 +13,7 @@ python = "^3.11"
|
|
| 13 |
python-dotenv = "^1.0.0"
|
| 14 |
gradio = "^4.37.1"
|
| 15 |
griptape = {git="https://github.com/griptape-ai/griptape.git", rev = "dev", extras=["drivers-embedding-voyageai","drivers-prompt-anthropic"]}
|
|
|
|
| 16 |
argparse = "^1.4.0"
|
| 17 |
azure-identity = "^1.17.1"
|
| 18 |
|
|
|
|
| 13 |
python-dotenv = "^1.0.0"
|
| 14 |
gradio = "^4.37.1"
|
| 15 |
griptape = {git="https://github.com/griptape-ai/griptape.git", rev = "dev", extras=["drivers-embedding-voyageai","drivers-prompt-anthropic"]}
|
| 16 |
+
#griptape = "^0.34"
|
| 17 |
argparse = "^1.4.0"
|
| 18 |
azure-identity = "^1.17.1"
|
| 19 |
|