SVashishta1 commited on
Commit
984ec75
·
1 Parent(s): 2736104
Files changed (1) hide show
  1. app.py +246 -294
app.py CHANGED
@@ -300,156 +300,109 @@ def process_text_query(query, history):
300
  try:
301
  print("Visualization requested, attempting to create plot...")
302
 
303
- # Increase plot size
304
- fig_width = 1000 # Increased from 900
305
- fig_height = 700 # Increased from 600
306
-
307
- # Determine visualization type from query
308
- viz_type = None
309
- for vtype, keywords in viz_keywords.items():
310
- if any(keyword in query.lower() for keyword in keywords):
311
- viz_type = vtype
312
- break
313
-
314
- # If no specific type is detected, infer from data
315
- if not viz_type:
316
- if len(result_df) <= 10 and len(result_df.columns) == 2:
317
- viz_type = 'pie' # Small dataset with 2 columns is good for pie charts
318
- elif any('date' in col.lower() or 'time' in col.lower() or 'month' in col.lower() or 'year' in col.lower() for col in result_df.columns):
319
- viz_type = 'line' # Time-related data is good for line charts
320
- else:
321
- viz_type = 'bar' # Default to bar chart
322
-
323
- print(f"Detected visualization type: {viz_type}")
324
-
325
- # Find numeric columns for visualization
326
- numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
327
 
328
  # Create the appropriate visualization based on type
329
- if len(numeric_cols) >= 1 and len(result_df) > 1:
330
- if viz_type == 'pie' and len(result_df) <= 20:
331
- # For pie charts, we need a category column and a value column
332
- category_col = result_df.columns[0]
333
- value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]
334
-
335
- # Handle case where all columns are numeric
336
- if len(numeric_cols) == len(result_df.columns):
337
- category_col = result_df.index.name or 'index'
338
- result_df = result_df.reset_index()
339
-
340
- fig = px.pie(
341
- result_df,
342
- names=category_col,
343
- values=value_col,
344
- title=f"Distribution of {value_col} by {category_col}",
345
- hole=0.3, # Donut chart for better readability
346
- color_discrete_sequence=px.colors.qualitative.Pastel
347
- )
348
-
349
- elif viz_type == 'histogram' and len(result_df.columns) > 0:
350
- # For histograms, we need at least one column
351
-
352
- # Find the best column for histogram (prefer numeric)
353
- if numeric_cols:
354
- x_col = numeric_cols[0]
355
- else:
356
- x_col = result_df.columns[0]
357
-
358
- # Check if data is already binned
359
- if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
360
- # Data is pre-binned, use a bar chart
361
- bin_col = 'bin' if 'bin' in result_df.columns else 'range'
362
- count_col = 'count' if 'count' in result_df.columns else numeric_cols[0] if numeric_cols else result_df.columns[1]
363
-
364
- fig = px.bar(
365
- result_df,
366
- x=bin_col,
367
- y=count_col,
368
- title=f"Histogram of {x_col}",
369
- labels={bin_col: x_col, count_col: 'Frequency'},
370
- color_discrete_sequence=['#636EFA']
371
- )
372
- else:
373
- # Create a proper histogram from raw data
374
- fig = px.histogram(
375
- result_df,
376
- x=x_col,
377
- title=f"Distribution of {x_col}",
378
- nbins=20,
379
- marginal="box", # Add a box plot on the margin
380
- color_discrete_sequence=['#636EFA'],
381
- opacity=0.8,
382
- histnorm='probability density' # Normalize to show density instead of count
383
- )
384
-
385
- # Add a KDE (kernel density estimate) curve
386
- from scipy import stats
387
- import numpy as np
388
-
389
- # Only add KDE if we have numeric data
390
- if pd.api.types.is_numeric_dtype(result_df[x_col]):
391
- # Remove NaN values
392
- data = result_df[x_col].dropna()
393
-
394
- if len(data) > 1: # Need at least 2 points for KDE
395
- # Calculate KDE
396
- kde = stats.gaussian_kde(data)
397
- x_range = np.linspace(data.min(), data.max(), 1000)
398
- y_kde = kde(x_range)
399
-
400
- # Add KDE curve
401
- fig.add_scatter(
402
- x=x_range,
403
- y=y_kde,
404
- mode='lines',
405
- line=dict(color='red', width=2),
406
- name='Density Curve'
407
- )
408
-
409
- # Improve histogram layout
410
- fig.update_layout(
411
- bargap=0.1, # Gap between bars
412
- xaxis_title=x_col,
413
- yaxis_title='Frequency',
414
- showlegend=True
415
- )
416
 
417
- elif viz_type == 'box' and numeric_cols:
418
- # For box plots, we need to handle the data differently
419
- # SQLite doesn't support window functions for percentiles
420
- # So we'll calculate the box plot statistics in Python
421
-
422
- # Get the numeric column to plot
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  x_col = numeric_cols[0]
 
 
 
 
 
 
 
 
424
 
425
- # Create a box plot using plotly express
426
- fig = px.box(
427
  result_df,
428
- y=x_col,
429
- title=f"Box Plot of {x_col}",
430
- points="outliers", # Only show outlier points
 
431
  color_discrete_sequence=['#636EFA']
432
  )
433
-
434
- # Add a strip plot (individual points) on the side for better visualization
435
- fig.add_trace(
436
- px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
 
 
 
 
 
 
437
  )
438
 
439
- elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
440
- # For heatmaps, we need at least 2 numeric columns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
- # If we have many numeric columns, create a correlation matrix
443
- if len(numeric_cols) >= 3:
444
- # Create a correlation matrix
445
- # First, drop any rows with NaN values in numeric columns
446
- clean_df = result_df[numeric_cols].dropna()
447
 
448
- if len(clean_df) > 1: # Need at least 2 rows for correlation
449
- corr_df = clean_df.corr()
450
-
451
- # Round to 2 decimal places for display
452
- corr_df = corr_df.round(2)
453
 
454
  fig = px.imshow(
455
  corr_df,
@@ -459,196 +412,195 @@ def process_text_query(query, history):
459
  aspect="auto",
460
  zmin=-1, zmax=1 # Set limits for correlation values
461
  )
462
-
463
- # Improve heatmap layout
464
- fig.update_layout(
465
- xaxis_title="Features",
466
- yaxis_title="Features",
467
- coloraxis_colorbar=dict(
468
- title="Correlation",
469
- thicknessmode="pixels", thickness=20,
470
- lenmode="pixels", len=300,
471
- yanchor="top", y=1,
472
- ticks="outside"
473
- )
474
- )
475
- else:
476
- # Not enough data for correlation
477
- fig = px.bar(
478
- pd.DataFrame({'Message': ['Not enough data for heatmap']}),
479
- title="Cannot create heatmap - insufficient data"
480
- )
481
- else:
482
- # If we only have 2 numeric columns, create a 2D histogram
483
- x_col = numeric_cols[0]
484
- y_col = numeric_cols[1]
485
-
486
- # Create a 2D histogram (heatmap)
487
- fig = px.density_heatmap(
488
- result_df,
489
- x=x_col,
490
- y=y_col,
491
- title=f"Density Heatmap of {x_col} vs {y_col}",
492
- color_continuous_scale='Viridis',
493
- nbinsx=20,
494
- nbinsy=20,
495
- marginal_x="histogram", # Add histograms on the margins
496
- marginal_y="histogram"
497
- )
498
 
499
  # Improve heatmap layout
500
  fig.update_layout(
501
- xaxis_title=x_col,
502
- yaxis_title=y_col,
503
  coloraxis_colorbar=dict(
504
- title="Count",
505
  thicknessmode="pixels", thickness=20,
506
  lenmode="pixels", len=300,
507
  yanchor="top", y=1,
508
  ticks="outside"
509
  )
510
  )
511
-
512
- elif viz_type == 'scatter' and len(numeric_cols) >= 2:
513
- # For scatter plots, we need at least 2 numeric columns
 
 
 
 
 
514
  x_col = numeric_cols[0]
515
  y_col = numeric_cols[1]
516
 
517
- # Add a third dimension (size) if available
518
- size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
519
-
520
- # Add a color dimension if available
521
- if len(result_df.columns) > len(numeric_cols):
522
- # Find a categorical column for color
523
- categorical_cols = [col for col in result_df.columns if col not in numeric_cols]
524
- color_col = categorical_cols[0] if categorical_cols else None
525
- else:
526
- color_col = None
527
-
528
- # Create scatter plot with enhanced features
529
- fig = px.scatter(
530
  result_df,
531
  x=x_col,
532
  y=y_col,
533
- size=size_col,
534
- color=color_col, # Add color dimension if available
535
- title=f"Relationship between {x_col} and {y_col}",
536
- opacity=0.7,
537
- size_max=15, # Maximum marker size
538
- color_discrete_sequence=px.colors.qualitative.Plotly
539
  )
540
 
541
- # Add a trend line
542
- if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
543
- fig.update_layout(
544
- shapes=[
545
- dict(
546
- type='line',
547
- xref='x', yref='y',
548
- x0=result_df[x_col].min(),
549
- y0=result_df[y_col].min(),
550
- x1=result_df[x_col].max(),
551
- y1=result_df[y_col].max(),
552
- line=dict(color='red', width=2, dash='dash')
553
- )
554
- ]
555
- )
556
-
557
- # Improve scatter plot layout
558
  fig.update_layout(
559
  xaxis_title=x_col,
560
  yaxis_title=y_col,
561
- showlegend=True,
562
- legend=dict(
563
- title=color_col if color_col else "",
564
- orientation="h",
565
- yanchor="bottom",
566
- y=1.02,
567
- xanchor="right",
568
- x=1
569
  )
570
  )
571
 
572
- elif viz_type == 'line':
573
- # For line charts, determine the x-axis (preferably a date/time column)
574
- time_cols = [col for col in result_df.columns if any(time_word in col.lower()
575
- for time_word in ['date', 'time', 'month', 'year', 'day'])]
576
-
577
- if time_cols:
578
- x_col = time_cols[0]
579
- else:
580
- x_col = result_df.columns[0]
581
-
582
- # Determine y-axis columns (numeric columns)
583
- y_cols = numeric_cols[:3] # Use up to 3 numeric columns
584
-
585
- if not y_cols and len(result_df.columns) > 1:
586
- # If no numeric columns, use the second column
587
- y_cols = [result_df.columns[1]]
588
-
589
- fig = px.line(
590
- result_df,
591
- x=x_col,
592
- y=y_cols,
593
- title="Time Series Analysis",
594
- markers=True, # Add markers at each data point
595
- color_discrete_sequence=px.colors.qualitative.Plotly
596
- )
597
-
598
- # Add range slider for time series
 
 
 
 
599
  fig.update_layout(
600
- xaxis=dict(
601
- rangeslider=dict(visible=True),
602
- type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-'
603
- )
 
 
 
 
 
 
 
604
  )
605
 
606
- else: # Default to bar chart
607
- # For bar charts, use the first column as x and numeric columns as y
608
- x_col = result_df.columns[0]
609
-
610
- # Determine y-axis columns (numeric columns)
611
- if numeric_cols and x_col not in numeric_cols:
612
- y_cols = numeric_cols[:3] # Use up to 3 numeric columns
613
- elif len(result_df.columns) > 1:
614
- y_cols = [result_df.columns[1]]
615
- else:
616
- y_cols = ['value']
617
- result_df['value'] = 1 # Default value if no suitable column
618
-
619
- fig = px.bar(
620
- result_df,
621
- x=x_col,
622
- y=y_cols[0], # Use only the first y column for bar charts
623
- title="Data Visualization",
624
- color_discrete_sequence=['#636EFA']
625
  )
 
 
 
 
 
 
 
 
 
 
 
626
 
627
- # Improve figure layout for all chart types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  fig.update_layout(
629
- autosize=True,
630
- width=fig_width,
631
- height=fig_height,
632
- margin=dict(l=50, r=50, b=100, t=100, pad=4),
633
- template="plotly_white",
634
- font=dict(size=14),
635
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
636
- plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background
637
- paper_bgcolor='white'
638
  )
639
 
640
- # Convert the figure to an image and encode it as base64
641
- img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height, scale=2)
642
- encoded = base64.b64encode(img_bytes).decode("ascii")
643
- img_src = f"data:image/png;base64,{encoded}"
644
 
645
- # Add the image directly to the response with increased size
646
- response += f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
 
 
 
 
 
 
647
 
648
- # Add note about visualization
649
- response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
650
- else:
651
- print("Not enough numeric columns or data points for visualization")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  except Exception as viz_error:
653
  print(f"Visualization error: {str(viz_error)}")
654
  traceback.print_exc()
 
300
  try:
301
  print("Visualization requested, attempting to create plot...")
302
 
303
+ # Set common figure parameters
304
+ fig_width = 1000
305
+ fig_height = 700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  # Create the appropriate visualization based on type
308
+ if viz_type == 'pie' and len(result_df) <= 20:
309
+ # For pie charts, we need a category column and a value column
310
+ category_col = result_df.columns[0]
311
+ value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
+ # Handle case where all columns are numeric
314
+ if len(numeric_cols) == len(result_df.columns):
315
+ category_col = result_df.index.name or 'index'
316
+ result_df = result_df.reset_index()
317
+
318
+ fig = px.pie(
319
+ result_df,
320
+ names=category_col,
321
+ values=value_col,
322
+ title=f"Distribution of {value_col} by {category_col}",
323
+ hole=0.3, # Donut chart for better readability
324
+ color_discrete_sequence=px.colors.qualitative.Pastel
325
+ )
326
+
327
+ elif viz_type == 'histogram' and len(result_df.columns) > 0:
328
+ # For histograms, we need at least one column
329
+
330
+ # Find the best column for histogram (prefer numeric)
331
+ if numeric_cols:
332
  x_col = numeric_cols[0]
333
+ else:
334
+ x_col = result_df.columns[0]
335
+
336
+ # Check if data is already binned
337
+ if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
338
+ # Data is pre-binned, use a bar chart
339
+ bin_col = 'bin' if 'bin' in result_df.columns else 'range'
340
+ count_col = 'count' if 'count' in result_df.columns else numeric_cols[0] if numeric_cols else result_df.columns[1]
341
 
342
+ fig = px.bar(
 
343
  result_df,
344
+ x=bin_col,
345
+ y=count_col,
346
+ title=f"Histogram of {x_col}",
347
+ labels={bin_col: x_col, count_col: 'Frequency'},
348
  color_discrete_sequence=['#636EFA']
349
  )
350
+ else:
351
+ # Create a proper histogram from raw data
352
+ fig = px.histogram(
353
+ result_df,
354
+ x=x_col,
355
+ title=f"Distribution of {x_col}",
356
+ nbins=20,
357
+ marginal="box", # Add a box plot on the margin
358
+ color_discrete_sequence=['#636EFA'],
359
+ opacity=0.8
360
  )
361
 
362
+ # Improve histogram layout
363
+ fig.update_layout(
364
+ bargap=0.1, # Gap between bars
365
+ xaxis_title=x_col,
366
+ yaxis_title='Frequency',
367
+ showlegend=True
368
+ )
369
+
370
+ elif viz_type == 'box' and numeric_cols:
371
+ # For box plots, we need to handle the data differently
372
+ # SQLite doesn't support window functions for percentiles
373
+ # So we'll calculate the box plot statistics in Python
374
+
375
+ # Get the numeric column to plot
376
+ x_col = numeric_cols[0]
377
+
378
+ # Create a box plot using plotly express
379
+ fig = px.box(
380
+ result_df,
381
+ y=x_col,
382
+ title=f"Box Plot of {x_col}",
383
+ points="outliers", # Only show outlier points
384
+ color_discrete_sequence=['#636EFA']
385
+ )
386
+
387
+ # Add a strip plot (individual points) on the side for better visualization
388
+ fig.add_trace(
389
+ px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
390
+ )
391
+
392
+ elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
393
+ # For heatmaps, we need at least 2 numeric columns
394
+
395
+ # If we have many numeric columns, create a correlation matrix
396
+ if len(numeric_cols) >= 3:
397
+ # Create a correlation matrix
398
+ # First, drop any rows with NaN values in numeric columns
399
+ clean_df = result_df[numeric_cols].dropna()
400
 
401
+ if len(clean_df) > 1: # Need at least 2 rows for correlation
402
+ corr_df = clean_df.corr()
 
 
 
403
 
404
+ # Round to 2 decimal places for display
405
+ corr_df = corr_df.round(2)
 
 
 
406
 
407
  fig = px.imshow(
408
  corr_df,
 
412
  aspect="auto",
413
  zmin=-1, zmax=1 # Set limits for correlation values
414
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
  # Improve heatmap layout
417
  fig.update_layout(
418
+ xaxis_title="Features",
419
+ yaxis_title="Features",
420
  coloraxis_colorbar=dict(
421
+ title="Correlation",
422
  thicknessmode="pixels", thickness=20,
423
  lenmode="pixels", len=300,
424
  yanchor="top", y=1,
425
  ticks="outside"
426
  )
427
  )
428
+ else:
429
+ # Not enough data for correlation
430
+ fig = px.bar(
431
+ pd.DataFrame({'Message': ['Not enough data for heatmap']}),
432
+ title="Cannot create heatmap - insufficient data"
433
+ )
434
+ else:
435
+ # If we only have 2 numeric columns, create a 2D histogram
436
  x_col = numeric_cols[0]
437
  y_col = numeric_cols[1]
438
 
439
+ # Create a 2D histogram (heatmap)
440
+ fig = px.density_heatmap(
 
 
 
 
 
 
 
 
 
 
 
441
  result_df,
442
  x=x_col,
443
  y=y_col,
444
+ title=f"Density Heatmap of {x_col} vs {y_col}",
445
+ color_continuous_scale='Viridis',
446
+ nbinsx=20,
447
+ nbinsy=20,
448
+ marginal_x="histogram", # Add histograms on the margins
449
+ marginal_y="histogram"
450
  )
451
 
452
+ # Improve heatmap layout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  fig.update_layout(
454
  xaxis_title=x_col,
455
  yaxis_title=y_col,
456
+ coloraxis_colorbar=dict(
457
+ title="Count",
458
+ thicknessmode="pixels", thickness=20,
459
+ lenmode="pixels", len=300,
460
+ yanchor="top", y=1,
461
+ ticks="outside"
 
 
462
  )
463
  )
464
 
465
+ elif viz_type == 'scatter' and len(numeric_cols) >= 2:
466
+ # For scatter plots, we need at least 2 numeric columns
467
+ x_col = numeric_cols[0]
468
+ y_col = numeric_cols[1]
469
+
470
+ # Add a third dimension (size) if available
471
+ size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
472
+
473
+ # Add a color dimension if available
474
+ if len(result_df.columns) > len(numeric_cols):
475
+ # Find a categorical column for color
476
+ categorical_cols = [col for col in result_df.columns if col not in numeric_cols]
477
+ color_col = categorical_cols[0] if categorical_cols else None
478
+ else:
479
+ color_col = None
480
+
481
+ # Create scatter plot with enhanced features
482
+ fig = px.scatter(
483
+ result_df,
484
+ x=x_col,
485
+ y=y_col,
486
+ size=size_col,
487
+ color=color_col, # Add color dimension if available
488
+ title=f"Relationship between {x_col} and {y_col}",
489
+ opacity=0.7,
490
+ size_max=15, # Maximum marker size
491
+ color_discrete_sequence=px.colors.qualitative.Plotly
492
+ )
493
+
494
+ # Add a trend line
495
+ if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
496
  fig.update_layout(
497
+ shapes=[
498
+ dict(
499
+ type='line',
500
+ xref='x', yref='y',
501
+ x0=result_df[x_col].min(),
502
+ y0=result_df[y_col].min(),
503
+ x1=result_df[x_col].max(),
504
+ y1=result_df[y_col].max(),
505
+ line=dict(color='red', width=2, dash='dash')
506
+ )
507
+ ]
508
  )
509
 
510
+ # Improve scatter plot layout
511
+ fig.update_layout(
512
+ xaxis_title=x_col,
513
+ yaxis_title=y_col,
514
+ showlegend=True,
515
+ legend=dict(
516
+ title=color_col if color_col else "",
517
+ orientation="h",
518
+ yanchor="bottom",
519
+ y=1.02,
520
+ xanchor="right",
521
+ x=1
 
 
 
 
 
 
 
522
  )
523
+ )
524
+
525
+ elif viz_type == 'line':
526
+ # For line charts, determine the x-axis (preferably a date/time column)
527
+ time_cols = [col for col in result_df.columns if any(time_word in col.lower()
528
+ for time_word in ['date', 'time', 'month', 'year', 'day'])]
529
+
530
+ if time_cols:
531
+ x_col = time_cols[0]
532
+ else:
533
+ x_col = result_df.columns[0]
534
 
535
+ # Determine y-axis columns (numeric columns)
536
+ y_cols = numeric_cols[:3] # Use up to 3 numeric columns
537
+
538
+ if not y_cols and len(result_df.columns) > 1:
539
+ # If no numeric columns, use the second column
540
+ y_cols = [result_df.columns[1]]
541
+
542
+ fig = px.line(
543
+ result_df,
544
+ x=x_col,
545
+ y=y_cols,
546
+ title="Time Series Analysis",
547
+ markers=True, # Add markers at each data point
548
+ color_discrete_sequence=px.colors.qualitative.Plotly
549
+ )
550
+
551
+ # Add range slider for time series
552
  fig.update_layout(
553
+ xaxis=dict(
554
+ rangeslider=dict(visible=True),
555
+ type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-'
556
+ )
 
 
 
 
 
557
  )
558
 
559
+ else: # Default to bar chart
560
+ # For bar charts, use the first column as x and numeric columns as y
561
+ x_col = result_df.columns[0]
 
562
 
563
+ # Determine y-axis columns (numeric columns)
564
+ if numeric_cols and x_col not in numeric_cols:
565
+ y_cols = numeric_cols[:3] # Use up to 3 numeric columns
566
+ elif len(result_df.columns) > 1:
567
+ y_cols = [result_df.columns[1]]
568
+ else:
569
+ y_cols = ['value']
570
+ result_df['value'] = 1 # Default value if no suitable column
571
 
572
+ fig = px.bar(
573
+ result_df,
574
+ x=x_col,
575
+ y=y_cols[0], # Use only the first y column for bar charts
576
+ title="Data Visualization",
577
+ color_discrete_sequence=['#636EFA']
578
+ )
579
+
580
+ # Improve figure layout for all chart types
581
+ fig.update_layout(
582
+ autosize=True,
583
+ width=fig_width,
584
+ height=fig_height,
585
+ margin=dict(l=50, r=50, b=100, t=100, pad=4),
586
+ template="plotly_white",
587
+ font=dict(size=14),
588
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
589
+ plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background
590
+ paper_bgcolor='white'
591
+ )
592
+
593
+ # Convert the figure to an image and encode it as base64
594
+ img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height, scale=2)
595
+ encoded = base64.b64encode(img_bytes).decode("ascii")
596
+ img_src = f"data:image/png;base64,{encoded}"
597
+
598
+ # Add the image directly to the response with increased size
599
+ response += f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
600
+
601
+ # Add note about visualization
602
+ response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
603
+
604
  except Exception as viz_error:
605
  print(f"Visualization error: {str(viz_error)}")
606
  traceback.print_exc()