Spaces:
Running
Running
Update app.py
Browse fileschanged strategy descriptor to download
app.py
CHANGED
|
@@ -214,26 +214,58 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
| 214 |
|
| 215 |
print(f"CSV file '{filename}' created successfully.")
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
# Filter the data
|
| 223 |
-
filtered_data = selected_test_info.iloc[matching_indices]
|
| 224 |
-
# new data contains etalon instead of 0/1 for ER/ME
|
| 225 |
-
filtered_data = filtered_data[filtered_data[8] == task_type] # Ensure test_info[6] matches
|
| 226 |
-
|
| 227 |
-
# Define filename dynamically
|
| 228 |
-
task_type_map = {0: "ER", 1: "ME"}
|
| 229 |
-
label_map = {0: "unsuccessful", 1: "successful"}
|
| 230 |
-
|
| 231 |
-
filename = f"{task_type_map[task_type]}-{label_map[label]}-strategies.csv"
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
# Write to CSV
|
| 235 |
-
process_and_write_csv(filtered_data, filename)
|
| 236 |
-
|
| 237 |
with open("fileHandler/roc_data2.pkl", 'rb') as file:
|
| 238 |
data = pickle.load(file)
|
| 239 |
t_label=data[0]
|
|
@@ -539,7 +571,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
| 539 |
textinfo='percent+label',
|
| 540 |
textposition='auto',
|
| 541 |
marker=dict(colors=colors),
|
| 542 |
-
sort=False
|
|
|
|
| 543 |
|
| 544 |
)])
|
| 545 |
|
|
@@ -577,7 +610,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
| 577 |
textinfo='percent+label',
|
| 578 |
textposition='auto',
|
| 579 |
marker=dict(colors=colors),
|
| 580 |
-
sort=False
|
|
|
|
| 581 |
# pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
|
| 582 |
|
| 583 |
)])
|
|
@@ -1142,31 +1176,82 @@ button, select, .slider-percentage {
|
|
| 1142 |
margin-bottom: 1rem !important;
|
| 1143 |
text-align: center !important;
|
| 1144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1146 |
}
|
| 1147 |
|
| 1148 |
|
| 1149 |
'''
|
|
|
|
| 1150 |
# Define the file directory
|
| 1151 |
FILE_DIR = "fileHandler"
|
| 1152 |
|
| 1153 |
# Function to get list of files
|
| 1154 |
def list_files():
|
| 1155 |
-
return ['Unsuccessful Strategies (ER)', 'Successful Strategies (ER)', 'Unsuccessful Strategies (ME)', 'Successful Strategies (ME)']
|
| 1156 |
label_to_filename = {
|
| 1157 |
-
|
| 1158 |
-
'Successful Strategies (ER)': 'ER-successful
|
| 1159 |
-
'Unsuccessful Strategies (
|
| 1160 |
-
'Successful Strategies (ME)': 'ME-successful
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1161 |
}
|
|
|
|
| 1162 |
# Function to provide the selected file path
|
| 1163 |
-
def
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1170 |
|
| 1171 |
|
| 1172 |
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
|
@@ -1205,31 +1290,53 @@ with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
|
| 1205 |
opt2_pie = gr.Plot(label="ME")
|
| 1206 |
|
| 1207 |
with gr.Row():
|
| 1208 |
-
gr.
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
|
| 1215 |
-
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
|
| 1220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1221 |
|
| 1222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1223 |
|
| 1224 |
btn.click(
|
| 1225 |
fn=lambda model, increment: (
|
| 1226 |
*process_file(model, increment), # Unpack all outputs from process_file
|
| 1227 |
-
gr.update(value=None),
|
|
|
|
| 1228 |
None, # Clear file output
|
| 1229 |
gr.update(visible=False) # Hide visualize markdown
|
| 1230 |
),
|
| 1231 |
inputs=[model_dropdown, increment_slider],
|
| 1232 |
-
outputs=[output_text, plot_output, opt1_pie, opt2_pie,
|
| 1233 |
)
|
| 1234 |
|
| 1235 |
|
|
|
|
| 214 |
|
| 215 |
print(f"CSV file '{filename}' created successfully.")
|
| 216 |
|
| 217 |
+
task_type_map = {0: "ER", 1: "ME"}
|
| 218 |
+
label_map = {0: "unsuccessful", 1: "successful"}
|
| 219 |
+
|
| 220 |
+
# -------------------------------
|
| 221 |
+
# 1. Where tlb == plb
|
| 222 |
+
# -------------------------------
|
| 223 |
+
for label in [0, 1]:
|
| 224 |
+
# All strategies
|
| 225 |
+
matching_indices = [i for i in range(len(tlb)) if tlb[i] == plb[i] == label]
|
| 226 |
+
filtered_data = selected_test_info.iloc[matching_indices]
|
| 227 |
+
filename = f"allstrategies-match-{label_map[label]}.csv"
|
| 228 |
+
process_and_write_csv(filtered_data, filename)
|
| 229 |
+
|
| 230 |
+
# Per task type
|
| 231 |
+
for task_type in [0, 1]:
|
| 232 |
+
task_data = filtered_data[filtered_data[8] == task_type]
|
| 233 |
+
filename = f"{task_type_map[task_type]}-match-{label_map[label]}.csv"
|
| 234 |
+
process_and_write_csv(task_data, filename)
|
| 235 |
+
|
| 236 |
+
# -------------------------------
|
| 237 |
+
# 2. Where tlb only
|
| 238 |
+
# -------------------------------
|
| 239 |
+
for label in [0, 1]:
|
| 240 |
+
# All strategies
|
| 241 |
+
matching_indices = [i for i in range(len(tlb)) if tlb[i] == label]
|
| 242 |
+
filtered_data = selected_test_info.iloc[matching_indices]
|
| 243 |
+
filename = f"allstrategies-groundtruth-{label_map[label]}.csv"
|
| 244 |
+
process_and_write_csv(filtered_data, filename)
|
| 245 |
+
|
| 246 |
+
# Per task type
|
| 247 |
+
for task_type in [0, 1]:
|
| 248 |
+
task_data = filtered_data[filtered_data[8] == task_type]
|
| 249 |
+
filename = f"{task_type_map[task_type]}-groundtruth-{label_map[label]}.csv"
|
| 250 |
+
process_and_write_csv(task_data, filename)
|
| 251 |
+
|
| 252 |
+
# -------------------------------
|
| 253 |
+
# 3. All data by task type (no label filtering)
|
| 254 |
+
# -------------------------------
|
| 255 |
+
# ER
|
| 256 |
+
task_data = selected_test_info[selected_test_info[8] == 0]
|
| 257 |
+
filename = f"ER-all.csv"
|
| 258 |
+
process_and_write_csv(task_data, filename)
|
| 259 |
+
|
| 260 |
+
# ME
|
| 261 |
+
task_data = selected_test_info[selected_test_info[8] == 1]
|
| 262 |
+
filename = f"ME-all.csv"
|
| 263 |
+
process_and_write_csv(task_data, filename)
|
| 264 |
+
|
| 265 |
+
# All strategies
|
| 266 |
+
filename = "allstrategies-all.csv"
|
| 267 |
+
process_and_write_csv(selected_test_info, filename)
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
with open("fileHandler/roc_data2.pkl", 'rb') as file:
|
| 270 |
data = pickle.load(file)
|
| 271 |
t_label=data[0]
|
|
|
|
| 571 |
textinfo='percent+label',
|
| 572 |
textposition='auto',
|
| 573 |
marker=dict(colors=colors),
|
| 574 |
+
sort=False,
|
| 575 |
+
hole=0.4
|
| 576 |
|
| 577 |
)])
|
| 578 |
|
|
|
|
| 610 |
textinfo='percent+label',
|
| 611 |
textposition='auto',
|
| 612 |
marker=dict(colors=colors),
|
| 613 |
+
sort=False,
|
| 614 |
+
hole=0.4
|
| 615 |
# pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
|
| 616 |
|
| 617 |
)])
|
|
|
|
| 1176 |
margin-bottom: 1rem !important;
|
| 1177 |
text-align: center !important;
|
| 1178 |
|
| 1179 |
+
#file-box {
|
| 1180 |
+
border: 1px solid #ccc;
|
| 1181 |
+
border-radius: 6px;
|
| 1182 |
+
padding: 10px;
|
| 1183 |
+
margin-top: 12px;
|
| 1184 |
+
background-color: #f9f9f9;
|
| 1185 |
+
}
|
| 1186 |
|
| 1187 |
+
.file-download {
|
| 1188 |
+
margin-bottom: 5px !important;
|
| 1189 |
+
padding: 4px !important;
|
| 1190 |
+
height: 10px;
|
| 1191 |
+
}
|
| 1192 |
}
|
| 1193 |
|
| 1194 |
|
| 1195 |
'''
|
| 1196 |
+
|
| 1197 |
# Define the file directory
|
| 1198 |
FILE_DIR = "fileHandler"
|
| 1199 |
|
| 1200 |
# Function to get list of files
|
| 1201 |
def list_files():
|
| 1202 |
+
return ['Unsuccessful Strategies (ER)', 'Successful Strategies (ER)', 'Unsuccessful Strategies (ME)', 'Successful Strategies (ME)','Ground Truth Unsuccessful Strategies (ER)','Ground Truth Successful Strategies (ER)','Ground Truth Unsuccessful Strategies (ME)','Ground Truth Successful Strategies (ME)']
|
| 1203 |
label_to_filename = {
|
| 1204 |
+
# Predicted (tlb == plb)
|
| 1205 |
+
'Predicted Successful Strategies (ER)': 'ER-match-successful.csv',
|
| 1206 |
+
'Predicted Unsuccessful Strategies (ER)': 'ER-match-unsuccessful.csv',
|
| 1207 |
+
'Predicted Successful Strategies (ME)': 'ME-match-successful.csv',
|
| 1208 |
+
'Predicted Unsuccessful Strategies (ME)': 'ME-match-unsuccessful.csv',
|
| 1209 |
+
'Predicted Successful Strategies (All)': 'allstrategies-match-successful.csv',
|
| 1210 |
+
'Predicted Unsuccessful Strategies (All)': 'allstrategies-match-unsuccessful.csv',
|
| 1211 |
+
|
| 1212 |
+
# Ground Truth (tlb only)
|
| 1213 |
+
'Ground Truth Successful Strategies (ER)': 'ER-groundtruth-successful.csv',
|
| 1214 |
+
'Ground Truth Unsuccessful Strategies (ER)': 'ER-groundtruth-unsuccessful.csv',
|
| 1215 |
+
'Ground Truth Successful Strategies (ME)': 'ME-groundtruth-successful.csv',
|
| 1216 |
+
'Ground Truth Unsuccessful Strategies (ME)': 'ME-groundtruth-unsuccessful.csv',
|
| 1217 |
+
'Ground Truth Successful Strategies (All)': 'allstrategies-groundtruth-successful.csv',
|
| 1218 |
+
'Ground Truth Unsuccessful Strategies (All)': 'allstrategies-groundtruth-unsuccessful.csv',
|
| 1219 |
+
|
| 1220 |
+
# All data
|
| 1221 |
+
'All Strategies (ER)': 'ER-all.csv',
|
| 1222 |
+
'All Strategies (ME)': 'ME-all.csv',
|
| 1223 |
+
'All Strategies (All)': 'allstrategies-all.csv'
|
| 1224 |
}
|
| 1225 |
+
|
| 1226 |
# Function to provide the selected file path
|
| 1227 |
+
def provide_file_paths(task_type, source):
|
| 1228 |
+
if not task_type or not source:
|
| 1229 |
+
return None, None, gr.update(visible=False)
|
| 1230 |
+
|
| 1231 |
+
# Handle "All" case for combined strategies
|
| 1232 |
+
if source == "All":
|
| 1233 |
+
label_success = f"All Strategies ({task_type})"
|
| 1234 |
+
label_unsuccess = f"All Strategies ({task_type})"
|
| 1235 |
+
else:
|
| 1236 |
+
label_success = f"{source} Successful Strategies ({task_type})"
|
| 1237 |
+
label_unsuccess = f"{source} Unsuccessful Strategies ({task_type})"
|
| 1238 |
+
label_all=f"All Strategies ({task_type})"
|
| 1239 |
+
|
| 1240 |
+
file_success = label_to_filename.get(label_success)
|
| 1241 |
+
file_unsuccess = label_to_filename.get(label_unsuccess)
|
| 1242 |
+
file_all=label_to_filename.get(label_all)
|
| 1243 |
+
|
| 1244 |
+
file_success_path = f"{FILE_DIR}/{file_success}" if file_success else None
|
| 1245 |
+
file_unsuccess_path = f"{FILE_DIR}/{file_unsuccess}" if file_unsuccess else None
|
| 1246 |
+
file_all_path = f"{FILE_DIR}/{file_all}" if file_all else None
|
| 1247 |
+
|
| 1248 |
+
dynamic_text = "🔍 [Visualize the strategies](https://path-analysis.vercel.app/)"
|
| 1249 |
+
if file_success and file_unsuccess and file_all:
|
| 1250 |
+
return file_success_path, file_unsuccess_path,file_all_path, gr.update(value=dynamic_text, visible=True)
|
| 1251 |
+
|
| 1252 |
+
return None, None,None, gr.update(visible=False)
|
| 1253 |
+
|
| 1254 |
+
|
| 1255 |
|
| 1256 |
|
| 1257 |
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
|
|
|
| 1290 |
opt2_pie = gr.Plot(label="ME")
|
| 1291 |
|
| 1292 |
with gr.Row():
|
| 1293 |
+
with gr.Column():
|
| 1294 |
+
# gr.Markdown("Select strategy filters and click Generate")
|
| 1295 |
+
task_type_radio = gr.Dropdown(
|
| 1296 |
+
choices=["ER", "ME", "All"],
|
| 1297 |
+
label="Filter by Problem Type",
|
| 1298 |
+
interactive=True
|
| 1299 |
+
)
|
| 1300 |
+
source_radio = gr.Checkbox(
|
| 1301 |
+
label="Predicted Labels",
|
| 1302 |
+
value=True
|
| 1303 |
+
)
|
| 1304 |
+
generate_button = gr.Button("Generate Strategies")
|
| 1305 |
+
|
| 1306 |
+
# with gr.Row():
|
| 1307 |
+
with gr.Column():
|
| 1308 |
+
with gr.Group(visible=False) as file_output_group:
|
| 1309 |
+
|
| 1310 |
+
gr.Markdown("**Download strategy descriptor files**")
|
| 1311 |
+
file_output_success = gr.File(label=" ")
|
| 1312 |
+
file_output_unsuccess = gr.File(label=" ")
|
| 1313 |
+
file_output_all = gr.File(label=" ")
|
| 1314 |
+
visualize_markdown = gr.Markdown(visible=False)
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
def handle_generate(task_type_dropdown, use_predicted):
|
| 1318 |
+
label_source = "Predicted" if use_predicted else "Ground Truth"
|
| 1319 |
+
file_success_path, file_unsuccess_path,file_all_path, viz_link = provide_file_paths(task_type_dropdown, label_source)
|
| 1320 |
+
|
| 1321 |
+
return file_success_path, file_unsuccess_path,file_all_path, viz_link,gr.update(visible=True)
|
| 1322 |
|
| 1323 |
|
| 1324 |
+
generate_button.click(
|
| 1325 |
+
fn=handle_generate,
|
| 1326 |
+
inputs=[task_type_radio, source_radio],
|
| 1327 |
+
outputs=[file_output_success, file_output_unsuccess,file_output_all, visualize_markdown,file_output_group]
|
| 1328 |
+
)
|
| 1329 |
|
| 1330 |
btn.click(
|
| 1331 |
fn=lambda model, increment: (
|
| 1332 |
*process_file(model, increment), # Unpack all outputs from process_file
|
| 1333 |
+
gr.update(value=None), # update outcome_radio
|
| 1334 |
+
gr.update(value=None), # Reset dropdown to first item
|
| 1335 |
None, # Clear file output
|
| 1336 |
gr.update(visible=False) # Hide visualize markdown
|
| 1337 |
),
|
| 1338 |
inputs=[model_dropdown, increment_slider],
|
| 1339 |
+
outputs=[output_text, plot_output, opt1_pie, opt2_pie, task_type_radio,source_radio,file_output_success,file_output_unsuccess, visualize_markdown]
|
| 1340 |
)
|
| 1341 |
|
| 1342 |
|