Spaces:
Running
Running
Anna's changes to numeric logic
Browse files
app.py
CHANGED
|
@@ -992,144 +992,194 @@ def generate_non_categorical_values(non_cat_vars, existing_values, max_retries):
|
|
| 992 |
"""Generate values for non-categorical variables given existing categorical values."""
|
| 993 |
if not non_cat_vars:
|
| 994 |
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
|
| 996 |
-
# Format the variables for the prompt
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
f"- {var['name']}: {var['description']} (Type:
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
for var in non_cat_vars
|
| 1006 |
-
]
|
| 1007 |
-
)
|
| 1008 |
-
|
| 1009 |
-
# Create prompt with existing categorical values as context
|
| 1010 |
-
prompt = f"""
|
| 1011 |
-
As a synthetic data generator, create values for these variables:
|
| 1012 |
-
|
| 1013 |
-
{vars_text}
|
| 1014 |
-
|
| 1015 |
-
These values should be coherent with the existing categorical values:
|
| 1016 |
-
{json.dumps(existing_values, indent=2)}
|
| 1017 |
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
"variable_name_2": value2
|
| 1022 |
-
}}
|
| 1023 |
-
"""
|
| 1024 |
|
| 1025 |
-
|
| 1026 |
-
try:
|
| 1027 |
-
response = call_model_api(
|
| 1028 |
-
model=st.session_state.model,
|
| 1029 |
-
prompt=prompt,
|
| 1030 |
-
max_tokens=1000,
|
| 1031 |
-
temperature=st.session_state.temperature,
|
| 1032 |
-
)
|
| 1033 |
|
| 1034 |
-
|
|
|
|
| 1035 |
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
pass
|
| 1049 |
-
else:
|
| 1050 |
-
try:
|
| 1051 |
-
values = json.loads(result, strict=False)
|
| 1052 |
-
if isinstance(values, dict):
|
| 1053 |
-
return values
|
| 1054 |
-
except:
|
| 1055 |
-
pass
|
| 1056 |
|
| 1057 |
-
|
| 1058 |
-
|
| 1059 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
|
|
|
|
|
|
| 1063 |
|
| 1064 |
|
| 1065 |
def generate_single_row(all_vars, max_retries):
|
| 1066 |
-
"""Generate a complete row of data
|
| 1067 |
-
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1080 |
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
|
| 1084 |
-
|
| 1085 |
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
"variable_name_1": value1,
|
| 1089 |
-
"variable_name_2": value2
|
| 1090 |
-
}}
|
| 1091 |
|
| 1092 |
-
|
| 1093 |
-
|
|
|
|
|
|
|
|
|
|
| 1094 |
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
model=st.session_state.model,
|
| 1099 |
-
messages=[{"role": "user", "content": prompt}],
|
| 1100 |
-
max_tokens=1000,
|
| 1101 |
-
temperature=st.session_state.temperature,
|
| 1102 |
-
)
|
| 1103 |
|
| 1104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
|
| 1110 |
-
|
| 1111 |
-
|
| 1112 |
-
|
| 1113 |
-
try:
|
| 1114 |
values = json.loads(json_str, strict=False)
|
| 1115 |
if isinstance(values, dict):
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
else:
|
| 1120 |
-
try:
|
| 1121 |
values = json.loads(result, strict=False)
|
| 1122 |
if isinstance(values, dict):
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
pass
|
| 1126 |
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
|
| 1131 |
-
|
| 1132 |
-
return None
|
| 1133 |
|
| 1134 |
|
| 1135 |
def get_default_value(var):
|
|
@@ -1182,7 +1232,7 @@ def generate_synthetic_outputs(
|
|
| 1182 |
input_vars = template_spec["input"]
|
| 1183 |
input_vars_text = "\n".join(
|
| 1184 |
[
|
| 1185 |
-
f"- {var['name']}: {var['description']} (Type: {var['type']})
|
| 1186 |
for var in input_vars
|
| 1187 |
]
|
| 1188 |
)
|
|
@@ -1198,7 +1248,11 @@ def generate_synthetic_outputs(
|
|
| 1198 |
progress_bar = st.progress(0)
|
| 1199 |
|
| 1200 |
try:
|
|
|
|
|
|
|
| 1201 |
for i, input_item in enumerate(input_data):
|
|
|
|
|
|
|
| 1202 |
# Fill the prompt template with input values
|
| 1203 |
filled_prompt = prompt_template
|
| 1204 |
for var_name, var_value in input_item.items():
|
|
@@ -1230,7 +1284,9 @@ Generate realistic output data for these variables. Return ONLY a JSON object wi
|
|
| 1230 |
Use appropriate data types for each variable. Return ONLY the JSON object with no additional text or explanation.
|
| 1231 |
The response must be valid JSON that can be parsed directly.
|
| 1232 |
"""
|
| 1233 |
-
|
|
|
|
|
|
|
| 1234 |
output_data = None
|
| 1235 |
for attempt in range(max_retries):
|
| 1236 |
try:
|
|
@@ -1527,8 +1583,14 @@ with st.sidebar:
|
|
| 1527 |
if model_provider == "OpenAI":
|
| 1528 |
st.session_state.model = st.selectbox(
|
| 1529 |
"Select OpenAI Model",
|
| 1530 |
-
options=[
|
| 1531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1532 |
)
|
| 1533 |
else: # Anthropic
|
| 1534 |
st.session_state.model = st.selectbox(
|
|
@@ -2605,7 +2667,10 @@ with tab2:
|
|
| 2605 |
placeholder="Enter custom lore or background information here...",
|
| 2606 |
height=150,
|
| 2607 |
)
|
| 2608 |
-
|
|
|
|
|
|
|
|
|
|
| 2609 |
# Generate Output button
|
| 2610 |
if st.button("Generate Output", key="generate_button"):
|
| 2611 |
# Check if API key is provided
|
|
@@ -3186,6 +3251,15 @@ with tab3:
|
|
| 3186 |
# Display combined data if available
|
| 3187 |
if st.session_state.combined_data:
|
| 3188 |
st.subheader("Complete Dataset (Inputs + Outputs)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3189 |
|
| 3190 |
# Add this function before the prepare_dataframe_with_json_columns function
|
| 3191 |
|
|
@@ -3227,12 +3301,23 @@ with tab3:
|
|
| 3227 |
|
| 3228 |
# Create a function to prepare the dataframe with JSON columns
|
| 3229 |
def prepare_dataframe_with_json_columns(
|
| 3230 |
-
data, template_spec, show_json_columns=False
|
| 3231 |
):
|
| 3232 |
df = pd.DataFrame(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3233 |
|
| 3234 |
# Create input and output JSON columns
|
| 3235 |
-
input_vars = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3236 |
output_vars = [var["name"] for var in template_spec["output"]]
|
| 3237 |
|
| 3238 |
# Create input JSON column
|
|
@@ -3254,9 +3339,11 @@ with tab3:
|
|
| 3254 |
# If not showing JSON columns in UI, remove them for display only
|
| 3255 |
if not show_json_columns:
|
| 3256 |
display_df = df.drop(columns=["input", "output"])
|
| 3257 |
-
|
|
|
|
| 3258 |
|
| 3259 |
-
|
|
|
|
| 3260 |
|
| 3261 |
# Toggle for showing JSON columns
|
| 3262 |
st.session_state.show_json_columns = st.checkbox(
|
|
@@ -3269,6 +3356,7 @@ with tab3:
|
|
| 3269 |
st.session_state.combined_data,
|
| 3270 |
st.session_state.template_spec,
|
| 3271 |
st.session_state.show_json_columns,
|
|
|
|
| 3272 |
)
|
| 3273 |
|
| 3274 |
# Show data in a table
|
|
@@ -3288,8 +3376,11 @@ with tab3:
|
|
| 3288 |
)
|
| 3289 |
|
| 3290 |
with col2:
|
| 3291 |
-
# JSON download
|
| 3292 |
-
|
|
|
|
|
|
|
|
|
|
| 3293 |
st.download_button(
|
| 3294 |
label="Download Dataset (JSON)",
|
| 3295 |
data=combined_json,
|
|
@@ -3323,4 +3414,4 @@ with tab3:
|
|
| 3323 |
else:
|
| 3324 |
st.info(
|
| 3325 |
"No template has been generated yet. Go to the 'Setup' tab to create one."
|
| 3326 |
-
)
|
|
|
|
| 992 |
"""Generate values for non-categorical variables given existing categorical values."""
|
| 993 |
if not non_cat_vars:
|
| 994 |
return {}
|
| 995 |
+
|
| 996 |
+
# Separate string and numeric variables
|
| 997 |
+
llm_vars = [var for var in non_cat_vars if var["type"] == "string"]
|
| 998 |
+
numeric_vars = [var for var in non_cat_vars if var["type"] in ["int", "float"]]
|
| 999 |
+
|
| 1000 |
+
# Sample numeric values within the specified range
|
| 1001 |
+
result_values = {}
|
| 1002 |
+
# result_values_descr = {} # Uncomment to include the var description, i.e. units so the LLM understands the numerical values
|
| 1003 |
+
# Otherwise, good practice is to include units in numerical vars names (e.g. price_in_euros instead of price)
|
| 1004 |
+
for var in numeric_vars:
|
| 1005 |
+
name = var["name"]
|
| 1006 |
+
var_min = var.get("min")
|
| 1007 |
+
var_max = var.get("max")
|
| 1008 |
+
# description = var.get("description")
|
| 1009 |
+
|
| 1010 |
+
if var_min is None or var_max is None:
|
| 1011 |
+
result_values[name] = get_default_value(var)
|
| 1012 |
+
# result_values_descr[name] = get_default_value(var)
|
| 1013 |
+
else:
|
| 1014 |
+
try:
|
| 1015 |
+
if var["type"] == "int":
|
| 1016 |
+
result_values[name] = random.randint(int(var_min), int(var_max))
|
| 1017 |
+
# result_values_descr[name] = [result_values[name], description]
|
| 1018 |
+
elif var["type"] == "float":
|
| 1019 |
+
result_values[name] = round(random.uniform(float(var_min), float(var_max)), 2)
|
| 1020 |
+
# result_values_descr[name] = [result_values[name], description]
|
| 1021 |
+
except:
|
| 1022 |
+
result_values[name] = get_default_value(var)
|
| 1023 |
+
# result_values_descr[name] = get_default_value(var)
|
| 1024 |
|
| 1025 |
+
# Format the string variables for the prompt
|
| 1026 |
+
if llm_vars:
|
| 1027 |
+
vars_text = "\n".join(
|
| 1028 |
+
[f"- {var['name']}: {var['description']} (Type: string)" for var in llm_vars]
|
| 1029 |
+
)
|
| 1030 |
+
# Combine categorical and numeric values for LLM context
|
| 1031 |
+
# context_values = {**existing_values, **result_values_descr}
|
| 1032 |
+
context_values = {**existing_values, **result_values}
|
| 1033 |
+
print(context_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
|
| 1035 |
+
# Create prompt with existing categorical and numerical values as context
|
| 1036 |
+
prompt = f"""
|
| 1037 |
+
As a synthetic data generator, create values for these variables:
|
|
|
|
|
|
|
|
|
|
| 1038 |
|
| 1039 |
+
{vars_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
|
| 1041 |
+
These values should be coherent with the existing categorical and/or numerical values:
|
| 1042 |
+
{json.dumps(context_values, indent=2)}
|
| 1043 |
|
| 1044 |
+
Return ONLY a JSON object with the new variable values:
|
| 1045 |
+
{{
|
| 1046 |
+
"variable_name_1": value1,
|
| 1047 |
+
"variable_name_2": value2
|
| 1048 |
+
}}
|
| 1049 |
+
"""
|
| 1050 |
+
# print("*************** PROMPT FOR STR VAR:", prompt)
|
| 1051 |
|
| 1052 |
+
for attempt in range(max_retries):
|
| 1053 |
+
try:
|
| 1054 |
+
response = call_model_api(
|
| 1055 |
+
model=st.session_state.model,
|
| 1056 |
+
prompt=prompt,
|
| 1057 |
+
max_tokens=1000,
|
| 1058 |
+
temperature=st.session_state.temperature,
|
| 1059 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1060 |
|
| 1061 |
+
result = response.strip()
|
| 1062 |
+
|
| 1063 |
+
# Extract JSON
|
| 1064 |
+
json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
|
| 1065 |
+
json_match = re.search(json_pattern, result)
|
| 1066 |
+
|
| 1067 |
+
if json_match:
|
| 1068 |
+
json_str = json_match.group(1) if json_match.group(1) else result
|
| 1069 |
+
json_str = re.sub(r"```.*|```", "", json_str).strip()
|
| 1070 |
+
try:
|
| 1071 |
+
values = json.loads(json_str, strict=False)
|
| 1072 |
+
if isinstance(values, dict):
|
| 1073 |
+
result_values.update(values)
|
| 1074 |
+
return result_values
|
| 1075 |
+
except:
|
| 1076 |
+
pass
|
| 1077 |
+
else:
|
| 1078 |
+
try:
|
| 1079 |
+
values = json.loads(result, strict=False)
|
| 1080 |
+
if isinstance(values, dict):
|
| 1081 |
+
result_values.update(values)
|
| 1082 |
+
return result_values
|
| 1083 |
+
except:
|
| 1084 |
+
pass
|
| 1085 |
+
|
| 1086 |
+
except Exception as e:
|
| 1087 |
+
if attempt == max_retries - 1:
|
| 1088 |
+
st.warning(f"Failed to generate string values: {str(e)}")
|
| 1089 |
|
| 1090 |
+
# Fallback: generate empty values for all string variables
|
| 1091 |
+
for var in llm_vars:
|
| 1092 |
+
result_values[var["name"]] = get_default_value(var)
|
| 1093 |
+
return result_values
|
| 1094 |
|
| 1095 |
|
| 1096 |
# Matches either a fenced ```json ... ``` block (captured in group 1) or a
# response that consists solely of a single {...} object.
_JSON_RESPONSE_RE = re.compile(r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$")


def _sample_numeric_value(var):
    """Sample a value for an ``int``/``float`` variable spec.

    Returns a random value within ``var["min"]``..``var["max"]`` (floats are
    rounded to two decimals). Falls back to ``get_default_value(var)`` when a
    bound is missing or the bounds cannot be used (non-numeric, inverted
    integer range, infinite, ...).
    """
    lower = var.get("min")
    upper = var.get("max")
    if lower is None or upper is None:
        return get_default_value(var)

    try:
        if var["type"] == "int":
            return random.randint(int(lower), int(upper))
        return round(random.uniform(float(lower), float(upper)), 2)
    except (TypeError, ValueError, OverflowError):
        # Only conversion/range problems are expected here; anything else
        # (e.g. KeyboardInterrupt) must not be swallowed.
        return get_default_value(var)


def _parse_json_object(text):
    """Extract and parse the JSON object contained in a model response.

    Raises ``ValueError`` (``json.JSONDecodeError`` is a subclass) when the
    text is not valid JSON or the parsed value is not a JSON object.
    """
    candidate = text
    match = _JSON_RESPONSE_RE.search(text)
    if match:
        candidate = match.group(1) if match.group(1) else text
        # Strip any leftover code-fence markers before parsing.
        candidate = re.sub(r"```.*|```", "", candidate).strip()

    parsed = json.loads(candidate, strict=False)
    if not isinstance(parsed, dict):
        raise ValueError("model response was not a JSON object")
    return parsed


def generate_single_row(all_vars, max_retries):
    """Generate a complete row of data using hybrid logic.

    - ``int``/``float`` variables are sampled locally within their range.
    - ``string``/``categorical`` variables are requested from the LLM, with
      the sampled numeric values passed along as context.

    Args:
        all_vars: List of variable spec dicts. Each has at least ``name`` and
            ``type``; LLM-generated variables also need ``description``, and
            may carry ``options``. Numeric variables may carry ``min``/``max``.
        max_retries: Maximum number of LLM calls to attempt.

    Returns:
        A dict mapping variable names to values, or ``None`` when no value
        could be produced at all. If every LLM attempt fails, the row holds
        only the numeric values and a Streamlit warning is shown.
    """
    numeric_vars = [var for var in all_vars if var["type"] in ["int", "float"]]
    llm_vars = [var for var in all_vars if var["type"] in ["string", "categorical"]]

    # Sample numeric vars
    row = {var["name"]: _sample_numeric_value(var) for var in numeric_vars}

    # Generate string and categorical via LLM
    if llm_vars:
        vars_text = "\n".join(
            [
                f"- {var['name']}: {var['description']} (Type: {var['type']})"
                + (
                    f", Options: {var['options']}"
                    if var["type"] == "categorical" and var.get("options")
                    else ""
                )
                for var in llm_vars
            ]
        )

        prompt = f"""
You are a synthetic data generator. Generate values for the following variables:

{vars_text}

Based on this partial row:
{json.dumps(row, indent=2)}

Return ONLY a JSON object with the new values:
{{
    "var_name_1": value1,
    "var_name_2": value2
}}

For categorical variables that allow multiple selections, return a list of values.
"""

        # Only accept keys we actually asked for, so the model can neither
        # overwrite the sampled numeric values nor add undeclared columns.
        requested_names = {var["name"] for var in llm_vars}

        for attempt in range(max_retries):
            try:
                response = call_model_api(
                    model=st.session_state.model,
                    prompt=prompt,
                    max_tokens=1000,
                    temperature=st.session_state.temperature,
                )
                values = _parse_json_object(response.strip())
                row.update(
                    {
                        key: value
                        for key, value in values.items()
                        if key in requested_names
                    }
                )
                break
            except Exception as e:
                # Best-effort: retry on any API/parsing failure and only
                # surface the problem once the retries are exhausted.
                if attempt == max_retries - 1:
                    st.warning(f"Failed to generate string/categorical values: {str(e)}")

    return row if row else None
|
|
|
|
| 1183 |
|
| 1184 |
|
| 1185 |
def get_default_value(var):
|
|
|
|
| 1232 |
input_vars = template_spec["input"]
|
| 1233 |
input_vars_text = "\n".join(
|
| 1234 |
[
|
| 1235 |
+
f"- {var['name']}: {var['description']} (Type: {var['type']})"
|
| 1236 |
for var in input_vars
|
| 1237 |
]
|
| 1238 |
)
|
|
|
|
| 1248 |
progress_bar = st.progress(0)
|
| 1249 |
|
| 1250 |
try:
|
| 1251 |
+
input_var_names = [var["name"] for var in template_spec["input"]]
|
| 1252 |
+
|
| 1253 |
for i, input_item in enumerate(input_data):
|
| 1254 |
+
# Filter out variables not defined in the template spec
|
| 1255 |
+
input_item = {k: v for k, v in input_item.items() if k in input_var_names}
|
| 1256 |
# Fill the prompt template with input values
|
| 1257 |
filled_prompt = prompt_template
|
| 1258 |
for var_name, var_value in input_item.items():
|
|
|
|
| 1284 |
Use appropriate data types for each variable. Return ONLY the JSON object with no additional text or explanation.
|
| 1285 |
The response must be valid JSON that can be parsed directly.
|
| 1286 |
"""
|
| 1287 |
+
# debug logs:
|
| 1288 |
+
# print("*************Filtered Input:", input_item)
|
| 1289 |
+
# print("*************Generated Prompt:", generation_prompt)
|
| 1290 |
output_data = None
|
| 1291 |
for attempt in range(max_retries):
|
| 1292 |
try:
|
|
|
|
| 1583 |
if model_provider == "OpenAI":
|
| 1584 |
st.session_state.model = st.selectbox(
|
| 1585 |
"Select OpenAI Model",
|
| 1586 |
+
options=[
|
| 1587 |
+
"gpt-4o-mini",
|
| 1588 |
+
"gpt-4.1-mini",
|
| 1589 |
+
"gpt-4.1",
|
| 1590 |
+
"gpt-4o",
|
| 1591 |
+
"gpt-4.1-nano",
|
| 1592 |
+
],
|
| 1593 |
+
index=1,
|
| 1594 |
)
|
| 1595 |
else: # Anthropic
|
| 1596 |
st.session_state.model = st.selectbox(
|
|
|
|
| 2667 |
placeholder="Enter custom lore or background information here...",
|
| 2668 |
height=150,
|
| 2669 |
)
|
| 2670 |
+
# Temperature control slider
|
| 2671 |
+
st.session_state.temperature = st.slider(
|
| 2672 |
+
"Temperature (creativity level)", min_value=0.0, max_value=1.0, value=0.7, step=0.05
|
| 2673 |
+
)
|
| 2674 |
# Generate Output button
|
| 2675 |
if st.button("Generate Output", key="generate_button"):
|
| 2676 |
# Check if API key is provided
|
|
|
|
| 3251 |
# Display combined data if available
|
| 3252 |
if st.session_state.combined_data:
|
| 3253 |
st.subheader("Complete Dataset (Inputs + Outputs)")
|
| 3254 |
+
# Get all available column names from the data
|
| 3255 |
+
all_columns = pd.DataFrame(st.session_state.combined_data).columns.tolist()
|
| 3256 |
+
|
| 3257 |
+
# Let the user select columns to exclude from input JSON
|
| 3258 |
+
st.session_state.columns_to_drop = st.multiselect(
|
| 3259 |
+
"Select input variables to exclude:",
|
| 3260 |
+
options=all_columns,
|
| 3261 |
+
default=st.session_state.get("columns_to_drop", []),
|
| 3262 |
+
)
|
| 3263 |
|
| 3264 |
# Add this function before the prepare_dataframe_with_json_columns function
|
| 3265 |
|
|
|
|
| 3301 |
|
| 3302 |
# Create a function to prepare the dataframe with JSON columns
|
| 3303 |
def prepare_dataframe_with_json_columns(
|
| 3304 |
+
data, template_spec, show_json_columns=False, columns_to_drop=None
|
| 3305 |
):
|
| 3306 |
df = pd.DataFrame(data)
|
| 3307 |
+
# Drop specified columns from the dataframe
|
| 3308 |
+
if columns_to_drop:
|
| 3309 |
+
df = df.drop(
|
| 3310 |
+
columns=[col for col in columns_to_drop if col in df.columns]
|
| 3311 |
+
)
|
| 3312 |
+
else:
|
| 3313 |
+
columns_to_drop = []
|
| 3314 |
|
| 3315 |
# Create input and output JSON columns
|
| 3316 |
+
input_vars = [
|
| 3317 |
+
var["name"]
|
| 3318 |
+
for var in template_spec["input"]
|
| 3319 |
+
if var["name"] not in columns_to_drop
|
| 3320 |
+
]
|
| 3321 |
output_vars = [var["name"] for var in template_spec["output"]]
|
| 3322 |
|
| 3323 |
# Create input JSON column
|
|
|
|
| 3339 |
# If not showing JSON columns in UI, remove them for display only
|
| 3340 |
if not show_json_columns:
|
| 3341 |
display_df = df.drop(columns=["input", "output"])
|
| 3342 |
+
else:
|
| 3343 |
+
display_df = df
|
| 3344 |
|
| 3345 |
+
# Return the same filtered df for export (full_df)
|
| 3346 |
+
return df, display_df
|
| 3347 |
|
| 3348 |
# Toggle for showing JSON columns
|
| 3349 |
st.session_state.show_json_columns = st.checkbox(
|
|
|
|
| 3356 |
st.session_state.combined_data,
|
| 3357 |
st.session_state.template_spec,
|
| 3358 |
st.session_state.show_json_columns,
|
| 3359 |
+
columns_to_drop=st.session_state.columns_to_drop,
|
| 3360 |
)
|
| 3361 |
|
| 3362 |
# Show data in a table
|
|
|
|
| 3376 |
)
|
| 3377 |
|
| 3378 |
with col2:
|
| 3379 |
+
# JSON download using cleaned dataframe
|
| 3380 |
+
json_ready_df = full_df.drop(columns=["input", "output"])
|
| 3381 |
+
combined_json = json.dumps(
|
| 3382 |
+
json_ready_df.to_dict(orient="records"), indent=2
|
| 3383 |
+
)
|
| 3384 |
st.download_button(
|
| 3385 |
label="Download Dataset (JSON)",
|
| 3386 |
data=combined_json,
|
|
|
|
| 3414 |
else:
|
| 3415 |
st.info(
|
| 3416 |
"No template has been generated yet. Go to the 'Setup' tab to create one."
|
| 3417 |
+
)
|