Tamannathakur commited on
Commit
73cf5aa
·
verified ·
1 Parent(s): 2602381

Update ai_agent.py

Browse files
Files changed (1) hide show
  1. ai_agent.py +31 -7
ai_agent.py CHANGED
@@ -168,6 +168,26 @@ def create_chart(df, selected_columns=None, chart_type="bar", title=None):
168
  except:
169
  return None
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
172
  plot_spec = plan.get("plot")
173
  if not plot_spec:
@@ -175,9 +195,12 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
175
  ptype = plot_spec.get("type", "bar")
176
  title = plot_spec.get("title", "Chart")
177
  plot_df = df if describe_stats else dfw
178
- x = plot_spec.get("x")
179
- y = plot_spec.get("y")
180
 
 
 
 
 
 
181
  if not x and len(plot_df.columns) > 0:
182
  categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
183
  x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
@@ -187,16 +210,16 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
187
 
188
  try:
189
  if ptype == "pie" and x and x in plot_df.columns:
190
- value_counts = plot_df[x].value_counts()
191
  fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
192
  fig.update_layout(title=title, width=900, height=500)
193
  return fig
194
  elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
195
- fig = px.bar(plot_df, x=x, y=y, title=title)
196
  fig.update_layout(width=900, height=500)
197
  return fig
198
  elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
199
- fig = px.line(plot_df, x=x, y=y, title=title, markers=True)
200
  fig.update_layout(width=900, height=500)
201
  return fig
202
  elif ptype == "hist" and y and y in plot_df.columns:
@@ -204,10 +227,11 @@ def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
204
  fig.update_layout(width=900, height=500)
205
  return fig
206
  elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
207
- fig = px.scatter(plot_df, x=x, y=y, title=title)
208
  fig.update_layout(width=900, height=500)
209
  return fig
210
- except:
 
211
  pass
212
  return None
213
 
 
168
  except:
169
  return None
170
 
171
+ def find_column(df, col_name):
172
+ """Find a column in DataFrame with fuzzy matching (case-insensitive, handles spaces/underscores)"""
173
+ if not col_name:
174
+ return None
175
+
176
+ # Exact match first
177
+ if col_name in df.columns:
178
+ return col_name
179
+
180
+ # Normalize the search term
181
+ search_term = col_name.lower().replace('_', '').replace(' ', '')
182
+
183
+ # Try fuzzy matching
184
+ for col in df.columns:
185
+ normalized_col = str(col).lower().replace('_', '').replace(' ', '')
186
+ if search_term == normalized_col or search_term in normalized_col or normalized_col in search_term:
187
+ return col
188
+
189
+ return None
190
+
191
  def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
192
  plot_spec = plan.get("plot")
193
  if not plot_spec:
 
195
  ptype = plot_spec.get("type", "bar")
196
  title = plot_spec.get("title", "Chart")
197
  plot_df = df if describe_stats else dfw
 
 
198
 
199
+ # Use fuzzy matching to find columns
200
+ x = find_column(plot_df, plot_spec.get("x"))
201
+ y = find_column(plot_df, plot_spec.get("y"))
202
+
203
+ # Fallback: auto-select columns if not found
204
  if not x and len(plot_df.columns) > 0:
205
  categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
206
  x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
 
210
 
211
  try:
212
  if ptype == "pie" and x and x in plot_df.columns:
213
+ value_counts = plot_df[x].value_counts().head(10) # Limit to top 10 for readability
214
  fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
215
  fig.update_layout(title=title, width=900, height=500)
216
  return fig
217
  elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
218
+ fig = px.bar(plot_df.head(50), x=x, y=y, title=title) # Limit rows for performance
219
  fig.update_layout(width=900, height=500)
220
  return fig
221
  elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
222
+ fig = px.line(plot_df.head(100), x=x, y=y, title=title, markers=True)
223
  fig.update_layout(width=900, height=500)
224
  return fig
225
  elif ptype == "hist" and y and y in plot_df.columns:
 
227
  fig.update_layout(width=900, height=500)
228
  return fig
229
  elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
230
+ fig = px.scatter(plot_df.head(200), x=x, y=y, title=title) # Limit for performance
231
  fig.update_layout(width=900, height=500)
232
  return fig
233
+ except Exception as e:
234
+ print(f"Error creating plot: {str(e)}")
235
  pass
236
  return None
237