SVashishta1 committed on
Commit
18187e5
·
1 Parent(s): 80b8363
Files changed (1) hide show
  1. app.py +148 -53
app.py CHANGED
@@ -300,16 +300,9 @@ def process_text_query(query, history):
300
  try:
301
  print("Visualization requested, attempting to create plot...")
302
 
303
- # Improved visualization type detection
304
- viz_keywords = {
305
- 'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
306
- 'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
307
- 'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
308
- 'histogram': ['histogram', 'distribution of', 'frequency distribution'],
309
- 'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
310
- 'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
311
- 'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
312
- }
313
 
314
  # Determine visualization type from query
315
  viz_type = None
@@ -334,10 +327,6 @@ def process_text_query(query, history):
334
 
335
  # Create the appropriate visualization based on type
336
  if len(numeric_cols) >= 1 and len(result_df) > 1:
337
- # Set common figure parameters
338
- fig_width = 900
339
- fig_height = 600
340
-
341
  if viz_type == 'pie' and len(result_df) <= 20:
342
  # For pie charts, we need a category column and a value column
343
  category_col = result_df.columns[0]
@@ -357,35 +346,72 @@ def process_text_query(query, history):
357
  color_discrete_sequence=px.colors.qualitative.Pastel
358
  )
359
 
360
- elif viz_type == 'histogram' and numeric_cols:
361
- # For histograms, manually create the bins if not already binned
 
 
 
362
  x_col = numeric_cols[0]
 
 
363
 
364
- if len(result_df[x_col].unique()) > 20:
365
- # Data is not pre-binned, create a histogram
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  fig = px.histogram(
367
  result_df,
368
  x=x_col,
369
  title=f"Distribution of {x_col}",
370
  nbins=20,
371
  marginal="box", # Add a box plot on the margin
372
- color_discrete_sequence=['#636EFA']
 
 
373
  )
374
- else:
375
- # Data is likely pre-binned, use a bar chart
376
- x_col = result_df.columns[0]
377
- y_col = numeric_cols[0] if x_col not in numeric_cols else numeric_cols[1] if len(numeric_cols) > 1 else 'count'
378
 
379
- if y_col == 'count' and 'count' not in result_df.columns:
380
- # If we need a count column but don't have one
381
- result_df['count'] = 1
382
 
383
- fig = px.bar(
384
- result_df,
385
- x=x_col,
386
- y=y_col,
387
- title=f"Histogram of {x_col}",
388
- color_discrete_sequence=['#636EFA']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  )
390
 
391
  elif viz_type == 'box' and numeric_cols:
@@ -416,7 +442,14 @@ def process_text_query(query, history):
416
  # If we have many numeric columns, create a correlation matrix
417
  if len(numeric_cols) >= 3:
418
  # Create a correlation matrix
419
- corr_df = result_df[numeric_cols].corr()
 
 
 
 
 
 
 
420
 
421
  fig = px.imshow(
422
  corr_df,
@@ -426,13 +459,31 @@ def process_text_query(query, history):
426
  aspect="auto",
427
  zmin=-1, zmax=1 # Set limits for correlation values
428
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  else:
430
- # If we only have 2 numeric columns, we need to bin the data
431
- # to create a 2D histogram (heatmap)
432
  x_col = numeric_cols[0]
433
  y_col = numeric_cols[1]
434
 
435
- # Create a 2D histogram
436
  fig = px.density_heatmap(
437
  result_df,
438
  x=x_col,
@@ -440,7 +491,22 @@ def process_text_query(query, history):
440
  title=f"Density Heatmap of {x_col} vs {y_col}",
441
  color_continuous_scale='Viridis',
442
  nbinsx=20,
443
- nbinsy=20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  )
445
 
446
  elif viz_type == 'scatter' and len(numeric_cols) >= 2:
@@ -451,29 +517,56 @@ def process_text_query(query, history):
451
  # Add a third dimension (size) if available
452
  size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
453
 
 
 
 
 
 
 
 
 
 
454
  fig = px.scatter(
455
  result_df,
456
  x=x_col,
457
  y=y_col,
458
  size=size_col,
 
459
  title=f"Relationship between {x_col} and {y_col}",
460
  opacity=0.7,
461
- color_discrete_sequence=['#636EFA']
 
462
  )
463
 
464
  # Add a trend line
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  fig.update_layout(
466
- shapes=[
467
- dict(
468
- type='line',
469
- xref='x', yref='y',
470
- x0=result_df[x_col].min(),
471
- y0=result_df[y_col].min(),
472
- x1=result_df[x_col].max(),
473
- y1=result_df[y_col].max(),
474
- line=dict(color='red', width=2, dash='dash')
475
- )
476
- ]
477
  )
478
 
479
  elif viz_type == 'line':
@@ -539,7 +632,9 @@ def process_text_query(query, history):
539
  margin=dict(l=50, r=50, b=100, t=100, pad=4),
540
  template="plotly_white",
541
  font=dict(size=14),
542
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
 
 
543
  )
544
 
545
  # Convert the figure to an image and encode it as base64
@@ -547,8 +642,8 @@ def process_text_query(query, history):
547
  encoded = base64.b64encode(img_bytes).decode("ascii")
548
  img_src = f"data:image/png;base64,{encoded}"
549
 
550
- # Add the image directly to the response
551
- response += f"\n\n<img src='{img_src}' width='100%' />"
552
 
553
  # Add note about visualization
554
  response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
@@ -880,7 +975,7 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
880
  show_label=False
881
  )
882
  with gr.Column(scale=1):
883
- voice_btn = gr.Button("🎤")
884
 
885
  with gr.Row():
886
  submit_btn = gr.Button("Submit")
 
300
  try:
301
  print("Visualization requested, attempting to create plot...")
302
 
303
+ # Increase plot size
304
+ fig_width = 1000 # Increased from 900
305
+ fig_height = 700 # Increased from 600
 
 
 
 
 
 
 
306
 
307
  # Determine visualization type from query
308
  viz_type = None
 
327
 
328
  # Create the appropriate visualization based on type
329
  if len(numeric_cols) >= 1 and len(result_df) > 1:
 
 
 
 
330
  if viz_type == 'pie' and len(result_df) <= 20:
331
  # For pie charts, we need a category column and a value column
332
  category_col = result_df.columns[0]
 
346
  color_discrete_sequence=px.colors.qualitative.Pastel
347
  )
348
 
349
+ elif viz_type == 'histogram' and len(result_df.columns) > 0:
350
+ # For histograms, we need at least one column
351
+
352
+ # Find the best column for histogram (prefer numeric)
353
+ if numeric_cols:
354
  x_col = numeric_cols[0]
355
+ else:
356
+ x_col = result_df.columns[0]
357
 
358
+ # Check if data is already binned
359
+ if len(result_df) <= 30 and 'bin' in result_df.columns or 'range' in result_df.columns:
360
+ # Data is pre-binned, use a bar chart
361
+ bin_col = 'bin' if 'bin' in result_df.columns else 'range'
362
+ count_col = 'count' if 'count' in result_df.columns else numeric_cols[0] if numeric_cols else result_df.columns[1]
363
+
364
+ fig = px.bar(
365
+ result_df,
366
+ x=bin_col,
367
+ y=count_col,
368
+ title=f"Histogram of {x_col}",
369
+ labels={bin_col: x_col, count_col: 'Frequency'},
370
+ color_discrete_sequence=['#636EFA']
371
+ )
372
+ else:
373
+ # Create a proper histogram from raw data
374
  fig = px.histogram(
375
  result_df,
376
  x=x_col,
377
  title=f"Distribution of {x_col}",
378
  nbins=20,
379
  marginal="box", # Add a box plot on the margin
380
+ color_discrete_sequence=['#636EFA'],
381
+ opacity=0.8,
382
+ histnorm='probability density' # Normalize to show density instead of count
383
  )
 
 
 
 
384
 
385
+ # Add a KDE (kernel density estimate) curve
386
+ from scipy import stats
387
+ import numpy as np
388
 
389
+ # Only add KDE if we have numeric data
390
+ if pd.api.types.is_numeric_dtype(result_df[x_col]):
391
+ # Remove NaN values
392
+ data = result_df[x_col].dropna()
393
+
394
+ if len(data) > 1: # Need at least 2 points for KDE
395
+ # Calculate KDE
396
+ kde = stats.gaussian_kde(data)
397
+ x_range = np.linspace(data.min(), data.max(), 1000)
398
+ y_kde = kde(x_range)
399
+
400
+ # Add KDE curve
401
+ fig.add_scatter(
402
+ x=x_range,
403
+ y=y_kde,
404
+ mode='lines',
405
+ line=dict(color='red', width=2),
406
+ name='Density Curve'
407
+ )
408
+
409
+ # Improve histogram layout
410
+ fig.update_layout(
411
+ bargap=0.1, # Gap between bars
412
+ xaxis_title=x_col,
413
+ yaxis_title='Frequency',
414
+ showlegend=True
415
  )
416
 
417
  elif viz_type == 'box' and numeric_cols:
 
442
  # If we have many numeric columns, create a correlation matrix
443
  if len(numeric_cols) >= 3:
444
  # Create a correlation matrix
445
+ # First, drop any rows with NaN values in numeric columns
446
+ clean_df = result_df[numeric_cols].dropna()
447
+
448
+ if len(clean_df) > 1: # Need at least 2 rows for correlation
449
+ corr_df = clean_df.corr()
450
+
451
+ # Round to 2 decimal places for display
452
+ corr_df = corr_df.round(2)
453
 
454
  fig = px.imshow(
455
  corr_df,
 
459
  aspect="auto",
460
  zmin=-1, zmax=1 # Set limits for correlation values
461
  )
462
+
463
+ # Improve heatmap layout
464
+ fig.update_layout(
465
+ xaxis_title="Features",
466
+ yaxis_title="Features",
467
+ coloraxis_colorbar=dict(
468
+ title="Correlation",
469
+ thicknessmode="pixels", thickness=20,
470
+ lenmode="pixels", len=300,
471
+ yanchor="top", y=1,
472
+ ticks="outside"
473
+ )
474
+ )
475
+ else:
476
+ # Not enough data for correlation
477
+ fig = px.bar(
478
+ pd.DataFrame({'Message': ['Not enough data for heatmap']}),
479
+ title="Cannot create heatmap - insufficient data"
480
+ )
481
  else:
482
+ # If we only have 2 numeric columns, create a 2D histogram
 
483
  x_col = numeric_cols[0]
484
  y_col = numeric_cols[1]
485
 
486
+ # Create a 2D histogram (heatmap)
487
  fig = px.density_heatmap(
488
  result_df,
489
  x=x_col,
 
491
  title=f"Density Heatmap of {x_col} vs {y_col}",
492
  color_continuous_scale='Viridis',
493
  nbinsx=20,
494
+ nbinsy=20,
495
+ marginal_x="histogram", # Add histograms on the margins
496
+ marginal_y="histogram"
497
+ )
498
+
499
+ # Improve heatmap layout
500
+ fig.update_layout(
501
+ xaxis_title=x_col,
502
+ yaxis_title=y_col,
503
+ coloraxis_colorbar=dict(
504
+ title="Count",
505
+ thicknessmode="pixels", thickness=20,
506
+ lenmode="pixels", len=300,
507
+ yanchor="top", y=1,
508
+ ticks="outside"
509
+ )
510
  )
511
 
512
  elif viz_type == 'scatter' and len(numeric_cols) >= 2:
 
517
  # Add a third dimension (size) if available
518
  size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
519
 
520
+ # Add a color dimension if available
521
+ if len(result_df.columns) > len(numeric_cols):
522
+ # Find a categorical column for color
523
+ categorical_cols = [col for col in result_df.columns if col not in numeric_cols]
524
+ color_col = categorical_cols[0] if categorical_cols else None
525
+ else:
526
+ color_col = None
527
+
528
+ # Create scatter plot with enhanced features
529
  fig = px.scatter(
530
  result_df,
531
  x=x_col,
532
  y=y_col,
533
  size=size_col,
534
+ color=color_col, # Add color dimension if available
535
  title=f"Relationship between {x_col} and {y_col}",
536
  opacity=0.7,
537
+ size_max=15, # Maximum marker size
538
+ color_discrete_sequence=px.colors.qualitative.Plotly
539
  )
540
 
541
 # Add a dashed reference line from (min x, min y) to (max x, max y); this is not a fitted trend line
542
+ if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
543
+ fig.update_layout(
544
+ shapes=[
545
+ dict(
546
+ type='line',
547
+ xref='x', yref='y',
548
+ x0=result_df[x_col].min(),
549
+ y0=result_df[y_col].min(),
550
+ x1=result_df[x_col].max(),
551
+ y1=result_df[y_col].max(),
552
+ line=dict(color='red', width=2, dash='dash')
553
+ )
554
+ ]
555
+ )
556
+
557
+ # Improve scatter plot layout
558
  fig.update_layout(
559
+ xaxis_title=x_col,
560
+ yaxis_title=y_col,
561
+ showlegend=True,
562
+ legend=dict(
563
+ title=color_col if color_col else "",
564
+ orientation="h",
565
+ yanchor="bottom",
566
+ y=1.02,
567
+ xanchor="right",
568
+ x=1
569
+ )
570
  )
571
 
572
  elif viz_type == 'line':
 
632
  margin=dict(l=50, r=50, b=100, t=100, pad=4),
633
  template="plotly_white",
634
  font=dict(size=14),
635
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
636
+ plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background
637
+ paper_bgcolor='white'
638
  )
639
 
640
  # Convert the figure to an image and encode it as base64
 
642
  encoded = base64.b64encode(img_bytes).decode("ascii")
643
  img_src = f"data:image/png;base64,{encoded}"
644
 
645
+ # Add the image directly to the response with increased size
646
+ response += f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
647
 
648
  # Add note about visualization
649
  response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
 
975
  show_label=False
976
  )
977
  with gr.Column(scale=1):
978
+ voice_btn = gr.Button("🎤")
979
 
980
  with gr.Row():
981
  submit_btn = gr.Button("Submit")