| | |
| | import asyncio |
| | import logging |
| | import os |
| | import threading |
| | from typing import Dict |
| | import uuid |
| | from fastapi.encoders import jsonable_encoder |
| | import numpy as np |
| | import pandas as pd |
| | from pandasai import SmartDataframe |
| | from langchain_groq.chat_models import ChatGroq |
| | from dotenv import load_dotenv |
| | from pydantic import BaseModel |
| | from cerebras_openrouter_chart_generator import generate_cerebras_chart, generate_openrouter_chart |
| | from csv_service import clean_data, extract_chart_filenames |
| | from langchain_groq import ChatGroq |
| | import pandas as pd |
| | from langchain_experimental.tools import PythonAstREPLTool |
| | from langchain_experimental.agents import create_pandas_dataframe_agent |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import matplotlib |
| | import seaborn as sns |
| | from supabase_service import upload_file_to_supabase |
| | from util_service import _prompt_generator, process_answer |
| | import matplotlib |
# Force the non-interactive Agg backend so figures are rendered to files only;
# this module runs server-side and never opens a GUI window.
matplotlib.use('Agg')
| |
|
| |
|
# Load variables from a local .env file (if any) into the process environment
# before any of the settings below are read.
load_dotenv()

# Settings read from the environment; they are not referenced elsewhere in the
# visible part of this module.
image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = os.getenv("ALLOWED_HOSTS", "").split(",")

# Comma-separated list of Groq API keys used for key rotation. A missing
# variable yields an empty list instead of crashing the import with
# AttributeError, and surrounding whitespace / empty entries ("k1, k2,") are
# dropped so they are never sent to the API as bogus keys.
groq_api_keys = [
    key.strip()
    for key in os.getenv("GROQ_API_KEYS", "").split(",")
    if key.strip()
]
# Name of the Groq-hosted model passed to every ChatGroq client.
model_name = os.getenv("GROQ_LLM_MODEL")
| |
|
class CsvUrlRequest(BaseModel):
    """Payload holding the location of a single CSV file."""

    # Location of the CSV; presumably an HTTP(S) URL - confirm against callers.
    csv_url: str
| | |
class ImageRequest(BaseModel):
    """Payload holding the path of a single image."""

    # Path of the image; whether it is relative or absolute is not visible here.
    image_path: str
| | |
class CsvCommonHeadersRequest(BaseModel):
    """Payload listing several CSV files (presumably to find shared headers)."""

    # Locations of the CSV files to inspect.
    file_urls: list[str]
| | |
class CsvsMergeRequest(BaseModel):
    """Payload describing a merge of several CSV files."""

    # Locations of the CSV files to merge.
    file_urls: list[str]
    # Kind of merge to perform; presumably a pandas `how` value such as
    # "inner" or "outer" - TODO confirm against the merge implementation.
    merge_type: str
    # Column names the files are merged on.
    common_columns_name: list[str]
| |
|
| | |
# Key-rotation state for groq_chat: index into groq_api_keys of the key in
# use, plus the lock that guards reads and updates of that index.
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Key-rotation state for langchain_csv_chat: index of the next key to hand
# out, plus the lock that guards it.
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()
| |
|
| |
|
| | |
def handle_out_of_range_float(value):
    """Replace float values that strict JSON cannot represent.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars are
            inspected; everything else is returned untouched.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity, otherwise ``value`` itself.

    NOTE(review): a function with this same name is defined again further
    down in this module; the later definition is the one bound at call time.
    """
    # np.floating covers float16/32/64 scalars; np.float32 is not a subclass
    # of the builtin float and would otherwise slip through unchecked.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign so -inf is not reported as positive infinity.
            return "Infinity" if value > 0 else "-Infinity"
    return value
| |
|
| | |
# Configure the root logger at INFO level when this module is imported and
# create the module-level logger used by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
def groq_chat(csv_url: str, question: str):
    """Answer a question about a CSV with a PandasAI SmartDataframe on Groq.

    The function keeps using the API key at ``current_groq_key_index`` and
    only moves on to the next key when a call fails with an error whose text
    contains "429". The index is never reset here, so once every key has been
    rate-limited all later calls report exhaustion.

    Args:
        csv_url: Location of the CSV, forwarded to ``clean_data``.
        question: Natural-language question forwarded to ``SmartDataframe.chat``.

    Returns:
        * a list of row dicts when the answer is a DataFrame,
        * a dict when the answer is a Series or a dict,
        * a list when the answer is a list,
        * ``{"answer": "<text>"}`` for any other answer,
        * ``{"error": "All API keys exhausted."}`` when no key is left on entry,
        * ``None`` when the last key hits a 429 or any other error occurs.
        NaN / infinite floats are replaced via ``handle_out_of_range_float``.
    """
    global current_groq_key_index

    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            # Remember which key this attempt used so a later 429 only
            # advances the shared index if nobody else advanced it already.
            key_index = current_groq_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Drop the on-disk cache DB (presumably PandasAI's response cache)
            # so every question is answered fresh. Best effort only.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            try:
                os.remove(cache_db_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.info(f"Error deleting cache DB file: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question)

            if isinstance(answer, pd.DataFrame):
                # DataFrame.apply would hand whole columns to the sanitizer,
                # which then never sees a float; sanitize cell by cell on the
                # plain-Python records instead.
                processed = [
                    {
                        column: handle_out_of_range_float(cell)
                        for column, cell in row.items()
                    }
                    for row in answer.to_dict(orient="records")
                ]
            elif isinstance(answer, pd.Series):
                # Sanitize after to_dict so None is not coerced back to NaN
                # by pandas dtype inference.
                processed = {
                    key: handle_out_of_range_float(cell)
                    for key, cell in answer.to_dict().items()
                }
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {
                    k: handle_out_of_range_float(v) for k, v in answer.items()
                }
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}

            return processed

        except Exception as e:
            error_message = str(e)
            if "429" in error_message:
                # Rate limited: rotate to the next key and retry. Concurrent
                # callers failing on the same key must advance it only once.
                with current_groq_key_lock:
                    if current_groq_key_index == key_index:
                        current_groq_key_index += 1
                    exhausted = current_groq_key_index >= len(groq_api_keys)
                if exhausted:
                    logger.info("All API keys exhausted.")
                    return None
            else:
                logger.info(
                    f"Error with API key index {key_index}: {error_message}"
                )
                return None
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question about a CSV with a LangChain pandas agent on Groq.

    Keys are handed out round-robin: every attempt takes the key at
    ``current_langchain_key_index`` and advances the index, wrapping to 0
    past the end. At most ``len(groq_api_keys)`` attempts are made; any
    exception from building or invoking the agent triggers the next attempt.

    Args:
        csv_url: Location of the CSV, forwarded to ``clean_data`` and to the
            prompt generator.
        question: User question.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The agent's ``"output"`` value from the first attempt that does not
        raise, or ``None`` when every attempt raised.

    Raises:
        Whatever ``clean_data`` raises; loading happens before the retry loop.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)
    total_keys = len(groq_api_keys)

    for _ in range(total_keys):
        with current_langchain_key_lock:
            # Wrap first, then advance: the stored index may equal total_keys.
            current_key = current_langchain_key_index % total_keys
            api_key = groq_api_keys[current_key]
            current_langchain_key_index = current_key + 1

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            # Extra REPL tool exposing the dataframe and plotting libraries.
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
            })

            # SECURITY: allow_dangerous_code lets the LLM execute arbitrary
            # Python in this process; the question is user-controlled input.
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as e:
            logger.info(f"Error with key index {current_key}: {str(e)}")

    logger.info("All API keys have been exhausted.")
    return None
| |
|
| |
|
def handle_out_of_range_float(value):
    """Replace float values that strict JSON cannot represent.

    Args:
        value: Any object. Only Python floats and NumPy floating scalars are
            inspected; everything else is returned untouched.

    Returns:
        ``None`` for NaN, ``"Infinity"`` for positive infinity,
        ``"-Infinity"`` for negative infinity, otherwise ``value`` itself.

    NOTE(review): this name is also defined earlier in this module; being the
    later definition, this one is what callers get at run time. Consider
    removing one of the two.
    """
    # np.floating covers float16/32/64 scalars; np.float32 is not a subclass
    # of the builtin float and would otherwise slip through unchecked.
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return None
        if np.isinf(value):
            # Keep the sign so -inf is not reported as positive infinity.
            return "Infinity" if value > 0 else "-Infinity"
    return value
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| |
|
# Extra styling guidance appended to the user's question in groq_chart before
# it is sent to the model.
instructions = """

- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, arrange them in a grid format (2x2, 3x3, etc.)
- Use colorblind-friendly palette
- Read above instructions and follow them.

"""
| |
|
| | |
# Key-rotation state for groq_chart / model: index into groq_api_keys of the
# key in use, plus the lock that guards it.
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()

# Key-rotation state for langchain_csv_chart: index of the next key to hand
# out, plus the lock that guards it.
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()
| |
|
def model():
    """Build a ChatGroq client for the currently selected chart API key.

    Returns:
        A ``ChatGroq`` instance configured with ``model_name`` and the key at
        ``current_groq_chart_key_index``.

    Raises:
        RuntimeError: If the chart key index is past the end of
            ``groq_api_keys`` (RuntimeError is an ``Exception`` subclass, so
            existing ``except Exception`` handlers keep working).
    """
    # The index is only read here, so no `global` declaration is needed.
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise RuntimeError("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]
    return ChatGroq(model=model_name, api_key=api_key)
| |
|
def groq_chart(csv_url: str, question: str):
    """Generate a chart for a CSV with a PandasAI SmartDataframe on Groq.

    Makes at most ``len(groq_api_keys)`` attempts. An error whose text
    contains "429" rotates to the next key (wrapping around) and retries;
    any other error ends the call immediately.

    Args:
        csv_url: Location of the CSV, forwarded to ``clean_data``.
        question: User request; the module-level ``instructions`` text is
            appended before it is sent to the model.

    Returns:
        * ``"Chart not generated"`` when ``process_answer(answer)`` is truthy,
        * otherwise the raw answer from ``SmartDataframe.chat``,
        * ``{"error": "<message>"}`` on a non-429 error,
        * ``None`` when every attempt was rate-limited.
    """
    global current_groq_chart_key_index

    for _ in range(len(groq_api_keys)):
        with current_groq_chart_lock:
            # Remember which key this attempt used so a later 429 only
            # rotates the shared index if nobody else rotated it already.
            key_index = current_groq_chart_key_index
            current_api_key = groq_api_keys[key_index]

        try:
            # Drop the on-disk cache DB (presumably PandasAI's response cache)
            # so every request is answered fresh. Best effort only.
            cache_db_path = "/workspace/cache/cache_db_0.11.db"
            try:
                os.remove(cache_db_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.info(f"Cache cleanup error: {e}")

            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                },
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            if "429" not in error:
                logger.info(f"Chart generation error: {error}")
                return {"error": error}
            # Rate limited: move on to the next key, but only once even if
            # several callers failed on the same key concurrently.
            with current_groq_chart_lock:
                if current_groq_chart_key_index == key_index:
                    current_groq_chart_key_index = (
                        (key_index + 1) % len(groq_api_keys)
                    )

    logger.info("All API keys exhausted for chart generation")
    return None
| |
|
| |
|
| |
|
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """Generate chart files for a CSV with a LangChain pandas agent on Groq.

    Keys are handed out round-robin, one per attempt, for at most
    ``len(groq_api_keys)`` attempts. An attempt succeeds when
    ``extract_chart_filenames`` finds at least one filename in the agent's
    output; an empty result or any exception moves on to the next key.

    Args:
        csv_url: Location of the CSV, forwarded to ``clean_data`` and to the
            prompt generator.
        question: User request.
        chart_required: Currently unused - the prompt is always generated
            with ``True``. NOTE(review): confirm whether that is intended.

    Returns:
        The value returned by ``extract_chart_filenames`` for the first
        successful attempt, or ``None`` when no attempt produced a chart.

    Raises:
        Whatever ``clean_data`` raises; loading happens before the retry loop.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)
    total_keys = len(groq_api_keys)

    for attempt in range(total_keys):
        # Pick the key before entering the try block so `current_key` is
        # always bound when the except handler logs it.
        with current_langchain_chart_lock:
            current_key = current_langchain_chart_key_index
            api_key = groq_api_keys[current_key]
            current_langchain_chart_key_index = (current_key + 1) % total_keys

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            # Extra REPL tool exposing the dataframe, plotting libraries and
            # uuid to the code the agent runs.
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid,
            })

            # SECURITY: allow_dangerous_code lets the LLM execute arbitrary
            # Python in this process; the question is user-controlled input.
            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True,
            )

            result = agent.invoke(
                {"input": _prompt_generator(question, True, csv_url)}
            )
            output = result.get("output", "")

            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return chart_files

            # No chart in the output; the final attempt is reported by the
            # exhaustion message below instead.
            if attempt < total_keys - 1:
                logger.info(
                    f"Langchain chart error (key {current_key}): {output}"
                )

        except Exception as e:
            logger.info(f"Langchain chart error (key {current_key}): {str(e)}")

    logger.info("All API keys exhausted for chart generation")
    return None
| |
|
| |
|
| |
|
| |
|
| | |
async def csv_chat(csv_url: str, query: str) -> dict:
    """Answer a text question about a CSV via the LangChain-Groq agent.

    The blocking ``langchain_csv_chat`` call runs in a worker thread so the
    event loop is not blocked. A suffix asking for no charts is appended to
    the query.

    Args:
        csv_url: Location of the CSV.
        query: User question.

    Returns:
        ``{"answer": <JSON-encodable answer>}`` when the agent produced a
        usable response, otherwise
        ``{"error": "Could not process the request at this time."}``.
        Every path returns a dict; exceptions are logged, not raised.
    """
    updated_query = f"{query} and Do not show any charts or graphs."

    try:
        lang_groq_answer = await asyncio.to_thread(
            langchain_csv_chat, csv_url, updated_query, False
        )
        logger.info(f"LangChain-Groq answer: {lang_groq_answer}")

        # is_valid_response already rejects None / empty values.
        if is_valid_response(lang_groq_answer):
            return {"answer": jsonable_encoder(lang_groq_answer)}

        logger.warning("LangChain-Groq response not usable")
    except Exception as lang_groq_error:
        logger.error(f"LangChain-Groq failed: {str(lang_groq_error)}")

    # Shared fallback for both the unusable-answer and the exception path so
    # the function never falls off the end and returns None.
    return {"error": "Could not process the request at this time."}
| |
|
| | |
| |
|
| |
|
def is_valid_response(response) -> bool:
    """Tell whether *response* carries a usable, non-empty answer.

    Args:
        response: Value returned by a chat backend (string, dict or other).

    Returns:
        ``False`` for any falsy value, for a whitespace-only string and for a
        dict whose ``"answer"`` entry is missing or falsy; ``True`` otherwise.
    """
    if not response:
        return False
    if isinstance(response, str):
        return bool(response.strip())
    if isinstance(response, dict):
        return bool(response.get("answer"))
    return True
| |
|
| |
|
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
async def csv_chart(csv_url: str, query: str, chat_id: str) -> Dict[str, str]:
    """Generate a chart for a CSV and return the public URL of the image.

    Strategy:
        1. Try the Cerebras agent.
        2. If it fails or returns nothing, try the OpenRouter agent.
        3. Upload the first image produced and return its URL.

    Both generators are blocking and run in a worker thread.

    Args:
        csv_url: Location of the CSV.
        query: User request describing the chart.
        chat_id: Chat identifier, logged and forwarded to the upload helper.

    Returns:
        ``{"image_url": <public url>}`` on success, otherwise
        ``{"error": "Could not generate chart. Both primary and fallback
        agents failed."}``. Exceptions from either provider (including upload
        failures) are logged and never propagate.
    """
    logger.info(f"Received csv_chart request. Chat ID: {chat_id}")
    error_messages = []

    async def upload_and_return(image_path: str) -> Dict[str, str]:
        """Upload the image, remove the local file, return the public URL."""
        unique_name = f'{uuid.uuid4()}.png'
        try:
            public_url = await upload_file_to_supabase(
                image_path, unique_name, chat_id
            )
            logger.info(f"Uploaded chart to: {public_url}")
        finally:
            # Clean up the temp file even when the upload raised, otherwise
            # every failed upload leaks a chart on disk.
            try:
                os.remove(image_path)
            except OSError as e:
                logger.warning(
                    f"Could not delete temp file {image_path}: {str(e)}"
                )
        return {"image_url": public_url}

    # Primary provider.
    try:
        logger.info("Attempting with Cerebras...")
        cerebras_path = await asyncio.to_thread(
            generate_cerebras_chart, csv_url, query, 3
        )

        if cerebras_path:
            return await upload_and_return(cerebras_path)

        error_messages.append("Cerebras returned no valid image.")
        logger.warning("Cerebras failed. Switching to fallback.")

    except Exception as e:
        msg = f"Cerebras critical error: {str(e)}"
        error_messages.append(msg)
        logger.error(msg)

    # Fallback provider.
    try:
        logger.info("Attempting with OpenRouter...")
        openrouter_path = await asyncio.to_thread(
            generate_openrouter_chart, csv_url, query, 3
        )

        if openrouter_path:
            return await upload_and_return(openrouter_path)

        error_messages.append("OpenRouter fallback returned no valid image.")

    except Exception as e:
        msg = f"OpenRouter critical error: {str(e)}"
        error_messages.append(msg)
        logger.error(msg)

    logger.error(
        f"All chart generation providers failed. Errors: {'; '.join(error_messages)}"
    )
    return {"error": "Could not generate chart. Both primary and fallback agents failed."}