Spaces:
Sleeping
Sleeping
| import glob | |
| from venv import create | |
| import gradio as gr | |
| from typing import Any | |
| from dotenv import load_dotenv | |
| import requests | |
| from griptape.structures import Agent, Structure, Workflow | |
| from griptape.tasks import PromptTask, StructureRunTask | |
| from griptape.drivers import ( | |
| LocalConversationMemoryDriver, | |
| GriptapeCloudStructureRunDriver, | |
| LocalFileManagerDriver, | |
| LocalStructureRunDriver, | |
| GriptapeCloudConversationMemoryDriver, | |
| ) | |
| from griptape.memory.structure import ConversationMemory | |
| from griptape.tools import StructureRunTool, FileManagerTool | |
| from griptape.rules import Rule, Ruleset | |
| from griptape.configs.drivers import AnthropicDriversConfig | |
| from griptape.configs import Defaults | |
| import time | |
| import os | |
| from urllib.parse import urljoin | |
# Load environment variables (python-dotenv reads them from a .env file).
load_dotenv()

# Use the Anthropic driver configuration as the Griptape-wide default.
Defaults.drivers_config = AnthropicDriversConfig()

# Griptape Cloud base URL and the headers sent with thread-creation requests.
# Indexing os.environ (rather than .get) fails at import time if the key is unset.
base_url = "https://cloud.griptape.ai"
headers_api = {
    "Authorization": "Bearer " + os.environ["GT_CLOUD_API_KEY"],
    "Content-Type": "application/json",
}

# Maps a session id to the Griptape Cloud thread id created for it.
threads = {}

# Unused custom CSS kept for reference (note the rule is never closed):
# custom_css = """
# #component-2 {
# height: 75vh !important;
# min-height: 600px !important;
# """
def create_thread_id(session_id: str) -> str:
    """Return the Griptape Cloud thread id for a session, creating it on first use.

    The id is cached in the module-level ``threads`` dict so that later calls
    with the same ``session_id`` do not hit the API again.

    Args:
        session_id: Key identifying the chat session; also sent as the thread name.

    Returns:
        The thread id stored in ``threads`` for ``session_id``.

    Raises:
        requests.HTTPError: If the thread-creation request returns an error status.
        requests.Timeout: If the API does not answer within the timeout.
    """
    # Fast path: a thread was already created for this session.
    if session_id in threads:
        return threads[session_id]

    response = requests.post(
        url=urljoin(base_url, "/api/threads"),
        headers=headers_api,
        json={"name": session_id, "messages": []},
        # Without a timeout a stalled connection would block the worker forever.
        timeout=30,
    )
    response.raise_for_status()

    thread_id = response.json()["thread_id"]
    threads[session_id] = thread_id
    return thread_id
| # Create an agent that will create a prompt that can be used as input for the query agent from the Griptape Cloud. | |
| # Function that logs user history - adds to history parameter of Gradio | |
| # TODO: Figure out the exact use of this function | |
def user(user_message, history):
    """Record a new user turn in the chat history.

    Appends ``[user_message, None]`` to ``history`` in place and returns an
    empty string together with that same history list. The empty string is
    presumably used by the UI to clear the input box (not wired up in this
    file).
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    return "", history
| # Function that logs bot history - adds to the history parameter of Gradio | |
| # TODO: Figure out the exact use of this function | |
def bot(history, knowledge_bases=None, request: "gr.Request" = None):
    """Stream the agent's reply to the latest user turn, one character at a time.

    Args:
        history: Chat history as ``[user_message, bot_message]`` pairs. The bot
            slot of the last pair is overwritten as the reply streams in.
        knowledge_bases: Knowledge-base selection forwarded to ``send_message``.
        request: Gradio request forwarded to ``send_message``. The annotation is
            quoted so the function can be defined without evaluating ``gr``.

    Yields:
        ``history`` after each appended character.
    """
    # send_message takes (message, history, knowledge_bases, request); calling it
    # with the message alone raised TypeError on every invocation.
    response = send_message(history[-1][0], history, knowledge_bases, request)

    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        # Short pause between characters to produce a typing effect.
        time.sleep(0.005)
        yield history
def create_prompt_task(session_id: str, message: str) -> PromptTask:
    """Build the prompt task that restructures the user's message into a query.

    ``message`` is interpolated into the prompt text; ``session_id`` is
    currently unused.
    """
    prompt = f"""
        Re-structure the values to form a query from the user's questions: '{message}' and the input value from the conversation memory. Leave out attributes that aren't important to the user:
        """
    return PromptTask(prompt)
def build_talk_agent(session_id: str, message: str) -> Agent:
    """Build the agent that shapes the user's input into a query structure.

    The agent's conversation memory is backed by the Griptape Cloud thread
    belonging to ``session_id`` (created on first use), and its single task is
    the prompt task built from ``message``.
    """
    # create_thread_id returns the same value it stores in threads[session_id].
    thread_id = create_thread_id(session_id)

    follow_up_rule = """You ask the user follow-up questions to fill in missing information for:
                    years experience,
                    location,
                    role,
                    skills,
                    expected salary,
                    availability,
                    past companies,
                    past projects,
                    show reel details
                    """
    rule_texts = (
        "You are responsible for structuring a user's questions into a specific format for a query.",
        follow_up_rule,
        "Return the current query structure and any questions to fill in missing information.",
    )
    talk_rules = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in rule_texts],
    )

    cloud_memory = ConversationMemory(
        conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
            thread_id=thread_id,
        )
    )
    return Agent(
        conversation_memory=cloud_memory,
        tasks=[create_prompt_task(session_id, message)],
        rulesets=[talk_rules],
    )
# Creates a new agent for each run.
# The agent's conversation memory lives in a Griptape Cloud thread, one per session id.
def build_agent(session_id: str, message: str, kbs: str) -> Agent:
    """Build the top-level agent that formulates a query and runs the search.

    The agent gets two tools: one that runs the local "talk" agent to shape the
    user's input into a query, and one that runs the Griptape Cloud structure
    identified by the ``GT_STRUCTURE_ID`` environment variable.

    Args:
        session_id: Session key; selects the Griptape Cloud thread used as
            conversation memory (created on first use).
        message: The user's message, handed to the local talk agent.
        kbs: String form of the selected knowledge bases, embedded in the
            search tool's description.

    Returns:
        The configured ``Agent``.
    """
    thread_id = create_thread_id(session_id)

    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[
            Rule(
                value="You are responsible for structuring a user's questions into a query and then querying."
            ),
            Rule(
                value="Only return the result of the query, do not provide additional commentary."
            ),
            Rule(value="Only perform one task at a time."),
            Rule(
                value="Do not perform the query unless the user has confirmed they are done with formulating."
            ),
            Rule(value="Only perform the query as one string argument."),
            Rule(
                value="If the user says they want to start over, then you must delete the conversation memory file."
            ),
            Rule(
                value="Do not ever search conversation memory for a formulated query instead of querying. Query every time."
            ),
        ],
    )

    # Both tools pass their driver as ``structure_run_driver``; the cloud tool
    # previously used ``driver=``, inconsistent with the tool below.
    # To run the search structure locally instead, swap in
    # LocalStructureRunDriver(create_structure=...).
    query_client = StructureRunTool(
        name="QueryResumeSearcher",
        description=f"""Use it to search for a candidate with the query. Add each item in this list as separate arguments:{kbs}. Do not add any other arguments.""",
        structure_run_driver=GriptapeCloudStructureRunDriver(
            structure_id=os.getenv("GT_STRUCTURE_ID"),
            api_key=os.getenv("GT_CLOUD_API_KEY"),
            structure_run_wait_time_interval=3,
            structure_run_max_wait_time_attempts=30,
        ),
    )
    talk_client = StructureRunTool(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        structure_run_driver=LocalStructureRunDriver(
            create_structure=lambda: build_talk_agent(session_id, message),
        ),
    )

    return Agent(
        conversation_memory=ConversationMemory(
            conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
                thread_id=thread_id,
            )
        ),
        tools=[talk_client, query_client],
        rulesets=[ruleset],
    )
def send_message(message: str, history, knowledge_bases, request: gr.Request) -> Any:
    """Run the session's agent on ``message`` and return its output value.

    Args:
        message: The user's message.
        history: Chat history supplied by the UI; not used here.
        knowledge_bases: Selected knowledge bases; passed to ``build_agent`` as
            a string.
        request: Gradio request whose ``session_hash`` identifies the session.

    Returns:
        ``response.output.value`` from the agent run.

    Raises:
        ValueError: If no request is supplied, since the session cannot be
            identified. Previously this surfaced as an UnboundLocalError.
    """
    if not request:
        raise ValueError(
            "send_message requires a Gradio request to identify the session."
        )

    agent = build_agent(request.session_hash, message, str(knowledge_bases))
    response = agent.run(message)
    return response.output.value
def send_message_call(message: str, history, knowledge_bases) -> Any:
    """Start a Griptape Cloud structure run for ``message`` and return its output.

    Args:
        message: The user's message; sent as the first run argument.
        history: Chat history supplied by the UI; not used here.
        knowledge_bases: Iterable of knowledge-base names appended to the run
            arguments. ``None`` is treated as an empty selection.

    Returns:
        The run's output value followed by the run's timestamp, structure id
        and run id, or an "Assistant Call Failed" message when the run could
        not be created.
    """
    structure_id = os.getenv("GT_STRUCTURE_ID")
    api_key = os.getenv("GT_CLOUD_API_KEY")
    structure_url = f"https://cloud.griptape.ai/api/structures/{structure_id}/runs"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"args": [message, *(knowledge_bases or [])]}

    response = requests.post(structure_url, headers=headers, json=payload, timeout=30)

    # Do not call raise_for_status() here: it made the failure branch below
    # unreachable for HTTP errors, so the user never saw the API's messages.
    if response.status_code != 201:
        try:
            body = response.json()
        except ValueError:
            # Error responses are not guaranteed to carry a JSON body.
            body = None
        errors = body.get("errors") if isinstance(body, dict) else None
        if not errors:
            errors = [response.text or f"HTTP {response.status_code}"]
        elif isinstance(errors, str):
            errors = [errors]
        return (
            f"Assistant Call Failed due to these errors: \n {','.join(str(error) for error in errors)} "
        )

    data = response.json()
    structure_run_id = data["structure_run_id"]
    output = poll_structure(structure_run_id, headers)
    output = output["output_task_output"]["value"]
    output += f" \n UTC Timestamp: {data['created_at']}\n Structure ID: {structure_id} \n Run ID: {structure_run_id}"
    return output
def poll_for_events(
    offset: int, structure_run_id: str, headers: dict, timeout: float = 30.0
):
    """Fetch one page of events for a structure run.

    Args:
        offset: Index of the first event to fetch.
        structure_run_id: Id of the structure run being polled.
        headers: HTTP headers (including authorization) for the request.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``requests`` response; callers read ``events`` and ``next_offset``
        from its JSON body.

    Raises:
        requests.HTTPError: If the API returns an error status.
        requests.Timeout: If the API does not answer within ``timeout``.
    """
    url = f"https://cloud.griptape.ai/api/structure-runs/{structure_run_id}/events"
    response = requests.get(
        url=url,
        headers=headers,
        params={"offset": offset, "limit": 100},
        # A missing timeout would let a stalled connection hang the poll loop.
        timeout=timeout,
    )
    response.raise_for_status()
    return response
def poll_structure(
    structure_run_id: str, headers: dict, poll_interval: float = 0.5, max_attempts=None
):
    """Poll a structure run's events until it finishes and return the final payload.

    Args:
        structure_run_id: Id of the structure run to watch.
        headers: HTTP headers (including authorization) for the event requests.
        poll_interval: Seconds to sleep between polls.
        max_attempts: Optional cap on the number of polls. ``None`` (the
            default) polls until a finish event arrives.

    Returns:
        The payload of the ``FinishStructureRunEvent`` as a dict.

    Raises:
        TimeoutError: If ``max_attempts`` polls pass without a finish event.
        requests.HTTPError: Propagated from ``poll_for_events``.
    """
    offset = 0
    attempts = 0
    while True:
        # poll_for_events already raises on HTTP errors, so no second check here.
        page = poll_for_events(offset, structure_run_id, headers).json()

        for event in page["events"]:
            if event["type"] == "FinishStructureRunEvent":
                # Return immediately; issuing another poll after the run has
                # finished is wasted work and could fail needlessly.
                return dict(event["payload"])

        offset = page["next_offset"]
        attempts += 1
        if max_attempts is not None and attempts >= max_attempts:
            raise TimeoutError(
                f"Structure run {structure_run_id} did not finish after {attempts} polls."
            )
        time.sleep(poll_interval)
with gr.Blocks() as demo:
    # The selected values reach send_message_call as its extra (third) argument.
    knowledge_bases = gr.CheckboxGroup(
        choices=["skills", "demographics", "linked_in", "showreels"],
        label="Select Knowledge Bases",
    )
    chatbot = gr.ChatInterface(
        fn=send_message_call,
        additional_inputs=knowledge_bases,
        chatbot=gr.Chatbot(height=300),
    )

# Credentials for the login screen come from the environment; either value is
# None when its variable is unset.
demo.launch(
    auth=(
        os.environ.get("GRADIO_USERNAME"),
        os.environ.get("GRADIO_PASSWORD"),
    )
)
# demo.launch()

# NOTE(review): this rebinding only runs once launch() has returned; it does not
# clear entries as individual sessions end. Is there a better way?
threads = {}