DJHumanRPT committed on
Commit
843fdb1
·
verified ·
1 Parent(s): 8689bd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -43
app.py CHANGED
@@ -344,11 +344,10 @@ def create_example_outputs(template):
344
  return outputs
345
 
346
 
347
- # Add this function after generate_categorical_permutations function
348
  def calculate_cartesian_product_size(categorical_vars):
349
  """Calculate the size of the Cartesian product based on selected options."""
350
  if not categorical_vars:
351
- return 0
352
 
353
  # Calculate the product size
354
  product_size = 1
@@ -356,6 +355,7 @@ def calculate_cartesian_product_size(categorical_vars):
356
 
357
  for var in categorical_vars:
358
  options = var.get("options", [])
 
359
  selected_options = var.get("selected_options", options)
360
  min_sel = var.get("min", 1)
361
  max_sel = var.get("max", 1)
@@ -448,6 +448,9 @@ def parse_template_file(uploaded_template):
448
  template_content = uploaded_template.getvalue().decode("utf-8")
449
  template_spec = json.loads(template_content)
450
 
 
 
 
451
  # Validate the template structure
452
  required_keys = [
453
  "name",
@@ -491,6 +494,44 @@ def parse_template_file(uploaded_template):
491
  return None, f"Error parsing template file: {str(e)}"
492
 
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  # LLM call function
495
  def call_llm(prompt, model="gpt-3.5-turbo"):
496
  """Call the LLM API to generate text based on the prompt."""
@@ -772,6 +813,8 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
772
  ]
773
  non_categorical_vars = [var for var in input_vars if var not in categorical_vars]
774
 
 
 
775
  # Process in batches and show progress
776
  with st.spinner(f"Generating {num_samples} synthetic inputs..."):
777
  progress_bar = st.progress(0)
@@ -794,9 +837,18 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
794
 
795
  # Create a complete row by adding non-categorical values
796
  row = perm.copy()
797
- if non_categorical_vars:
 
 
 
 
 
 
 
 
 
798
  non_cat_values = generate_non_categorical_values(
799
- non_categorical_vars, perm, max_retries
800
  )
801
  row.update(non_cat_values)
802
 
@@ -1573,6 +1625,8 @@ with tab1:
1573
  if error:
1574
  st.error(error)
1575
  else:
 
 
1576
  st.success(f"Successfully loaded template: {template_spec['name']}")
1577
 
1578
  # Show template preview
@@ -1879,6 +1933,62 @@ with tab2:
1879
  st.success("Knowledge base updated")
1880
  st.rerun()
1881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1882
  # Knowledge Base Analysis Section
1883
  if st.session_state.knowledge_base:
1884
  with st.expander("Knowledge Base Analysis", expanded=False):
@@ -2004,9 +2114,45 @@ with tab2:
2004
  with col1:
2005
  # Create the appropriate input field based on variable type
2006
  if var_type == "string":
2007
- st.session_state.user_inputs[var_name] = st.text_input(
2008
- f"Enter value for {var_name}", key=f"use_{var_name}"
2009
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2010
  elif var_type == "int":
2011
  st.session_state.user_inputs[var_name] = (
2012
  st.number_input(
@@ -2616,6 +2762,10 @@ with tab3:
2616
  template_spec_copy = st.session_state.template_spec.copy()
2617
  template_spec_copy["input"] = st.session_state.template_spec["input"].copy()
2618
 
 
 
 
 
2619
  # For each categorical variable, allow selecting options
2620
  for i, var in enumerate(
2621
  [
@@ -2624,37 +2774,42 @@ with tab3:
2624
  if v["type"] == "categorical" and v.get("options")
2625
  ]
2626
  ):
 
 
 
 
 
 
 
 
 
2627
  with st.expander(
2628
  f"{var['name']} - {var['description']}", expanded=False
2629
  ):
2630
  options = var.get("options", [])
2631
 
2632
- # Initialize selected_options if not present
2633
- if "selected_options" not in var:
2634
- # First time initialization
2635
- var["selected_options"] = options.copy()
2636
- else:
2637
- # Filter selected_options to only include valid options
2638
- var["selected_options"] = [
2639
- opt
2640
- for opt in var.get("selected_options", [])
2641
- if opt in options
2642
- ]
2643
 
2644
- # Check for new options that need to be automatically selected
2645
- previous_options = var.get("previous_options", [])
 
 
2646
 
2647
- # Find new options that weren't in the previous options list
2648
- new_options = [
2649
- opt for opt in options if opt not in previous_options
2650
- ]
 
 
 
2651
 
2652
- # Add new options to selected_options
2653
- if new_options:
2654
- var["selected_options"].extend(new_options)
2655
 
2656
  # Store current options for future comparison
2657
- var["previous_options"] = options.copy()
2658
 
2659
  # Add "Select All" and "Clear All" buttons
2660
  col1, col2 = st.columns([1, 1])
@@ -2663,33 +2818,36 @@ with tab3:
2663
  f"Select All Options for {var['name']}",
2664
  key=f"select_all_{i}",
2665
  ):
2666
- var["selected_options"] = options.copy()
2667
  with col2:
2668
  if st.button(
2669
  f"Clear All Options for {var['name']}", key=f"clear_all_{i}"
2670
  ):
2671
- var["selected_options"] = []
2672
 
2673
  # Create multiselect for options
2674
- var["selected_options"] = st.multiselect(
2675
  f"Select options to include for {var['name']}",
2676
  options=options,
2677
- default=var.get(
2678
- "selected_options", []
2679
- ), # Use empty list as fallback
2680
  key=f"options_select_{i}",
2681
  )
2682
 
2683
  # Show selected count
2684
  st.write(
2685
- f"Selected {len(var['selected_options'])} out of {len(options)} options"
2686
  )
2687
 
2688
- # Update the template spec with the selected options
2689
- for j, input_var in enumerate(template_spec_copy["input"]):
2690
- if input_var["name"] == var["name"]:
2691
- template_spec_copy["input"][j] = var
2692
- break
 
 
 
 
 
2693
 
2694
  # Calculate and display Cartesian product size
2695
  product_size, var_counts = calculate_cartesian_product_size(
@@ -2723,17 +2881,38 @@ with tab3:
2723
  )
2724
  else:
2725
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
2726
- # Use the modified template spec with selected options
 
 
 
 
 
 
2727
  if categorical_vars:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2728
  st.session_state.synthetic_inputs = (
2729
  generate_synthetic_inputs_hybrid(
2730
- template_spec_copy, num_samples=num_samples
2731
  )
2732
  )
2733
  else:
2734
  st.session_state.synthetic_inputs = (
2735
  generate_synthetic_inputs_hybrid(
2736
- st.session_state.template_spec, num_samples=num_samples
2737
  )
2738
  )
2739
 
 
344
  return outputs
345
 
346
 
 
347
  def calculate_cartesian_product_size(categorical_vars):
348
  """Calculate the size of the Cartesian product based on selected options."""
349
  if not categorical_vars:
350
+ return 0, []
351
 
352
  # Calculate the product size
353
  product_size = 1
 
355
 
356
  for var in categorical_vars:
357
  options = var.get("options", [])
358
+ # Use selected_options if available, otherwise use all options
359
  selected_options = var.get("selected_options", options)
360
  min_sel = var.get("min", 1)
361
  max_sel = var.get("max", 1)
 
448
  template_content = uploaded_template.getvalue().decode("utf-8")
449
  template_spec = json.loads(template_content)
450
 
451
+ # Sanitize the template to remove UI-specific keys
452
+ template_spec = sanitize_template_spec(template_spec)
453
+
454
  # Validate the template structure
455
  required_keys = [
456
  "name",
 
494
  return None, f"Error parsing template file: {str(e)}"
495
 
496
 
497
def sanitize_template_spec(template_spec):
    """
    Remove UI-specific keys from a template specification.

    The keys "previous_options" and "selected_options" are stripped from
    every variable listed under "input" and "output". The caller's
    specification is left untouched: the top-level dict is shallow-copied
    and the "input"/"output" lists are rebuilt with new per-variable dicts
    instead of being edited in place.

    Args:
        template_spec (dict): The template specification to sanitize.

    Returns:
        dict: A sanitized copy of the specification. A falsy argument
        (None, empty dict) is returned unchanged. Values nested inside
        each variable (for example an "options" list) are shared with the
        original, not copied.
    """
    if not template_spec:
        return template_spec

    # Keys that only carry UI state and must not be part of the template
    ui_specific_keys = ("previous_options", "selected_options")

    # dict.copy() is shallow: the "input"/"output" lists would still be the
    # caller's own objects, so they are replaced below rather than mutated.
    sanitized_spec = template_spec.copy()

    for section in ("input", "output"):
        if section in sanitized_spec and isinstance(sanitized_spec[section], list):
            sanitized_spec[section] = [
                {k: v for k, v in var.items() if k not in ui_specific_keys}
                for var in sanitized_spec[section]
            ]

    return sanitized_spec
533
+
534
+
535
  # LLM call function
536
  def call_llm(prompt, model="gpt-3.5-turbo"):
537
  """Call the LLM API to generate text based on the prompt."""
 
813
  ]
814
  non_categorical_vars = [var for var in input_vars if var not in categorical_vars]
815
 
816
+ default_value_vars = [var for var in input_vars if "default_value" in var]
817
+
818
  # Process in batches and show progress
819
  with st.spinner(f"Generating {num_samples} synthetic inputs..."):
820
  progress_bar = st.progress(0)
 
837
 
838
  # Create a complete row by adding non-categorical values
839
  row = perm.copy()
840
+
841
+ # Add default values first
842
+ for var in default_value_vars:
843
+ row[var["name"]] = var["default_value"]
844
+
845
+ # Generate values for remaining non-categorical variables
846
+ remaining_non_cat_vars = [
847
+ var for var in non_categorical_vars if var not in default_value_vars
848
+ ]
849
+ if remaining_non_cat_vars:
850
  non_cat_values = generate_non_categorical_values(
851
+ remaining_non_cat_vars, perm, max_retries
852
  )
853
  row.update(non_cat_values)
854
 
 
1625
  if error:
1626
  st.error(error)
1627
  else:
1628
+ # Sanitize the template to remove UI-specific keys
1629
+ template_spec = sanitize_template_spec(template_spec)
1630
  st.success(f"Successfully loaded template: {template_spec['name']}")
1631
 
1632
  # Show template preview
 
1933
  st.success("Knowledge base updated")
1934
  st.rerun()
1935
 
1936
+ # Add knowledge base as input variable option
1937
+ if st.session_state.knowledge_base:
1938
+ kb_var_option = st.checkbox(
1939
+ "Create input variable from knowledge base"
1940
+ )
1941
+
1942
+ if kb_var_option:
1943
+ # Allow editing the content to include as variable
1944
+ kb_content = st.text_area(
1945
+ "Edit knowledge base content for input variable",
1946
+ value=st.session_state.knowledge_base,
1947
+ height=300,
1948
+ )
1949
+
1950
+ # Create input variable name
1951
+ kb_var_name = st.text_input(
1952
+ "Input variable name", value="kb_content"
1953
+ )
1954
+
1955
+ # Add button to create the input variable
1956
+ if st.button("Add as input variable"):
1957
+ # Check if variable already exists
1958
+ var_exists = False
1959
+ for var in st.session_state.template_spec["input"]:
1960
+ if var["name"] == kb_var_name:
1961
+ var_exists = True
1962
+ var["description"] = "Knowledge base content"
1963
+ var["type"] = "string"
1964
+ var["default_value"] = kb_content
1965
+ st.success(
1966
+ f"Updated existing input variable '{kb_var_name}'"
1967
+ )
1968
+ break
1969
+
1970
+ if not var_exists:
1971
+ # Create new input variable
1972
+ new_var = {
1973
+ "name": kb_var_name,
1974
+ "description": "Knowledge base content",
1975
+ "type": "string",
1976
+ "min": len(kb_content),
1977
+ "max": len(kb_content) * 2,
1978
+ "default_value": kb_content,
1979
+ }
1980
+ st.session_state.template_spec["input"].append(
1981
+ new_var
1982
+ )
1983
+ st.success(
1984
+ f"Added new input variable '{kb_var_name}'"
1985
+ )
1986
+
1987
+ # Remind user to update prompt template
1988
+ st.info(
1989
+ f"Remember to use {{{kb_var_name}}} in your prompt template"
1990
+ )
1991
+
1992
  # Knowledge Base Analysis Section
1993
  if st.session_state.knowledge_base:
1994
  with st.expander("Knowledge Base Analysis", expanded=False):
 
2114
  with col1:
2115
  # Create the appropriate input field based on variable type
2116
  if var_type == "string":
2117
+ # Check if this is a knowledge base variable with default value
2118
+ if "default_value" in input_var:
2119
+ use_default = st.checkbox(
2120
+ f"Use default value for {var_name}",
2121
+ value=True,
2122
+ key=f"use_default_{var_name}",
2123
+ )
2124
+ if use_default:
2125
+ st.session_state.user_inputs[var_name] = (
2126
+ input_var["default_value"]
2127
+ )
2128
+ st.text_area(
2129
+ f"Default value for {var_name}",
2130
+ value=input_var["default_value"][:500]
2131
+ + (
2132
+ "..."
2133
+ if len(input_var["default_value"]) > 500
2134
+ else ""
2135
+ ),
2136
+ height=150,
2137
+ disabled=True,
2138
+ key=f"preview_{var_name}",
2139
+ )
2140
+ else:
2141
+ st.session_state.user_inputs[var_name] = (
2142
+ st.text_area(
2143
+ f"Enter value for {var_name}",
2144
+ value=input_var["default_value"],
2145
+ height=150,
2146
+ key=f"use_{var_name}",
2147
+ )
2148
+ )
2149
+ else:
2150
+ st.session_state.user_inputs[var_name] = (
2151
+ st.text_input(
2152
+ f"Enter value for {var_name}",
2153
+ key=f"use_{var_name}",
2154
+ )
2155
+ )
2156
  elif var_type == "int":
2157
  st.session_state.user_inputs[var_name] = (
2158
  st.number_input(
 
2762
  template_spec_copy = st.session_state.template_spec.copy()
2763
  template_spec_copy["input"] = st.session_state.template_spec["input"].copy()
2764
 
2765
+ # Initialize UI state for categorical variables if not present
2766
+ if "categorical_ui_state" not in st.session_state:
2767
+ st.session_state.categorical_ui_state = {}
2768
+
2769
  # For each categorical variable, allow selecting options
2770
  for i, var in enumerate(
2771
  [
 
2774
  if v["type"] == "categorical" and v.get("options")
2775
  ]
2776
  ):
2777
+ var_name = var["name"]
2778
+
2779
+ # Initialize UI state for this variable if not present
2780
+ if var_name not in st.session_state.categorical_ui_state:
2781
+ st.session_state.categorical_ui_state[var_name] = {
2782
+ "selected_options": var.get("options", []).copy(),
2783
+ "previous_options": var.get("options", []).copy(),
2784
+ }
2785
+
2786
  with st.expander(
2787
  f"{var['name']} - {var['description']}", expanded=False
2788
  ):
2789
  options = var.get("options", [])
2790
 
2791
+ # Get UI state for this variable
2792
+ ui_state = st.session_state.categorical_ui_state[var_name]
 
 
 
 
 
 
 
 
 
2793
 
2794
+ # Filter selected_options to only include valid options
2795
+ ui_state["selected_options"] = [
2796
+ opt for opt in ui_state["selected_options"] if opt in options
2797
+ ]
2798
 
2799
+ # Check for new options that need to be automatically selected
2800
+ previous_options = ui_state["previous_options"]
2801
+
2802
+ # Find new options that weren't in the previous options list
2803
+ new_options = [
2804
+ opt for opt in options if opt not in previous_options
2805
+ ]
2806
 
2807
+ # Add new options to selected_options
2808
+ if new_options:
2809
+ ui_state["selected_options"].extend(new_options)
2810
 
2811
  # Store current options for future comparison
2812
+ ui_state["previous_options"] = options.copy()
2813
 
2814
  # Add "Select All" and "Clear All" buttons
2815
  col1, col2 = st.columns([1, 1])
 
2818
  f"Select All Options for {var['name']}",
2819
  key=f"select_all_{i}",
2820
  ):
2821
+ ui_state["selected_options"] = options.copy()
2822
  with col2:
2823
  if st.button(
2824
  f"Clear All Options for {var['name']}", key=f"clear_all_{i}"
2825
  ):
2826
+ ui_state["selected_options"] = []
2827
 
2828
  # Create multiselect for options
2829
+ ui_state["selected_options"] = st.multiselect(
2830
  f"Select options to include for {var['name']}",
2831
  options=options,
2832
+ default=ui_state["selected_options"],
 
 
2833
  key=f"options_select_{i}",
2834
  )
2835
 
2836
  # Show selected count
2837
  st.write(
2838
+ f"Selected {len(ui_state['selected_options'])} out of {len(options)} options"
2839
  )
2840
 
2841
+ # Create a temporary copy of the variable with selected_options for the calculation
2842
+ # but don't modify the actual template
2843
+ var_copy = var.copy()
2844
+ var_copy["selected_options"] = ui_state["selected_options"]
2845
+
2846
+ # Update the template spec copy with the selected options for calculation purposes only
2847
+ for j, input_var in enumerate(template_spec_copy["input"]):
2848
+ if input_var["name"] == var["name"]:
2849
+ template_spec_copy["input"][j] = var_copy
2850
+ break
2851
 
2852
  # Calculate and display Cartesian product size
2853
  product_size, var_counts = calculate_cartesian_product_size(
 
2881
  )
2882
  else:
2883
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
2884
+ # Create a clean template spec without UI state variables
2885
+ clean_template_spec = st.session_state.template_spec.copy()
2886
+ clean_template_spec["input"] = st.session_state.template_spec[
2887
+ "input"
2888
+ ].copy()
2889
+
2890
+ # If we have categorical variables, apply the selected options from UI state
2891
  if categorical_vars:
2892
+ for i, var in enumerate(clean_template_spec["input"]):
2893
+ if (
2894
+ var["type"] == "categorical"
2895
+ and var.get("options")
2896
+ and var["name"] in st.session_state.categorical_ui_state
2897
+ ):
2898
+ # Create a copy of the variable with selected_options for generation
2899
+ var_copy = var.copy()
2900
+ var_copy["selected_options"] = (
2901
+ st.session_state.categorical_ui_state[var["name"]][
2902
+ "selected_options"
2903
+ ]
2904
+ )
2905
+ clean_template_spec["input"][i] = var_copy
2906
+
2907
  st.session_state.synthetic_inputs = (
2908
  generate_synthetic_inputs_hybrid(
2909
+ clean_template_spec, num_samples=num_samples
2910
  )
2911
  )
2912
  else:
2913
  st.session_state.synthetic_inputs = (
2914
  generate_synthetic_inputs_hybrid(
2915
+ clean_template_spec, num_samples=num_samples
2916
  )
2917
  )
2918