DJHumanRPT committed on
Commit
4e3b069
·
verified ·
1 Parent(s): 2fa84b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -32
app.py CHANGED
@@ -288,6 +288,58 @@ def create_example_outputs(template):
288
  return outputs
289
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  @st.cache_data
292
  def parse_documents(uploaded_files):
293
  """Parse multiple document files and extract their text content."""
@@ -2381,6 +2433,7 @@ with tab3:
2381
  if var["type"] == "categorical" and var.get("options")
2382
  ]
2383
 
 
2384
  if categorical_vars:
2385
  st.subheader("Categorical Variable Options")
2386
  st.info(
@@ -2406,7 +2459,30 @@ with tab3:
2406
 
2407
  # Initialize selected_options if not present
2408
  if "selected_options" not in var:
 
2409
  var["selected_options"] = options.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2410
 
2411
  # Add "Select All" and "Clear All" buttons
2412
  col1, col2 = st.columns([1, 1])
@@ -2426,7 +2502,9 @@ with tab3:
2426
  var["selected_options"] = st.multiselect(
2427
  f"Select options to include for {var['name']}",
2428
  options=options,
2429
- default=var.get("selected_options", options),
 
 
2430
  key=f"options_select_{i}",
2431
  )
2432
 
@@ -2435,11 +2513,33 @@ with tab3:
2435
  f"Selected {len(var['selected_options'])} out of {len(options)} options"
2436
  )
2437
 
2438
- # Update the template spec with the selected options
2439
- for j, input_var in enumerate(template_spec_copy["input"]):
2440
- if input_var["name"] == var["name"]:
2441
- template_spec_copy["input"][j] = var
2442
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2443
 
2444
  # Generate inputs button
2445
  if st.button("Generate Synthetic Inputs"):
@@ -2585,6 +2685,56 @@ with tab3:
2585
  "Filled Prompt", value=filled_prompt, height=300, disabled=True
2586
  )
2587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2588
  # Generate outputs button
2589
  if st.button("Generate Outputs for Selected Samples"):
2590
  if not st.session_state.get("api_key"):
@@ -2604,6 +2754,31 @@ with tab3:
2604
  for i in st.session_state.selected_samples
2605
  ]
2606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2607
  with st.spinner(
2608
  f"Generating outputs for {len(selected_inputs)} samples..."
2609
  ):
@@ -2618,34 +2793,36 @@ with tab3:
2618
  if selection_method == "Generate for all samples":
2619
  st.session_state.combined_data = generated_outputs
2620
  else:
2621
- # If we're generating for specific samples, update only those samples
2622
- # First, ensure combined_data exists and has the right size
2623
- if not st.session_state.combined_data or len(
2624
- st.session_state.combined_data
2625
- ) != len(st.session_state.synthetic_inputs):
2626
- st.session_state.combined_data = [None] * len(
2627
- st.session_state.synthetic_inputs
2628
- )
2629
-
2630
- # Update only the selected samples
2631
- for i, output_idx in enumerate(
2632
- st.session_state.selected_samples
2633
- ):
2634
- if i < len(generated_outputs):
2635
- st.session_state.combined_data[output_idx] = (
2636
- generated_outputs[i]
2637
  )
2638
 
2639
- # Remove any None values (samples that haven't been generated yet)
2640
- st.session_state.combined_data = [
2641
- item
2642
- for item in st.session_state.combined_data
2643
- if item is not None
2644
- ]
 
 
2645
 
2646
- st.success(
2647
- f"Generated outputs for {len(generated_outputs)} samples"
2648
- )
 
 
 
 
 
2649
 
2650
  # Display combined data if available
2651
  if st.session_state.combined_data:
@@ -2747,4 +2924,4 @@ with tab3:
2747
  else:
2748
  st.info(
2749
  "No template has been generated yet. Go to the 'Setup' tab to create one."
2750
- )
 
288
  return outputs
289
 
290
 
291
+ # Combination-size helper; defined after generate_categorical_permutations
292
def calculate_cartesian_product_size(categorical_vars):
    """Estimate how many combinations the selected categorical options allow.

    Args:
        categorical_vars: Iterable of variable dicts. Each dict must have a
            "name" key and may have "options", "selected_options", "min" and
            "max" keys ("min"/"max" default to 1 when absent).

    Returns:
        A ``(product_size, var_counts)`` tuple. ``product_size`` is the
        product of the per-variable counts (a count of 0 is treated as 1 so
        it does not zero out the product). ``var_counts`` is a list of
        ``{"name": ..., "count": ...}`` dicts, one per variable, in input
        order. For empty/None input the result is ``(0, [])`` so callers can
        always unpack two values.

    Raises:
        KeyError: If a variable dict has no "name" key.
    """
    # Local import keeps the function self-contained; hoisted out of the
    # per-variable loop where it previously re-executed on every iteration.
    from math import comb

    if not categorical_vars:
        # Keep the return shape identical to the non-empty path; returning a
        # bare 0 here would break callers that unpack two values.
        return 0, []

    product_size = 1
    var_counts = []

    for var in categorical_vars:
        options = var.get("options", [])
        selected_options = var.get("selected_options", options)
        min_sel = var.get("min", 1)
        max_sel = var.get("max", 1)

        # Only count options that are currently selected; if nothing is
        # selected, fall back to the full option list.
        options_to_use = [
            opt for opt in options if opt in selected_options
        ] or options
        num_options = len(options_to_use)

        if min_sel == 1 and max_sel == 1:
            # Single selection: one combination per option.
            count = num_options
        else:
            # Multi-selection: all subsets of size min and of size max, plus
            # at most 3 subsets for each size strictly in between.
            # NOTE(review): presumably this mirrors how
            # generate_categorical_permutations samples subsets; confirm
            # against that function.
            count = 0
            if num_options >= min_sel:
                count += comb(num_options, min_sel)
            if max_sel != min_sel and num_options >= max_sel:
                count += comb(num_options, max_sel)
            for size in range(min_sel + 1, max_sel):
                if num_options >= size:
                    count += min(3, comb(num_options, size))

        var_counts.append({"name": var["name"], "count": count})
        product_size *= max(count, 1)  # Avoid multiplying by zero

    return product_size, var_counts
341
+
342
+
343
  @st.cache_data
344
  def parse_documents(uploaded_files):
345
  """Parse multiple document files and extract their text content."""
 
2433
  if var["type"] == "categorical" and var.get("options")
2434
  ]
2435
 
2436
+ # Categorical variable options section (tab3)
2437
  if categorical_vars:
2438
  st.subheader("Categorical Variable Options")
2439
  st.info(
 
2459
 
2460
  # Initialize selected_options if not present
2461
  if "selected_options" not in var:
2462
+ # First time initialization
2463
  var["selected_options"] = options.copy()
2464
+ else:
2465
+ # Filter selected_options to only include valid options
2466
+ var["selected_options"] = [
2467
+ opt
2468
+ for opt in var.get("selected_options", [])
2469
+ if opt in options
2470
+ ]
2471
+
2472
+ # Check for new options that need to be automatically selected
2473
+ previous_options = var.get("previous_options", [])
2474
+
2475
+ # Find new options that weren't in the previous options list
2476
+ new_options = [
2477
+ opt for opt in options if opt not in previous_options
2478
+ ]
2479
+
2480
+ # Add new options to selected_options
2481
+ if new_options:
2482
+ var["selected_options"].extend(new_options)
2483
+
2484
+ # Store current options for future comparison
2485
+ var["previous_options"] = options.copy()
2486
 
2487
  # Add "Select All" and "Clear All" buttons
2488
  col1, col2 = st.columns([1, 1])
 
2502
  var["selected_options"] = st.multiselect(
2503
  f"Select options to include for {var['name']}",
2504
  options=options,
2505
+ default=var.get(
2506
+ "selected_options", []
2507
+ ), # Use empty list as fallback
2508
  key=f"options_select_{i}",
2509
  )
2510
 
 
2513
  f"Selected {len(var['selected_options'])} out of {len(options)} options"
2514
  )
2515
 
2516
+ # Update the template spec with the selected options
2517
+ for j, input_var in enumerate(template_spec_copy["input"]):
2518
+ if input_var["name"] == var["name"]:
2519
+ template_spec_copy["input"][j] = var
2520
+ break
2521
+
2522
+ # Calculate and display Cartesian product size
2523
+ product_size, var_counts = calculate_cartesian_product_size(
2524
+ [v for v in template_spec_copy["input"] if v["type"] == "categorical"]
2525
+ )
2526
+
2527
+ st.subheader("Combination Analysis")
2528
+ st.info(f"Total number of possible combinations: {product_size:,}")
2529
+
2530
+ # Display breakdown of combinations
2531
+ st.write("Breakdown by variable:")
2532
+ for var in var_counts:
2533
+ st.write(f"- {var['name']}: {var['count']:,} possible values")
2534
+
2535
+ if product_size > num_samples:
2536
+ st.warning(
2537
+ f"Note: Only {num_samples} samples will be generated from the {product_size:,} possible combinations"
2538
+ )
2539
+ elif product_size < num_samples:
2540
+ st.warning(
2541
+ f"Note: Some combinations will be repeated to reach {num_samples} samples (only {product_size:,} unique combinations possible)"
2542
+ )
2543
 
2544
  # Generate inputs button
2545
  if st.button("Generate Synthetic Inputs"):
 
2685
  "Filled Prompt", value=filled_prompt, height=300, disabled=True
2686
  )
2687
 
2688
+ # Advanced output generation options
2689
+ with st.expander("Advanced Output Generation Options", expanded=False):
2690
+ st.info("Configure options for generating multiple outputs per input")
2691
+
2692
+ # Option to generate multiple outputs for some inputs
2693
+ enable_multiple_outputs = st.checkbox(
2694
+ "Generate multiple outputs for some inputs",
2695
+ help="Enable generating multiple variations of outputs for selected inputs",
2696
+ )
2697
+
2698
+ if enable_multiple_outputs:
2699
+ # Proportion of inputs to duplicate
2700
+ duplicate_proportion = st.slider(
2701
+ "Proportion of inputs to generate multiple outputs for",
2702
+ min_value=0.0,
2703
+ max_value=1.0,
2704
+ value=0.2,
2705
+ step=0.1,
2706
+ help="What fraction of the input samples should have multiple outputs",
2707
+ )
2708
+
2709
+ # Number of outputs per duplicated input
2710
+ outputs_per_input = st.number_input(
2711
+ "Number of outputs per selected input",
2712
+ min_value=2,
2713
+ max_value=5,
2714
+ value=2,
2715
+ help="How many different outputs to generate for each selected input",
2716
+ )
2717
+
2718
+ # Preview the effect
2719
+ if st.session_state.selected_samples:
2720
+ num_selected = len(st.session_state.selected_samples)
2721
+ num_to_duplicate = math.ceil(
2722
+ num_selected * duplicate_proportion
2723
+ )
2724
+ total_outputs = (num_selected - num_to_duplicate) + (
2725
+ num_to_duplicate * outputs_per_input
2726
+ )
2727
+
2728
+ st.write(
2729
+ f"This will result in approximately {total_outputs} total outputs:"
2730
+ )
2731
+ st.write(
2732
+ f"- {num_selected - num_to_duplicate} inputs with 1 output"
2733
+ )
2734
+ st.write(
2735
+ f"- {num_to_duplicate} inputs with {outputs_per_input} outputs each"
2736
+ )
2737
+
2738
  # Generate outputs button
2739
  if st.button("Generate Outputs for Selected Samples"):
2740
  if not st.session_state.get("api_key"):
 
2754
  for i in st.session_state.selected_samples
2755
  ]
2756
 
2757
+ # Handle multiple outputs if enabled
2758
+ if enable_multiple_outputs:
2759
+ # Calculate how many inputs should have multiple outputs
2760
+ num_to_duplicate = math.ceil(
2761
+ len(selected_inputs) * duplicate_proportion
2762
+ )
2763
+
2764
+ # Randomly select inputs for multiple outputs
2765
+ duplicate_indices = random.sample(
2766
+ range(len(selected_inputs)), num_to_duplicate
2767
+ )
2768
+
2769
+ # Create the expanded input list
2770
+ expanded_inputs = []
2771
+ for i, input_data in enumerate(selected_inputs):
2772
+ if i in duplicate_indices:
2773
+ # Add multiple copies for selected inputs
2774
+ expanded_inputs.extend([input_data] * outputs_per_input)
2775
+ else:
2776
+ # Add single copy for other inputs
2777
+ expanded_inputs.append(input_data)
2778
+
2779
+ # Update selected_inputs with the expanded list
2780
+ selected_inputs = expanded_inputs
2781
+
2782
  with st.spinner(
2783
  f"Generating outputs for {len(selected_inputs)} samples..."
2784
  ):
 
2793
  if selection_method == "Generate for all samples":
2794
  st.session_state.combined_data = generated_outputs
2795
  else:
2796
+ # For specific samples, we need to handle the case of multiple outputs
2797
+ if enable_multiple_outputs:
2798
+ # Simply use all generated outputs as the combined data
2799
+ st.session_state.combined_data = generated_outputs
2800
+ else:
2801
+ # Handle single outputs as before
2802
+ if not st.session_state.combined_data or len(
2803
+ st.session_state.combined_data
2804
+ ) != len(st.session_state.synthetic_inputs):
2805
+ st.session_state.combined_data = [None] * len(
2806
+ st.session_state.synthetic_inputs
 
 
 
 
 
2807
  )
2808
 
2809
+ # Update only the selected samples
2810
+ for i, output_idx in enumerate(
2811
+ st.session_state.selected_samples
2812
+ ):
2813
+ if i < len(generated_outputs):
2814
+ st.session_state.combined_data[output_idx] = (
2815
+ generated_outputs[i]
2816
+ )
2817
 
2818
+ # Remove any None values (samples that haven't been generated yet)
2819
+ st.session_state.combined_data = [
2820
+ item
2821
+ for item in st.session_state.combined_data
2822
+ if item is not None
2823
+ ]
2824
+
2825
+ st.success(f"Generated {len(generated_outputs)} outputs")
2826
 
2827
  # Display combined data if available
2828
  if st.session_state.combined_data:
 
2924
  else:
2925
  st.info(
2926
  "No template has been generated yet. Go to the 'Setup' tab to create one."
2927
+ )