chrissoria committed on
Commit
c04a288
·
verified ·
1 Parent(s): 0481916

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +124 -21
app.py CHANGED
@@ -606,30 +606,126 @@ def run_classify_data(input_type, input_data, description, categories,
606
  return None, None, None, None, f"Error: {str(e)}"
607
 
608
 
609
def create_distribution_chart(result_df, categories):
    """Build a horizontal bar chart of per-category percentages.

    For each category whose ``category_<n>`` column exists in ``result_df``
    (``n`` is the 1-based position in ``categories``), the bar length is the
    column sum expressed as a percentage of the row count, rounded to one
    decimal place. Categories without a matching column are left out.

    Args:
        result_df: DataFrame holding the classification results.
        categories: Category names, in column order.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))

    row_total = len(result_df)
    labels = []
    shares = []
    for position, label in enumerate(categories, 1):
        column = f"category_{position}"
        if column not in result_df.columns:
            continue
        hits = int(result_df[column].sum())
        # Guard against an empty frame so the division never fails.
        share = (hits / row_total) * 100 if row_total > 0 else 0
        labels.append(label)
        shares.append(round(share, 1))

    # barh stacks bars bottom-up; reverse so the first category ends on top.
    labels.reverse()
    shares.reverse()

    bars = ax.barh(labels, shares, color='#2563eb')
    ax.set_xlim(0, 100)
    ax.set_xlabel('Percentage (%)', fontsize=11)
    ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')

    # Print the numeric value just past the end of each bar.
    for bar, share in zip(bars, shares):
        ax.text(
            bar.get_width() + 1,
            bar.get_y() + bar.get_height() / 2,
            f'{share:.1f}%',
            va='center',
            fontsize=10,
        )

    plt.tight_layout()
    return fig
@@ -1256,7 +1352,9 @@ with col_input:
1256
  'pdf_path': pdf_path,
1257
  'code': code,
1258
  'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
1259
- 'categories': categories_entered
 
 
1260
  }
1261
  st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
1262
  st.rerun()
@@ -1270,7 +1368,12 @@ with col_output:
1270
  results = st.session_state.results
1271
 
1272
  # Distribution chart
1273
- fig = create_distribution_chart(results['df'], results['categories'])
 
 
 
 
 
1274
  st.pyplot(fig)
1275
  st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
1276
 
 
606
  return None, None, None, None, f"Error: {str(e)}"
607
 
608
 
609
def sanitize_model_name(model: str) -> str:
    """Turn a model identifier into a lowercase, column-safe suffix.

    Each character outside ``[a-zA-Z0-9]`` becomes an underscore, runs of
    underscores collapse to one, leading/trailing underscores are removed,
    the text is lower-cased, and the result is cut to at most 40 characters.
    Per the original author's note, this is meant to match catllm's logic.
    """
    import re

    underscored = re.sub(r'_+', '_', re.sub(r'[^a-zA-Z0-9]', '_', model))
    trimmed = underscored.strip('_')
    return trimmed.lower()[:40]
616
+
617
+
618
def _message_figure(message):
    """Return a blank 10x4 figure that shows only ``message``, centered."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=14)
    ax.axis('off')
    return fig


def _plot_category_bars(result_df, categories, column_template, color, title):
    """Draw one labelled horizontal bar per category column found in ``result_df``."""
    total_rows = len(result_df)
    fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))

    labels = []
    percentages = []
    for i, cat in enumerate(categories, 1):
        col_name = column_template.format(i)
        # Categories without a matching column are skipped, not drawn as 0.
        if col_name in result_df.columns:
            count = int(result_df[col_name].sum())
            labels.append(cat)
            percentages.append(round((count / total_rows) * 100, 1))

    # barh stacks bars bottom-up; reverse so the first category ends on top.
    labels = labels[::-1]
    percentages = percentages[::-1]

    bars = ax.barh(labels, percentages, color=color)
    ax.set_xlim(0, 100)
    ax.set_xlabel('Percentage (%)', fontsize=11)
    ax.set_title(title, fontsize=14, fontweight='bold')

    # Print the numeric value just past the end of each bar.
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{pct:.1f}%', va='center', fontsize=10)
    return fig


def _plot_model_comparison(result_df, categories, models_list):
    """Draw grouped horizontal bars, one bar per model within each category."""
    import numpy as np

    # Colors cycle if there are more models than entries.
    model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04',
                    '#9333ea', '#0891b2', '#be185d', '#65a30d']

    total_rows = len(result_df)
    n_models = len(models_list)
    n_categories = len(categories)

    fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))

    # Each category gets a band of height 0.8 shared evenly by the models.
    bar_height = 0.8 / n_models
    y_positions = np.arange(n_categories)

    for model_idx, model_name in enumerate(models_list):
        sanitized = sanitize_model_name(model_name)
        model_pcts = []
        for i in range(1, n_categories + 1):
            col_name = f"category_{i}_{sanitized}"
            if col_name in result_df.columns:
                count = int(result_df[col_name].sum())
                pct = (count / total_rows) * 100
            else:
                # A missing column is drawn as an empty bar for this model.
                pct = 0
            model_pcts.append(pct)

        # Reverse for horizontal bar chart (first category on top).
        model_pcts = model_pcts[::-1]
        # Center the group of model bars on the category tick.
        offset = (model_idx - n_models / 2 + 0.5) * bar_height
        color = model_colors[model_idx % len(model_colors)]

        # Shorter legend label: drop any "provider/" prefix and ":tag" suffix.
        display_name = model_name.split('/')[-1].split(':')[0][:20]
        ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
                label=display_name, color=color, alpha=0.85)

    ax.set_yticks(y_positions)
    ax.set_yticklabels(categories[::-1])
    ax.set_xlim(0, 100)
    ax.set_xlabel('Percentage (%)', fontsize=11)
    ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right', fontsize=9)
    return fig


def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a bar chart showing category distribution.

    Percentages are the column sum divided by the number of rows in
    ``result_df``. Which columns are read depends on ``classify_mode``:

    * ``"Single Model"``: ``category_<n>``
    * ``"Ensemble"``: ``category_<n>_consensus``
    * anything else (treated as "Model Comparison"):
      ``category_<n>_<sanitized model name>`` for each entry of
      ``models_list``, drawn as grouped bars

    Args:
        result_df: DataFrame with classification results.
        categories: List of category names, in column order.
        classify_mode: "Single Model", "Model Comparison", or "Ensemble".
        models_list: List of model names (used in Model Comparison mode).

    Returns:
        The matplotlib figure. When ``result_df`` is empty, or when Model
        Comparison is requested without any models, a placeholder figure
        with a short message is returned instead of a chart.
    """
    if len(result_df) == 0:
        return _message_figure('No data to display')

    if classify_mode == "Single Model":
        fig = _plot_category_bars(
            result_df, categories, "category_{}",
            '#2563eb', 'Category Distribution (%)',
        )
    elif classify_mode == "Ensemble":
        fig = _plot_category_bars(
            result_df, categories, "category_{}_consensus",
            '#16a34a', 'Ensemble Consensus Distribution (%)',
        )
    else:  # Model Comparison
        if not models_list:
            # Without models the bar height would be 0.8 / 0; show a
            # placeholder instead of raising ZeroDivisionError.
            return _message_figure('No models to display')
        fig = _plot_model_comparison(result_df, categories, models_list)

    plt.tight_layout()
    return fig
 
1352
  'pdf_path': pdf_path,
1353
  'code': code,
1354
  'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
1355
+ 'categories': categories_entered,
1356
+ 'classify_mode': classify_mode,
1357
+ 'models_list': models_list,
1358
  }
1359
  st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
1360
  st.rerun()
 
1368
  results = st.session_state.results
1369
 
1370
  # Distribution chart
1371
+ fig = create_distribution_chart(
1372
+ results['df'],
1373
+ results['categories'],
1374
+ classify_mode=results.get('classify_mode', 'Single Model'),
1375
+ models_list=results.get('models_list', [])
1376
+ )
1377
  st.pyplot(fig)
1378
  st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
1379