lisabdunlap committed on
Commit
dce46c6
·
verified ·
1 Parent(s): a350ac9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from typing import List, Dict, Any, Tuple, Optional
5
+
6
+ def create_cluster_browser_app():
7
+ """
8
+ Create a simple Gradio app for browsing prompts by cluster from uploaded CSV file.
9
+ """
10
+
11
def load_and_validate_csv(file) -> Tuple[Optional[pd.DataFrame], str, List[str], str]:
    """
    Load and validate the uploaded CSV file.

    Args:
        file: The upload to read. Either an object exposing the file path
            as ``.name`` or a plain filesystem path (``str`` or
            ``os.PathLike``). ``None`` means nothing has been uploaded.

    Returns:
        Tuple of (dataframe, status_message, cluster_options, cluster_stats).
        On any failure the dataframe is ``None``, the options are
        ``["(No data loaded)"]`` and the stats string is empty.
    """
    import os

    def failure(message: str) -> Tuple[None, str, List[str], str]:
        # Fresh list on every call so callers never share a mutable object.
        return None, message, ["(No data loaded)"], ""

    if file is None:
        return failure("Please upload a CSV file with 'prompt' and 'cluster' columns.")

    # A bare path must be used as-is: for a pathlib.Path, ``.name`` is only
    # the basename, which would point at the wrong file.
    csv_path = file if isinstance(file, (str, os.PathLike)) else file.name

    try:
        df = pd.read_csv(csv_path)

        # Validate required columns
        required_cols = ['prompt', 'cluster']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            return failure(
                f"Missing required columns: {missing_cols}. Please ensure your CSV has 'prompt' and 'cluster' columns."
            )

        # Validate data types
        if not pd.api.types.is_numeric_dtype(df['cluster']):
            return failure("The 'cluster' column must contain numeric values.")

        # One pass over the column gives both the distinct clusters and their
        # sizes. ``dropna=False`` keeps rows with a missing cluster visible,
        # and ``sort_index`` orders the labels (missing values last).
        counts = df['cluster'].value_counts(dropna=False).sort_index()

        cluster_options = ["(All Clusters)"] + [f"Cluster {c}" for c in counts.index]
        stats = [f"Cluster {c}: {n} prompts" for c, n in counts.items()]

        total_prompts = len(df)
        stats_text = f"**Total Prompts:** {total_prompts}\n\n**Cluster Distribution:**\n" + "\n".join(stats)

        return df, f"✅ Successfully loaded {total_prompts} prompts with {len(counts)} clusters.", cluster_options, stats_text

    except Exception as e:
        # UI boundary: report any read/parse problem to the user as a status
        # message instead of letting the event handler crash.
        return failure(f"Error loading CSV file: {str(e)}")
55
+
56
def filter_by_cluster(df: pd.DataFrame, cluster_sel: str) -> pd.DataFrame:
    """
    Filter dataframe by selected cluster.

    Args:
        df: Frame with a numeric 'cluster' column, or ``None``.
        cluster_sel: Dropdown label: "(All Clusters)", "(No data loaded)",
            or "Cluster X" where X is the cluster value as text.

    Returns:
        An empty frame when ``df`` is ``None``; ``df`` itself when no
        specific cluster is selected; otherwise the matching rows with a
        fresh 0-based index.

    Raises:
        ValueError: If the text after "Cluster" is not a number.
    """
    if df is None:
        return pd.DataFrame()
    if not cluster_sel or cluster_sel in ("(All Clusters)", "(No data loaded)"):
        return df

    label = cluster_sel.split()[-1]  # Extract value from "Cluster X"
    try:
        # Prefer int so very large integer ids are compared exactly.
        cluster_value = int(label)
    except ValueError:
        # Float cluster columns produce labels such as "1.0" or "nan".
        cluster_value = float(label)

    if pd.isna(cluster_value):
        # NaN never compares equal to itself, so match missing values explicitly.
        mask = df['cluster'].isna()
    else:
        mask = df['cluster'] == cluster_value
    return df[mask].reset_index(drop=True)
63
+
64
def format_prompt_cell(prompt_text: str) -> str:
    """
    Format a single prompt in its own cell.

    The prompt comes from an uploaded CSV, so it is HTML-escaped before
    being placed in the markup. Text containing ``<``, ``>``, ``&`` or
    quotes is then displayed literally rather than interpreted as markup.

    Args:
        prompt_text: Raw prompt text.

    Returns:
        An HTML snippet: a styled card ``<div>`` wrapping the escaped text.
    """
    from html import escape

    safe_text = escape(str(prompt_text))
    return f"""
<div style="
    background: #f8f9fa;
    border: 1px solid #e9ecef;
    border-radius: 8px;
    padding: 16px;
    margin: 8px 0;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
">
    <div style="font-size: 14px; line-height: 1.5; color: #333;">
        {safe_text}
    </div>
</div>
"""
80
+
81
def format_prompts(df: pd.DataFrame) -> str:
    """
    Format all prompts in the dataframe as individual cells.

    Args:
        df: Frame with a 'prompt' column, or ``None``.

    Returns:
        "No prompts to display." when ``df`` is ``None`` or has no rows;
        otherwise one formatted cell per row, joined with newlines.
    """
    if df is None or len(df) == 0:
        return "No prompts to display."

    # Read the column directly instead of using iterrows(): iterrows builds a
    # Series per row and upcasts mixed numeric rows, so an integer prompt
    # could be rendered as "1.0".
    return "\n".join(
        format_prompt_cell(str(prompt).strip()) for prompt in df['prompt']
    )
92
+
93
def on_file_upload(file):
    """
    React to a change of the uploaded file.

    Returns a 5-tuple: (dataframe or None, status message, dropdown update,
    prompts HTML, cluster statistics text).
    """
    loaded_df, message, options, stats_text = load_and_validate_csv(file)

    if loaded_df is None:
        # Nothing usable was loaded: lock the dropdown on its placeholder.
        locked = gr.Dropdown(choices=options, value="(No data loaded)", interactive=False)
        return None, message, locked, "No data loaded.", ""

    # Show all prompts initially
    rendered = format_prompts(loaded_df)
    unlocked = gr.Dropdown(choices=options, value="(All Clusters)", interactive=True)
    return loaded_df, message, unlocked, rendered, stats_text
103
+
104
def on_cluster_change(df, cluster_sel):
    """
    React to a new dropdown selection.

    Returns the formatted prompts of the selected cluster, or
    "No data loaded." when no dataframe is available.
    """
    if df is not None:
        return format_prompts(filter_by_cluster(df, cluster_sel))
    return "No data loaded."
111
+
112
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# Prompt Cluster Browser")

    # Holds the dataframe returned by the upload handler
    df_state = gr.State(None)

    with gr.Row():
        # Left column: upload, status, statistics and cluster picker
        with gr.Column(scale=1):
            file_upload = gr.File(
                label="Upload Clustered Prompts CSV",
                file_types=[".csv"],
                file_count="single",
            )
            status_md = gr.Markdown("Please upload a CSV file to get started.")
            stats_md = gr.Markdown("")
            cluster_dropdown = gr.Dropdown(
                choices=["(No data loaded)"],
                value="(No data loaded)",
                label="Select Cluster",
                interactive=False,
            )

        # Right column: the rendered prompts
        with gr.Column(scale=3):
            prompts_html = gr.HTML("Upload a CSV file to browse clusters")

    # A new upload refreshes the state, status, dropdown, prompts and stats
    file_upload.change(
        fn=on_file_upload,
        inputs=[file_upload],
        outputs=[df_state, status_md, cluster_dropdown, prompts_html, stats_md],
    )

    # A new selection re-renders only the prompts
    cluster_dropdown.change(
        fn=on_cluster_change,
        inputs=[df_state, cluster_dropdown],
        outputs=[prompts_html],
    )
159
+
160
+ return demo
161
+
162
def launch_cluster_browser():
    """
    Build the cluster browser app and launch it.
    """
    create_cluster_browser_app().launch()
168
+
169
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    launch_cluster_browser()