# Import necessary modules
from concurrent.futures import ProcessPoolExecutor
import logging
import os
import asyncio
import threading
import traceback
import uuid
from fastapi import FastAPI, HTTPException, Header
from fastapi.encoders import jsonable_encoder
from typing import Dict, List, Optional
from fastapi.responses import FileResponse
import numpy as np
import pandas as pd
from pandasai import SmartDataframe
from langchain_groq.chat_models import ChatGroq
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from cerebras_report_generator import generate_csv_report_cerebras
from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
from urllib.parse import unquote
from langchain_groq import ChatGroq
import pandas as pd
from langchain_experimental.tools import PythonAstREPLTool
from langchain_experimental.agents import create_pandas_dataframe_agent
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from gemini_report_generator import generate_csv_report_gemini
from groq_report_generator import generate_csv_report_groq
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
from orc_agent_main_cerebras import csv_orchestrator_chat_cerebras
from orchestrator_agent import csv_orchestrator_chat_gemini
from python_code_executor_service import CsvChatResult, PythonExecutor
from supabase_service import upload_file_to_supabase
from cerebras_csv_agent import query_csv_agent_cerebras
from util_service import _prompt_generator, process_answer
from fastapi.middleware.cors import CORSMiddleware
import matplotlib
matplotlib.use('Agg')

# Initialize FastAPI app
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CPU count, used further down to size the process pool.
# os.cpu_count() may return None when the count cannot be determined.
max_cpus = os.cpu_count()
logger.info(f"Max CPUs: {max_cpus}")

# Ensure the cache directory exists
os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app", exist_ok=True)
open("/app/pandasai.log", "a").close()  # Create the file if it doesn't exist

# Ensure the generated_charts directory exists
os.makedirs("/app/generated_charts", exist_ok=True)

load_dotenv()

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
# Drop empty entries so an unset ALLOWED_HOSTS yields [] rather than [""].
allowed_hosts = [
    host.strip() for host in os.getenv("ALLOWED_HOSTS", "").split(",") if host.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load environment variables.
# A missing GROQ_API_KEYS yields an empty list instead of crashing at import;
# the Groq-backed helpers below then report "All API keys exhausted".
groq_api_keys = [
    key.strip() for key in os.getenv("GROQ_API_KEYS", "").split(",") if key.strip()
]
model_name = os.getenv("GROQ_LLAMA_MODEL")


# Request body carrying a (possibly URL-encoded) CSV location.
class CsvUrlRequest(BaseModel):
    csv_url: str


# Request body for /api/get-chart: local chart path plus the owning chat.
class ImageRequest(BaseModel):
    image_path: str
    chat_id: str


# Description of one generated file.
class FileProps(BaseModel):
    fileName: str
    filePath: str
    fileType: str  # 'csv' | 'image'


# Generated files grouped by kind.
class Files(BaseModel):
    csv_files: List[FileProps]
    image_files: List[FileProps]


# Wrapper around a Files collection.
class FileBoxProps(BaseModel):
    files: Files


# Thread-safe key management for groq_chat
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

# Thread-safe key management for langchain_csv_chat
current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()


async def _upload_chart_and_cleanup(local_path: str, chat_id: str) -> str:
    """Upload a local chart to Supabase under a fresh UUID name and delete the local file.

    Args:
        local_path: Path of the chart image on local disk.
        chat_id: Chat identifier forwarded to ``upload_file_to_supabase``.

    Returns:
        Whatever ``upload_file_to_supabase`` returns (used by callers as the
        public image URL).

    Raises:
        Any exception from the upload or from ``os.remove``; callers handle it.
    """
    unique_file_name = f"{uuid.uuid4()}.png"
    logger.info("Uploading the chart to supabase...")
    image_public_url = await upload_file_to_supabase(
        str(local_path), unique_file_name, chat_id=chat_id
    )
    # Use a %s placeholder: passing the value as a bare extra argument makes
    # the logging module fail to format the record.
    logger.info("Image uploaded to Supabase and Image URL is... %s", image_public_url)
    # The local copy is removed only after the upload call has returned.
    os.remove(local_path)
    return image_public_url


# ROOT CHECK
@app.get("/")
async def root():
    """Liveness message for the service root."""
    return {"message": "CSV Chat Service-1 server is running"}


# PING CHECK
@app.get("/ping")
async def ping():
    """Lightweight ping endpoint."""
    # Previously also named ``root``, which shadowed the handler above.
    return {"message": "Pong !!"}


# BASIC KNOWLEDGE BASED ON CSV
# Remove trailing slash from the URL otherwise it will redirect to GET method
@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    """Return basic information about the CSV at ``request.csv_url``.

    The synchronous ``get_csv_basic_info`` runs in the shared process pool so
    the event loop is not blocked.

    Raises:
        HTTPException: 400 if anything fails while fetching the CSV data.
    """
    try:
        decoded_url = unquote(request.csv_url)
        logger.info(f"Fetching CSV data from URL: {decoded_url}")
        loop = asyncio.get_running_loop()
        csv_data = await loop.run_in_executor(
            process_executor, get_csv_basic_info, decoded_url
        )
        logger.info(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        logger.error(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")


# GET THE CHART FROM A SPECIFIC FILE PATH
@app.post("/api/get-chart")
async def get_image(request: ImageRequest, authorization: str = Header(None)):
    """Upload the chart at ``request.image_path`` to Supabase and return its URL.

    Returns:
        ``{"image_url": ...}`` on success, ``{"answer": "error"}`` if the
        upload or cleanup fails.

    Raises:
        HTTPException: 401 for a missing/malformed Authorization header,
            403 when the bearer token does not match ``AUTH_TOKEN``.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header missing")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = authorization.split(" ")[1]
    if not token:
        raise HTTPException(status_code=401, detail="Token missing")
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid token")

    try:
        logger.info("Groq Chat created a chat for the user query...")
        # NOTE(review): image_path comes straight from the request body and is
        # both uploaded and deleted — consider restricting it to the chart
        # directory.
        image_public_url = await _upload_chart_and_cleanup(
            request.image_path, request.chat_id
        )
        return {"image_url": image_public_url}
    except Exception as e:
        logger.error(f"Error: {e}")
        return {"answer": "error"}


# GET CSV DATA FOR GENERATING THE TABLE
@app.post("/api/csv_data")
async def get_csv_data(request: CsvUrlRequest):
    """Return the result of ``generate_csv_data`` for the CSV at ``request.csv_url``.

    Raises:
        HTTPException: 400 if anything fails while fetching the CSV data.
    """
    try:
        decoded_url = unquote(request.csv_url)
        logger.info(f"Fetching CSV data from URL: {decoded_url}")
        loop = asyncio.get_running_loop()
        csv_data = await loop.run_in_executor(
            process_executor, generate_csv_data, decoded_url
        )
        return csv_data
    except Exception as e:
        logger.error(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")


# EXECUTE THE PYTHON CODE
class ExecutionRequest(BaseModel):
    chat_id: str = Field(..., alias="chat_id")
    csv_url: str = Field(..., alias="csv_url")
    codeExecutionPayload: CsvChatResult


@app.post("/api/code_execution_csv")
async def code_execution_csv(
    request_data: ExecutionRequest,
    authorization: Optional[str] = Header(None)
):
    """Run a code-execution payload against the cleaned CSV.

    Returns:
        ``{"answer": formatted_output}`` on success, otherwise
        ``{"error": "Failed to execute request", "message": ...}``.

    Raises:
        HTTPException: 401 when the header or ``AUTH_TOKEN`` is missing, or
            the supplied token does not match.
    """
    expected_token = os.environ.get("AUTH_TOKEN")
    if not authorization or not expected_token or authorization.replace("Bearer ", "") != expected_token:
        raise HTTPException(status_code=401, detail="Unauthorized")

    try:
        # %s placeholder so the payload is actually rendered in the log line.
        logger.info("Incoming request data: %s", request_data)

        decoded_url = unquote(request_data.csv_url)
        df = clean_data(decoded_url)
        executor = PythonExecutor(df)
        formatted_output = await executor.process_response(
            request_data.codeExecutionPayload, request_data.chat_id
        )
        return {"answer": formatted_output}
    except Exception as e:
        logger.error("Processing error: %s", e)
        return {"error": "Failed to execute request", "message": str(e)}


# CHAT CODING STARTS FROM HERE

def groq_chat(csv_url: str, question: str):
    """Answer ``question`` about the CSV using PandasAI with rotating Groq keys.

    Each call tries at most ``len(groq_api_keys)`` keys. The shared key index
    wraps around, so an earlier exhaustion no longer makes every later call
    fail.

    Returns:
        The processed answer (records list, dict, list, or
        ``{"answer": str}``), or ``{"error": ...}`` on failure.
    """
    global current_groq_key_index

    for _ in range(len(groq_api_keys)):
        with current_groq_key_lock:
            current_groq_key_index %= len(groq_api_keys)
            current_api_key = groq_api_keys[current_groq_key_index]

        try:
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename,  # Unique filename
                    'enable_cache': False
                }
            )

            answer = df.chat(question)

            # Process different response types
            if isinstance(answer, pd.DataFrame):
                processed = answer.apply(handle_out_of_range_float).to_dict(orient="records")
            elif isinstance(answer, pd.Series):
                processed = answer.apply(handle_out_of_range_float).to_dict()
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}

            return processed

        except Exception as e:
            error_message = str(e)
            # Any error with a non-empty message is treated as retryable.
            if error_message != "":
                logger.warning("Rate limit exceeded. Switching to next API key.")
                with current_groq_key_lock:
                    current_groq_key_index = (current_groq_key_index + 1) % len(groq_api_keys)
            else:
                logger.error("Error in groq_chat: %s", e)
                return {"error": error_message}

    return {"error": "All API keys exhausted."}


def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer ``question`` with a LangChain pandas agent, rotating Groq keys.

    Args:
        csv_url: Location of the CSV passed to ``clean_data``.
        question: User question.
        chart_required: Forwarded to ``_prompt_generator``.

    Returns:
        The agent's ``output`` value, or ``{"error": ...}`` when an error has
        an empty message or every key has been tried.
    """
    global current_langchain_key_index

    data = clean_data(csv_url)

    for _ in range(len(groq_api_keys)):
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            key_slot = current_langchain_key_index
            api_key = groq_api_keys[key_slot]
            # Advance now so the next attempt (or caller) gets a different key.
            current_langchain_key_index += 1

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )

            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as e:
            error_message = str(e)
            if error_message != "":
                # The key was already rotated above. Log the slot index only,
                # never the secret itself.
                logger.warning(
                    "Rate limit exceeded on API key #%d. Switching to next API key.",
                    key_slot,
                )
            else:
                logger.error("Error with API key #%d: %s", key_slot, error_message)
                return {"error": error_message}

    return {"error": "All API keys exhausted"}


async def handle_detailed_answer(decoded_url, query, conversation_history, chat_id):
    """
    Try CSV processing first with Cerebras orchestrator, then fallback to Gemini if needed.

    Returns ``{"answer": <encoded answer>}`` from the first orchestrator that
    yields a non-None result, or ``{"answer": None}`` when both fail.
    """
    orchestrator_answer = None

    # Step 1: Try Cerebras
    try:
        logger.info("Processing detailed answer with Cerebras orchestrator...")
        orchestrator_answer = await asyncio.to_thread(
            csv_orchestrator_chat_cerebras, decoded_url, query, conversation_history, chat_id
        )
        if orchestrator_answer is not None:
            logger.info(f"Cerebras answer successful: {str(orchestrator_answer)[:200]}...")
            return {"answer": jsonable_encoder(orchestrator_answer)}
        else:
            logger.warning("Cerebras orchestrator returned None")
    except Exception as e:
        logger.error(f"Cerebras orchestrator failed: {str(e)}")

    # Step 2: Fallback to Gemini
    try:
        logger.info("Falling back to Gemini orchestrator...")
        orchestrator_answer = await asyncio.to_thread(
            csv_orchestrator_chat_gemini, decoded_url, query, conversation_history, chat_id
        )
        if orchestrator_answer is not None:
            logger.info(f"Gemini answer successful: {str(orchestrator_answer)[:200]}...")
            return {"answer": jsonable_encoder(orchestrator_answer)}
        else:
            logger.warning("Gemini orchestrator returned None")
    except Exception as e:
        logger.error(f"Gemini orchestrator failed: {str(e)}")

    # Step 3: Both failed
    logger.error("Both Cerebras and Gemini orchestrators failed or returned None")
    return {"answer": None}


# Async endpoint with non-blocking execution
@app.post("/api/csv-chat")
async def csv_chat(request: Dict, authorization: str = Header(None)):
    """Answer a question about a CSV, trying several back-ends in order.

    Order: report generation (Cerebras, Gemini, Groq) when ``generate_report``
    is True; langchain for initial chat questions; the orchestrators when
    ``detailed_answer`` is True; the Cerebras CSV agent; then langchain as the
    last fallback.

    Returns:
        ``{"answer": ...}``; the value is ``"error"`` when every method fails.

    Raises:
        HTTPException: 401 for a missing/malformed Authorization header,
            403 for a token that does not match ``AUTH_TOKEN``.
    """
    # Authorization checks
    if not authorization or not authorization.startswith("Bearer "):
        logger.error("Authorization failed: Missing or invalid authorization header")
        raise HTTPException(status_code=401, detail="Invalid authorization")

    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        logger.error("Authorization failed: Invalid token")
        raise HTTPException(status_code=403, detail="Invalid token")

    logger.info("Authorization successful")

    try:
        # Extract request parameters
        query = request.get("query")
        csv_url = request.get("csv_url")
        decoded_url = unquote(csv_url)
        detailed_answer = request.get("detailed_answer")
        conversation_history = request.get("conversation_history", [])
        generate_report = request.get("generate_report")
        chat_id = request.get("chat_id")

        logger.info(f"Request parameters: query='{query[:100]}...', csv_url='{csv_url}', detailed_answer={detailed_answer}, generate_report={generate_report}, chat_id={chat_id}")

        # Handle report generation with Cerebras first, then Gemini fallback
        if generate_report is True:
            logger.info("Starting report generation process...")

            # Try Cerebras first for report generation
            logger.info("Attempting report generation with Cerebras...")
            try:
                report_files = await generate_csv_report_cerebras(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Cerebras report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                else:
                    logger.warning("Cerebras report generation returned empty or None result")
            except Exception as cerebras_error:
                logger.error(f"Cerebras report generation failed: {str(cerebras_error)}")

            # Fallback to Gemini for report generation
            logger.info("Falling back to Gemini for report generation...")
            try:
                report_files = await generate_csv_report_gemini(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Gemini report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                else:
                    logger.warning("Gemini report generation returned empty or None result")
            except Exception as gemini_error:
                logger.error(f"Gemini report generation failed: {str(gemini_error)}")

            logger.error("Both Cerebras and Gemini report generation failed")

            # Gemini failed, last resort Groq Report Generation
            logger.info("Attempting report generation with Groq as last resort...")
            try:
                report_files = await generate_csv_report_groq(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Groq report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                else:
                    logger.warning("Groq report generation returned empty or None result")
            except Exception as groq_error:
                logger.error(f"Groq report generation failed: {str(groq_error)}")

            logger.error("All report generation methods failed")

        # Handle initial chat questions with langchain
        if if_initial_chat_question(query):
            logger.info("Processing as initial chat question with langchain...")
            try:
                answer = await asyncio.to_thread(
                    langchain_csv_chat, decoded_url, query, False
                )
                logger.info(f"Langchain initial chat answer: {str(answer)[:200]}...")
                return {"answer": jsonable_encoder(answer)}
            except Exception as e:
                logger.error(f"Langchain initial chat failed: {str(e)}")

        # Handle detailed answers with orchestrator
        if detailed_answer is True:
            logger.info("Processing detailed answer with orchestrator...")
            return await handle_detailed_answer(decoded_url, query, conversation_history, chat_id)

        # Process with standard CSV agent (Cerebras)
        logger.info("Processing with standard CSV agent (Cerebras)...")
        try:
            result = await query_csv_agent_cerebras(decoded_url, query, chat_id)
            logger.info(f"Standard CSV agent (Cerebras) result: {str(result)[:200]}...")
            if result is not None and result != "":
                return {"answer": result}
            else:
                logger.warning("Standard CSV agent (Cerebras) returned empty or None result")
        except Exception as e:
            logger.error(f"Standard CSV agent (Cerebras) failed: {str(e)}")

        # Fallback to langchain
        logger.info("Falling back to langchain CSV chat...")
        try:
            lang_answer = await asyncio.to_thread(
                langchain_csv_chat, decoded_url, query, False
            )
            logger.info(f"Langchain fallback result: {str(lang_answer)[:200]}...")
            if process_answer(lang_answer):
                logger.error("Langchain fallback produced error response")
                return {"answer": "error"}
            logger.info("Langchain fallback successful")
            return {"answer": jsonable_encoder(lang_answer)}
        except Exception as e:
            logger.error(f"Langchain fallback failed: {str(e)}")

        # If all methods fail
        logger.error("All processing methods failed")
        return {"answer": "error"}

    except Exception as e:
        logger.error(f"Critical error processing request: {str(e)}")
        logger.error(f"Error traceback: {traceback.format_exc()}")
        return {"answer": "error"}


def handle_out_of_range_float(value):
    """Handle out of range float values for JSON serialization"""
    if isinstance(value, float):
        if np.isnan(value):
            logger.debug("Converting NaN to None")
            return None
        elif np.isinf(value):
            logger.debug("Converting Infinity to string")
            return "Infinity"
    return value


# CHART CODING STARTS FROM HERE

instructions = """
- Please ensure that each value is clearly visible, You may need to adjust the font size, rotate the labels, or use truncation to improve readability (if needed).
- For multiple charts, put all of them in a single file.
- Use colorblind-friendly palette
- Read above instructions and follow them.
"""

# Thread-safe configuration for chart endpoints
current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()


def model():
    """Build a ChatGroq client using the currently selected chart API key.

    Raises:
        Exception: if the chart key index is past the end of the key list
            (for example when no keys are configured).
    """
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]
    return ChatGroq(model=model_name, api_key=api_key)


def groq_chart(csv_url: str, question: str):
    """Generate a chart with PandasAI, rotating Groq keys on failure.

    Returns:
        The PandasAI answer; ``"Chart not generated"`` when ``process_answer``
        flags the answer; or ``{"error": ...}`` on failure.
    """
    global current_groq_chart_key_index

    for attempt in range(len(groq_api_keys)):
        try:
            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]

            llm = ChatGroq(model=model_name, api_key=current_api_key)

            # Generate unique filename using UUID
            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            # Configure SmartDataframe with chart settings
            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,  # Enable chart saving
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),  # Directory to save
                    'custom_chart_filename': chart_filename,  # Unique filename
                    'enable_cache': False
                }
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            # Any non-empty error is treated as retryable.
            if error != "":
                with current_groq_chart_lock:
                    # Wrap around so the index never runs past the key list.
                    current_groq_chart_key_index = (
                        (current_groq_chart_key_index + 1) % len(groq_api_keys)
                    )
            else:
                logger.error(f"Chart generation error: {error}")
                return {"error": error}

    return {"error": "All API keys exhausted for chart generation"}


# Global locks for key rotation (chart endpoints)
current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

# Use a process pool to run CPU-bound charts generation.
# Always keep at least one worker: ProcessPoolExecutor rejects max_workers <= 0,
# and os.cpu_count() may be None.
process_executor = ProcessPoolExecutor(max_workers=max(1, (max_cpus or 1) - 2))


# --- LANGCHAIN-BASED CHART GENERATION ---
def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """
    Generate a chart using the langchain-based method.

    Modifications:
      • No shared deletion of cache.
      • Close matplotlib figures after generation.
      • Return a list of full chart file paths.

    Returns a list of chart paths on success, ``{"error": ...}`` when an error
    has an empty message, or a failure string once every key has been tried.
    ``chart_required`` is accepted for interface compatibility; the prompt is
    always built with chart generation enabled.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)

    for attempt in range(len(groq_api_keys)):
        # Pick a key and advance the shared index once per attempt.
        with current_langchain_chart_lock:
            current_key = current_langchain_chart_key_index % len(groq_api_keys)
            api_key = groq_api_keys[current_key]
            current_langchain_chart_key_index = (current_key + 1) % len(groq_api_keys)

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )

            result = agent.invoke({"input": _prompt_generator(question, True, csv_url)})
            output = result.get("output", "")

            # Extract chart filenames (assuming extract_chart_filenames returns a list)
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                # Return full paths (join with your image_file_path)
                return [os.path.join(image_file_path, f) for f in chart_files]

            if attempt < len(groq_api_keys) - 1:
                logger.info(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            error_message = str(e)
            if error_message != "":
                # The key was already rotated above. Log the slot index only,
                # never the secret itself.
                logger.warning(
                    "Rate limit exceeded on API key #%d. Switching to next API key.",
                    current_key,
                )
            else:
                logger.error("Error with API key #%d: %s", current_key, error_message)
                return {"error": error_message}
        finally:
            # Close figures to avoid interference, on the error path too.
            plt.close('all')

    logger.error("All API keys exhausted for chart generation")
    return "Chart generation failed after all retries"


# --- FASTAPI ENDPOINT FOR CHART GENERATION ---
@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """
    Endpoint for generating a chart from CSV data.

    This endpoint uses a ProcessPoolExecutor to run the (CPU-bound) chart generation
    functions in separate processes so that multiple requests can run in parallel.
    """
    # --- Authorization Check ---
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")

    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        detailed_answer = request.get("detailed_answer", False)
        conversation_history = request.get("conversation_history", [])
        generate_report = request.get("generate_report", False)
        chat_id = request.get("chat_id", "")

        if generate_report is True:
            report_files = await generate_csv_report_gemini(csv_url, query, chat_id, conversation_history)
            if report_files is not None:
                return {"orchestrator_response": jsonable_encoder(report_files)}

        loop = asyncio.get_running_loop()

        # First, try the langchain-based method if the question qualifies
        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            logger.info("Langchain chart result: %s", langchain_result)
            if isinstance(langchain_result, list) and len(langchain_result) > 0:
                image_public_url = await _upload_chart_and_cleanup(langchain_result[0], chat_id)
                return {"image_url": image_public_url}

        # Use orchestrator to handle the user's chart query first
        if detailed_answer is True:
            orchestrator_answer = await asyncio.to_thread(
                csv_orchestrator_chat_gemini, csv_url, query, conversation_history, chat_id
            )
            if orchestrator_answer is not None:
                return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}

        logger.info("Trying cerebras ai llama...")
        result = await query_csv_agent_cerebras(csv_url, query, chat_id)
        logger.info("cerebras ai result ==> %s", result)
        if result is not None and result != "":
            return {"orchestrator_response": jsonable_encoder(result)}

        # Fallback: try langchain-based again
        logger.error("Cerebras ai llama response failed, trying langchain groq....")
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        logger.info("Fallback langchain chart result: %s", langchain_paths)
        if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
            image_public_url = await _upload_chart_and_cleanup(langchain_paths[0], chat_id)
            return {"image_url": image_public_url}
        else:
            logger.error("All chart generation methods failed")
            return {"answer": "error"}

    except Exception as e:
        logger.error(f"Critical chart error: {str(e)}")
        return {"answer": "error"}