Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -288,6 +288,58 @@ def create_example_outputs(template):
|
|
| 288 |
return outputs
|
| 289 |
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
@st.cache_data
|
| 292 |
def parse_documents(uploaded_files):
|
| 293 |
"""Parse multiple document files and extract their text content."""
|
|
@@ -2381,6 +2433,7 @@ with tab3:
|
|
| 2381 |
if var["type"] == "categorical" and var.get("options")
|
| 2382 |
]
|
| 2383 |
|
|
|
|
| 2384 |
if categorical_vars:
|
| 2385 |
st.subheader("Categorical Variable Options")
|
| 2386 |
st.info(
|
|
@@ -2406,7 +2459,30 @@ with tab3:
|
|
| 2406 |
|
| 2407 |
# Initialize selected_options if not present
|
| 2408 |
if "selected_options" not in var:
|
|
|
|
| 2409 |
var["selected_options"] = options.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2410 |
|
| 2411 |
# Add "Select All" and "Clear All" buttons
|
| 2412 |
col1, col2 = st.columns([1, 1])
|
|
@@ -2426,7 +2502,9 @@ with tab3:
|
|
| 2426 |
var["selected_options"] = st.multiselect(
|
| 2427 |
f"Select options to include for {var['name']}",
|
| 2428 |
options=options,
|
| 2429 |
-
default=var.get(
|
|
|
|
|
|
|
| 2430 |
key=f"options_select_{i}",
|
| 2431 |
)
|
| 2432 |
|
|
@@ -2435,11 +2513,33 @@ with tab3:
|
|
| 2435 |
f"Selected {len(var['selected_options'])} out of {len(options)} options"
|
| 2436 |
)
|
| 2437 |
|
| 2438 |
-
|
| 2439 |
-
|
| 2440 |
-
|
| 2441 |
-
|
| 2442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2443 |
|
| 2444 |
# Generate inputs button
|
| 2445 |
if st.button("Generate Synthetic Inputs"):
|
|
@@ -2585,6 +2685,56 @@ with tab3:
|
|
| 2585 |
"Filled Prompt", value=filled_prompt, height=300, disabled=True
|
| 2586 |
)
|
| 2587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2588 |
# Generate outputs button
|
| 2589 |
if st.button("Generate Outputs for Selected Samples"):
|
| 2590 |
if not st.session_state.get("api_key"):
|
|
@@ -2604,6 +2754,31 @@ with tab3:
|
|
| 2604 |
for i in st.session_state.selected_samples
|
| 2605 |
]
|
| 2606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2607 |
with st.spinner(
|
| 2608 |
f"Generating outputs for {len(selected_inputs)} samples..."
|
| 2609 |
):
|
|
@@ -2618,34 +2793,36 @@ with tab3:
|
|
| 2618 |
if selection_method == "Generate for all samples":
|
| 2619 |
st.session_state.combined_data = generated_outputs
|
| 2620 |
else:
|
| 2621 |
-
#
|
| 2622 |
-
|
| 2623 |
-
|
| 2624 |
-
st.session_state.combined_data
|
| 2625 |
-
|
| 2626 |
-
|
| 2627 |
-
|
| 2628 |
-
|
| 2629 |
-
|
| 2630 |
-
|
| 2631 |
-
|
| 2632 |
-
st.session_state.selected_samples
|
| 2633 |
-
):
|
| 2634 |
-
if i < len(generated_outputs):
|
| 2635 |
-
st.session_state.combined_data[output_idx] = (
|
| 2636 |
-
generated_outputs[i]
|
| 2637 |
)
|
| 2638 |
|
| 2639 |
-
|
| 2640 |
-
|
| 2641 |
-
|
| 2642 |
-
|
| 2643 |
-
|
| 2644 |
-
|
|
|
|
|
|
|
| 2645 |
|
| 2646 |
-
|
| 2647 |
-
|
| 2648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2649 |
|
| 2650 |
# Display combined data if available
|
| 2651 |
if st.session_state.combined_data:
|
|
@@ -2747,4 +2924,4 @@ with tab3:
|
|
| 2747 |
else:
|
| 2748 |
st.info(
|
| 2749 |
"No template has been generated yet. Go to the 'Setup' tab to create one."
|
| 2750 |
-
)
|
|
|
|
| 288 |
return outputs
|
| 289 |
|
| 290 |
|
| 291 |
+
# Add this function after generate_categorical_permutations function
|
| 292 |
+
def calculate_cartesian_product_size(categorical_vars):
|
| 293 |
+
"""Calculate the size of the Cartesian product based on selected options."""
|
| 294 |
+
if not categorical_vars:
|
| 295 |
+
return 0
|
| 296 |
+
|
| 297 |
+
# Calculate the product size
|
| 298 |
+
product_size = 1
|
| 299 |
+
var_counts = []
|
| 300 |
+
|
| 301 |
+
for var in categorical_vars:
|
| 302 |
+
options = var.get("options", [])
|
| 303 |
+
selected_options = var.get("selected_options", options)
|
| 304 |
+
min_sel = var.get("min", 1)
|
| 305 |
+
max_sel = var.get("max", 1)
|
| 306 |
+
|
| 307 |
+
# Use only selected options for calculation
|
| 308 |
+
options_to_use = [opt for opt in options if opt in selected_options]
|
| 309 |
+
|
| 310 |
+
# If no options selected, use all options
|
| 311 |
+
if not options_to_use:
|
| 312 |
+
options_to_use = options
|
| 313 |
+
|
| 314 |
+
# Single selection case
|
| 315 |
+
if min_sel == 1 and max_sel == 1:
|
| 316 |
+
count = len(options_to_use)
|
| 317 |
+
else:
|
| 318 |
+
# Multi-selection case - calculate combinations
|
| 319 |
+
count = 0
|
| 320 |
+
# Include min selections
|
| 321 |
+
from math import comb
|
| 322 |
+
|
| 323 |
+
if len(options_to_use) >= min_sel:
|
| 324 |
+
count += comb(len(options_to_use), min_sel)
|
| 325 |
+
|
| 326 |
+
# Include max selections if different from min
|
| 327 |
+
if max_sel != min_sel and len(options_to_use) >= max_sel:
|
| 328 |
+
count += comb(len(options_to_use), max_sel)
|
| 329 |
+
|
| 330 |
+
# Include some intermediate selections if applicable
|
| 331 |
+
for size in range(min_sel + 1, max_sel):
|
| 332 |
+
if len(options_to_use) >= size:
|
| 333 |
+
count += min(
|
| 334 |
+
3, comb(len(options_to_use), size)
|
| 335 |
+
) # Take up to 3 samples
|
| 336 |
+
|
| 337 |
+
var_counts.append({"name": var["name"], "count": count})
|
| 338 |
+
product_size *= max(count, 1) # Avoid multiplying by zero
|
| 339 |
+
|
| 340 |
+
return product_size, var_counts
|
| 341 |
+
|
| 342 |
+
|
| 343 |
@st.cache_data
|
| 344 |
def parse_documents(uploaded_files):
|
| 345 |
"""Parse multiple document files and extract their text content."""
|
|
|
|
| 2433 |
if var["type"] == "categorical" and var.get("options")
|
| 2434 |
]
|
| 2435 |
|
| 2436 |
+
# In tab3, modify the categorical variable options section
|
| 2437 |
if categorical_vars:
|
| 2438 |
st.subheader("Categorical Variable Options")
|
| 2439 |
st.info(
|
|
|
|
| 2459 |
|
| 2460 |
# Initialize selected_options if not present
|
| 2461 |
if "selected_options" not in var:
|
| 2462 |
+
# First time initialization
|
| 2463 |
var["selected_options"] = options.copy()
|
| 2464 |
+
else:
|
| 2465 |
+
# Filter selected_options to only include valid options
|
| 2466 |
+
var["selected_options"] = [
|
| 2467 |
+
opt
|
| 2468 |
+
for opt in var.get("selected_options", [])
|
| 2469 |
+
if opt in options
|
| 2470 |
+
]
|
| 2471 |
+
|
| 2472 |
+
# Check for new options that need to be automatically selected
|
| 2473 |
+
previous_options = var.get("previous_options", [])
|
| 2474 |
+
|
| 2475 |
+
# Find new options that weren't in the previous options list
|
| 2476 |
+
new_options = [
|
| 2477 |
+
opt for opt in options if opt not in previous_options
|
| 2478 |
+
]
|
| 2479 |
+
|
| 2480 |
+
# Add new options to selected_options
|
| 2481 |
+
if new_options:
|
| 2482 |
+
var["selected_options"].extend(new_options)
|
| 2483 |
+
|
| 2484 |
+
# Store current options for future comparison
|
| 2485 |
+
var["previous_options"] = options.copy()
|
| 2486 |
|
| 2487 |
# Add "Select All" and "Clear All" buttons
|
| 2488 |
col1, col2 = st.columns([1, 1])
|
|
|
|
| 2502 |
var["selected_options"] = st.multiselect(
|
| 2503 |
f"Select options to include for {var['name']}",
|
| 2504 |
options=options,
|
| 2505 |
+
default=var.get(
|
| 2506 |
+
"selected_options", []
|
| 2507 |
+
), # Use empty list as fallback
|
| 2508 |
key=f"options_select_{i}",
|
| 2509 |
)
|
| 2510 |
|
|
|
|
| 2513 |
f"Selected {len(var['selected_options'])} out of {len(options)} options"
|
| 2514 |
)
|
| 2515 |
|
| 2516 |
+
# Update the template spec with the selected options
|
| 2517 |
+
for j, input_var in enumerate(template_spec_copy["input"]):
|
| 2518 |
+
if input_var["name"] == var["name"]:
|
| 2519 |
+
template_spec_copy["input"][j] = var
|
| 2520 |
+
break
|
| 2521 |
+
|
| 2522 |
+
# Calculate and display Cartesian product size
|
| 2523 |
+
product_size, var_counts = calculate_cartesian_product_size(
|
| 2524 |
+
[v for v in template_spec_copy["input"] if v["type"] == "categorical"]
|
| 2525 |
+
)
|
| 2526 |
+
|
| 2527 |
+
st.subheader("Combination Analysis")
|
| 2528 |
+
st.info(f"Total number of possible combinations: {product_size:,}")
|
| 2529 |
+
|
| 2530 |
+
# Display breakdown of combinations
|
| 2531 |
+
st.write("Breakdown by variable:")
|
| 2532 |
+
for var in var_counts:
|
| 2533 |
+
st.write(f"- {var['name']}: {var['count']:,} possible values")
|
| 2534 |
+
|
| 2535 |
+
if product_size > num_samples:
|
| 2536 |
+
st.warning(
|
| 2537 |
+
f"Note: Only {num_samples} samples will be generated from the {product_size:,} possible combinations"
|
| 2538 |
+
)
|
| 2539 |
+
elif product_size < num_samples:
|
| 2540 |
+
st.warning(
|
| 2541 |
+
f"Note: Some combinations will be repeated to reach {num_samples} samples (only {product_size:,} unique combinations possible)"
|
| 2542 |
+
)
|
| 2543 |
|
| 2544 |
# Generate inputs button
|
| 2545 |
if st.button("Generate Synthetic Inputs"):
|
|
|
|
| 2685 |
"Filled Prompt", value=filled_prompt, height=300, disabled=True
|
| 2686 |
)
|
| 2687 |
|
| 2688 |
+
# Advanced output generation options
|
| 2689 |
+
with st.expander("Advanced Output Generation Options", expanded=False):
|
| 2690 |
+
st.info("Configure options for generating multiple outputs per input")
|
| 2691 |
+
|
| 2692 |
+
# Option to generate multiple outputs for some inputs
|
| 2693 |
+
enable_multiple_outputs = st.checkbox(
|
| 2694 |
+
"Generate multiple outputs for some inputs",
|
| 2695 |
+
help="Enable generating multiple variations of outputs for selected inputs",
|
| 2696 |
+
)
|
| 2697 |
+
|
| 2698 |
+
if enable_multiple_outputs:
|
| 2699 |
+
# Proportion of inputs to duplicate
|
| 2700 |
+
duplicate_proportion = st.slider(
|
| 2701 |
+
"Proportion of inputs to generate multiple outputs for",
|
| 2702 |
+
min_value=0.0,
|
| 2703 |
+
max_value=1.0,
|
| 2704 |
+
value=0.2,
|
| 2705 |
+
step=0.1,
|
| 2706 |
+
help="What fraction of the input samples should have multiple outputs",
|
| 2707 |
+
)
|
| 2708 |
+
|
| 2709 |
+
# Number of outputs per duplicated input
|
| 2710 |
+
outputs_per_input = st.number_input(
|
| 2711 |
+
"Number of outputs per selected input",
|
| 2712 |
+
min_value=2,
|
| 2713 |
+
max_value=5,
|
| 2714 |
+
value=2,
|
| 2715 |
+
help="How many different outputs to generate for each selected input",
|
| 2716 |
+
)
|
| 2717 |
+
|
| 2718 |
+
# Preview the effect
|
| 2719 |
+
if st.session_state.selected_samples:
|
| 2720 |
+
num_selected = len(st.session_state.selected_samples)
|
| 2721 |
+
num_to_duplicate = math.ceil(
|
| 2722 |
+
num_selected * duplicate_proportion
|
| 2723 |
+
)
|
| 2724 |
+
total_outputs = (num_selected - num_to_duplicate) + (
|
| 2725 |
+
num_to_duplicate * outputs_per_input
|
| 2726 |
+
)
|
| 2727 |
+
|
| 2728 |
+
st.write(
|
| 2729 |
+
f"This will result in approximately {total_outputs} total outputs:"
|
| 2730 |
+
)
|
| 2731 |
+
st.write(
|
| 2732 |
+
f"- {num_selected - num_to_duplicate} inputs with 1 output"
|
| 2733 |
+
)
|
| 2734 |
+
st.write(
|
| 2735 |
+
f"- {num_to_duplicate} inputs with {outputs_per_input} outputs each"
|
| 2736 |
+
)
|
| 2737 |
+
|
| 2738 |
# Generate outputs button
|
| 2739 |
if st.button("Generate Outputs for Selected Samples"):
|
| 2740 |
if not st.session_state.get("api_key"):
|
|
|
|
| 2754 |
for i in st.session_state.selected_samples
|
| 2755 |
]
|
| 2756 |
|
| 2757 |
+
# Handle multiple outputs if enabled
|
| 2758 |
+
if enable_multiple_outputs:
|
| 2759 |
+
# Calculate how many inputs should have multiple outputs
|
| 2760 |
+
num_to_duplicate = math.ceil(
|
| 2761 |
+
len(selected_inputs) * duplicate_proportion
|
| 2762 |
+
)
|
| 2763 |
+
|
| 2764 |
+
# Randomly select inputs for multiple outputs
|
| 2765 |
+
duplicate_indices = random.sample(
|
| 2766 |
+
range(len(selected_inputs)), num_to_duplicate
|
| 2767 |
+
)
|
| 2768 |
+
|
| 2769 |
+
# Create the expanded input list
|
| 2770 |
+
expanded_inputs = []
|
| 2771 |
+
for i, input_data in enumerate(selected_inputs):
|
| 2772 |
+
if i in duplicate_indices:
|
| 2773 |
+
# Add multiple copies for selected inputs
|
| 2774 |
+
expanded_inputs.extend([input_data] * outputs_per_input)
|
| 2775 |
+
else:
|
| 2776 |
+
# Add single copy for other inputs
|
| 2777 |
+
expanded_inputs.append(input_data)
|
| 2778 |
+
|
| 2779 |
+
# Update selected_inputs with the expanded list
|
| 2780 |
+
selected_inputs = expanded_inputs
|
| 2781 |
+
|
| 2782 |
with st.spinner(
|
| 2783 |
f"Generating outputs for {len(selected_inputs)} samples..."
|
| 2784 |
):
|
|
|
|
| 2793 |
if selection_method == "Generate for all samples":
|
| 2794 |
st.session_state.combined_data = generated_outputs
|
| 2795 |
else:
|
| 2796 |
+
# For specific samples, we need to handle the case of multiple outputs
|
| 2797 |
+
if enable_multiple_outputs:
|
| 2798 |
+
# Simply use all generated outputs as the combined data
|
| 2799 |
+
st.session_state.combined_data = generated_outputs
|
| 2800 |
+
else:
|
| 2801 |
+
# Handle single outputs as before
|
| 2802 |
+
if not st.session_state.combined_data or len(
|
| 2803 |
+
st.session_state.combined_data
|
| 2804 |
+
) != len(st.session_state.synthetic_inputs):
|
| 2805 |
+
st.session_state.combined_data = [None] * len(
|
| 2806 |
+
st.session_state.synthetic_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2807 |
)
|
| 2808 |
|
| 2809 |
+
# Update only the selected samples
|
| 2810 |
+
for i, output_idx in enumerate(
|
| 2811 |
+
st.session_state.selected_samples
|
| 2812 |
+
):
|
| 2813 |
+
if i < len(generated_outputs):
|
| 2814 |
+
st.session_state.combined_data[output_idx] = (
|
| 2815 |
+
generated_outputs[i]
|
| 2816 |
+
)
|
| 2817 |
|
| 2818 |
+
# Remove any None values (samples that haven't been generated yet)
|
| 2819 |
+
st.session_state.combined_data = [
|
| 2820 |
+
item
|
| 2821 |
+
for item in st.session_state.combined_data
|
| 2822 |
+
if item is not None
|
| 2823 |
+
]
|
| 2824 |
+
|
| 2825 |
+
st.success(f"Generated {len(generated_outputs)} outputs")
|
| 2826 |
|
| 2827 |
# Display combined data if available
|
| 2828 |
if st.session_state.combined_data:
|
|
|
|
| 2924 |
else:
|
| 2925 |
st.info(
|
| 2926 |
"No template has been generated yet. Go to the 'Setup' tab to create one."
|
| 2927 |
+
)
|