DJHumanRPT committed on
Commit
297b883
·
verified ·
1 Parent(s): e9c8d9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +503 -59
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import json
3
  import PyPDF2
 
4
  import re
5
  from io import BytesIO
6
  import openai
@@ -26,36 +27,57 @@ def get_openai_client():
26
  return None
27
 
28
 
29
- # Define helper functions for PDF parsing
30
- def parse_pdf(file):
31
- """Extract text from a PDF file."""
32
- try:
33
- pdf_reader = PyPDF2.PdfReader(file)
34
- text = ""
35
- for page_num in range(len(pdf_reader.pages)):
36
- text += pdf_reader.pages[page_num].extract_text() or ""
37
- return text
38
- except Exception as e:
39
- st.error(f"Error parsing PDF: {str(e)}")
40
- return ""
 
 
41
 
42
 
 
43
  def parse_documents(uploaded_files):
44
  """Parse multiple document files and extract their text content."""
 
 
 
 
 
 
 
45
  content = ""
 
46
  for file in uploaded_files:
47
  try:
48
  file_type = file.name.split(".")[-1].lower()
49
- if file_type == "pdf":
50
- # Create a copy of the file to avoid buffer issues
51
- file_copy = BytesIO(file.getvalue())
52
- content += parse_pdf(file_copy) + "\n\n"
53
- elif file_type == "txt":
54
- content += file.getvalue().decode("utf-8") + "\n\n"
 
 
 
 
 
 
 
 
 
55
  else:
56
  st.warning(f"Unsupported file type: {file.name}")
57
  except Exception as e:
58
  st.error(f"Error processing file {file.name}: {str(e)}")
 
59
  return content
60
 
61
 
@@ -243,7 +265,7 @@ If document content was provided, design the template to effectively use that in
243
  response = client.chat.completions.create(
244
  model=st.session_state.model,
245
  messages=[{"role": "user", "content": prompt}],
246
- max_completion_tokens=4096,
247
  temperature=0.7,
248
  )
249
 
@@ -340,7 +362,7 @@ Return ONLY the revised prompt template text, with no additional explanations.
340
  response = client.chat.completions.create(
341
  model=st.session_state.model,
342
  messages=[{"role": "user", "content": prompt}],
343
- max_completion_tokens=4096,
344
  temperature=0.7,
345
  )
346
 
@@ -473,25 +495,35 @@ def generate_categorical_permutations(categorical_vars, target_count):
473
  min_sel = var.get("min", 1)
474
  max_sel = var.get("max", 1)
475
 
 
 
 
 
 
 
 
 
 
 
476
  # Single selection case
477
  if min_sel == 1 and max_sel == 1:
478
- option_sets.append([(var_name, opt) for opt in options])
479
  else:
480
  # Multi-selection case - generate varied selection sizes
481
  var_options = []
482
 
483
  # Include min selections
484
- for combo in itertools.combinations(options, min_sel):
485
  var_options.append((var_name, list(combo)))
486
 
487
  # Include max selections if different from min
488
  if max_sel != min_sel:
489
- for combo in itertools.combinations(options, max_sel):
490
  var_options.append((var_name, list(combo)))
491
 
492
  # Include some intermediate selections if applicable
493
  for size in range(min_sel + 1, max_sel):
494
- combos = list(itertools.combinations(options, size))
495
  if combos:
496
  sample_size = min(3, len(combos)) # Take up to 3 samples
497
  for combo in random.sample(combos, sample_size):
@@ -519,12 +551,18 @@ def generate_categorical_permutations(categorical_vars, target_count):
519
  var = random.choice(categorical_vars)
520
  var_name = var["name"]
521
  options = var.get("options", [])
 
522
 
523
- if options and len(options) > 1:
 
 
 
 
 
524
  if var.get("min", 1) == 1 and var.get("max", 1) == 1:
525
  # For single selection, choose a different option
526
  current = new_perm[var_name]
527
- other_options = [opt for opt in options if opt != current]
528
  if other_options:
529
  new_perm[var_name] = random.choice(other_options)
530
  else:
@@ -537,7 +575,9 @@ def generate_categorical_permutations(categorical_vars, target_count):
537
  if len(current_selection) < max_sel and random.random() > 0.5:
538
  # Add an item not already in the selection
539
  available = [
540
- opt for opt in options if opt not in current_selection
 
 
541
  ]
542
  if available:
543
  current_selection.append(random.choice(available))
@@ -893,6 +933,157 @@ The response must be valid JSON that can be parsed directly.
893
  return results
894
 
895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896
  # Initialize session state
897
  if "template_spec" not in st.session_state:
898
  st.session_state.template_spec = None
@@ -977,7 +1168,7 @@ with tab1:
977
  uploaded_files = st.file_uploader(
978
  "Upload documents to use as knowledge base",
979
  accept_multiple_files=True,
980
- type=["pdf", "txt"],
981
  )
982
 
983
  # Rest of your existing code for document processing...
@@ -992,8 +1183,7 @@ with tab1:
992
  with st.expander("Preview extracted content"):
993
  st.text_area(
994
  "Extracted Text",
995
- value=st.session_state.knowledge_base[:10000]
996
- + ("..." if len(st.session_state.knowledge_base) > 1000 else ""),
997
  height=200,
998
  disabled=True,
999
  )
@@ -1079,6 +1269,14 @@ with tab2:
1079
  if st.session_state.show_template_editor and st.session_state.template_spec:
1080
  st.header("Template Editor")
1081
 
 
 
 
 
 
 
 
 
1082
  # Basic template information
1083
  with st.expander("Template Information", expanded=True):
1084
  col1, col2 = st.columns(2)
@@ -1097,6 +1295,87 @@ with tab2:
1097
  height=100,
1098
  )
1099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  # Prompt Template Section
1101
  with st.expander("Prompt Template", expanded=True):
1102
  st.info("Use {variable_name} to refer to input variables in your template")
@@ -1131,17 +1410,26 @@ with tab2:
1131
  with st.expander("Input Variables", expanded=True):
1132
  st.subheader("Input Variables")
1133
 
1134
- # Add input variable button
1135
- if st.button("Add Input Variable"):
1136
- new_var = {
1137
- "name": f"new_input_{len(st.session_state.template_spec['input']) + 1}",
1138
- "description": "New input variable",
1139
- "type": "string",
1140
- "min": 1,
1141
- "max": 100,
1142
- }
1143
- st.session_state.template_spec["input"].append(new_var)
1144
- st.rerun()
 
 
 
 
 
 
 
 
 
1145
 
1146
  # Display input variables
1147
  for i, input_var in enumerate(st.session_state.template_spec["input"]):
@@ -1192,6 +1480,42 @@ with tab2:
1192
  )
1193
 
1194
  if var_type == "categorical":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1195
  options = input_var.get("options", [])
1196
  options_str = st.text_area(
1197
  "Options (one per line)",
@@ -1224,7 +1548,7 @@ with tab2:
1224
  with col3:
1225
  if st.button("Remove", key=f"remove_input_{i}"):
1226
  st.session_state.template_spec["input"].pop(i)
1227
- st.rerun()
1228
 
1229
  st.divider()
1230
 
@@ -1232,17 +1556,26 @@ with tab2:
1232
  with st.expander("Output Variables", expanded=True):
1233
  st.subheader("Output Variables")
1234
 
1235
- # Add output variable button
1236
- if st.button("Add Output Variable"):
1237
- new_var = {
1238
- "name": f"new_output_{len(st.session_state.template_spec['output']) + 1}",
1239
- "description": "New output variable",
1240
- "type": "string",
1241
- "min": 1,
1242
- "max": 100,
1243
- }
1244
- st.session_state.template_spec["output"].append(new_var)
1245
- st.rerun()
 
 
 
 
 
 
 
 
 
1246
 
1247
  # Display output variables
1248
  for i, output_var in enumerate(st.session_state.template_spec["output"]):
@@ -1293,6 +1626,42 @@ with tab2:
1293
  )
1294
 
1295
  if var_type == "categorical":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1296
  options = output_var.get("options", [])
1297
  options_str = st.text_area(
1298
  "Options (one per line)",
@@ -1325,7 +1694,7 @@ with tab2:
1325
  with col3:
1326
  if st.button("Remove", key=f"remove_output_{i}"):
1327
  st.session_state.template_spec["output"].pop(i)
1328
- st.rerun()
1329
 
1330
  st.divider()
1331
 
@@ -1566,17 +1935,92 @@ with tab4:
1566
  if "selected_samples" not in st.session_state:
1567
  st.session_state.selected_samples = []
1568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1569
  # Generate inputs button
1570
  if st.button("Generate Synthetic Inputs"):
1571
  if not st.session_state.get("api_key"):
1572
  st.error("Please provide an OpenAI API key in the sidebar.")
1573
  else:
1574
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
1575
- st.session_state.synthetic_inputs = (
1576
- generate_synthetic_inputs_hybrid(
1577
- st.session_state.template_spec, num_samples=num_samples
 
 
 
 
 
 
 
 
 
1578
  )
1579
- )
1580
 
1581
  if st.session_state.synthetic_inputs:
1582
  st.success(
@@ -1864,4 +2308,4 @@ with tab4:
1864
  else:
1865
  st.info(
1866
  "No template has been generated yet. Go to the 'Setup' tab to create one."
1867
- )
 
1
  import streamlit as st
2
  import json
3
  import PyPDF2
4
+ from docling.document_converter import DocumentConverter
5
  import re
6
  from io import BytesIO
7
  import openai
 
27
  return None
28
 
29
 
30
@st.cache_resource
def get_document_converter():
    """Build the docling ``DocumentConverter`` and let Streamlit cache it.

    ``st.cache_resource`` runs the body on the first call only and returns the
    same object on later calls and reruns, so construction is lazy and the
    converter is not rebuilt on every interaction.
    """
    return DocumentConverter()


def get_or_create_document_converter():
    """Return the shared ``DocumentConverter``, creating it on first use.

    Kept as a thin wrapper so existing callers keep working.
    """
    # The previous version cached ``None`` and then stored the converter in a
    # ``_cached_obj`` attribute on the wrapper. Nothing ever read that
    # attribute, so a new converter was constructed on every call.
    return get_document_converter()
44
 
45
 
46
@st.cache_data
def parse_documents(uploaded_files):
    """Parse multiple document files and extract their text content.

    Each supported upload (pdf, txt, docx, html) is written to a temporary
    file with the matching extension, converted through the shared document
    converter and exported to markdown. Unsupported files trigger a warning;
    a file that fails to convert triggers an error and is skipped.

    Args:
        uploaded_files: Iterable of uploaded file objects exposing ``.name``
            and ``.getvalue()``.

    Returns:
        The markdown of all successfully converted files, separated by blank
        lines, or ``""`` when nothing was uploaded or converted.
    """
    if not uploaded_files:
        return ""

    import os
    import tempfile

    supported_types = {"pdf", "txt", "docx", "html"}
    converter = get_or_create_document_converter()
    parts = []

    for file in uploaded_files:
        try:
            file_type = file.name.split(".")[-1].lower()
            if file_type not in supported_types:
                st.warning(f"Unsupported file type: {file.name}")
                continue

            tmp_path = None
            try:
                # The converter is given a path, so the upload is spooled to
                # disk under the correct extension first.
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=f".{file_type}"
                ) as tmp_file:
                    tmp_path = tmp_file.name
                    tmp_file.write(file.getvalue())

                source = converter.convert(tmp_path)
                parts.append(source.document.export_to_markdown())
            finally:
                # Always remove the temp file, including when conversion
                # fails; previously it was only removed on success.
                if tmp_path is not None:
                    os.unlink(tmp_path)
        except Exception as e:
            st.error(f"Error processing file {file.name}: {str(e)}")

    # Separate documents with a blank line so the end of one file does not run
    # into the first line of the next.
    return "\n\n".join(parts)
82
 
83
 
 
265
  response = client.chat.completions.create(
266
  model=st.session_state.model,
267
  messages=[{"role": "user", "content": prompt}],
268
+ max_tokens=4096,
269
  temperature=0.7,
270
  )
271
 
 
362
  response = client.chat.completions.create(
363
  model=st.session_state.model,
364
  messages=[{"role": "user", "content": prompt}],
365
+ max_tokens=4096,
366
  temperature=0.7,
367
  )
368
 
 
495
  min_sel = var.get("min", 1)
496
  max_sel = var.get("max", 1)
497
 
498
+ # Get selected options if they exist
499
+ selected_options = var.get("selected_options", options)
500
+
501
+ # Use only selected options for permutation
502
+ options_to_use = [opt for opt in options if opt in selected_options]
503
+
504
+ # If no options selected, use all options
505
+ if not options_to_use:
506
+ options_to_use = options
507
+
508
  # Single selection case
509
  if min_sel == 1 and max_sel == 1:
510
+ option_sets.append([(var_name, opt) for opt in options_to_use])
511
  else:
512
  # Multi-selection case - generate varied selection sizes
513
  var_options = []
514
 
515
  # Include min selections
516
+ for combo in itertools.combinations(options_to_use, min_sel):
517
  var_options.append((var_name, list(combo)))
518
 
519
  # Include max selections if different from min
520
  if max_sel != min_sel:
521
+ for combo in itertools.combinations(options_to_use, max_sel):
522
  var_options.append((var_name, list(combo)))
523
 
524
  # Include some intermediate selections if applicable
525
  for size in range(min_sel + 1, max_sel):
526
+ combos = list(itertools.combinations(options_to_use, size))
527
  if combos:
528
  sample_size = min(3, len(combos)) # Take up to 3 samples
529
  for combo in random.sample(combos, sample_size):
 
551
  var = random.choice(categorical_vars)
552
  var_name = var["name"]
553
  options = var.get("options", [])
554
+ selected_options = var.get("selected_options", options)
555
 
556
+ # Use only selected options for variation
557
+ options_to_use = [opt for opt in options if opt in selected_options]
558
+ if not options_to_use:
559
+ options_to_use = options
560
+
561
+ if options_to_use and len(options_to_use) > 1:
562
  if var.get("min", 1) == 1 and var.get("max", 1) == 1:
563
  # For single selection, choose a different option
564
  current = new_perm[var_name]
565
+ other_options = [opt for opt in options_to_use if opt != current]
566
  if other_options:
567
  new_perm[var_name] = random.choice(other_options)
568
  else:
 
575
  if len(current_selection) < max_sel and random.random() > 0.5:
576
  # Add an item not already in the selection
577
  available = [
578
+ opt
579
+ for opt in options_to_use
580
+ if opt not in current_selection
581
  ]
582
  if available:
583
  current_selection.append(random.choice(available))
 
933
  return results
934
 
935
 
936
def suggest_variable_values_from_kb(
    variable_name, variable_type, knowledge_base, client, model="gpt-3.5-turbo"
):
    """Ask the LLM to suggest values for a variable based on the knowledge base.

    Especially useful for categorical variables, to extract options from the
    uploaded documents.

    Args:
        variable_name: Name of the variable to suggest values for.
        variable_type: One of the template variable types named in the prompt
            (categorical, string, int, float, bool).
        knowledge_base: Extracted document text; only the first 100000
            characters are sent.
        client: OpenAI-style client exposing ``chat.completions.create``.
        model: Chat model name passed to the API.

    Returns:
        The parsed JSON value from the model's reply, or ``None`` when there
        is no knowledge base or client, the request fails, or no JSON can be
        parsed out of the reply.
    """
    if not knowledge_base or not client:
        return None

    # Truncate knowledge base if it's too long
    kb_excerpt = (
        knowledge_base[:100000] + "..."
        if len(knowledge_base) > 100000
        else knowledge_base
    )

    prompt = f"""
Based on the following knowledge base content, suggest appropriate values for a variable named "{variable_name}" of type "{variable_type}".

KNOWLEDGE BASE EXCERPT:
{kb_excerpt}

TASK:
Extract or suggest appropriate values for this variable from the knowledge base.

If the variable type is "categorical", return a list of possible options found in the knowledge base.
If the variable type is "string", suggest a few example values.
If the variable type is "int" or "float", suggest appropriate min/max ranges.
If the variable type is "bool", suggest appropriate true/false conditions.

Return your response as a JSON object with the following structure:
For categorical: {{"options": ["option1", "option2", ...]}}
For string: {{"examples": ["example1", "example2", ...], "min": min_length, "max": max_length}}
For int/float: {{"min": minimum_value, "max": maximum_value, "examples": [value1, value2, ...]}}
For bool: {{"examples": ["condition for true", "condition for false"]}}

Only include values that are actually present or strongly implied in the knowledge base.
"""

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.3,
        )
        result = response.choices[0].message.content
    except Exception as e:
        print(f"Error suggesting variable values: {str(e)}")
        return None

    if not result:
        return None

    # Try the most specific interpretation first: a fenced block (with or
    # without a "json" tag), then the whole reply, then the outermost {...}
    # span for replies that wrap the JSON in prose.
    candidates = []
    fenced = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", result)
    if fenced and fenced.group(1):
        candidates.append(fenced.group(1))
    candidates.append(result.strip())
    start, end = result.find("{"), result.rfind("}")
    if start != -1 and end > start:
        candidates.append(result[start : end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    return None
1009
+
1010
+
1011
@st.cache_data
def analyze_knowledge_base(knowledge_base, _client, model="gpt-4o-mini"):
    """Analyze the knowledge base to extract potential variable names and values.

    This can be used to suggest variables when creating a new template.

    Args:
        knowledge_base: Extracted document text; only the first 100000
            characters are sent.
        _client: OpenAI-style client exposing ``chat.completions.create``.
            The leading underscore tells ``st.cache_data`` not to hash it.
        model: Chat model name passed to the API.

    Returns:
        The parsed JSON value from the model's reply (the prompt asks for a
        list of variable descriptions), or ``None`` when there is no knowledge
        base or client, the request fails, or no JSON can be parsed.
    """
    # Check the parameter itself. The previous guard tested a name ``client``
    # that is not a parameter here, so it relied on a module-level global
    # happening to exist.
    if not knowledge_base or not _client:
        return None

    # Truncate knowledge base if it's too long
    kb_excerpt = (
        knowledge_base[:100000] + "..."
        if len(knowledge_base) > 100000
        else knowledge_base
    )

    prompt = f"""
Analyze the following knowledge base content and identify potential variables that could be used in a template.

KNOWLEDGE BASE EXCERPT:
{kb_excerpt}

TASK:
1. Identify key entities, attributes, or concepts that could be used as variables
2. For each variable, suggest an appropriate type (string, int, float, bool, categorical)
3. For categorical variables, suggest possible options

Return your analysis as a JSON array with the following structure:
[
    {{
        "name": "variable_name",
        "description": "what this variable represents",
        "type": "string/int/float/bool/categorical",
        "options": ["option1", "option2", ...] (only for categorical type)
    }},
    ...
]

Focus on extracting variables that appear frequently or seem important in the knowledge base.
"""

    try:
        response = _client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=2000,
            temperature=0.3,
        )
        result = response.choices[0].message.content
    except Exception as e:
        print(f"Error analyzing knowledge base: {str(e)}")
        return None

    if not result:
        return None

    # Try the most specific interpretation first: a fenced block (with or
    # without a "json" tag), then the whole reply, then the outermost [...]
    # span for replies that wrap the array in prose.
    candidates = []
    fenced = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", result)
    if fenced and fenced.group(1):
        candidates.append(fenced.group(1))
    candidates.append(result.strip())
    start, end = result.find("["), result.rfind("]")
    if start != -1 and end > start:
        candidates.append(result[start : end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    return None
1085
+
1086
+
1087
  # Initialize session state
1088
  if "template_spec" not in st.session_state:
1089
  st.session_state.template_spec = None
 
1168
  uploaded_files = st.file_uploader(
1169
  "Upload documents to use as knowledge base",
1170
  accept_multiple_files=True,
1171
+ type=["pdf", "txt", "html"],
1172
  )
1173
 
1174
  # Rest of your existing code for document processing...
 
1183
  with st.expander("Preview extracted content"):
1184
  st.text_area(
1185
  "Extracted Text",
1186
+ value=st.session_state.knowledge_base,
 
1187
  height=200,
1188
  disabled=True,
1189
  )
 
1269
  if st.session_state.show_template_editor and st.session_state.template_spec:
1270
  st.header("Template Editor")
1271
 
1272
+ # Initialize suggested variables in session state if not present
1273
+ if "suggested_variables" not in st.session_state:
1274
+ st.session_state.suggested_variables = []
1275
+
1276
+ # Initialize a tracking variable for added suggestions
1277
+ if "added_suggestions" not in st.session_state:
1278
+ st.session_state.added_suggestions = set()
1279
+
1280
  # Basic template information
1281
  with st.expander("Template Information", expanded=True):
1282
  col1, col2 = st.columns(2)
 
1295
  height=100,
1296
  )
1297
 
1298
+ # Knowledge Base Analysis
1299
+ with st.expander("Knowledge Base Analysis", expanded=True):
1300
+ if st.session_state.knowledge_base:
1301
+ st.info("Analyze the knowledge base to suggest variables and values")
1302
+
1303
+ if st.button(
1304
+ "Analyze Knowledge Base for Variables", key="analyze_kb_button"
1305
+ ):
1306
+ client = get_openai_client()
1307
+ if not client:
1308
+ st.error(
1309
+ "Please provide an OpenAI API key to analyze the knowledge base."
1310
+ )
1311
+ else:
1312
+ with st.spinner("Analyzing knowledge base..."):
1313
+ suggested_vars = analyze_knowledge_base(
1314
+ st.session_state.knowledge_base, client
1315
+ )
1316
+ if suggested_vars:
1317
+ st.session_state.suggested_variables = suggested_vars
1318
+ st.success(
1319
+ f"Found {len(suggested_vars)} potential variables in the knowledge base"
1320
+ )
1321
+ else:
1322
+ st.warning(
1323
+ "Could not extract variables from the knowledge base"
1324
+ )
1325
+
1326
+ # Display suggested variables if they exist
1327
+ if st.session_state.suggested_variables:
1328
+ st.subheader("Suggested Variables")
1329
+
1330
+ # Create a container for the variables
1331
+ for i, var in enumerate(st.session_state.suggested_variables):
1332
+ # Generate a unique ID for this variable
1333
+ var_id = f"{var['name']}_{i}"
1334
+
1335
+ # Check if this variable has already been added
1336
+ if var_id in st.session_state.added_suggestions:
1337
+ continue
1338
+
1339
+ with st.container():
1340
+ col1, col2 = st.columns([3, 1])
1341
+ with col1:
1342
+ st.markdown(
1343
+ f"**{var['name']}** ({var['type']}): {var['description']}"
1344
+ )
1345
+ if var.get("options"):
1346
+ st.markdown(f"Options: {', '.join(var['options'])}")
1347
+ with col2:
1348
+ # Use a unique key for each button
1349
+ if st.button("Add", key=f"add_suggested_{var_id}"):
1350
+ # Add this variable to the template
1351
+ new_var = {
1352
+ "name": var["name"],
1353
+ "description": var["description"],
1354
+ "type": var["type"],
1355
+ }
1356
+ if var.get("options"):
1357
+ new_var["options"] = var["options"]
1358
+ if var["type"] in ["string", "int", "float"]:
1359
+ new_var["min"] = 1
1360
+ new_var["max"] = 100
1361
+
1362
+ # Add to input variables
1363
+ st.session_state.template_spec["input"].append(
1364
+ new_var
1365
+ )
1366
+
1367
+ # Mark this variable as added
1368
+ st.session_state.added_suggestions.add(var_id)
1369
+
1370
+ # Show success message
1371
+ st.success(
1372
+ f"Added {var['name']} to input variables!"
1373
+ )
1374
+ else:
1375
+ st.warning(
1376
+ "No knowledge base available. Please upload documents in the Setup tab first."
1377
+ )
1378
+
1379
  # Prompt Template Section
1380
  with st.expander("Prompt Template", expanded=True):
1381
  st.info("Use {variable_name} to refer to input variables in your template")
 
1410
  with st.expander("Input Variables", expanded=True):
1411
  st.subheader("Input Variables")
1412
 
1413
+ # Add input variable button with smart functionality
1414
+ col1, col2 = st.columns([3, 1])
1415
+ with col1:
1416
+ new_input_name = st.text_input(
1417
+ "New input variable name", key="new_input_name"
1418
+ )
1419
+ with col2:
1420
+ if st.button("Add Input Variable"):
1421
+ new_var = {
1422
+ "name": (
1423
+ new_input_name
1424
+ if new_input_name
1425
+ else f"new_input_{len(st.session_state.template_spec['input']) + 1}"
1426
+ ),
1427
+ "description": "New input variable",
1428
+ "type": "string",
1429
+ "min": 1,
1430
+ "max": 100,
1431
+ }
1432
+ st.session_state.template_spec["input"].append(new_var)
1433
 
1434
  # Display input variables
1435
  for i, input_var in enumerate(st.session_state.template_spec["input"]):
 
1480
  )
1481
 
1482
  if var_type == "categorical":
1483
+ # Add a button to suggest options from knowledge base
1484
+ kb_button_key = f"suggest_input_{i}_{input_var['name']}"
1485
+ if st.button("Suggest Options from KB", key=kb_button_key):
1486
+ client = get_openai_client()
1487
+ if not client:
1488
+ st.error(
1489
+ "Please provide an OpenAI API key to suggest options."
1490
+ )
1491
+ elif not st.session_state.knowledge_base:
1492
+ st.warning(
1493
+ "No knowledge base available. Please upload documents first."
1494
+ )
1495
+ else:
1496
+ with st.spinner(
1497
+ f"Suggesting options for {input_var['name']}..."
1498
+ ):
1499
+ suggestions = suggest_variable_values_from_kb(
1500
+ input_var["name"],
1501
+ "categorical",
1502
+ st.session_state.knowledge_base,
1503
+ client,
1504
+ )
1505
+ if suggestions and "options" in suggestions:
1506
+ # Update the options
1507
+ input_var["options"] = suggestions[
1508
+ "options"
1509
+ ]
1510
+ st.success(
1511
+ f"Found {len(suggestions['options'])} options"
1512
+ )
1513
+ else:
1514
+ st.warning(
1515
+ "Could not find suitable options in the knowledge base"
1516
+ )
1517
+
1518
+ # Display and edit options
1519
  options = input_var.get("options", [])
1520
  options_str = st.text_area(
1521
  "Options (one per line)",
 
1548
  with col3:
1549
  if st.button("Remove", key=f"remove_input_{i}"):
1550
  st.session_state.template_spec["input"].pop(i)
1551
+ st.rerun() # Only use rerun for removal
1552
 
1553
  st.divider()
1554
 
 
1556
  with st.expander("Output Variables", expanded=True):
1557
  st.subheader("Output Variables")
1558
 
1559
+ # Add output variable button with smart functionality
1560
+ col1, col2 = st.columns([3, 1])
1561
+ with col1:
1562
+ new_output_name = st.text_input(
1563
+ "New output variable name", key="new_output_name"
1564
+ )
1565
+ with col2:
1566
+ if st.button("Add Output Variable"):
1567
+ new_var = {
1568
+ "name": (
1569
+ new_output_name
1570
+ if new_output_name
1571
+ else f"new_output_{len(st.session_state.template_spec['output']) + 1}"
1572
+ ),
1573
+ "description": "New output variable",
1574
+ "type": "string",
1575
+ "min": 1,
1576
+ "max": 100,
1577
+ }
1578
+ st.session_state.template_spec["output"].append(new_var)
1579
 
1580
  # Display output variables
1581
  for i, output_var in enumerate(st.session_state.template_spec["output"]):
 
1626
  )
1627
 
1628
  if var_type == "categorical":
1629
+ # Add a button to suggest options from knowledge base
1630
+ kb_button_key = f"suggest_output_{i}_{output_var['name']}"
1631
+ if st.button("Suggest Options from KB", key=kb_button_key):
1632
+ client = get_openai_client()
1633
+ if not client:
1634
+ st.error(
1635
+ "Please provide an OpenAI API key to suggest options."
1636
+ )
1637
+ elif not st.session_state.knowledge_base:
1638
+ st.warning(
1639
+ "No knowledge base available. Please upload documents first."
1640
+ )
1641
+ else:
1642
+ with st.spinner(
1643
+ f"Suggesting options for {output_var['name']}..."
1644
+ ):
1645
+ suggestions = suggest_variable_values_from_kb(
1646
+ output_var["name"],
1647
+ "categorical",
1648
+ st.session_state.knowledge_base,
1649
+ client,
1650
+ )
1651
+ if suggestions and "options" in suggestions:
1652
+ # Update the options
1653
+ output_var["options"] = suggestions[
1654
+ "options"
1655
+ ]
1656
+ st.success(
1657
+ f"Found {len(suggestions['options'])} options"
1658
+ )
1659
+ else:
1660
+ st.warning(
1661
+ "Could not find suitable options in the knowledge base"
1662
+ )
1663
+
1664
+ # Display and edit options
1665
  options = output_var.get("options", [])
1666
  options_str = st.text_area(
1667
  "Options (one per line)",
 
1694
  with col3:
1695
  if st.button("Remove", key=f"remove_output_{i}"):
1696
  st.session_state.template_spec["output"].pop(i)
1697
+ st.rerun() # Only use rerun for removal
1698
 
1699
  st.divider()
1700
 
 
1935
  if "selected_samples" not in st.session_state:
1936
  st.session_state.selected_samples = []
1937
 
1938
+ # Add option selection for categorical variables
1939
+ categorical_vars = [
1940
+ var
1941
+ for var in st.session_state.template_spec["input"]
1942
+ if var["type"] == "categorical" and var.get("options")
1943
+ ]
1944
+
1945
+ if categorical_vars:
1946
+ st.subheader("Categorical Variable Options")
1947
+ st.info(
1948
+ "Select which options to include in the permutations for each categorical variable."
1949
+ )
1950
+
1951
+ # Create a copy of the template spec for modification
1952
+ template_spec_copy = st.session_state.template_spec.copy()
1953
+ template_spec_copy["input"] = st.session_state.template_spec["input"].copy()
1954
+
1955
+ # For each categorical variable, allow selecting options
1956
+ for i, var in enumerate(
1957
+ [
1958
+ v
1959
+ for v in template_spec_copy["input"]
1960
+ if v["type"] == "categorical" and v.get("options")
1961
+ ]
1962
+ ):
1963
+ with st.expander(
1964
+ f"{var['name']} - {var['description']}", expanded=False
1965
+ ):
1966
+ options = var.get("options", [])
1967
+
1968
+ # Initialize selected_options if not present
1969
+ if "selected_options" not in var:
1970
+ var["selected_options"] = options.copy()
1971
+
1972
+ # Add "Select All" and "Clear All" buttons
1973
+ col1, col2 = st.columns([1, 1])
1974
+ with col1:
1975
+ if st.button(
1976
+ f"Select All Options for {var['name']}",
1977
+ key=f"select_all_{i}",
1978
+ ):
1979
+ var["selected_options"] = options.copy()
1980
+ with col2:
1981
+ if st.button(
1982
+ f"Clear All Options for {var['name']}", key=f"clear_all_{i}"
1983
+ ):
1984
+ var["selected_options"] = []
1985
+
1986
+ # Create multiselect for options
1987
+ var["selected_options"] = st.multiselect(
1988
+ f"Select options to include for {var['name']}",
1989
+ options=options,
1990
+ default=var.get("selected_options", options),
1991
+ key=f"options_select_{i}",
1992
+ )
1993
+
1994
+ # Show selected count
1995
+ st.write(
1996
+ f"Selected {len(var['selected_options'])} out of {len(options)} options"
1997
+ )
1998
+
1999
+ # Update the template spec with the selected options
2000
+ for j, input_var in enumerate(template_spec_copy["input"]):
2001
+ if input_var["name"] == var["name"]:
2002
+ template_spec_copy["input"][j] = var
2003
+ break
2004
+
2005
  # Generate inputs button
2006
  if st.button("Generate Synthetic Inputs"):
2007
  if not st.session_state.get("api_key"):
2008
  st.error("Please provide an OpenAI API key in the sidebar.")
2009
  else:
2010
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
2011
+ # Use the modified template spec with selected options
2012
+ if categorical_vars:
2013
+ st.session_state.synthetic_inputs = (
2014
+ generate_synthetic_inputs_hybrid(
2015
+ template_spec_copy, num_samples=num_samples
2016
+ )
2017
+ )
2018
+ else:
2019
+ st.session_state.synthetic_inputs = (
2020
+ generate_synthetic_inputs_hybrid(
2021
+ st.session_state.template_spec, num_samples=num_samples
2022
+ )
2023
  )
 
2024
 
2025
  if st.session_state.synthetic_inputs:
2026
  st.success(
 
2308
  else:
2309
  st.info(
2310
  "No template has been generated yet. Go to the 'Setup' tab to create one."
2311
+ )