Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -606,30 +606,126 @@ def run_classify_data(input_type, input_data, description, categories,
|
|
| 606 |
return None, None, None, None, f"Error: {str(e)}"
|
| 607 |
|
| 608 |
|
| 609 |
-
def
|
| 610 |
-
"""
|
| 611 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
-
dist_data = []
|
| 614 |
total_rows = len(result_df)
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
|
| 634 |
plt.tight_layout()
|
| 635 |
return fig
|
|
@@ -1256,7 +1352,9 @@ with col_input:
|
|
| 1256 |
'pdf_path': pdf_path,
|
| 1257 |
'code': code,
|
| 1258 |
'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
|
| 1259 |
-
'categories': categories_entered
|
|
|
|
|
|
|
| 1260 |
}
|
| 1261 |
st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
|
| 1262 |
st.rerun()
|
|
@@ -1270,7 +1368,12 @@ with col_output:
|
|
| 1270 |
results = st.session_state.results
|
| 1271 |
|
| 1272 |
# Distribution chart
|
| 1273 |
-
fig = create_distribution_chart(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1274 |
st.pyplot(fig)
|
| 1275 |
st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
|
| 1276 |
|
|
|
|
| 606 |
return None, None, None, None, f"Error: {str(e)}"
|
| 607 |
|
| 608 |
|
| 609 |
+
def sanitize_model_name(model: str) -> str:
|
| 610 |
+
"""Convert model name to column-safe suffix (matches catllm logic)."""
|
| 611 |
+
import re
|
| 612 |
+
sanitized = re.sub(r'[^a-zA-Z0-9]', '_', model)
|
| 613 |
+
sanitized = re.sub(r'_+', '_', sanitized)
|
| 614 |
+
sanitized = sanitized.strip('_').lower()
|
| 615 |
+
return sanitized[:40]
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
|
| 619 |
+
"""Create a bar chart showing category distribution.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
result_df: DataFrame with classification results
|
| 623 |
+
categories: List of category names
|
| 624 |
+
classify_mode: "Single Model", "Model Comparison", or "Ensemble"
|
| 625 |
+
models_list: List of model names (for multi-model modes)
|
| 626 |
+
"""
|
| 627 |
+
import numpy as np
|
| 628 |
|
|
|
|
| 629 |
total_rows = len(result_df)
|
| 630 |
+
if total_rows == 0:
|
| 631 |
+
fig, ax = plt.subplots(figsize=(10, 4))
|
| 632 |
+
ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
|
| 633 |
+
ax.axis('off')
|
| 634 |
+
return fig
|
| 635 |
+
|
| 636 |
+
# Define colors for different models
|
| 637 |
+
model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d']
|
| 638 |
+
|
| 639 |
+
if classify_mode == "Single Model":
|
| 640 |
+
# Single model: use category_1, category_2, etc.
|
| 641 |
+
fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
|
| 642 |
+
|
| 643 |
+
dist_data = []
|
| 644 |
+
for i, cat in enumerate(categories, 1):
|
| 645 |
+
col_name = f"category_{i}"
|
| 646 |
+
if col_name in result_df.columns:
|
| 647 |
+
count = int(result_df[col_name].sum())
|
| 648 |
+
pct = (count / total_rows) * 100
|
| 649 |
+
dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
|
| 650 |
|
| 651 |
+
categories_list = [d["Category"] for d in dist_data][::-1]
|
| 652 |
+
percentages = [d["Percentage"] for d in dist_data][::-1]
|
| 653 |
|
| 654 |
+
bars = ax.barh(categories_list, percentages, color='#2563eb')
|
| 655 |
+
ax.set_xlim(0, 100)
|
| 656 |
+
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 657 |
+
ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
|
| 658 |
|
| 659 |
+
for bar, pct in zip(bars, percentages):
|
| 660 |
+
ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
|
| 661 |
+
f'{pct:.1f}%', va='center', fontsize=10)
|
| 662 |
+
|
| 663 |
+
elif classify_mode == "Ensemble":
|
| 664 |
+
# Ensemble: use category_1_consensus, category_2_consensus, etc.
|
| 665 |
+
fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
|
| 666 |
+
|
| 667 |
+
dist_data = []
|
| 668 |
+
for i, cat in enumerate(categories, 1):
|
| 669 |
+
col_name = f"category_{i}_consensus"
|
| 670 |
+
if col_name in result_df.columns:
|
| 671 |
+
count = int(result_df[col_name].sum())
|
| 672 |
+
pct = (count / total_rows) * 100
|
| 673 |
+
dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
|
| 674 |
+
|
| 675 |
+
categories_list = [d["Category"] for d in dist_data][::-1]
|
| 676 |
+
percentages = [d["Percentage"] for d in dist_data][::-1]
|
| 677 |
+
|
| 678 |
+
bars = ax.barh(categories_list, percentages, color='#16a34a')
|
| 679 |
+
ax.set_xlim(0, 100)
|
| 680 |
+
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 681 |
+
ax.set_title('Ensemble Consensus Distribution (%)', fontsize=14, fontweight='bold')
|
| 682 |
+
|
| 683 |
+
for bar, pct in zip(bars, percentages):
|
| 684 |
+
ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
|
| 685 |
+
f'{pct:.1f}%', va='center', fontsize=10)
|
| 686 |
+
|
| 687 |
+
else: # Model Comparison
|
| 688 |
+
# Model Comparison: grouped bars for each model
|
| 689 |
+
if not models_list:
|
| 690 |
+
models_list = []
|
| 691 |
+
|
| 692 |
+
sanitized_names = [sanitize_model_name(m) for m in models_list]
|
| 693 |
+
n_models = len(sanitized_names)
|
| 694 |
+
n_categories = len(categories)
|
| 695 |
+
|
| 696 |
+
fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))
|
| 697 |
+
|
| 698 |
+
# Gather data for each model
|
| 699 |
+
bar_height = 0.8 / n_models
|
| 700 |
+
y_positions = np.arange(n_categories)
|
| 701 |
+
|
| 702 |
+
for model_idx, (model_name, sanitized) in enumerate(zip(models_list, sanitized_names)):
|
| 703 |
+
model_pcts = []
|
| 704 |
+
for i in range(1, n_categories + 1):
|
| 705 |
+
col_name = f"category_{i}_{sanitized}"
|
| 706 |
+
if col_name in result_df.columns:
|
| 707 |
+
count = int(result_df[col_name].sum())
|
| 708 |
+
pct = (count / total_rows) * 100
|
| 709 |
+
else:
|
| 710 |
+
pct = 0
|
| 711 |
+
model_pcts.append(pct)
|
| 712 |
+
|
| 713 |
+
# Reverse for horizontal bar chart
|
| 714 |
+
model_pcts = model_pcts[::-1]
|
| 715 |
+
offset = (model_idx - n_models / 2 + 0.5) * bar_height
|
| 716 |
+
color = model_colors[model_idx % len(model_colors)]
|
| 717 |
+
|
| 718 |
+
# Use shorter display name
|
| 719 |
+
display_name = model_name.split('/')[-1].split(':')[0][:20]
|
| 720 |
+
bars = ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
|
| 721 |
+
label=display_name, color=color, alpha=0.85)
|
| 722 |
+
|
| 723 |
+
ax.set_yticks(y_positions)
|
| 724 |
+
ax.set_yticklabels(categories[::-1])
|
| 725 |
+
ax.set_xlim(0, 100)
|
| 726 |
+
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 727 |
+
ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
|
| 728 |
+
ax.legend(loc='lower right', fontsize=9)
|
| 729 |
|
| 730 |
plt.tight_layout()
|
| 731 |
return fig
|
|
|
|
| 1352 |
'pdf_path': pdf_path,
|
| 1353 |
'code': code,
|
| 1354 |
'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
|
| 1355 |
+
'categories': categories_entered,
|
| 1356 |
+
'classify_mode': classify_mode,
|
| 1357 |
+
'models_list': models_list,
|
| 1358 |
}
|
| 1359 |
st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
|
| 1360 |
st.rerun()
|
|
|
|
| 1368 |
results = st.session_state.results
|
| 1369 |
|
| 1370 |
# Distribution chart
|
| 1371 |
+
fig = create_distribution_chart(
|
| 1372 |
+
results['df'],
|
| 1373 |
+
results['categories'],
|
| 1374 |
+
classify_mode=results.get('classify_mode', 'Single Model'),
|
| 1375 |
+
models_list=results.get('models_list', [])
|
| 1376 |
+
)
|
| 1377 |
st.pyplot(fig)
|
| 1378 |
st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
|
| 1379 |
|