|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
import uuid |
|
|
import matplotlib.pyplot as plt |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, List, Literal, Optional, Union |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import json |
|
|
import io |
|
|
import contextlib |
|
|
import traceback |
|
|
import time |
|
|
from datetime import datetime, timedelta |
|
|
import seaborn as sns |
|
|
import scipy.stats as stats |
|
|
from pydantic import BaseModel |
|
|
from tabulate import tabulate |
|
|
import asyncio |
|
|
|
|
|
from supabase_service import upload_file_to_supabase |
|
|
|
|
|
|
|
|
# Load environment variables from a local .env file (e.g. Supabase credentials
# needed by supabase_service) before anything reads os.environ.
load_dotenv()
|
|
|
|
|
class CodeResponse(BaseModel):
    """Container for code-related responses"""

    # Language of the snippet; this pipeline only executes Python.
    language: str = "python"
    # The executable source code itself.
    code: str
|
|
|
|
|
|
|
|
class ChartSpecification(BaseModel):
    """Details about requested charts"""

    # Human-readable description of the desired chart (used in output text
    # and as the uploaded chart's description).
    image_description: str
    # Optional plotting code; figures are captured when the code calls
    # plt.show() (see PythonExecutor.execute_code).
    code: Optional[str] = None
|
|
|
|
|
|
|
|
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""

    # Code to execute; expected to define the variable(s) named in result_var.
    code: CodeResponse
    # Variable name(s) to read back after execution: a single name, a
    # comma-separated string, or a list of names.
    result_var: Union[str, List[str]]
|
|
|
|
|
|
|
|
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""

    # Conversational text shown to the user verbatim.
    casual_response: str
    # Optional code-backed analysis to run against the DataFrame.
    analysis_operations: Optional[AnalysisOperation] = None
    # Optional chart request to render and upload.
    charts: Optional[ChartSpecification] = None
|
|
|
|
|
|
|
|
class PythonExecutor:
    """Handles execution of Python code with comprehensive data analysis libraries.

    Runs model-generated analysis/plotting code against a DataFrame, captures
    stdout and matplotlib figures, persists chart images to Supabase storage,
    and renders a structured ``CsvChatResult`` into a markdown reply.
    Variables defined by executed code persist across calls via
    ``self.exec_locals``.
    """

    def __init__(self, df: pd.DataFrame, charts_folder: str = "generated_charts"):
        """
        Initialize the PythonExecutor with a DataFrame.

        Args:
            df (pd.DataFrame): The DataFrame to operate on
            charts_folder (str): Folder to save charts in
        """
        # Work on a copy so executed code cannot mutate the caller's frame.
        self.df = df.copy()
        self.charts_folder = Path(charts_folder)
        self.charts_folder.mkdir(exist_ok=True, parents=True)
        # Persistent namespace: variables defined by one execution are
        # visible to subsequent executions.
        self.exec_locals: Dict[str, Any] = {}
        self._setup_matplotlib()

    def _setup_matplotlib(self) -> None:
        """Configure matplotlib for non-interactive (headless) use."""
        plt.ioff()
        plt.rcParams['figure.figsize'] = [10, 6]
        plt.rcParams['figure.dpi'] = 100
        plt.rcParams['savefig.bbox'] = 'tight'

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code with full data analysis context and return results.

        ``plt.show`` is temporarily monkey-patched so any figures the code
        displays are captured as PNG bytes instead of being rendered.

        Args:
            code (str): Python code to execute

        Returns:
            dict: ``{'output': str, 'error': Optional[dict], 'plots': List[bytes],
            'locals': dict}`` — ``error`` carries ``message`` and ``traceback``
            keys when execution failed.
        """
        output = ""
        error: Optional[Dict[str, str]] = None
        plots: List[bytes] = []
        stdout = io.StringIO()
        original_show = plt.show

        def custom_show(*args, **kwargs):
            """Save all open figures as PNG bytes instead of displaying them."""
            nonlocal plots
            for fignum in plt.get_fignums():
                figure = plt.figure(fignum)
                buf = io.BytesIO()
                figure.savefig(buf, format='png', bbox_inches='tight', dpi=100)
                buf.seek(0)
                plots.append(buf.getvalue())
            plt.close('all')

        try:
            # NOTE(security): exec() of arbitrary code with full builtins —
            # only run code from trusted sources; do not expose to untrusted input.
            exec_globals = {
                # Data analysis context
                'pd': pd,
                'np': np,
                'df': self.df,
                # Plotting
                'plt': plt,
                'sns': sns,
                'tabulate': tabulate,
                # Statistics
                'stats': stats,
                # Time utilities
                'datetime': datetime,
                'timedelta': timedelta,
                'time': time,
                # Misc
                'json': json,
                '__builtins__': __builtins__,
            }
            # Expose variables from previous executions.
            exec_globals.update(self.exec_locals)

            plt.show = custom_show

            with contextlib.redirect_stdout(stdout):
                compiled_code = compile(code, '<string>', 'exec')
                exec(compiled_code, exec_globals, self.exec_locals)

            output = stdout.getvalue()
        except Exception as e:
            error = {
                "message": str(e),
                "traceback": traceback.format_exc(),
            }
        finally:
            # Restore plt.show and drop any stray figures in all cases
            # (the old code also closed figures in the except branch, which
            # was redundant with this finally).
            plt.show = original_show
            plt.close('all')

        return {
            'output': output,
            'error': error,
            'plots': plots,
            'locals': dict(self.exec_locals),
        }

    async def save_plot_to_supabase(self, plot_data: bytes, description: str, chat_id: str) -> str:
        """
        Save plot to Supabase storage and return the public URL.

        The image is written to a local temp file (the upload helper takes a
        path), uploaded with a 30s timeout, and the temp file is removed
        best-effort in all cases.

        Args:
            plot_data (bytes): Image data in bytes
            description (str): Description of the plot (kept for interface
                compatibility; not used by the upload itself)
            chat_id (str): ID of the chat session

        Returns:
            str: Public URL of the uploaded chart

        Raises:
            Exception: If the local write fails ("Failed to save plot"),
                the upload times out, or the upload fails.
        """
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        try:
            filepath.write_bytes(plot_data)
        except OSError as e:
            # Only local-write failures get this message; previously upload
            # errors were re-wrapped here too, double-wrapping the message.
            raise Exception(f"Failed to save plot: {e}") from e

        try:
            public_url = await asyncio.wait_for(
                upload_file_to_supabase(
                    file_path=str(filepath),
                    file_name=filename,
                    chat_id=chat_id,
                ),
                timeout=30.0,
            )
            return public_url
        except asyncio.TimeoutError as e:
            raise Exception("Upload timed out after 30 seconds") from e
        except Exception as e:
            raise Exception(f"Failed to upload plot to Supabase: {e}") from e
        finally:
            # Best-effort cleanup of the temp file on success and failure alike.
            try:
                os.remove(filepath)
            except OSError:
                pass

    def _format_result(self, result: Any) -> str:
        """Format the result for display.

        DataFrames/Series render via ``to_string``, dicts/lists as indented
        JSON (numpy/pandas/datetime values coerced), everything else via
        ``str()``.
        """
        if isinstance(result, (pd.DataFrame, pd.Series)):
            return result.to_string()

        if isinstance(result, (dict, list)):
            def json_serializer(obj: Any) -> Any:
                """Handle special types that aren't JSON serializable."""
                if isinstance(obj, (pd.Timestamp, datetime)):
                    return obj.isoformat()
                if isinstance(obj, (np.integer, np.int64, np.int32)):
                    return int(obj)
                if isinstance(obj, (np.floating, np.float64, np.float32)):
                    return float(obj)
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                if isinstance(obj, pd.Series):
                    return obj.to_dict()
                if isinstance(obj, pd.DataFrame):
                    return obj.to_dict('records')
                # Fallback: readable string form for anything else.
                return str(obj)

            try:
                return json.dumps(result, indent=2, default=json_serializer)
            except Exception as e:
                return f"Result (JSON serialization failed: {str(e)}):\n{str(result)}"

        if isinstance(result, (pd.Timestamp, datetime)):
            return result.isoformat()
        if isinstance(result, (np.integer, np.int64, np.int32)):
            return str(int(result))
        if isinstance(result, (np.floating, np.float64, np.float32)):
            return str(float(result))
        if isinstance(result, np.ndarray):
            return str(result)
        # Every object has __str__, so the old `hasattr(result, '__str__')`
        # guard (and its repr() fallback) was unreachable dead code.
        return str(result)

    def _get_result_variables(self, result_var: Union[str, List[str]]) -> Dict[str, Any]:
        """Get result variables from execution locals.

        Args:
            result_var: A single name, a comma-separated string of names,
                or a list of names.

        Returns:
            dict: Mapping of each found name to its value; names missing
            from the execution namespace are silently omitted.
        """
        if isinstance(result_var, str):
            # str.split(',') on a comma-free string yields a single element,
            # so this covers both the single-name and comma-list cases.
            var_names = [name.strip() for name in result_var.split(',')]
        else:
            var_names = result_var

        return {name: self.exec_locals[name] for name in var_names if name in self.exec_locals}

    async def process_response(self, response: "CsvChatResult", chat_id: str) -> str:
        """Process the response with proper variable handling and error checking.

        Executes the analysis operation (if any) and reports its result
        variables, then executes the chart code (if any), uploads captured
        plots, and embeds their URLs.

        Args:
            response (CsvChatResult): Structured model output to render
            chat_id (str): Chat session id used when uploading charts

        Returns:
            str: Markdown-formatted reply.
        """
        output_parts = [response.casual_response]

        # --- Analysis operation: run code and report requested variables ---
        if response.analysis_operations is not None:
            operation = response.analysis_operations
            try:
                if operation and operation.code and operation.code.code:
                    execution_result = self.execute_code(operation.code.code)

                    if execution_result.get('error'):
                        output_parts.append("\n**Error in analysis operation:**")
                        output_parts.append("```python\n" + execution_result['error']['message'] + "\n```")
                    else:
                        result_vars = self._get_result_variables(operation.result_var)

                        if result_vars:
                            for var_name, result in result_vars.items():
                                if result is not None:
                                    # An empty container usually means the analysis filtered everything out.
                                    if hasattr(result, '__len__') and len(result) == 0:
                                        output_parts.append(f"\n**Warning:** Variable '{var_name}' contains empty data")
                                    else:
                                        output_parts.append(f"\n**{var_name}:**")
                                        formatted_result = self._format_result(result)
                                        output_parts.append("```python\n" + formatted_result + "\n```")
                                else:
                                    output_parts.append(f"\n**Warning:** Variable '{var_name}' is None or not found")
                        else:
                            # No named results — fall back to whatever the code printed.
                            output_str = execution_result.get('output', '').strip()
                            if output_str:
                                output_parts.append("\n**Execution output:**")
                                output_parts.append("```python\n" + output_str + "\n```")
                            else:
                                output_parts.append(f"\n**Note:** Analysis operation executed but no results found for: {operation.result_var}")
                else:
                    output_parts.append("\n**Warning:** Invalid analysis operation - missing code or result variable")
            except Exception as e:
                output_parts.append(f"\n**Error:** Error processing analysis operation: {str(e)}")
                if hasattr(operation, 'result_var'):
                    output_parts.append(f"Expected variables: {operation.result_var}")

        # --- Chart: execute plotting code and upload captured figures ---
        if response.charts is not None:
            chart = response.charts
            try:
                if chart and (chart.code or chart.image_description):
                    if chart.code:
                        chart_result = self.execute_code(chart.code)
                        if chart_result.get('plots'):
                            if chart.image_description:
                                output_parts.append(f"\n**Chart:** {chart.image_description}")

                            for i, plot_data in enumerate(chart_result['plots']):
                                try:
                                    public_url = await self.save_plot_to_supabase(
                                        plot_data=plot_data,
                                        description=chart.image_description,
                                        chat_id=chat_id,
                                    )
                                    # BUG FIX: previously an empty string was
                                    # appended here and public_url was discarded,
                                    # so the chart link never reached the user.
                                    output_parts.append(f"\n![{chart.image_description}]({public_url})")
                                except Exception as e:
                                    output_parts.append(f"\n**Warning:** Error uploading chart {i+1}: {str(e)}")
                        elif chart_result.get('error'):
                            output_parts.append("```python\n" + f"Error generating {chart.image_description}: {chart_result['error']['message']}" + "\n```")
                        else:
                            output_parts.append(f"\n**Warning:** No chart generated for '{chart.image_description}'")
                    else:
                        output_parts.append(f"\n**Warning:** No code provided for chart: {chart.image_description}")
                else:
                    output_parts.append("\n**Warning:** Invalid chart specification")
            except Exception as e:
                output_parts.append(f"\n**Error:** Error processing chart: {str(e)}")

        return "\n".join(output_parts)