changed model to .env gemini-flash-2.0
Browse files- cereberas_langchain_agent.py +31 -23
- controller.py +41 -12
cereberas_langchain_agent.py
CHANGED
|
@@ -16,15 +16,22 @@ import datetime as dt
|
|
| 16 |
matplotlib.use('Agg')
|
| 17 |
|
| 18 |
load_dotenv()
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def create_agent(llm, data, tools):
|
| 27 |
-
"""Create agent with tool names"""
|
| 28 |
return create_pandas_dataframe_agent(
|
| 29 |
llm,
|
| 30 |
data,
|
|
@@ -37,7 +44,6 @@ def create_agent(llm, data, tools):
|
|
| 37 |
|
| 38 |
def _prompt_generator(question: str, chart_required: bool, csv_url: str):
|
| 39 |
chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
|
| 40 |
-
|
| 41 |
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
|
| 42 |
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
|
| 43 |
3. **Communication:** Provide concise, professional, and well-structured responses.
|
|
@@ -48,7 +54,6 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
|
|
| 48 |
"""
|
| 49 |
|
| 50 |
chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
|
| 51 |
-
|
| 52 |
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
|
| 53 |
2. Visualization requirements:
|
| 54 |
- Adjust font sizes, rotate labels (45° if needed), truncate for readability
|
|
@@ -80,15 +85,12 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
|
|
| 80 |
- Always use pd.read_csv({csv_url}) to read the CSV file
|
| 81 |
"""
|
| 82 |
|
| 83 |
-
if chart_required
|
| 84 |
-
return ChatPromptTemplate.from_template(chart_prompt)
|
| 85 |
-
else:
|
| 86 |
-
return ChatPromptTemplate.from_template(chat_prompt)
|
| 87 |
|
| 88 |
def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
|
| 89 |
-
"""Process CSV using ChatCerebras"""
|
|
|
|
| 90 |
data = pd.read_csv(csv_url)
|
| 91 |
-
|
| 92 |
tool = PythonAstREPLTool(
|
| 93 |
locals={
|
| 94 |
"df": data,
|
|
@@ -101,13 +103,19 @@ def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
|
|
| 101 |
"dt": dt
|
| 102 |
},
|
| 103 |
)
|
| 104 |
-
|
| 105 |
-
agent = create_agent(llm, data, [tool])
|
| 106 |
-
prompt = _prompt_generator(question, chart_required, csv_url)
|
| 107 |
-
result = agent.invoke({"input": prompt})
|
| 108 |
-
output = result.get("output")
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
matplotlib.use('Agg')

load_dotenv()

# Parse the comma-separated key list. Strip whitespace and drop empty entries so
# an unset/empty CEREBRAS_API_KEYS yields [] rather than [""] (which would look
# like one bogus key and waste a rotation attempt on a guaranteed auth failure).
api_keys = [key.strip() for key in os.getenv("CEREBRAS_API_KEYS", "").split(",") if key.strip()]
base_url = os.getenv("CEREBRAS_BASE_URL")
model_name = os.getenv("CEREBRAS_MODEL")

current_key_index = 0  # Index of the API key currently in use; advanced on failure by the handler.
|
| 24 |
+
|
| 25 |
+
def get_next_llm():
    """Build a ChatCerebras client bound to the key at the current rotation index.

    Returns:
        ChatCerebras: a client configured from the module-level model/base-url
        settings and the currently selected API key.

    Raises:
        ValueError: when every configured API key has already been consumed.
    """
    global current_key_index
    if not current_key_index < len(api_keys):
        raise ValueError("All Cerebras API keys exhausted.")
    active_key = api_keys[current_key_index]
    print(f"Using Cerebras API key index: {current_key_index}")
    return ChatCerebras(model=model_name, api_key=active_key, base_url=base_url)
|
| 33 |
|
| 34 |
def create_agent(llm, data, tools):
|
|
|
|
| 35 |
return create_pandas_dataframe_agent(
|
| 36 |
llm,
|
| 37 |
data,
|
|
|
|
| 44 |
|
| 45 |
def _prompt_generator(question: str, chart_required: bool, csv_url: str):
|
| 46 |
chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
|
|
|
|
| 47 |
1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
|
| 48 |
2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
|
| 49 |
3. **Communication:** Provide concise, professional, and well-structured responses.
|
|
|
|
| 54 |
"""
|
| 55 |
|
| 56 |
chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
|
|
|
|
| 57 |
1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
|
| 58 |
2. Visualization requirements:
|
| 59 |
- Adjust font sizes, rotate labels (45° if needed), truncate for readability
|
|
|
|
| 85 |
- Always use pd.read_csv({csv_url}) to read the CSV file
|
| 86 |
"""
|
| 87 |
|
| 88 |
+
return ChatPromptTemplate.from_template(chart_prompt if chart_required else chat_prompt)
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
|
| 91 |
+
"""Process CSV using ChatCerebras with key rotation"""
|
| 92 |
+
global current_key_index
|
| 93 |
data = pd.read_csv(csv_url)
|
|
|
|
| 94 |
tool = PythonAstREPLTool(
|
| 95 |
locals={
|
| 96 |
"df": data,
|
|
|
|
| 103 |
"dt": dt
|
| 104 |
},
|
| 105 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
while current_key_index < len(api_keys):
|
| 108 |
+
try:
|
| 109 |
+
llm = get_next_llm()
|
| 110 |
+
agent = create_agent(llm, data, [tool])
|
| 111 |
+
prompt = _prompt_generator(question, chart_required, csv_url)
|
| 112 |
+
result = agent.invoke({"input": prompt})
|
| 113 |
+
output = result.get("output")
|
| 114 |
+
if output is None:
|
| 115 |
+
raise ValueError("Received None response from agent")
|
| 116 |
+
return output
|
| 117 |
+
except Exception as e:
|
| 118 |
+
print(f"Error with key index {current_key_index}: {e}")
|
| 119 |
+
current_key_index += 1
|
| 120 |
+
|
| 121 |
+
raise ValueError("All Cerebras API keys exhausted.")
|
controller.py
CHANGED
|
@@ -16,6 +16,7 @@ from pandasai import SmartDataframe
|
|
| 16 |
from langchain_groq.chat_models import ChatGroq
|
| 17 |
from dotenv import load_dotenv
|
| 18 |
from pydantic import BaseModel, Field
|
|
|
|
| 19 |
from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
|
| 20 |
from urllib.parse import unquote
|
| 21 |
from langchain_groq import ChatGroq
|
|
@@ -325,7 +326,45 @@ def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
|
|
| 325 |
return {"error": error_message}
|
| 326 |
|
| 327 |
return {"error": "All API keys exhausted"}
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
# Async endpoint with non-blocking execution
|
| 331 |
@app.post("/api/csv-chat")
|
|
@@ -400,17 +439,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
|
|
| 400 |
# Handle detailed answers with orchestrator
|
| 401 |
if detailed_answer is True:
|
| 402 |
logger.info("Processing detailed answer with orchestrator...")
|
| 403 |
-
|
| 404 |
-
orchestrator_answer = await asyncio.to_thread(
|
| 405 |
-
csv_orchestrator_chat_cerebras, decoded_url, query, conversation_history, chat_id
|
| 406 |
-
)
|
| 407 |
-
if orchestrator_answer is not None:
|
| 408 |
-
logger.info(f"Orchestrator answer successful: {str(orchestrator_answer)[:200]}...")
|
| 409 |
-
return {"answer": jsonable_encoder(orchestrator_answer)}
|
| 410 |
-
else:
|
| 411 |
-
logger.warning("Orchestrator returned None result")
|
| 412 |
-
except Exception as e:
|
| 413 |
-
logger.error(f"Orchestrator processing failed: {str(e)}")
|
| 414 |
|
| 415 |
# Process with standard CSV agent (not Cerebras)
|
| 416 |
logger.info("Processing with standard CSV agent...")
|
|
|
|
| 16 |
from langchain_groq.chat_models import ChatGroq
|
| 17 |
from dotenv import load_dotenv
|
| 18 |
from pydantic import BaseModel, Field
|
| 19 |
+
from cerebras_report_generator import generate_csv_report_cerebras
|
| 20 |
from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
|
| 21 |
from urllib.parse import unquote
|
| 22 |
from langchain_groq import ChatGroq
|
|
|
|
| 326 |
return {"error": error_message}
|
| 327 |
|
| 328 |
return {"error": "All API keys exhausted"}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
async def handle_detailed_answer(decoded_url, query, conversation_history, chat_id):
    """Answer a detailed CSV question, preferring Cerebras with a Gemini fallback.

    Runs each orchestrator in a worker thread (they are blocking calls) and
    returns the first non-None answer, JSON-encoded. The two attempts were
    previously duplicated try blocks; they are consolidated into one loop so
    logging and error handling stay consistent for both providers.

    Args:
        decoded_url: URL of the CSV file to analyze.
        query: the user's question.
        conversation_history: prior chat turns passed through to the orchestrator.
        chat_id: identifier of the chat session.

    Returns:
        dict: {"answer": <encoded answer>} on success, or {"answer": None}
        when both orchestrators fail or return None.
    """
    orchestrators = (
        ("Cerebras", csv_orchestrator_chat_cerebras),
        ("Gemini", csv_orchestrator_chat_gemini),
    )
    for label, orchestrator in orchestrators:
        try:
            logger.info("Processing detailed answer with %s orchestrator...", label)
            # Blocking orchestrator call — run off the event loop.
            answer = await asyncio.to_thread(
                orchestrator, decoded_url, query, conversation_history, chat_id
            )
            if answer is not None:
                logger.info("%s answer successful: %s...", label, str(answer)[:200])
                return {"answer": jsonable_encoder(answer)}
            logger.warning("%s orchestrator returned None", label)
        except Exception as e:
            # Swallow and fall through to the next provider; failure here is expected
            # (e.g. exhausted API keys) and the fallback is the whole point.
            logger.error("%s orchestrator failed: %s", label, e)

    logger.error("Both Cerebras and Gemini orchestrators failed or returned None")
    return {"answer": None}
|
| 368 |
|
| 369 |
# Async endpoint with non-blocking execution
|
| 370 |
@app.post("/api/csv-chat")
|
|
|
|
| 439 |
# Handle detailed answers with orchestrator
|
| 440 |
if detailed_answer is True:
|
| 441 |
logger.info("Processing detailed answer with orchestrator...")
|
| 442 |
+
return await handle_detailed_answer(decoded_url, query, conversation_history, chat_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
# Process with standard CSV agent (not Cerebras)
|
| 445 |
logger.info("Processing with standard CSV agent...")
|