Soumik555 committed on
Commit
f182068
·
1 Parent(s): b2c0e5d

changed model to .env gemini-flash-2.0

Browse files
Files changed (2) hide show
  1. cereberas_langchain_agent.py +31 -23
  2. controller.py +41 -12
cereberas_langchain_agent.py CHANGED
@@ -16,15 +16,22 @@ import datetime as dt
16
  matplotlib.use('Agg')
17
 
18
  load_dotenv()
19
- model_name = os.getenv("CEREBRAS_LLM_MODEL") # Specify your Cerebras model name
20
- cerebras_api_key = os.getenv("CEREBRAS_API_KEY")
21
- cerebras_base_url = os.getenv("CEREBRAS_BASE_URL")
22
 
23
- # Initialize ChatCerebras LLM
24
- llm = ChatCerebras(model=model_name, api_key=cerebras_api_key)
 
 
 
 
 
 
 
 
25
 
26
  def create_agent(llm, data, tools):
27
- """Create agent with tool names"""
28
  return create_pandas_dataframe_agent(
29
  llm,
30
  data,
@@ -37,7 +44,6 @@ def create_agent(llm, data, tools):
37
 
38
  def _prompt_generator(question: str, chart_required: bool, csv_url: str):
39
  chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
40
-
41
  1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
42
  2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
43
  3. **Communication:** Provide concise, professional, and well-structured responses.
@@ -48,7 +54,6 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
48
  """
49
 
50
  chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
51
-
52
  1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
53
  2. Visualization requirements:
54
  - Adjust font sizes, rotate labels (45° if needed), truncate for readability
@@ -80,15 +85,12 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
80
  - Always use pd.read_csv({csv_url}) to read the CSV file
81
  """
82
 
83
- if chart_required:
84
- return ChatPromptTemplate.from_template(chart_prompt)
85
- else:
86
- return ChatPromptTemplate.from_template(chat_prompt)
87
 
88
  def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
89
- """Process CSV using ChatCerebras"""
 
90
  data = pd.read_csv(csv_url)
91
-
92
  tool = PythonAstREPLTool(
93
  locals={
94
  "df": data,
@@ -101,13 +103,19 @@ def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
101
  "dt": dt
102
  },
103
  )
104
-
105
- agent = create_agent(llm, data, [tool])
106
- prompt = _prompt_generator(question, chart_required, csv_url)
107
- result = agent.invoke({"input": prompt})
108
- output = result.get("output")
109
 
110
- if output is None:
111
- raise ValueError("Received None response from agent")
112
-
113
- return output
 
 
 
 
 
 
 
 
 
 
 
 
16
  matplotlib.use('Agg')
17
 
18
  load_dotenv()
19
+ api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
20
+ base_url = os.getenv("CEREBRAS_BASE_URL")
21
+ model_name = os.getenv("CEREBRAS_MODEL")
22
 
23
+ current_key_index = 0 # Track which key is being used
24
+
25
+ def get_next_llm():
26
+ """Return a ChatCerebras instance using the next available API key"""
27
+ global current_key_index
28
+ if current_key_index >= len(api_keys):
29
+ raise ValueError("All Cerebras API keys exhausted.")
30
+ key = api_keys[current_key_index]
31
+ print(f"Using Cerebras API key index: {current_key_index}")
32
+ return ChatCerebras(model=model_name, api_key=key, base_url=base_url)
33
 
34
  def create_agent(llm, data, tools):
 
35
  return create_pandas_dataframe_agent(
36
  llm,
37
  data,
 
44
 
45
  def _prompt_generator(question: str, chart_required: bool, csv_url: str):
46
  chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
 
47
  1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
48
  2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
49
  3. **Communication:** Provide concise, professional, and well-structured responses.
 
54
  """
55
 
56
  chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
 
57
  1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
58
  2. Visualization requirements:
59
  - Adjust font sizes, rotate labels (45° if needed), truncate for readability
 
85
  - Always use pd.read_csv({csv_url}) to read the CSV file
86
  """
87
 
88
+ return ChatPromptTemplate.from_template(chart_prompt if chart_required else chat_prompt)
 
 
 
89
 
90
  def cerebras_csv_handler(csv_url: str, question: str, chart_required: bool):
91
+ """Process CSV using ChatCerebras with key rotation"""
92
+ global current_key_index
93
  data = pd.read_csv(csv_url)
 
94
  tool = PythonAstREPLTool(
95
  locals={
96
  "df": data,
 
103
  "dt": dt
104
  },
105
  )
 
 
 
 
 
106
 
107
+ while current_key_index < len(api_keys):
108
+ try:
109
+ llm = get_next_llm()
110
+ agent = create_agent(llm, data, [tool])
111
+ prompt = _prompt_generator(question, chart_required, csv_url)
112
+ result = agent.invoke({"input": prompt})
113
+ output = result.get("output")
114
+ if output is None:
115
+ raise ValueError("Received None response from agent")
116
+ return output
117
+ except Exception as e:
118
+ print(f"Error with key index {current_key_index}: {e}")
119
+ current_key_index += 1
120
+
121
+ raise ValueError("All Cerebras API keys exhausted.")
controller.py CHANGED
@@ -16,6 +16,7 @@ from pandasai import SmartDataframe
16
  from langchain_groq.chat_models import ChatGroq
17
  from dotenv import load_dotenv
18
  from pydantic import BaseModel, Field
 
19
  from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
20
  from urllib.parse import unquote
21
  from langchain_groq import ChatGroq
@@ -325,7 +326,45 @@ def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
325
  return {"error": error_message}
326
 
327
  return {"error": "All API keys exhausted"}
328
- from cerebras_report_generator import generate_csv_report_cerebras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  # Async endpoint with non-blocking execution
331
  @app.post("/api/csv-chat")
@@ -400,17 +439,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
400
  # Handle detailed answers with orchestrator
401
  if detailed_answer is True:
402
  logger.info("Processing detailed answer with orchestrator...")
403
- try:
404
- orchestrator_answer = await asyncio.to_thread(
405
- csv_orchestrator_chat_cerebras, decoded_url, query, conversation_history, chat_id
406
- )
407
- if orchestrator_answer is not None:
408
- logger.info(f"Orchestrator answer successful: {str(orchestrator_answer)[:200]}...")
409
- return {"answer": jsonable_encoder(orchestrator_answer)}
410
- else:
411
- logger.warning("Orchestrator returned None result")
412
- except Exception as e:
413
- logger.error(f"Orchestrator processing failed: {str(e)}")
414
 
415
  # Process with standard CSV agent (not Cerebras)
416
  logger.info("Processing with standard CSV agent...")
 
16
  from langchain_groq.chat_models import ChatGroq
17
  from dotenv import load_dotenv
18
  from pydantic import BaseModel, Field
19
+ from cerebras_report_generator import generate_csv_report_cerebras
20
  from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
21
  from urllib.parse import unquote
22
  from langchain_groq import ChatGroq
 
326
  return {"error": error_message}
327
 
328
  return {"error": "All API keys exhausted"}
329
+
330
+
331
async def handle_detailed_answer(decoded_url, query, conversation_history, chat_id):
    """
    Answer a detailed CSV question by trying orchestrators in order.

    Tries the Cerebras orchestrator first, then falls back to Gemini.
    Each orchestrator is a blocking callable, so it is run via
    asyncio.to_thread to keep the event loop responsive.

    Returns:
        {"answer": <JSON-encodable result>} from the first orchestrator
        that yields a non-None answer, or {"answer": None} when every
        orchestrator raises or returns None.
    """
    # Ordered fallback chain; adding a new backend is a one-line change
    # instead of another copy-pasted try/except block.
    orchestrators = [
        ("Cerebras", csv_orchestrator_chat_cerebras),
        ("Gemini", csv_orchestrator_chat_gemini),
    ]

    for name, orchestrator in orchestrators:
        try:
            logger.info(f"Processing detailed answer with {name} orchestrator...")
            orchestrator_answer = await asyncio.to_thread(
                orchestrator, decoded_url, query, conversation_history, chat_id
            )
            if orchestrator_answer is not None:
                logger.info(f"{name} answer successful: {str(orchestrator_answer)[:200]}...")
                return {"answer": jsonable_encoder(orchestrator_answer)}
            logger.warning(f"{name} orchestrator returned None")
        except Exception as e:
            # Swallow and fall through to the next orchestrator; only the
            # final failure of the whole chain is reported to the caller.
            logger.error(f"{name} orchestrator failed: {str(e)}")

    logger.error("Both Cerebras and Gemini orchestrators failed or returned None")
    return {"answer": None}
368
 
369
  # Async endpoint with non-blocking execution
370
  @app.post("/api/csv-chat")
 
439
  # Handle detailed answers with orchestrator
440
  if detailed_answer is True:
441
  logger.info("Processing detailed answer with orchestrator...")
442
+ return await handle_detailed_answer(decoded_url, query, conversation_history, chat_id)
 
 
 
 
 
 
 
 
 
 
443
 
444
  # Process with standard CSV agent (not Cerebras)
445
  logger.info("Processing with standard CSV agent...")