Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -344,11 +344,10 @@ def create_example_outputs(template):
|
|
| 344 |
return outputs
|
| 345 |
|
| 346 |
|
| 347 |
-
# Add this function after generate_categorical_permutations function
|
| 348 |
def calculate_cartesian_product_size(categorical_vars):
|
| 349 |
"""Calculate the size of the Cartesian product based on selected options."""
|
| 350 |
if not categorical_vars:
|
| 351 |
-
return 0
|
| 352 |
|
| 353 |
# Calculate the product size
|
| 354 |
product_size = 1
|
|
@@ -356,6 +355,7 @@ def calculate_cartesian_product_size(categorical_vars):
|
|
| 356 |
|
| 357 |
for var in categorical_vars:
|
| 358 |
options = var.get("options", [])
|
|
|
|
| 359 |
selected_options = var.get("selected_options", options)
|
| 360 |
min_sel = var.get("min", 1)
|
| 361 |
max_sel = var.get("max", 1)
|
|
@@ -448,6 +448,9 @@ def parse_template_file(uploaded_template):
|
|
| 448 |
template_content = uploaded_template.getvalue().decode("utf-8")
|
| 449 |
template_spec = json.loads(template_content)
|
| 450 |
|
|
|
|
|
|
|
|
|
|
| 451 |
# Validate the template structure
|
| 452 |
required_keys = [
|
| 453 |
"name",
|
|
@@ -491,6 +494,44 @@ def parse_template_file(uploaded_template):
|
|
| 491 |
return None, f"Error parsing template file: {str(e)}"
|
| 492 |
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
# LLM call function
|
| 495 |
def call_llm(prompt, model="gpt-3.5-turbo"):
|
| 496 |
"""Call the LLM API to generate text based on the prompt."""
|
|
@@ -772,6 +813,8 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
|
|
| 772 |
]
|
| 773 |
non_categorical_vars = [var for var in input_vars if var not in categorical_vars]
|
| 774 |
|
|
|
|
|
|
|
| 775 |
# Process in batches and show progress
|
| 776 |
with st.spinner(f"Generating {num_samples} synthetic inputs..."):
|
| 777 |
progress_bar = st.progress(0)
|
|
@@ -794,9 +837,18 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
|
|
| 794 |
|
| 795 |
# Create a complete row by adding non-categorical values
|
| 796 |
row = perm.copy()
|
| 797 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
non_cat_values = generate_non_categorical_values(
|
| 799 |
-
|
| 800 |
)
|
| 801 |
row.update(non_cat_values)
|
| 802 |
|
|
@@ -1573,6 +1625,8 @@ with tab1:
|
|
| 1573 |
if error:
|
| 1574 |
st.error(error)
|
| 1575 |
else:
|
|
|
|
|
|
|
| 1576 |
st.success(f"Successfully loaded template: {template_spec['name']}")
|
| 1577 |
|
| 1578 |
# Show template preview
|
|
@@ -1879,6 +1933,62 @@ with tab2:
|
|
| 1879 |
st.success("Knowledge base updated")
|
| 1880 |
st.rerun()
|
| 1881 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1882 |
# Knowledge Base Analysis Section
|
| 1883 |
if st.session_state.knowledge_base:
|
| 1884 |
with st.expander("Knowledge Base Analysis", expanded=False):
|
|
@@ -2004,9 +2114,45 @@ with tab2:
|
|
| 2004 |
with col1:
|
| 2005 |
# Create the appropriate input field based on variable type
|
| 2006 |
if var_type == "string":
|
| 2007 |
-
|
| 2008 |
-
|
| 2009 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2010 |
elif var_type == "int":
|
| 2011 |
st.session_state.user_inputs[var_name] = (
|
| 2012 |
st.number_input(
|
|
@@ -2616,6 +2762,10 @@ with tab3:
|
|
| 2616 |
template_spec_copy = st.session_state.template_spec.copy()
|
| 2617 |
template_spec_copy["input"] = st.session_state.template_spec["input"].copy()
|
| 2618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2619 |
# For each categorical variable, allow selecting options
|
| 2620 |
for i, var in enumerate(
|
| 2621 |
[
|
|
@@ -2624,37 +2774,42 @@ with tab3:
|
|
| 2624 |
if v["type"] == "categorical" and v.get("options")
|
| 2625 |
]
|
| 2626 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2627 |
with st.expander(
|
| 2628 |
f"{var['name']} - {var['description']}", expanded=False
|
| 2629 |
):
|
| 2630 |
options = var.get("options", [])
|
| 2631 |
|
| 2632 |
-
#
|
| 2633 |
-
|
| 2634 |
-
# First time initialization
|
| 2635 |
-
var["selected_options"] = options.copy()
|
| 2636 |
-
else:
|
| 2637 |
-
# Filter selected_options to only include valid options
|
| 2638 |
-
var["selected_options"] = [
|
| 2639 |
-
opt
|
| 2640 |
-
for opt in var.get("selected_options", [])
|
| 2641 |
-
if opt in options
|
| 2642 |
-
]
|
| 2643 |
|
| 2644 |
-
|
| 2645 |
-
|
|
|
|
|
|
|
| 2646 |
|
| 2647 |
-
|
| 2648 |
-
|
| 2649 |
-
|
| 2650 |
-
|
|
|
|
|
|
|
|
|
|
| 2651 |
|
| 2652 |
-
|
| 2653 |
-
|
| 2654 |
-
|
| 2655 |
|
| 2656 |
# Store current options for future comparison
|
| 2657 |
-
|
| 2658 |
|
| 2659 |
# Add "Select All" and "Clear All" buttons
|
| 2660 |
col1, col2 = st.columns([1, 1])
|
|
@@ -2663,33 +2818,36 @@ with tab3:
|
|
| 2663 |
f"Select All Options for {var['name']}",
|
| 2664 |
key=f"select_all_{i}",
|
| 2665 |
):
|
| 2666 |
-
|
| 2667 |
with col2:
|
| 2668 |
if st.button(
|
| 2669 |
f"Clear All Options for {var['name']}", key=f"clear_all_{i}"
|
| 2670 |
):
|
| 2671 |
-
|
| 2672 |
|
| 2673 |
# Create multiselect for options
|
| 2674 |
-
|
| 2675 |
f"Select options to include for {var['name']}",
|
| 2676 |
options=options,
|
| 2677 |
-
default=
|
| 2678 |
-
"selected_options", []
|
| 2679 |
-
), # Use empty list as fallback
|
| 2680 |
key=f"options_select_{i}",
|
| 2681 |
)
|
| 2682 |
|
| 2683 |
# Show selected count
|
| 2684 |
st.write(
|
| 2685 |
-
f"Selected {len(
|
| 2686 |
)
|
| 2687 |
|
| 2688 |
-
|
| 2689 |
-
|
| 2690 |
-
|
| 2691 |
-
|
| 2692 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2693 |
|
| 2694 |
# Calculate and display Cartesian product size
|
| 2695 |
product_size, var_counts = calculate_cartesian_product_size(
|
|
@@ -2723,17 +2881,38 @@ with tab3:
|
|
| 2723 |
)
|
| 2724 |
else:
|
| 2725 |
with st.spinner(f"Generating {num_samples} synthetic input samples..."):
|
| 2726 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2727 |
if categorical_vars:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2728 |
st.session_state.synthetic_inputs = (
|
| 2729 |
generate_synthetic_inputs_hybrid(
|
| 2730 |
-
|
| 2731 |
)
|
| 2732 |
)
|
| 2733 |
else:
|
| 2734 |
st.session_state.synthetic_inputs = (
|
| 2735 |
generate_synthetic_inputs_hybrid(
|
| 2736 |
-
|
| 2737 |
)
|
| 2738 |
)
|
| 2739 |
|
|
|
|
| 344 |
return outputs
|
| 345 |
|
| 346 |
|
|
|
|
| 347 |
def calculate_cartesian_product_size(categorical_vars):
|
| 348 |
"""Calculate the size of the Cartesian product based on selected options."""
|
| 349 |
if not categorical_vars:
|
| 350 |
+
return 0, []
|
| 351 |
|
| 352 |
# Calculate the product size
|
| 353 |
product_size = 1
|
|
|
|
| 355 |
|
| 356 |
for var in categorical_vars:
|
| 357 |
options = var.get("options", [])
|
| 358 |
+
# Use selected_options if available, otherwise use all options
|
| 359 |
selected_options = var.get("selected_options", options)
|
| 360 |
min_sel = var.get("min", 1)
|
| 361 |
max_sel = var.get("max", 1)
|
|
|
|
| 448 |
template_content = uploaded_template.getvalue().decode("utf-8")
|
| 449 |
template_spec = json.loads(template_content)
|
| 450 |
|
| 451 |
+
# Sanitize the template to remove UI-specific keys
|
| 452 |
+
template_spec = sanitize_template_spec(template_spec)
|
| 453 |
+
|
| 454 |
# Validate the template structure
|
| 455 |
required_keys = [
|
| 456 |
"name",
|
|
|
|
| 494 |
return None, f"Error parsing template file: {str(e)}"
|
| 495 |
|
| 496 |
|
| 497 |
+
def sanitize_template_spec(template_spec):
|
| 498 |
+
"""
|
| 499 |
+
Remove UI-specific keys from template specification that shouldn't be part of the template.
|
| 500 |
+
|
| 501 |
+
Args:
|
| 502 |
+
template_spec (dict): The template specification to sanitize
|
| 503 |
+
|
| 504 |
+
Returns:
|
| 505 |
+
dict: Sanitized template specification
|
| 506 |
+
"""
|
| 507 |
+
if not template_spec:
|
| 508 |
+
return template_spec
|
| 509 |
+
|
| 510 |
+
# Create a deep copy to avoid modifying the original
|
| 511 |
+
sanitized_spec = template_spec.copy()
|
| 512 |
+
|
| 513 |
+
# List of UI-specific keys that should be removed
|
| 514 |
+
ui_specific_keys = ["previous_options", "selected_options"]
|
| 515 |
+
|
| 516 |
+
# Clean input variables
|
| 517 |
+
if "input" in sanitized_spec and isinstance(sanitized_spec["input"], list):
|
| 518 |
+
for i, var in enumerate(sanitized_spec["input"]):
|
| 519 |
+
# Remove UI-specific keys from each variable
|
| 520 |
+
sanitized_spec["input"][i] = {
|
| 521 |
+
k: v for k, v in var.items() if k not in ui_specific_keys
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
# Clean output variables
|
| 525 |
+
if "output" in sanitized_spec and isinstance(sanitized_spec["output"], list):
|
| 526 |
+
for i, var in enumerate(sanitized_spec["output"]):
|
| 527 |
+
# Remove UI-specific keys from each variable
|
| 528 |
+
sanitized_spec["output"][i] = {
|
| 529 |
+
k: v for k, v in var.items() if k not in ui_specific_keys
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
return sanitized_spec
|
| 533 |
+
|
| 534 |
+
|
| 535 |
# LLM call function
|
| 536 |
def call_llm(prompt, model="gpt-3.5-turbo"):
|
| 537 |
"""Call the LLM API to generate text based on the prompt."""
|
|
|
|
| 813 |
]
|
| 814 |
non_categorical_vars = [var for var in input_vars if var not in categorical_vars]
|
| 815 |
|
| 816 |
+
default_value_vars = [var for var in input_vars if "default_value" in var]
|
| 817 |
+
|
| 818 |
# Process in batches and show progress
|
| 819 |
with st.spinner(f"Generating {num_samples} synthetic inputs..."):
|
| 820 |
progress_bar = st.progress(0)
|
|
|
|
| 837 |
|
| 838 |
# Create a complete row by adding non-categorical values
|
| 839 |
row = perm.copy()
|
| 840 |
+
|
| 841 |
+
# Add default values first
|
| 842 |
+
for var in default_value_vars:
|
| 843 |
+
row[var["name"]] = var["default_value"]
|
| 844 |
+
|
| 845 |
+
# Generate values for remaining non-categorical variables
|
| 846 |
+
remaining_non_cat_vars = [
|
| 847 |
+
var for var in non_categorical_vars if var not in default_value_vars
|
| 848 |
+
]
|
| 849 |
+
if remaining_non_cat_vars:
|
| 850 |
non_cat_values = generate_non_categorical_values(
|
| 851 |
+
remaining_non_cat_vars, perm, max_retries
|
| 852 |
)
|
| 853 |
row.update(non_cat_values)
|
| 854 |
|
|
|
|
| 1625 |
if error:
|
| 1626 |
st.error(error)
|
| 1627 |
else:
|
| 1628 |
+
# Sanitize the template to remove UI-specific keys
|
| 1629 |
+
template_spec = sanitize_template_spec(template_spec)
|
| 1630 |
st.success(f"Successfully loaded template: {template_spec['name']}")
|
| 1631 |
|
| 1632 |
# Show template preview
|
|
|
|
| 1933 |
st.success("Knowledge base updated")
|
| 1934 |
st.rerun()
|
| 1935 |
|
| 1936 |
+
# Add knowledge base as input variable option
|
| 1937 |
+
if st.session_state.knowledge_base:
|
| 1938 |
+
kb_var_option = st.checkbox(
|
| 1939 |
+
"Create input variable from knowledge base"
|
| 1940 |
+
)
|
| 1941 |
+
|
| 1942 |
+
if kb_var_option:
|
| 1943 |
+
# Allow editing the content to include as variable
|
| 1944 |
+
kb_content = st.text_area(
|
| 1945 |
+
"Edit knowledge base content for input variable",
|
| 1946 |
+
value=st.session_state.knowledge_base,
|
| 1947 |
+
height=300,
|
| 1948 |
+
)
|
| 1949 |
+
|
| 1950 |
+
# Create input variable name
|
| 1951 |
+
kb_var_name = st.text_input(
|
| 1952 |
+
"Input variable name", value="kb_content"
|
| 1953 |
+
)
|
| 1954 |
+
|
| 1955 |
+
# Add button to create the input variable
|
| 1956 |
+
if st.button("Add as input variable"):
|
| 1957 |
+
# Check if variable already exists
|
| 1958 |
+
var_exists = False
|
| 1959 |
+
for var in st.session_state.template_spec["input"]:
|
| 1960 |
+
if var["name"] == kb_var_name:
|
| 1961 |
+
var_exists = True
|
| 1962 |
+
var["description"] = "Knowledge base content"
|
| 1963 |
+
var["type"] = "string"
|
| 1964 |
+
var["default_value"] = kb_content
|
| 1965 |
+
st.success(
|
| 1966 |
+
f"Updated existing input variable '{kb_var_name}'"
|
| 1967 |
+
)
|
| 1968 |
+
break
|
| 1969 |
+
|
| 1970 |
+
if not var_exists:
|
| 1971 |
+
# Create new input variable
|
| 1972 |
+
new_var = {
|
| 1973 |
+
"name": kb_var_name,
|
| 1974 |
+
"description": "Knowledge base content",
|
| 1975 |
+
"type": "string",
|
| 1976 |
+
"min": len(kb_content),
|
| 1977 |
+
"max": len(kb_content) * 2,
|
| 1978 |
+
"default_value": kb_content,
|
| 1979 |
+
}
|
| 1980 |
+
st.session_state.template_spec["input"].append(
|
| 1981 |
+
new_var
|
| 1982 |
+
)
|
| 1983 |
+
st.success(
|
| 1984 |
+
f"Added new input variable '{kb_var_name}'"
|
| 1985 |
+
)
|
| 1986 |
+
|
| 1987 |
+
# Remind user to update prompt template
|
| 1988 |
+
st.info(
|
| 1989 |
+
f"Remember to use {{{kb_var_name}}} in your prompt template"
|
| 1990 |
+
)
|
| 1991 |
+
|
| 1992 |
# Knowledge Base Analysis Section
|
| 1993 |
if st.session_state.knowledge_base:
|
| 1994 |
with st.expander("Knowledge Base Analysis", expanded=False):
|
|
|
|
| 2114 |
with col1:
|
| 2115 |
# Create the appropriate input field based on variable type
|
| 2116 |
if var_type == "string":
|
| 2117 |
+
# Check if this is a knowledge base variable with default value
|
| 2118 |
+
if "default_value" in input_var:
|
| 2119 |
+
use_default = st.checkbox(
|
| 2120 |
+
f"Use default value for {var_name}",
|
| 2121 |
+
value=True,
|
| 2122 |
+
key=f"use_default_{var_name}",
|
| 2123 |
+
)
|
| 2124 |
+
if use_default:
|
| 2125 |
+
st.session_state.user_inputs[var_name] = (
|
| 2126 |
+
input_var["default_value"]
|
| 2127 |
+
)
|
| 2128 |
+
st.text_area(
|
| 2129 |
+
f"Default value for {var_name}",
|
| 2130 |
+
value=input_var["default_value"][:500]
|
| 2131 |
+
+ (
|
| 2132 |
+
"..."
|
| 2133 |
+
if len(input_var["default_value"]) > 500
|
| 2134 |
+
else ""
|
| 2135 |
+
),
|
| 2136 |
+
height=150,
|
| 2137 |
+
disabled=True,
|
| 2138 |
+
key=f"preview_{var_name}",
|
| 2139 |
+
)
|
| 2140 |
+
else:
|
| 2141 |
+
st.session_state.user_inputs[var_name] = (
|
| 2142 |
+
st.text_area(
|
| 2143 |
+
f"Enter value for {var_name}",
|
| 2144 |
+
value=input_var["default_value"],
|
| 2145 |
+
height=150,
|
| 2146 |
+
key=f"use_{var_name}",
|
| 2147 |
+
)
|
| 2148 |
+
)
|
| 2149 |
+
else:
|
| 2150 |
+
st.session_state.user_inputs[var_name] = (
|
| 2151 |
+
st.text_input(
|
| 2152 |
+
f"Enter value for {var_name}",
|
| 2153 |
+
key=f"use_{var_name}",
|
| 2154 |
+
)
|
| 2155 |
+
)
|
| 2156 |
elif var_type == "int":
|
| 2157 |
st.session_state.user_inputs[var_name] = (
|
| 2158 |
st.number_input(
|
|
|
|
| 2762 |
template_spec_copy = st.session_state.template_spec.copy()
|
| 2763 |
template_spec_copy["input"] = st.session_state.template_spec["input"].copy()
|
| 2764 |
|
| 2765 |
+
# Initialize UI state for categorical variables if not present
|
| 2766 |
+
if "categorical_ui_state" not in st.session_state:
|
| 2767 |
+
st.session_state.categorical_ui_state = {}
|
| 2768 |
+
|
| 2769 |
# For each categorical variable, allow selecting options
|
| 2770 |
for i, var in enumerate(
|
| 2771 |
[
|
|
|
|
| 2774 |
if v["type"] == "categorical" and v.get("options")
|
| 2775 |
]
|
| 2776 |
):
|
| 2777 |
+
var_name = var["name"]
|
| 2778 |
+
|
| 2779 |
+
# Initialize UI state for this variable if not present
|
| 2780 |
+
if var_name not in st.session_state.categorical_ui_state:
|
| 2781 |
+
st.session_state.categorical_ui_state[var_name] = {
|
| 2782 |
+
"selected_options": var.get("options", []).copy(),
|
| 2783 |
+
"previous_options": var.get("options", []).copy(),
|
| 2784 |
+
}
|
| 2785 |
+
|
| 2786 |
with st.expander(
|
| 2787 |
f"{var['name']} - {var['description']}", expanded=False
|
| 2788 |
):
|
| 2789 |
options = var.get("options", [])
|
| 2790 |
|
| 2791 |
+
# Get UI state for this variable
|
| 2792 |
+
ui_state = st.session_state.categorical_ui_state[var_name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2793 |
|
| 2794 |
+
# Filter selected_options to only include valid options
|
| 2795 |
+
ui_state["selected_options"] = [
|
| 2796 |
+
opt for opt in ui_state["selected_options"] if opt in options
|
| 2797 |
+
]
|
| 2798 |
|
| 2799 |
+
# Check for new options that need to be automatically selected
|
| 2800 |
+
previous_options = ui_state["previous_options"]
|
| 2801 |
+
|
| 2802 |
+
# Find new options that weren't in the previous options list
|
| 2803 |
+
new_options = [
|
| 2804 |
+
opt for opt in options if opt not in previous_options
|
| 2805 |
+
]
|
| 2806 |
|
| 2807 |
+
# Add new options to selected_options
|
| 2808 |
+
if new_options:
|
| 2809 |
+
ui_state["selected_options"].extend(new_options)
|
| 2810 |
|
| 2811 |
# Store current options for future comparison
|
| 2812 |
+
ui_state["previous_options"] = options.copy()
|
| 2813 |
|
| 2814 |
# Add "Select All" and "Clear All" buttons
|
| 2815 |
col1, col2 = st.columns([1, 1])
|
|
|
|
| 2818 |
f"Select All Options for {var['name']}",
|
| 2819 |
key=f"select_all_{i}",
|
| 2820 |
):
|
| 2821 |
+
ui_state["selected_options"] = options.copy()
|
| 2822 |
with col2:
|
| 2823 |
if st.button(
|
| 2824 |
f"Clear All Options for {var['name']}", key=f"clear_all_{i}"
|
| 2825 |
):
|
| 2826 |
+
ui_state["selected_options"] = []
|
| 2827 |
|
| 2828 |
# Create multiselect for options
|
| 2829 |
+
ui_state["selected_options"] = st.multiselect(
|
| 2830 |
f"Select options to include for {var['name']}",
|
| 2831 |
options=options,
|
| 2832 |
+
default=ui_state["selected_options"],
|
|
|
|
|
|
|
| 2833 |
key=f"options_select_{i}",
|
| 2834 |
)
|
| 2835 |
|
| 2836 |
# Show selected count
|
| 2837 |
st.write(
|
| 2838 |
+
f"Selected {len(ui_state['selected_options'])} out of {len(options)} options"
|
| 2839 |
)
|
| 2840 |
|
| 2841 |
+
# Create a temporary copy of the variable with selected_options for the calculation
|
| 2842 |
+
# but don't modify the actual template
|
| 2843 |
+
var_copy = var.copy()
|
| 2844 |
+
var_copy["selected_options"] = ui_state["selected_options"]
|
| 2845 |
+
|
| 2846 |
+
# Update the template spec copy with the selected options for calculation purposes only
|
| 2847 |
+
for j, input_var in enumerate(template_spec_copy["input"]):
|
| 2848 |
+
if input_var["name"] == var["name"]:
|
| 2849 |
+
template_spec_copy["input"][j] = var_copy
|
| 2850 |
+
break
|
| 2851 |
|
| 2852 |
# Calculate and display Cartesian product size
|
| 2853 |
product_size, var_counts = calculate_cartesian_product_size(
|
|
|
|
| 2881 |
)
|
| 2882 |
else:
|
| 2883 |
with st.spinner(f"Generating {num_samples} synthetic input samples..."):
|
| 2884 |
+
# Create a clean template spec without UI state variables
|
| 2885 |
+
clean_template_spec = st.session_state.template_spec.copy()
|
| 2886 |
+
clean_template_spec["input"] = st.session_state.template_spec[
|
| 2887 |
+
"input"
|
| 2888 |
+
].copy()
|
| 2889 |
+
|
| 2890 |
+
# If we have categorical variables, apply the selected options from UI state
|
| 2891 |
if categorical_vars:
|
| 2892 |
+
for i, var in enumerate(clean_template_spec["input"]):
|
| 2893 |
+
if (
|
| 2894 |
+
var["type"] == "categorical"
|
| 2895 |
+
and var.get("options")
|
| 2896 |
+
and var["name"] in st.session_state.categorical_ui_state
|
| 2897 |
+
):
|
| 2898 |
+
# Create a copy of the variable with selected_options for generation
|
| 2899 |
+
var_copy = var.copy()
|
| 2900 |
+
var_copy["selected_options"] = (
|
| 2901 |
+
st.session_state.categorical_ui_state[var["name"]][
|
| 2902 |
+
"selected_options"
|
| 2903 |
+
]
|
| 2904 |
+
)
|
| 2905 |
+
clean_template_spec["input"][i] = var_copy
|
| 2906 |
+
|
| 2907 |
st.session_state.synthetic_inputs = (
|
| 2908 |
generate_synthetic_inputs_hybrid(
|
| 2909 |
+
clean_template_spec, num_samples=num_samples
|
| 2910 |
)
|
| 2911 |
)
|
| 2912 |
else:
|
| 2913 |
st.session_state.synthetic_inputs = (
|
| 2914 |
generate_synthetic_inputs_hybrid(
|
| 2915 |
+
clean_template_spec, num_samples=num_samples
|
| 2916 |
)
|
| 2917 |
)
|
| 2918 |
|