added gemini too
Browse files- controller.py +6 -2
- orchestrator_agent.py +4 -3
controller.py
CHANGED
|
@@ -301,6 +301,8 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
| 301 |
csv_url = request.get("csv_url")
|
| 302 |
decoded_url = unquote(csv_url)
|
| 303 |
detailed_answer = request.get("detailed_answer")
|
|
|
|
|
|
|
| 304 |
|
| 305 |
if if_initial_chat_question(query):
|
| 306 |
answer = await asyncio.to_thread(
|
|
@@ -312,7 +314,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
| 312 |
# Orchestrate the execution
|
| 313 |
if detailed_answer is True:
|
| 314 |
orchestrator_answer = await asyncio.to_thread(
|
| 315 |
-
csv_orchestrator_chat, decoded_url, query
|
| 316 |
)
|
| 317 |
if orchestrator_answer is not None:
|
| 318 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
|
@@ -798,6 +800,8 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
| 798 |
query = request.get("query", "")
|
| 799 |
csv_url = unquote(request.get("csv_url", ""))
|
| 800 |
detailed_answer = request.get("detailed_answer", False)
|
|
|
|
|
|
|
| 801 |
|
| 802 |
loop = asyncio.get_running_loop()
|
| 803 |
# First, try the langchain-based method if the question qualifies
|
|
@@ -817,7 +821,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
| 817 |
# Use orchestrator to handle the user's chart query first
|
| 818 |
if detailed_answer is True:
|
| 819 |
orchestrator_answer = await asyncio.to_thread(
|
| 820 |
-
csv_orchestrator_chat, csv_url, query
|
| 821 |
)
|
| 822 |
|
| 823 |
if orchestrator_answer is not None:
|
|
|
|
| 301 |
csv_url = request.get("csv_url")
|
| 302 |
decoded_url = unquote(csv_url)
|
| 303 |
detailed_answer = request.get("detailed_answer")
|
| 304 |
+
conversation_history = request.get("conversation_history", [])
|
| 305 |
+
return {"answer": jsonable_encoder(conversation_history)}
|
| 306 |
|
| 307 |
if if_initial_chat_question(query):
|
| 308 |
answer = await asyncio.to_thread(
|
|
|
|
| 314 |
# Orchestrate the execution
|
| 315 |
if detailed_answer is True:
|
| 316 |
orchestrator_answer = await asyncio.to_thread(
|
| 317 |
+
csv_orchestrator_chat, decoded_url, query, conversation_history
|
| 318 |
)
|
| 319 |
if orchestrator_answer is not None:
|
| 320 |
return {"answer": jsonable_encoder(orchestrator_answer)}
|
|
|
|
| 800 |
query = request.get("query", "")
|
| 801 |
csv_url = unquote(request.get("csv_url", ""))
|
| 802 |
detailed_answer = request.get("detailed_answer", False)
|
| 803 |
+
conversation_history = request.get("conversation_history", [])
|
| 804 |
+
return {"orchestrator_response": jsonable_encoder(conversation_history)}
|
| 805 |
|
| 806 |
loop = asyncio.get_running_loop()
|
| 807 |
# First, try the langchain-based method if the question qualifies
|
|
|
|
| 821 |
# Use orchestrator to handle the user's chart query first
|
| 822 |
if detailed_answer is True:
|
| 823 |
orchestrator_answer = await asyncio.to_thread(
|
| 824 |
+
csv_orchestrator_chat, csv_url, query, conversation_history
|
| 825 |
)
|
| 826 |
|
| 827 |
if orchestrator_answer is not None:
|
orchestrator_agent.py
CHANGED
|
@@ -90,7 +90,7 @@ async def generate_chart(csv_url: str, user_questions: List[str]) -> Any:
|
|
| 90 |
return charts
|
| 91 |
|
| 92 |
# Function to create an agent with a specific CSV URL
|
| 93 |
-
def create_agent(csv_url: str, api_key: str) -> Agent:
|
| 94 |
csv_metadata = get_csv_basic_info(csv_url)
|
| 95 |
|
| 96 |
system_prompt = f"""
|
|
@@ -130,6 +130,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
|
|
| 130 |
## Current Context:
|
| 131 |
- Working with CSV_URL: {csv_url}
|
| 132 |
- Dataset overview: {csv_metadata}
|
|
|
|
| 133 |
- Output format: Markdown compatible
|
| 134 |
|
| 135 |
## Response Template:
|
|
@@ -192,7 +193,7 @@ def create_agent(csv_url: str, api_key: str) -> Agent:
|
|
| 192 |
system_prompt=system_prompt
|
| 193 |
)
|
| 194 |
|
| 195 |
-
def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
|
| 196 |
print("CSV URL:", csv_url)
|
| 197 |
print("User questions:", user_question)
|
| 198 |
|
|
@@ -200,7 +201,7 @@ def csv_orchestrator_chat(csv_url: str, user_question: str) -> str:
|
|
| 200 |
for api_key in GEMINI_API_KEYS:
|
| 201 |
try:
|
| 202 |
print(f"Attempting with API key: {api_key}")
|
| 203 |
-
agent = create_agent(csv_url, api_key)
|
| 204 |
result = agent.run_sync(user_question)
|
| 205 |
print("Orchestrator Result:", result.data)
|
| 206 |
return result.data
|
|
|
|
| 90 |
return charts
|
| 91 |
|
| 92 |
# Function to create an agent with a specific CSV URL
|
| 93 |
+
def create_agent(csv_url: str, api_key: str, conversation_history: List) -> Agent:
|
| 94 |
csv_metadata = get_csv_basic_info(csv_url)
|
| 95 |
|
| 96 |
system_prompt = f"""
|
|
|
|
| 130 |
## Current Context:
|
| 131 |
- Working with CSV_URL: {csv_url}
|
| 132 |
- Dataset overview: {csv_metadata}
|
| 133 |
+
- Your conversation history: {conversation_history}
|
| 134 |
- Output format: Markdown compatible
|
| 135 |
|
| 136 |
## Response Template:
|
|
|
|
| 193 |
system_prompt=system_prompt
|
| 194 |
)
|
| 195 |
|
| 196 |
+
def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List) -> str:
|
| 197 |
print("CSV URL:", csv_url)
|
| 198 |
print("User questions:", user_question)
|
| 199 |
|
|
|
|
| 201 |
for api_key in GEMINI_API_KEYS:
|
| 202 |
try:
|
| 203 |
print(f"Attempting with API key: {api_key}")
|
| 204 |
+
agent = create_agent(csv_url, api_key, conversation_history)
|
| 205 |
result = agent.run_sync(user_question)
|
| 206 |
print("Orchestrator Result:", result.data)
|
| 207 |
return result.data
|