Ara Yeroyan committed on
Commit
5262a14
·
1 Parent(s): 02d7f4f

add retrieval visualisations

Browse files
Files changed (1) hide show
  1. app.py +216 -3
app.py CHANGED
@@ -10,10 +10,13 @@ import uuid
10
  import logging
11
  import traceback
12
  from pathlib import Path
13
-
 
14
 
15
  import streamlit as st
16
  from langchain_core.messages import HumanMessage, AIMessage
 
 
17
 
18
  from multi_agent_chatbot import get_multi_agent_chatbot
19
  from smart_chatbot import get_chatbot as get_smart_chatbot
@@ -273,6 +276,203 @@ def serialize_documents(sources):
273
 
274
  return serialized
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  @st.cache_data
277
  def load_filter_options():
278
  try:
@@ -607,14 +807,27 @@ def main():
607
  # Count unique filenames
608
  unique_filenames = set()
609
  for doc in sources:
610
- filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
 
611
  unique_filenames.add(filename)
612
 
613
  st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
614
  if len(unique_filenames) < len(sources):
615
  st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
616
 
617
- for i, doc in enumerate(sources): # Show top 10
 
 
 
 
 
 
 
 
 
 
 
 
618
  # Get relevance score and ID if available
619
  metadata = getattr(doc, 'metadata', {})
620
  score = metadata.get('reranked_score', metadata.get('original_score', None))
 
10
  import logging
11
  import traceback
12
  from pathlib import Path
13
+ from typing import List, Dict, Any
14
+ from collections import Counter
15
 
16
  import streamlit as st
17
  from langchain_core.messages import HumanMessage, AIMessage
18
+ import pandas as pd
19
+ import plotly.express as px
20
 
21
  from multi_agent_chatbot import get_multi_agent_chatbot
22
  from smart_chatbot import get_chatbot as get_smart_chatbot
 
276
 
277
  return serialized
278
 
279
def extract_chunk_statistics(
    sources: List[Any],
    *,
    min_year: int = 1900,
    max_year: int = 2030,
) -> Dict[str, Any]:
    """Summarise the ``source``, ``year`` and ``filename`` metadata of chunks.

    Each item in *sources* is expected to expose a ``metadata`` mapping as an
    attribute. Items without that attribute, or whose ``metadata`` is
    ``None``/empty, are counted with every field set to ``'Unknown'``.

    Args:
        sources: Retrieved chunks/documents to summarise.
        min_year: Smallest year (inclusive) accepted as valid.
        max_year: Largest year (inclusive) accepted as valid.

    Returns:
        ``{}`` when *sources* is empty. Otherwise a dict with:

        * ``total_chunks`` - number of items in *sources*;
        * ``unique_sources`` / ``unique_filenames`` - distinct value counts
          (``'Unknown'`` included);
        * ``unique_years`` - distinct valid years (``'Unknown'`` excluded);
        * ``source_distribution`` / ``year_distribution`` /
          ``filename_distribution`` - value -> occurrence count;
        * ``sources`` / ``years`` / ``filenames`` - per-chunk values, in the
          same order as *sources*.
    """
    if not sources:
        return {}

    sources_list = []
    years = []
    filenames = []

    for doc in sources:
        # ``or {}`` also covers objects whose metadata attribute is None.
        metadata = getattr(doc, 'metadata', None) or {}

        # A key that is present but None is treated like a missing key so
        # that downstream string handling never receives None.
        source = metadata.get('source')
        sources_list.append('Unknown' if source is None else source)

        years.append(_normalize_year(metadata.get('year'), min_year, max_year))

        filename = metadata.get('filename')
        filenames.append('Unknown' if filename is None else filename)

    source_counts = Counter(sources_list)
    year_counts = Counter(years)
    filename_counts = Counter(filenames)

    return {
        'total_chunks': len(sources),
        'unique_sources': len(source_counts),
        'unique_years': sum(1 for year in year_counts if year != 'Unknown'),
        'unique_filenames': len(filename_counts),
        'source_distribution': dict(source_counts),
        'year_distribution': dict(year_counts),
        'filename_distribution': dict(filename_counts),
        'sources': sources_list,
        'years': years,
        'filenames': filenames,
    }


def _normalize_year(raw_year: Any, min_year: int, max_year: int) -> str:
    """Return *raw_year* as a canonical year string, or ``'Unknown'``.

    Accepts ints, floats and numeric strings (``2020``, ``2020.0``,
    ``'2020.0'``). Anything falsy, non-numeric, non-finite or outside
    ``[min_year, max_year]`` maps to ``'Unknown'``.
    """
    if not raw_year or raw_year == 'Unknown':
        return 'Unknown'
    try:
        # Go through float first so strings such as '2020.0' are accepted.
        year_int = int(float(raw_year))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf')); ValueError: NaN or non-numeric text.
        return 'Unknown'
    if min_year <= year_int <= max_year:
        return str(year_int)
    return 'Unknown'
331
+
332
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieved Chunks Statistics"):
    """Render chunk statistics as Streamlit metrics plus two Plotly bar charts.

    Shows four summary metrics (total chunks, unique sources, unique years,
    unique files), then side by side a horizontal bar chart of chunk counts
    per source and a bar chart of chunk counts per year. ``'Unknown'`` years
    are left out of the year chart; if no valid year remains an info message
    is shown instead.

    Args:
        stats: Dict with the keys ``total_chunks``, ``unique_sources``,
            ``unique_years``, ``unique_filenames``, ``source_distribution``
            and ``year_distribution``. Nothing is rendered when it is empty
            or ``total_chunks`` is 0.
        title: Text of the section sub-header.
    """
    if not stats or stats.get('total_chunks', 0) == 0:
        return

    st.subheader(f"📊 {title}")

    # Summary metrics, one per column.
    summary_metrics = (
        ("Total Chunks", 'total_chunks'),
        ("Unique Sources", 'unique_sources'),
        ("Unique Years", 'unique_years'),
        ("Unique Files", 'unique_filenames'),
    )
    for column, (label, key) in zip(st.columns(len(summary_metrics)), summary_metrics):
        with column:
            st.metric(label, stats[key])

    # Charts side by side.
    source_col, year_col = st.columns(2)

    with source_col:
        source_distribution = stats.get('source_distribution')
        if source_distribution:
            source_df = pd.DataFrame(
                list(source_distribution.items()),
                columns=['Source', 'Count'],
            )
            fig_source = px.bar(
                source_df,
                x='Count',
                y='Source',
                orientation='h',
                title='Distribution by Source',
                color='Count',
                color_continuous_scale='viridis',
            )
            fig_source.update_layout(height=400, showlegend=False)
            st.plotly_chart(fig_source, use_container_width=True)

    with year_col:
        year_distribution = stats.get('year_distribution')
        if year_distribution:
            # 'Unknown' is a placeholder, not a year, so keep it off the axis.
            known_years = {
                year: count
                for year, count in year_distribution.items()
                if year != 'Unknown'
            }
            if known_years:
                year_df = pd.DataFrame(
                    list(known_years.items()),
                    columns=['Year', 'Count'],
                )
                # Order numerically, but keep the labels as strings so the
                # axis stays categorical.
                year_df['Year_Int'] = year_df['Year'].astype(int)
                year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)

                fig_year = px.bar(
                    year_df,
                    x='Year',
                    y='Count',
                    title='Distribution by Year',
                    color='Count',
                    color_continuous_scale='plasma',
                )
                # Treat years as discrete categories, not a continuous range.
                fig_year.update_xaxes(type='category')
                fig_year.update_layout(height=400, showlegend=False)
                st.plotly_chart(fig_year, use_container_width=True)
            else:
                st.info("No valid years found in the results")
400
+
401
def display_chunk_statistics_table(
    stats: Dict[str, Any],
    title: str = "Retrieved Chunks Statistics",
    *,
    max_files: int = 5,
    max_name_length: int = 30,
):
    """Render chunk statistics as four side-by-side Streamlit tables.

    The columns show, left to right: summary counts, chunk counts per source
    (most frequent first), chunk counts per year (ascending, ``'Unknown'``
    excluded) and the most frequent filenames.

    Args:
        stats: Dict with the keys ``total_chunks``, ``unique_sources``,
            ``unique_years``, ``unique_filenames``, ``source_distribution``,
            ``year_distribution`` and ``filename_distribution``. Nothing is
            rendered when it is empty or ``total_chunks`` is 0.
        title: Text of the section sub-header.
        max_files: Maximum number of filenames listed in the files table.
        max_name_length: Filenames longer than this are cut to this many
            characters and suffixed with ``"..."``.
    """
    if not stats or stats.get('total_chunks', 0) == 0:
        return

    st.subheader(f"📊 {title}")

    # Group the four tables in one container so they render as a single row.
    stats_container = st.container()

    with stats_container:
        summary_col, source_col, year_col, file_col = st.columns(4)

        with summary_col:
            st.markdown("**📈 Summary**")
            summary_df = pd.DataFrame({
                "Metric": ["Total", "Sources", "Years", "Files"],
                "Count": [
                    stats['total_chunks'],
                    stats['unique_sources'],
                    stats['unique_years'],
                    stats['unique_filenames'],
                ],
            })
            st.dataframe(summary_df, hide_index=True, use_container_width=True)

        with source_col:
            st.markdown("**📂 Sources**")
            source_distribution = stats.get('source_distribution')
            if source_distribution:
                source_df = pd.DataFrame({
                    "Source": list(source_distribution.keys()),
                    "Count": list(source_distribution.values()),
                }).sort_values('Count', ascending=False)
                st.dataframe(source_df, hide_index=True, use_container_width=True)
            else:
                st.write("No source data")

        with year_col:
            st.markdown("**📅 Years**")
            # 'Unknown' is a placeholder, not a year, so it is not listed.
            known_years = {
                year: count
                for year, count in (stats.get('year_distribution') or {}).items()
                if year != 'Unknown'
            }
            if known_years:
                year_df = pd.DataFrame({
                    "Year": list(known_years.keys()),
                    "Count": list(known_years.values()),
                })
                # Order numerically while still displaying the string labels.
                year_df['Year_Int'] = year_df['Year'].astype(int)
                year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
                st.dataframe(year_df, hide_index=True, use_container_width=True)
            else:
                st.write("No year data")

        with file_col:
            st.markdown("**📄 Files**")
            filename_distribution = stats.get('filename_distribution')
            if filename_distribution:
                top_files = sorted(
                    filename_distribution.items(),
                    key=lambda item: item[1],
                    reverse=True,
                )[:max_files]

                # str() guards against non-string filenames (e.g. None),
                # which would otherwise fail on len() and slicing.
                display_names = []
                for name, _count in top_files:
                    text = str(name)
                    if len(text) > max_name_length:
                        text = text[:max_name_length] + "..."
                    display_names.append(text)

                file_df = pd.DataFrame({
                    "File": display_names,
                    "Count": [count for _name, count in top_files],
                })
                st.dataframe(file_df, hide_index=True, use_container_width=True)
            else:
                st.write("No file data")
475
+
476
  @st.cache_data
477
  def load_filter_options():
478
  try:
 
807
  # Count unique filenames
808
  unique_filenames = set()
809
  for doc in sources:
810
+ metadata = getattr(doc, 'metadata', {})
811
+ filename = metadata.get('filename', 'Unknown')
812
  unique_filenames.add(filename)
813
 
814
  st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
815
  if len(unique_filenames) < len(sources):
816
  st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
817
 
818
+ # Extract and display statistics
819
+ stats = extract_chunk_statistics(sources)
820
+
821
+ # Show charts for 10+ results, tables for fewer
822
+ if len(sources) >= 10:
823
+ display_chunk_statistics_charts(stats, "Retrieved Documents Statistics")
824
+ else:
825
+ display_chunk_statistics_table(stats, "Retrieved Documents Statistics")
826
+
827
+ st.markdown("---")
828
+ st.markdown("### πŸ“„ Document Details")
829
+
830
+ for i, doc in enumerate(sources): # Show all documents
831
  # Get relevance score and ID if available
832
  metadata = getattr(doc, 'metadata', {})
833
  score = metadata.get('reranked_score', metadata.get('original_score', None))