added together ai agent
Browse files- controller.py +71 -68
- orchestrator_functions.py +11 -10
controller.py
CHANGED
|
@@ -343,25 +343,25 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
| 343 |
generate_report = request.get("generate_report")
|
| 344 |
chat_id = request.get("chat_id")
|
| 345 |
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
|
| 358 |
-
#
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
|
| 366 |
# Process with groq_chat first
|
| 367 |
# groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
|
@@ -369,18 +369,19 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
| 369 |
|
| 370 |
result = await query_csv_agent(decoded_url, query, chat_id)
|
| 371 |
logger.info("together ai csv answer == >", result)
|
| 372 |
-
|
|
|
|
| 373 |
|
| 374 |
# if process_answer(groq_answer) == "Empty response received.":
|
| 375 |
# return {"answer": "Sorry, I couldn't find relevant data..."}
|
| 376 |
|
| 377 |
# if process_answer(groq_answer):
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
|
| 385 |
# return {"answer": jsonable_encoder(groq_answer)}
|
| 386 |
|
|
@@ -589,35 +590,35 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
| 589 |
generate_report = request.get("generate_report", False)
|
| 590 |
chat_id = request.get("chat_id", "")
|
| 591 |
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
|
| 597 |
loop = asyncio.get_running_loop()
|
| 598 |
-
#
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
|
| 613 |
-
#
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
|
| 619 |
-
|
| 620 |
-
|
| 621 |
|
| 622 |
# Next, try the groq-based method
|
| 623 |
# groq_result = await loop.run_in_executor(
|
|
@@ -632,28 +633,30 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
|
|
| 632 |
# os.remove(groq_result)
|
| 633 |
# return {"image_url": image_public_url}
|
| 634 |
# return FileResponse(groq_result, media_type="image/png")
|
| 635 |
-
|
|
|
|
| 636 |
result = await query_csv_agent(csv_url, query, chat_id)
|
| 637 |
logger.info("together ai result ==>", result)
|
| 638 |
-
|
|
|
|
| 639 |
|
| 640 |
# Fallback: try langchain-based again
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
|
| 658 |
except Exception as e:
|
| 659 |
logger.error(f"Critical chart error: {str(e)}")
|
|
|
|
| 343 |
generate_report = request.get("generate_report")
|
| 344 |
chat_id = request.get("chat_id")
|
| 345 |
|
| 346 |
+
if generate_report is True:
|
| 347 |
+
report_files = await generate_csv_report(csv_url, query, chat_id)
|
| 348 |
+
if report_files is not None:
|
| 349 |
+
return {"answer": jsonable_encoder(report_files)}
|
| 350 |
+
|
| 351 |
+
if if_initial_chat_question(query):
|
| 352 |
+
answer = await asyncio.to_thread(
|
| 353 |
+
langchain_csv_chat, decoded_url, query, False
|
| 354 |
+
)
|
| 355 |
+
logger.info("langchain_answer:", answer)
|
| 356 |
+
return {"answer": jsonable_encoder(answer)}
|
| 357 |
|
| 358 |
+
# Orchestrate the execution
|
| 359 |
+
if detailed_answer is True:
|
| 360 |
+
orchestrator_answer = await asyncio.to_thread(
|
| 361 |
+
csv_orchestrator_chat, decoded_url, query, conversation_history, chat_id
|
| 362 |
+
)
|
| 363 |
+
if orchestrator_answer is not None:
|
| 364 |
+
return {"answer": jsonable_encoder(orchestrator_answer)}
|
| 365 |
|
| 366 |
# Process with groq_chat first
|
| 367 |
# groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
|
|
|
|
| 369 |
|
| 370 |
result = await query_csv_agent(decoded_url, query, chat_id)
|
| 371 |
logger.info("together ai csv answer == >", result)
|
| 372 |
+
if result is not None or result == "":
|
| 373 |
+
return {"answer": result}
|
| 374 |
|
| 375 |
# if process_answer(groq_answer) == "Empty response received.":
|
| 376 |
# return {"answer": "Sorry, I couldn't find relevant data..."}
|
| 377 |
|
| 378 |
# if process_answer(groq_answer):
|
| 379 |
+
lang_answer = await asyncio.to_thread(
|
| 380 |
+
langchain_csv_chat, decoded_url, query, False
|
| 381 |
+
)
|
| 382 |
+
if process_answer(lang_answer):
|
| 383 |
+
return {"answer": "error"}
|
| 384 |
+
return {"answer": jsonable_encoder(lang_answer)}
|
| 385 |
|
| 386 |
# return {"answer": jsonable_encoder(groq_answer)}
|
| 387 |
|
|
|
|
| 590 |
generate_report = request.get("generate_report", False)
|
| 591 |
chat_id = request.get("chat_id", "")
|
| 592 |
|
| 593 |
+
if generate_report is True:
|
| 594 |
+
report_files = await generate_csv_report(csv_url, query, chat_id)
|
| 595 |
+
if report_files is not None:
|
| 596 |
+
return {"orchestrator_response": jsonable_encoder(report_files)}
|
| 597 |
|
| 598 |
loop = asyncio.get_running_loop()
|
| 599 |
+
# First, try the langchain-based method if the question qualifies
|
| 600 |
+
if if_initial_chart_question(query):
|
| 601 |
+
langchain_result = await loop.run_in_executor(
|
| 602 |
+
process_executor, langchain_csv_chart, csv_url, query, True
|
| 603 |
+
)
|
| 604 |
+
logger.info("Langchain chart result:", langchain_result)
|
| 605 |
+
if isinstance(langchain_result, list) and len(langchain_result) > 0:
|
| 606 |
+
unique_file_name =f'{str(uuid.uuid4())}.png'
|
| 607 |
+
logger.info("Uploading the chart to supabase...")
|
| 608 |
+
image_public_url = await upload_file_to_supabase(f"{langchain_result[0]}", unique_file_name, chat_id=chat_id)
|
| 609 |
+
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
|
| 610 |
+
os.remove(langchain_result[0])
|
| 611 |
+
return {"image_url": image_public_url}
|
| 612 |
+
# return FileResponse(langchain_result[0], media_type="image/png")
|
| 613 |
|
| 614 |
+
# Use orchestrator to handle the user's chart query first
|
| 615 |
+
if detailed_answer is True:
|
| 616 |
+
orchestrator_answer = await asyncio.to_thread(
|
| 617 |
+
csv_orchestrator_chat, csv_url, query, conversation_history, chat_id
|
| 618 |
+
)
|
| 619 |
|
| 620 |
+
if orchestrator_answer is not None:
|
| 621 |
+
return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}
|
| 622 |
|
| 623 |
# Next, try the groq-based method
|
| 624 |
# groq_result = await loop.run_in_executor(
|
|
|
|
| 633 |
# os.remove(groq_result)
|
| 634 |
# return {"image_url": image_public_url}
|
| 635 |
# return FileResponse(groq_result, media_type="image/png")
|
| 636 |
+
|
| 637 |
+
logger.info("Trying together ai llama...")
|
| 638 |
result = await query_csv_agent(csv_url, query, chat_id)
|
| 639 |
logger.info("together ai result ==>", result)
|
| 640 |
+
if result is not None and result != "":
|
| 641 |
+
return {"orchestrator_response": jsonable_encoder(result)}
|
| 642 |
|
| 643 |
# Fallback: try langchain-based again
|
| 644 |
+
logger.error("Together ai llama response failed, trying langchain groq....")
|
| 645 |
+
langchain_paths = await loop.run_in_executor(
|
| 646 |
+
process_executor, langchain_csv_chart, csv_url, query, True
|
| 647 |
+
)
|
| 648 |
+
logger.info("Fallback langchain chart result:", langchain_paths)
|
| 649 |
+
if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
|
| 650 |
+
unique_file_name =f'{str(uuid.uuid4())}.png'
|
| 651 |
+
logger.info("Uploading the chart to supabase...")
|
| 652 |
+
image_public_url = await upload_file_to_supabase(f"{langchain_paths[0]}", unique_file_name, chat_id=chat_id)
|
| 653 |
+
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url})
|
| 654 |
+
os.remove(langchain_paths[0])
|
| 655 |
+
return {"image_url": image_public_url}
|
| 656 |
+
return FileResponse(langchain_paths[0], media_type="image/png")
|
| 657 |
+
else:
|
| 658 |
+
logger.error("All chart generation methods failed")
|
| 659 |
+
return {"answer": "error"}
|
| 660 |
|
| 661 |
except Exception as e:
|
| 662 |
logger.error(f"Critical chart error: {str(e)}")
|
orchestrator_functions.py
CHANGED
|
@@ -572,7 +572,7 @@ async def csv_chat(csv_url: str, query: str):
|
|
| 572 |
async def csv_chart(csv_url: str, query: str, chat_id: str):
|
| 573 |
"""
|
| 574 |
Generate a chart based on the provided CSV URL and query.
|
| 575 |
-
Prioritizes
|
| 576 |
|
| 577 |
Parameters:
|
| 578 |
- csv_url (str): The URL of the CSV file.
|
|
@@ -599,18 +599,19 @@ async def csv_chart(csv_url: str, query: str, chat_id: str):
|
|
| 599 |
return {"image_url": public_url}
|
| 600 |
|
| 601 |
try:
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
|
|
|
| 606 |
|
| 607 |
-
|
| 608 |
-
|
| 609 |
|
| 610 |
-
|
| 611 |
|
| 612 |
-
except Exception as openai_error:
|
| 613 |
-
|
| 614 |
# --- 2. Second Attempt: Raw Groq ---
|
| 615 |
try:
|
| 616 |
groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
|
|
|
|
| 572 |
async def csv_chart(csv_url: str, query: str, chat_id: str):
|
| 573 |
"""
|
| 574 |
Generate a chart based on the provided CSV URL and query.
|
| 575 |
+
Prioritizes PandasAI Groq, then LangChain Gemini, and finally LangChain Groq as fallback.
|
| 576 |
|
| 577 |
Parameters:
|
| 578 |
- csv_url (str): The URL of the CSV file.
|
|
|
|
| 599 |
return {"image_url": public_url}
|
| 600 |
|
| 601 |
try:
|
| 602 |
+
# Commented out for now because aiml api is not working
|
| 603 |
+
# try:
|
| 604 |
+
# # --- 1. First Attempt: OpenAI ---
|
| 605 |
+
# openai_result = await asyncio.to_thread(openai_chart, csv_url, query)
|
| 606 |
+
# logger.info(f"OpenAI chart result:", openai_result)
|
| 607 |
|
| 608 |
+
# if openai_result and openai_result != 'Chart not generated':
|
| 609 |
+
# return await upload_and_return(openai_result, chat_id)
|
| 610 |
|
| 611 |
+
# raise Exception("OpenAI failed to generate chart")
|
| 612 |
|
| 613 |
+
# except Exception as openai_error:
|
| 614 |
+
# logger.info(f"OpenAI failed ({str(openai_error)}), trying raw Groq...")
|
| 615 |
# --- 2. Second Attempt: Raw Groq ---
|
| 616 |
try:
|
| 617 |
groq_result = await asyncio.to_thread(groq_chart, csv_url, query)
|