Spaces:
Sleeping
Sleeping
File size: 6,434 Bytes
dce46c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import gradio as gr
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
def create_cluster_browser_app():
    """
    Create a simple Gradio app for browsing prompts by cluster from an uploaded CSV file.

    The uploaded CSV must contain a 'prompt' column and a numeric 'cluster'
    column. The sidebar shows the upload widget, a status line, per-cluster
    counts and a cluster dropdown; the main area renders the prompts of the
    selected cluster (or of all clusters) as HTML cards.

    Returns:
        The assembled ``gr.Blocks`` interface. It is not launched here; the
        caller invokes ``.launch()`` on it.
    """
    # Function-scope import: only the HTML card formatter below needs it.
    import html

    all_clusters_choice = "(All Clusters)"
    no_data_choice = "(No data loaded)"

    def cluster_label(cluster_id) -> str:
        """Return the dropdown label for a cluster id (2.0 -> 'Cluster 2')."""
        # A cluster column containing a missing value is parsed as float, so
        # integral ids would otherwise be shown as "Cluster 2.0".
        if float(cluster_id).is_integer():
            cluster_id = int(cluster_id)
        return f"Cluster {cluster_id}"

    def load_and_validate_csv(file) -> Tuple[Optional[pd.DataFrame], str, List[str], str]:
        """
        Load and validate the uploaded CSV file.

        Args:
            file: Uploaded file from Gradio; either an object exposing the
                path as ``.name`` or a plain path string. ``None`` means no
                file is currently selected.

        Returns:
            Tuple of (dataframe, status_message, cluster_options, cluster_stats).
            On any failure the dataframe is ``None``, the options hold only
            the "(No data loaded)" placeholder and the stats text is empty.
        """
        if file is None:
            return None, "Please upload a CSV file with 'prompt' and 'cluster' columns.", [no_data_choice], ""

        # Accept both a temp-file wrapper (``.name``) and a bare path string.
        csv_path = getattr(file, "name", file)

        # Deliberately broad: this runs inside a UI callback, so every failure
        # is reported as a status message instead of crashing the handler.
        try:
            df = pd.read_csv(csv_path)

            # Validate required columns
            required_cols = ['prompt', 'cluster']
            missing_cols = [col for col in required_cols if col not in df.columns]
            if missing_cols:
                return None, f"Missing required columns: {missing_cols}. Please ensure your CSV has 'prompt' and 'cluster' columns.", [no_data_choice], ""

            # Validate data types
            if not pd.api.types.is_numeric_dtype(df['cluster']):
                return None, "The 'cluster' column must contain numeric values.", [no_data_choice], ""

            # One pass over the column; value_counts() skips missing ids, so
            # NaN never becomes a selectable "cluster".
            counts = df['cluster'].value_counts().sort_index()
            cluster_options = [all_clusters_choice] + [cluster_label(c) for c in counts.index]

            # Emit a Markdown list: plain single newlines would be folded
            # into one paragraph when rendered.
            stats = [f"- {cluster_label(c)}: {n} prompts" for c, n in counts.items()]
            unassigned = int(df['cluster'].isna().sum())
            if unassigned:
                stats.append(f"- Unassigned (no cluster value): {unassigned} prompts")

            stats_text = (
                f"**Total Prompts:** {len(df)}\n\n**Cluster Distribution:**\n\n"
                + "\n".join(stats)
            )
            return df, f"✅ Successfully loaded {len(df)} prompts with {len(counts)} clusters.", cluster_options, stats_text
        except Exception as e:
            return None, f"Error loading CSV file: {str(e)}", [no_data_choice], ""

    def filter_by_cluster(df: pd.DataFrame, cluster_sel: str) -> pd.DataFrame:
        """
        Filter dataframe by selected cluster.

        Args:
            df: Loaded prompts, or ``None`` when nothing is loaded.
            cluster_sel: Dropdown value: "(All Clusters)", the "(No data
                loaded)" placeholder, a "Cluster <id>" label, or
                ``None``/empty when the selection was cleared.

        Returns:
            An empty frame when ``df`` is ``None``; ``df`` itself when no
            specific cluster is selected; otherwise the matching rows with a
            fresh index (no rows if the label cannot be parsed).
        """
        if df is None:
            return pd.DataFrame()

        selection = (cluster_sel or "").strip()
        if not selection or selection in (all_clusters_choice, no_data_choice):
            return df

        # Extract the id from "Cluster X". Try int first so large integer ids
        # stay exact, then float so labels such as "Cluster 1.0" or
        # "Cluster 2.5" still resolve instead of raising ValueError.
        token = selection.split()[-1]
        try:
            cluster_id = int(token)
        except ValueError:
            try:
                cluster_id = float(token)
            except ValueError:
                return df.iloc[0:0].reset_index(drop=True)

        return df[df['cluster'] == cluster_id].reset_index(drop=True)

    def format_prompt_cell(prompt_text: str) -> str:
        """Format a single prompt in its own cell, HTML-escaping its text."""
        # Prompts come from an uploaded file: escape them so markup inside a
        # prompt is displayed literally rather than interpreted by the browser.
        safe_text = html.escape(prompt_text)
        return f"""
<div style="
background: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 8px;
padding: 16px;
margin: 8px 0;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
">
<div style="font-size: 14px; line-height: 1.5; color: #333;">
{safe_text}
</div>
</div>
"""

    def format_prompts(df: pd.DataFrame) -> str:
        """Format all prompts in the dataframe as individual cells."""
        if df is None or df.empty:
            return "No prompts to display."
        # Iterate the column directly; building a row object per prompt via
        # iterrows() is unnecessary when only one column is read.
        return "\n".join(
            format_prompt_cell(str(prompt).strip()) for prompt in df['prompt']
        )

    def on_file_upload(file):
        """
        Handle file upload and validation.

        Returns values for, in order: the dataframe state, the status
        Markdown, the cluster dropdown, the prompts HTML and the stats
        Markdown.
        """
        df, status_msg, cluster_options, cluster_stats = load_and_validate_csv(file)
        if df is None:
            return None, status_msg, gr.Dropdown(choices=cluster_options, value=no_data_choice, interactive=False), "No data loaded.", ""
        # Show all prompts initially
        return df, status_msg, gr.Dropdown(choices=cluster_options, value=all_clusters_choice, interactive=True), format_prompts(df), cluster_stats

    def on_cluster_change(df, cluster_sel):
        """Handle cluster selection change by re-rendering the prompt cards."""
        if df is None:
            return "No data loaded."
        return format_prompts(filter_by_cluster(df, cluster_sel))

    # Create the Gradio interface
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
        gr.Markdown("# Prompt Cluster Browser")

        # Store the loaded dataframe
        df_state = gr.State(None)

        with gr.Row():
            # Sidebar
            with gr.Column(scale=1):
                # File upload section
                file_upload = gr.File(
                    label="Upload Clustered Prompts CSV",
                    file_types=[".csv"],
                    file_count="single"
                )
                # Status
                status_md = gr.Markdown("Please upload a CSV file to get started.")
                # Cluster statistics
                stats_md = gr.Markdown("")
                # Cluster selection
                cluster_dropdown = gr.Dropdown(
                    [no_data_choice],
                    label="Select Cluster",
                    value=no_data_choice,
                    interactive=False
                )

            # Main content area
            with gr.Column(scale=3):
                prompts_html = gr.HTML("Upload a CSV file to browse clusters")

        # Connect event handlers
        file_upload.change(
            on_file_upload,
            [file_upload],
            [df_state, status_md, cluster_dropdown, prompts_html, stats_md]
        )
        cluster_dropdown.change(
            on_cluster_change,
            [df_state, cluster_dropdown],
            [prompts_html]
        )

    return demo
def launch_cluster_browser():
    """Build the cluster browser interface and start serving it."""
    create_cluster_browser_app().launch()
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    launch_cluster_browser()
|