SVashishta1 commited on
Commit
b0db292
·
1 Parent(s): 028022d
Files changed (1) hide show
  1. app.py +134 -300
app.py CHANGED
@@ -299,311 +299,21 @@ def process_text_query(query, history):
299
  # Add visualization if requested
300
  if is_visualization and not result_df.empty:
301
  try:
302
- print("Visualization requested, attempting to create plot...")
 
303
 
304
- # Set common figure parameters
305
- fig_width = 1000
306
- fig_height = 700
307
-
308
- # Create the appropriate visualization based on type
309
- if viz_type == 'pie' and len(result_df) <= 20:
310
- # For pie charts, we need a category column and a value column
311
- category_col = result_df.columns[0]
312
- value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]
313
-
314
- # Handle case where all columns are numeric
315
- if len(numeric_cols) == len(result_df.columns):
316
- category_col = result_df.index.name or 'index'
317
- result_df = result_df.reset_index()
318
-
319
- fig = px.pie(
320
- result_df,
321
- names=category_col,
322
- values=value_col,
323
- title=f"Distribution of {value_col} by {category_col}",
324
- hole=0.3, # Donut chart for better readability
325
- color_discrete_sequence=px.colors.qualitative.Pastel
326
- )
327
-
328
- elif viz_type == 'histogram' and len(result_df.columns) > 0:
329
- # For histograms, we need at least one column
330
-
331
- # Find the best column for histogram (prefer numeric)
332
- if numeric_cols:
333
- x_col = numeric_cols[0]
334
- else:
335
- x_col = result_df.columns[0]
336
-
337
- # Check if data is already binned
338
- if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
339
- # Data is pre-binned, use a bar chart
340
- bin_col = 'bin' if 'bin' in result_df.columns else 'range'
341
- count_col = 'count' if 'count' in result_df.columns else numeric_cols[0] if numeric_cols else result_df.columns[1]
342
-
343
- fig = px.bar(
344
- result_df,
345
- x=bin_col,
346
- y=count_col,
347
- title=f"Histogram of {x_col}",
348
- labels={bin_col: x_col, count_col: 'Frequency'},
349
- color_discrete_sequence=['#636EFA']
350
- )
351
- else:
352
- # Create a proper histogram from raw data
353
- fig = px.histogram(
354
- result_df,
355
- x=x_col,
356
- title=f"Distribution of {x_col}",
357
- nbins=20,
358
- marginal="box", # Add a box plot on the margin
359
- color_discrete_sequence=['#636EFA'],
360
- opacity=0.8
361
- )
362
-
363
- # Improve histogram layout
364
- fig.update_layout(
365
- bargap=0.1, # Gap between bars
366
- xaxis_title=x_col,
367
- yaxis_title='Frequency',
368
- showlegend=True
369
- )
370
-
371
- elif viz_type == 'box' and numeric_cols:
372
- # For box plots, we need to handle the data differently
373
- # SQLite doesn't support window functions for percentiles
374
- # So we'll calculate the box plot statistics in Python
375
-
376
- # Get the numeric column to plot
377
- x_col = numeric_cols[0]
378
-
379
- # Create a box plot using plotly express
380
- fig = px.box(
381
- result_df,
382
- y=x_col,
383
- title=f"Box Plot of {x_col}",
384
- points="outliers", # Only show outlier points
385
- color_discrete_sequence=['#636EFA']
386
- )
387
-
388
- # Add a strip plot (individual points) on the side for better visualization
389
- fig.add_trace(
390
- px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
391
- )
392
-
393
- elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
394
- # For heatmaps, we need at least 2 numeric columns
395
-
396
- # If we have many numeric columns, create a correlation matrix
397
- if len(numeric_cols) >= 3:
398
- # Create a correlation matrix
399
- # First, drop any rows with NaN values in numeric columns
400
- clean_df = result_df[numeric_cols].dropna()
401
-
402
- if len(clean_df) > 1: # Need at least 2 rows for correlation
403
- corr_df = clean_df.corr()
404
-
405
- # Round to 2 decimal places for display
406
- corr_df = corr_df.round(2)
407
-
408
- fig = px.imshow(
409
- corr_df,
410
- title="Correlation Heatmap",
411
- color_continuous_scale='RdBu_r',
412
- text_auto=True, # Show correlation values
413
- aspect="auto",
414
- zmin=-1, zmax=1 # Set limits for correlation values
415
- )
416
-
417
- # Improve heatmap layout
418
- fig.update_layout(
419
- xaxis_title="Features",
420
- yaxis_title="Features",
421
- coloraxis_colorbar=dict(
422
- title="Correlation",
423
- thicknessmode="pixels", thickness=20,
424
- lenmode="pixels", len=300,
425
- yanchor="top", y=1,
426
- ticks="outside"
427
- )
428
- )
429
- else:
430
- # Not enough data for correlation
431
- fig = px.bar(
432
- pd.DataFrame({'Message': ['Not enough data for heatmap']}),
433
- title="Cannot create heatmap - insufficient data"
434
- )
435
- else:
436
- # If we only have 2 numeric columns, create a 2D histogram
437
- x_col = numeric_cols[0]
438
- y_col = numeric_cols[1]
439
-
440
- # Create a 2D histogram (heatmap)
441
- fig = px.density_heatmap(
442
- result_df,
443
- x=x_col,
444
- y=y_col,
445
- title=f"Density Heatmap of {x_col} vs {y_col}",
446
- color_continuous_scale='Viridis',
447
- nbinsx=20,
448
- nbinsy=20,
449
- marginal_x="histogram", # Add histograms on the margins
450
- marginal_y="histogram"
451
- )
452
-
453
- # Improve heatmap layout
454
- fig.update_layout(
455
- xaxis_title=x_col,
456
- yaxis_title=y_col,
457
- coloraxis_colorbar=dict(
458
- title="Count",
459
- thicknessmode="pixels", thickness=20,
460
- lenmode="pixels", len=300,
461
- yanchor="top", y=1,
462
- ticks="outside"
463
- )
464
- )
465
-
466
- elif viz_type == 'scatter' and len(numeric_cols) >= 2:
467
- # For scatter plots, we need at least 2 numeric columns
468
- x_col = numeric_cols[0]
469
- y_col = numeric_cols[1]
470
 
471
- # Add a third dimension (size) if available
472
- size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
473
-
474
- # Add a color dimension if available
475
- if len(result_df.columns) > len(numeric_cols):
476
- # Find a categorical column for color
477
- categorical_cols = [col for col in result_df.columns if col not in numeric_cols]
478
- color_col = categorical_cols[0] if categorical_cols else None
479
- else:
480
- color_col = None
481
-
482
- # Create scatter plot with enhanced features
483
- fig = px.scatter(
484
- result_df,
485
- x=x_col,
486
- y=y_col,
487
- size=size_col,
488
- color=color_col, # Add color dimension if available
489
- title=f"Relationship between {x_col} and {y_col}",
490
- opacity=0.7,
491
- size_max=15, # Maximum marker size
492
- color_discrete_sequence=px.colors.qualitative.Plotly
493
- )
494
-
495
- # Add a trend line
496
- if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
497
- fig.update_layout(
498
- shapes=[
499
- dict(
500
- type='line',
501
- xref='x', yref='y',
502
- x0=result_df[x_col].min(),
503
- y0=result_df[y_col].min(),
504
- x1=result_df[x_col].max(),
505
- y1=result_df[y_col].max(),
506
- line=dict(color='red', width=2, dash='dash')
507
- )
508
- ]
509
- )
510
-
511
- # Improve scatter plot layout
512
- fig.update_layout(
513
- xaxis_title=x_col,
514
- yaxis_title=y_col,
515
- showlegend=True,
516
- legend=dict(
517
- title=color_col if color_col else "",
518
- orientation="h",
519
- yanchor="bottom",
520
- y=1.02,
521
- xanchor="right",
522
- x=1
523
- )
524
- )
525
-
526
- elif viz_type == 'line':
527
- # For line charts, determine the x-axis (preferably a date/time column)
528
- time_cols = [col for col in result_df.columns if any(time_word in col.lower()
529
- for time_word in ['date', 'time', 'month', 'year', 'day'])]
530
-
531
- if time_cols:
532
- x_col = time_cols[0]
533
- else:
534
- x_col = result_df.columns[0]
535
-
536
- # Determine y-axis columns (numeric columns)
537
- y_cols = numeric_cols[:3] # Use up to 3 numeric columns
538
-
539
- if not y_cols and len(result_df.columns) > 1:
540
- # If no numeric columns, use the second column
541
- y_cols = [result_df.columns[1]]
542
-
543
- fig = px.line(
544
- result_df,
545
- x=x_col,
546
- y=y_cols,
547
- title="Time Series Analysis",
548
- markers=True, # Add markers at each data point
549
- color_discrete_sequence=px.colors.qualitative.Plotly
550
- )
551
-
552
- # Add range slider for time series
553
- fig.update_layout(
554
- xaxis=dict(
555
- rangeslider=dict(visible=True),
556
- type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-'
557
- )
558
- )
559
-
560
- else: # Default to bar chart
561
- # For bar charts, use the first column as x and numeric columns as y
562
- x_col = result_df.columns[0]
563
-
564
- # Determine y-axis columns (numeric columns)
565
- if numeric_cols and x_col not in numeric_cols:
566
- y_cols = numeric_cols[:3] # Use up to 3 numeric columns
567
- elif len(result_df.columns) > 1:
568
- y_cols = [result_df.columns[1]]
569
- else:
570
- y_cols = ['value']
571
- result_df['value'] = 1 # Default value if no suitable column
572
-
573
- fig = px.bar(
574
- result_df,
575
- x=x_col,
576
- y=y_cols[0], # Use only the first y column for bar charts
577
- title="Data Visualization",
578
- color_discrete_sequence=['#636EFA']
579
- )
580
-
581
- # Improve figure layout for all chart types
582
- fig.update_layout(
583
- autosize=True,
584
- width=fig_width,
585
- height=fig_height,
586
- margin=dict(l=50, r=50, b=100, t=100, pad=4),
587
- template="plotly_white",
588
- font=dict(size=14),
589
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
590
- plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background
591
- paper_bgcolor='white'
592
- )
593
-
594
- # Convert the figure to an image and encode it as base64
595
- img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height, scale=2)
596
- encoded = base64.b64encode(img_bytes).decode("ascii")
597
- img_src = f"data:image/png;base64,{encoded}"
598
-
599
- # Add the image directly to the response with increased size
600
- response += f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
601
-
602
- # Add note about visualization
603
- response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
604
 
605
  except Exception as viz_error:
606
  print(f"Visualization error: {str(viz_error)}")
 
607
  traceback.print_exc()
608
 
609
  except Exception as e:
@@ -910,6 +620,130 @@ except NameError as e:
910
  importlib.reload(backend.vector_db)
911
  from backend.vector_db import ChromaVectorDB
912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
913
  # Create Gradio interface
914
  with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
915
  gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
 
299
  # Add visualization if requested
300
  if is_visualization and not result_df.empty:
301
  try:
302
+ # Generate visualization
303
+ viz_html = generate_visualization(result_df, query)
304
 
305
+ if viz_html:
306
+ # Add the visualization to the response
307
+ response += f"\n\n{viz_html}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
+ # Add note about visualization
310
+ response += "\n\n**A visualization has been generated and is displayed above.**"
311
+ else:
312
+ response += "\n\n**Could not generate visualization due to an error.**"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  except Exception as viz_error:
315
  print(f"Visualization error: {str(viz_error)}")
316
+ import traceback
317
  traceback.print_exc()
318
 
319
  except Exception as e:
 
620
  importlib.reload(backend.vector_db)
621
  from backend.vector_db import ChromaVectorDB
622
 
623
+ # Add this function to app.py
624
+ def generate_visualization(result_df, query):
625
+ """Generate a visualization based on the query and data"""
626
+ try:
627
+ print("Visualization requested, attempting to create plot...")
628
+
629
+ # Set common figure parameters
630
+ fig_width = 1000
631
+ fig_height = 700
632
+
633
+ # Determine visualization type from query
634
+ viz_type = 'bar' # Default
635
+
636
+ if any(word in query.lower() for word in ['pie', 'distribution', 'proportion']):
637
+ viz_type = 'pie'
638
+ elif any(word in query.lower() for word in ['line', 'trend', 'time series']):
639
+ viz_type = 'line'
640
+ elif any(word in query.lower() for word in ['scatter', 'relationship']):
641
+ viz_type = 'scatter'
642
+ elif any(word in query.lower() for word in ['histogram', 'distribution of']):
643
+ viz_type = 'histogram'
644
+ elif any(word in query.lower() for word in ['box', 'boxplot', 'outliers']):
645
+ viz_type = 'box'
646
+ elif any(word in query.lower() for word in ['heatmap', 'correlation']):
647
+ viz_type = 'heatmap'
648
+
649
+ print(f"Creating {viz_type} visualization...")
650
+
651
+ # Find numeric columns
652
+ numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
653
+
654
+ # Create basic visualization based on type
655
+ if viz_type == 'pie' and len(result_df) <= 20:
656
+ # Simple pie chart
657
+ labels = result_df.iloc[:, 0].tolist()
658
+ values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)
659
+
660
+ import plotly.graph_objects as go
661
+ fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
662
+ fig.update_layout(title_text='Pie Chart')
663
+
664
+ elif viz_type == 'histogram' and len(numeric_cols) > 0:
665
+ # Simple histogram
666
+ import plotly.express as px
667
+ fig = px.histogram(result_df, x=numeric_cols[0])
668
+ fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')
669
+
670
+ elif viz_type == 'box' and len(numeric_cols) > 0:
671
+ # Simple box plot
672
+ import plotly.express as px
673
+ fig = px.box(result_df, y=numeric_cols[0])
674
+ fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')
675
+
676
+ elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
677
+ # Simple heatmap
678
+ import plotly.express as px
679
+ # Create correlation matrix
680
+ corr_df = result_df[numeric_cols].corr()
681
+ fig = px.imshow(corr_df, text_auto=True)
682
+ fig.update_layout(title_text='Correlation Heatmap')
683
+
684
+ elif viz_type == 'scatter' and len(numeric_cols) >= 2:
685
+ # Simple scatter plot
686
+ import plotly.express as px
687
+ fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
688
+ fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')
689
+
690
+ elif viz_type == 'line':
691
+ # Simple line chart
692
+ import plotly.express as px
693
+ x_col = result_df.columns[0]
694
+ y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None
695
+
696
+ if y_cols:
697
+ fig = px.line(result_df, x=x_col, y=y_cols[0])
698
+ fig.update_layout(title_text=f'Line Chart of {y_cols[0]} over {x_col}')
699
+ else:
700
+ # Fallback to bar chart
701
+ viz_type = 'bar'
702
+
703
+ if viz_type == 'bar' or 'fig' not in locals():
704
+ # Simple bar chart (default)
705
+ import plotly.express as px
706
+ x_col = result_df.columns[0]
707
+ y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None
708
+
709
+ if y_col:
710
+ fig = px.bar(result_df, x=x_col, y=y_col)
711
+ fig.update_layout(title_text=f'Bar Chart of {y_col} by {x_col}')
712
+ else:
713
+ fig = px.bar(result_df, x=x_col)
714
+ fig.update_layout(title_text=f'Bar Chart of {x_col}')
715
+
716
+ # Set common layout properties
717
+ fig.update_layout(
718
+ width=fig_width,
719
+ height=fig_height,
720
+ template="plotly_white"
721
+ )
722
+
723
+ print(f"Created figure with width={fig_width}, height={fig_height}")
724
+
725
+ # Convert to image
726
+ print("Converting figure to image...")
727
+ import plotly.io as pio
728
+ img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=2)
729
+ print("Image conversion successful")
730
+
731
+ # Encode as base64
732
+ import base64
733
+ encoded = base64.b64encode(img_bytes).decode("ascii")
734
+ img_src = f"data:image/png;base64,{encoded}"
735
+
736
+ print("HTML conversion successful")
737
+
738
+ # Return the HTML img tag
739
+ return f"<img src='{img_src}' width='100%' style='min-height:700px;' />"
740
+
741
+ except Exception as e:
742
+ import traceback
743
+ print(f"Error generating visualization: {str(e)}")
744
+ traceback.print_exc()
745
+ return None
746
+
747
  # Create Gradio interface
748
  with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
749
  gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")