DJHumanRPT commited on
Commit
8689bd7
·
verified ·
1 Parent(s): 32fdd42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -84
app.py CHANGED
@@ -5,6 +5,7 @@ from docling.document_converter import DocumentConverter
5
  import re
6
  from io import BytesIO
7
  import openai
 
8
  import pandas as pd
9
  import itertools
10
  import random
@@ -27,6 +28,61 @@ def get_openai_client():
27
  return None
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # @st.cache_resource
31
  def get_document_converter():
32
  """Cache the DocumentConverter to prevent reloading on each interaction"""
@@ -439,11 +495,6 @@ def parse_template_file(uploaded_template):
439
  def call_llm(prompt, model="gpt-3.5-turbo"):
440
  """Call the LLM API to generate text based on the prompt."""
441
  try:
442
- client = get_openai_client()
443
- if not client:
444
- st.error("Please provide an OpenAI API key in the sidebar.")
445
- return "Error: No API key provided."
446
-
447
  # Get output specifications from the template if available
448
  output_specs = ""
449
  if st.session_state.show_template_editor and st.session_state.template_spec:
@@ -461,15 +512,13 @@ def call_llm(prompt, model="gpt-3.5-turbo"):
461
  # Add the output specs to the prompt
462
  prompt = f"{prompt}\n\n{output_specs}\n\nReturn ONLY a JSON object with the output variables, with no additional text or explanation."
463
 
464
- response = client.chat.completions.create(
465
  model=model,
466
- messages=[{"role": "user", "content": prompt}],
467
  max_tokens=1000,
468
  temperature=st.session_state.get("temperature", 0.7),
469
  )
470
 
471
- result = response.choices[0].message.content
472
-
473
  # Try to parse as JSON if the template has output variables
474
  if (
475
  st.session_state.show_template_editor
@@ -513,10 +562,6 @@ def generate_template_from_instructions(instructions, document_content=""):
513
  Use LLM to generate a template specification based on user instructions
514
  and document content.
515
  """
516
- client = get_openai_client()
517
- if not client:
518
- st.error("Please provide an OpenAI API key to generate a template.")
519
- return create_fallback_template(instructions)
520
 
521
  # Prepare the prompt for the LLM
522
  prompt = f"""
@@ -564,15 +609,13 @@ If document content was provided, design the template to effectively use that in
564
 
565
  try:
566
  # Call the LLM to generate the template
567
- response = client.chat.completions.create(
568
  model=st.session_state.model,
569
- messages=[{"role": "user", "content": prompt}],
570
  max_tokens=4096,
571
  temperature=0.7,
572
  )
573
 
574
- template_text = response.choices[0].message.content
575
-
576
  # Extract the JSON part from the response
577
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*{[\s\S]*}\s*$"
578
  json_match = re.search(json_pattern, template_text)
@@ -581,12 +624,12 @@ If document content was provided, design the template to effectively use that in
581
  json_str = json_match.group(1) if json_match.group(1) else template_text
582
  # Clean up any remaining markdown or comments
583
  json_str = re.sub(r"```.*|```", "", json_str).strip()
584
- template_spec = json.loads(json_str)
585
  return template_spec
586
  else:
587
  # If no JSON format found, try to parse the entire response
588
  try:
589
- template_spec = json.loads(template_text)
590
  return template_spec
591
  except:
592
  st.warning("LLM didn't return valid JSON. Using fallback template.")
@@ -604,9 +647,10 @@ def generate_improved_prompt_template(template_spec, knowledge_base=""):
604
  """
605
  Use LLM to generate an improved prompt template based on current template variables.
606
  """
607
- client = get_openai_client()
608
- if not client:
609
- st.error("Please provide an OpenAI API key to rewrite the prompt.")
 
610
  return template_spec["prompt"]
611
 
612
  # Extract template information for context
@@ -661,15 +705,13 @@ Return ONLY the revised prompt template text, with no additional explanations.
661
 
662
  try:
663
  # Call the LLM to generate the improved prompt template
664
- response = client.chat.completions.create(
665
  model=st.session_state.model,
666
- messages=[{"role": "user", "content": prompt}],
667
  max_tokens=4096,
668
  temperature=0.7,
669
  )
670
 
671
- improved_template = response.choices[0].message.content.strip()
672
-
673
  # Remove any markdown code block formatting if present
674
  improved_template = re.sub(r"```.*\n|```", "", improved_template)
675
 
@@ -715,8 +757,9 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
715
  - Use LLM to fill in non-categorical variables
716
  - Process row by row for resilience
717
  """
718
- client = get_openai_client()
719
- if not client:
 
720
  st.error("Please provide an OpenAI API key to generate synthetic data.")
721
  return []
722
 
@@ -753,7 +796,7 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
753
  row = perm.copy()
754
  if non_categorical_vars:
755
  non_cat_values = generate_non_categorical_values(
756
- non_categorical_vars, perm, client, max_retries
757
  )
758
  row.update(non_cat_values)
759
 
@@ -769,14 +812,14 @@ def generate_synthetic_inputs_hybrid(template_spec, num_samples=10, max_retries=
769
  progress_bar.progress(min((i + 1) / num_samples, 1.0))
770
 
771
  # Generate a complete row of values
772
- row = generate_single_row(input_vars, client, max_retries)
773
  if row:
774
  results.append(row)
775
 
776
  # Ensure we have the requested number of samples
777
  while len(results) < num_samples:
778
  # Generate additional rows if needed
779
- row = generate_single_row(input_vars, client, max_retries)
780
  if row:
781
  results.append(row)
782
 
@@ -893,7 +936,7 @@ def generate_categorical_permutations(categorical_vars, target_count):
893
  return all_permutations
894
 
895
 
896
- def generate_non_categorical_values(non_cat_vars, existing_values, client, max_retries):
897
  """Generate values for non-categorical variables given existing categorical values."""
898
  if not non_cat_vars:
899
  return {}
@@ -929,14 +972,14 @@ def generate_non_categorical_values(non_cat_vars, existing_values, client, max_r
929
 
930
  for attempt in range(max_retries):
931
  try:
932
- response = client.chat.completions.create(
933
  model=st.session_state.model,
934
- messages=[{"role": "user", "content": prompt}],
935
  max_tokens=1000,
936
  temperature=st.session_state.temperature,
937
  )
938
 
939
- result = response.choices[0].message.content.strip()
940
 
941
  # Extract JSON
942
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
@@ -946,14 +989,14 @@ def generate_non_categorical_values(non_cat_vars, existing_values, client, max_r
946
  json_str = json_match.group(1) if json_match.group(1) else result
947
  json_str = re.sub(r"```.*|```", "", json_str).strip()
948
  try:
949
- values = json.loads(json_str)
950
  if isinstance(values, dict):
951
  return values
952
  except:
953
  pass
954
  else:
955
  try:
956
- values = json.loads(result)
957
  if isinstance(values, dict):
958
  return values
959
  except:
@@ -967,7 +1010,7 @@ def generate_non_categorical_values(non_cat_vars, existing_values, client, max_r
967
  return {var["name"]: get_default_value(var) for var in non_cat_vars}
968
 
969
 
970
- def generate_single_row(all_vars, client, max_retries):
971
  """Generate a complete row of data for all variables."""
972
  # Format the variables for the prompt
973
  vars_text = "\n".join(
@@ -999,14 +1042,14 @@ def generate_single_row(all_vars, client, max_retries):
999
 
1000
  for attempt in range(max_retries):
1001
  try:
1002
- response = client.chat.completions.create(
1003
  model=st.session_state.model,
1004
  messages=[{"role": "user", "content": prompt}],
1005
  max_tokens=1000,
1006
  temperature=st.session_state.temperature,
1007
  )
1008
 
1009
- result = response.choices[0].message.content.strip()
1010
 
1011
  # Extract JSON
1012
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
@@ -1016,14 +1059,14 @@ def generate_single_row(all_vars, client, max_retries):
1016
  json_str = json_match.group(1) if json_match.group(1) else result
1017
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1018
  try:
1019
- values = json.loads(json_str)
1020
  if isinstance(values, dict):
1021
  return values
1022
  except:
1023
  pass
1024
  else:
1025
  try:
1026
- values = json.loads(result)
1027
  if isinstance(values, dict):
1028
  return values
1029
  except:
@@ -1072,10 +1115,6 @@ def generate_synthetic_outputs(
1072
  template_spec, input_data, knowledge_base="", max_retries=3
1073
  ):
1074
  """Generate synthetic output data based on template and input data with retry logic."""
1075
- client = get_openai_client()
1076
- if not client:
1077
- st.error("Please provide an OpenAI API key to generate synthetic outputs.")
1078
- return []
1079
 
1080
  output_vars = template_spec["output"]
1081
  prompt_template = template_spec["prompt"]
@@ -1141,17 +1180,16 @@ The response must be valid JSON that can be parsed directly.
1141
  """
1142
 
1143
  output_data = None
1144
- print(generation_prompt)
1145
  for attempt in range(max_retries):
1146
  try:
1147
- response = client.chat.completions.create(
1148
  model=st.session_state.model,
1149
- messages=[{"role": "user", "content": generation_prompt}],
1150
  max_tokens=2000,
1151
  temperature=st.session_state.temperature,
1152
  )
1153
 
1154
- result = response.choices[0].message.content.strip()
1155
 
1156
  # Extract JSON from the response
1157
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
@@ -1164,7 +1202,7 @@ The response must be valid JSON that can be parsed directly.
1164
  # Clean up any remaining markdown or comments
1165
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1166
  try:
1167
- output_data = json.loads(json_str)
1168
  # Validate that we got a dictionary
1169
  if isinstance(output_data, dict):
1170
  # Check if all required output variables are present
@@ -1191,7 +1229,7 @@ The response must be valid JSON that can be parsed directly.
1191
  else:
1192
  # Try to parse the entire response as JSON
1193
  try:
1194
- output_data = json.loads(result)
1195
  # Validate that we got a dictionary
1196
  if isinstance(output_data, dict):
1197
  # Check if all required output variables are present
@@ -1249,13 +1287,13 @@ The response must be valid JSON that can be parsed directly.
1249
 
1250
 
1251
  def suggest_variable_values_from_kb(
1252
- variable_name, variable_type, knowledge_base, client, model="gpt-3.5-turbo"
1253
  ):
1254
  """
1255
  Use LLM to suggest possible values for a variable based on the knowledge base content.
1256
  Especially useful for categorical variables to extract options from documents.
1257
  """
1258
- if not knowledge_base or not client:
1259
  return None
1260
 
1261
  # Truncate knowledge base if it's too long
@@ -1289,15 +1327,13 @@ def suggest_variable_values_from_kb(
1289
  """
1290
 
1291
  try:
1292
- response = client.chat.completions.create(
1293
  model=model,
1294
- messages=[{"role": "user", "content": prompt}],
1295
  max_tokens=1000,
1296
  temperature=0.3,
1297
  )
1298
 
1299
- result = response.choices[0].message.content
1300
-
1301
  # Extract JSON from the response
1302
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
1303
  json_match = re.search(json_pattern, result)
@@ -1306,13 +1342,13 @@ def suggest_variable_values_from_kb(
1306
  json_str = json_match.group(1) if json_match.group(1) else result
1307
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1308
  try:
1309
- suggestions = json.loads(json_str)
1310
  return suggestions
1311
  except:
1312
  pass
1313
  else:
1314
  try:
1315
- suggestions = json.loads(result)
1316
  return suggestions
1317
  except:
1318
  pass
@@ -1324,12 +1360,12 @@ def suggest_variable_values_from_kb(
1324
 
1325
 
1326
  @st.cache_data
1327
- def analyze_knowledge_base(knowledge_base, _client, model="gpt-4o-mini"):
1328
  """
1329
  Analyze the knowledge base to extract potential variable names and values.
1330
  This can be used to suggest variables when creating a new template.
1331
  """
1332
- if not knowledge_base or not client:
1333
  return None
1334
 
1335
  # Truncate knowledge base if it's too long
@@ -1365,15 +1401,13 @@ def analyze_knowledge_base(knowledge_base, _client, model="gpt-4o-mini"):
1365
  """
1366
 
1367
  try:
1368
- response = _client.chat.completions.create(
1369
  model=model,
1370
- messages=[{"role": "user", "content": prompt}],
1371
  max_tokens=2000,
1372
  temperature=0.3,
1373
  )
1374
 
1375
- result = response.choices[0].message.content
1376
-
1377
  # Extract JSON from the response
1378
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\[[\s\S]*\]\s*$"
1379
  json_match = re.search(json_pattern, result)
@@ -1382,13 +1416,13 @@ def analyze_knowledge_base(knowledge_base, _client, model="gpt-4o-mini"):
1382
  json_str = json_match.group(1) if json_match.group(1) else result
1383
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1384
  try:
1385
- suggestions = json.loads(json_str)
1386
  return suggestions
1387
  except:
1388
  pass
1389
  else:
1390
  try:
1391
- suggestions = json.loads(result)
1392
  return suggestions
1393
  except:
1394
  pass
@@ -1420,18 +1454,42 @@ with st.sidebar:
1420
  st.title("Template Generator")
1421
  st.write("Create templates for generating content with LLMs.")
1422
 
1423
- # API Key input
 
1424
  api_key = st.text_input("OpenAI API Key", type="password")
1425
  if api_key:
1426
  st.session_state.api_key = api_key
1427
 
 
 
 
 
1428
  # Model selection
1429
- st.session_state.model = st.selectbox(
1430
- "Select LLM Model",
1431
- options=["gpt-4o-mini", "gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-turbo"],
 
1432
  index=0,
1433
  )
1434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435
  # Main application layout
1436
  st.title("Template Generator")
1437
 
@@ -1565,7 +1623,9 @@ with tab1:
1565
 
1566
  # Generate Template button
1567
  if st.button("Generate Template"):
1568
- if not st.session_state.get("api_key"):
 
 
1569
  st.error(
1570
  "Please provide an OpenAI API key in the sidebar before generating a template."
1571
  )
@@ -1838,7 +1898,7 @@ with tab2:
1838
  else:
1839
  with st.spinner("Analyzing knowledge base..."):
1840
  suggested_vars = analyze_knowledge_base(
1841
- st.session_state.knowledge_base, client
1842
  )
1843
  if suggested_vars:
1844
  st.session_state.suggested_variables = (
@@ -2095,7 +2155,6 @@ with tab2:
2095
  input_var["name"],
2096
  "categorical",
2097
  st.session_state.knowledge_base,
2098
- client,
2099
  )
2100
  )
2101
  if (
@@ -2278,7 +2337,6 @@ with tab2:
2278
  output_var["name"],
2279
  "categorical",
2280
  st.session_state.knowledge_base,
2281
- client,
2282
  )
2283
  )
2284
  if suggestions and "options" in suggestions:
@@ -2405,9 +2463,11 @@ with tab2:
2405
  # Generate Output button
2406
  if st.button("Generate Output", key="generate_button"):
2407
  # Check if API key is provided
2408
- if not st.session_state.get("api_key"):
 
 
2409
  st.error(
2410
- "Please provide an OpenAI API key in the sidebar before generating output."
2411
  )
2412
  else:
2413
  # Fill the prompt template with user-provided values
@@ -2655,8 +2715,12 @@ with tab3:
2655
 
2656
  # Generate inputs button
2657
  if st.button("Generate Synthetic Inputs"):
2658
- if not st.session_state.get("api_key"):
2659
- st.error("Please provide an OpenAI API key in the sidebar.")
 
 
 
 
2660
  else:
2661
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
2662
  # Use the modified template spec with selected options
@@ -2849,8 +2913,12 @@ with tab3:
2849
 
2850
  # Generate outputs button
2851
  if st.button("Generate Outputs for Selected Samples"):
2852
- if not st.session_state.get("api_key"):
2853
- st.error("Please provide an OpenAI API key in the sidebar.")
 
 
 
 
2854
  elif not st.session_state.selected_samples:
2855
  st.error("No samples selected for output generation.")
2856
  else:
@@ -2940,6 +3008,44 @@ with tab3:
2940
  if st.session_state.combined_data:
2941
  st.subheader("Complete Dataset (Inputs + Outputs)")
2942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2943
  # Create a function to prepare the dataframe with JSON columns
2944
  def prepare_dataframe_with_json_columns(
2945
  data, template_spec, show_json_columns=False
@@ -3017,8 +3123,10 @@ with tab3:
3017
  try:
3018
  # Create a BytesIO object to hold the Parquet file
3019
  parquet_buffer = BytesIO()
 
 
3020
  # Write the DataFrame to the BytesIO object in Parquet format
3021
- full_df.to_parquet(parquet_buffer, index=False)
3022
  # Reset the buffer's position to the beginning
3023
  parquet_buffer.seek(0)
3024
 
@@ -3036,4 +3144,4 @@ with tab3:
3036
  else:
3037
  st.info(
3038
  "No template has been generated yet. Go to the 'Setup' tab to create one."
3039
- )
 
5
  import re
6
  from io import BytesIO
7
  import openai
8
+ import anthropic # Add import for Anthropic's Claude models
9
  import pandas as pd
10
  import itertools
11
  import random
 
28
  return None
29
 
30
 
31
def get_anthropic_client():
    """Build an Anthropic client from the API key stored in session state.

    Returns:
        anthropic.Anthropic | None: A client when a key has been entered in
        the sidebar, otherwise None.
    """
    key = st.session_state.get("anthropic_api_key", "")
    # No key entered yet -> signal "unavailable" with None so callers can fall back.
    return anthropic.Anthropic(api_key=key) if key else None
36
+
37
+
38
def call_model_api(prompt, model, temperature=0.7, max_tokens=1000):
    """
    Abstraction function to call the appropriate LLM API based on the model name.

    Dispatches to Anthropic for model names beginning with "claude" and to
    OpenAI for everything else. Failures are reported as error strings rather
    than raised, so callers always receive text.

    Args:
        prompt (str): The prompt to send to the model
        model (str): The model name (e.g., "gpt-4", "claude-3-opus-latest")
        temperature (float): Creativity parameter (0.0 to 1.0)
        max_tokens (int): Maximum number of tokens to generate

    Returns:
        str: The generated text response, or an "Error ..." message string.
    """
    if model.startswith("claude"):
        # --- Anthropic branch ---
        client = get_anthropic_client()
        if client is None:
            return "Error: No Anthropic API key provided."
        try:
            reply = client.messages.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
            )
            # Anthropic returns a list of content blocks; the text lives in the first one.
            return reply.content[0].text
        except Exception as e:
            return f"Error calling Anthropic API: {str(e)}"

    # --- OpenAI branch (default) ---
    client = get_openai_client()
    if client is None:
        return "Error: No OpenAI API key provided."
    try:
        reply = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return reply.choices[0].message.content
    except Exception as e:
        return f"Error calling OpenAI API: {str(e)}"
84
+
85
+
86
  # @st.cache_resource
87
  def get_document_converter():
88
  """Cache the DocumentConverter to prevent reloading on each interaction"""
 
495
  def call_llm(prompt, model="gpt-3.5-turbo"):
496
  """Call the LLM API to generate text based on the prompt."""
497
  try:
 
 
 
 
 
498
  # Get output specifications from the template if available
499
  output_specs = ""
500
  if st.session_state.show_template_editor and st.session_state.template_spec:
 
512
  # Add the output specs to the prompt
513
  prompt = f"{prompt}\n\n{output_specs}\n\nReturn ONLY a JSON object with the output variables, with no additional text or explanation."
514
 
515
+ result = call_model_api(
516
  model=model,
517
+ prompt=prompt,
518
  max_tokens=1000,
519
  temperature=st.session_state.get("temperature", 0.7),
520
  )
521
 
 
 
522
  # Try to parse as JSON if the template has output variables
523
  if (
524
  st.session_state.show_template_editor
 
562
  Use LLM to generate a template specification based on user instructions
563
  and document content.
564
  """
 
 
 
 
565
 
566
  # Prepare the prompt for the LLM
567
  prompt = f"""
 
609
 
610
  try:
611
  # Call the LLM to generate the template
612
+ template_text = call_model_api(
613
  model=st.session_state.model,
614
+ prompt=prompt,
615
  max_tokens=4096,
616
  temperature=0.7,
617
  )
618
 
 
 
619
  # Extract the JSON part from the response
620
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*{[\s\S]*}\s*$"
621
  json_match = re.search(json_pattern, template_text)
 
624
  json_str = json_match.group(1) if json_match.group(1) else template_text
625
  # Clean up any remaining markdown or comments
626
  json_str = re.sub(r"```.*|```", "", json_str).strip()
627
+ template_spec = json.loads(json_str, strict=False)
628
  return template_spec
629
  else:
630
  # If no JSON format found, try to parse the entire response
631
  try:
632
+ template_spec = json.loads(template_text, strict=False)
633
  return template_spec
634
  except:
635
  st.warning("LLM didn't return valid JSON. Using fallback template.")
 
647
  """
648
  Use LLM to generate an improved prompt template based on current template variables.
649
  """
650
+ if not st.session_state.get("api_key") and not st.session_state.get(
651
+ "anthropic_api_key"
652
+ ):
653
+ st.error("Please provide an OpenAI or Anthropic API key to rewrite the prompt.")
654
  return template_spec["prompt"]
655
 
656
  # Extract template information for context
 
705
 
706
  try:
707
  # Call the LLM to generate the improved prompt template
708
+ improved_template = call_model_api(
709
  model=st.session_state.model,
710
+ prompt=prompt,
711
  max_tokens=4096,
712
  temperature=0.7,
713
  )
714
 
 
 
715
  # Remove any markdown code block formatting if present
716
  improved_template = re.sub(r"```.*\n|```", "", improved_template)
717
 
 
757
  - Use LLM to fill in non-categorical variables
758
  - Process row by row for resilience
759
  """
760
+ if not st.session_state.get("api_key") and not st.session_state.get(
761
+ "anthropic_api_key"
762
+ ):
763
  st.error("Please provide an OpenAI API key to generate synthetic data.")
764
  return []
765
 
 
796
  row = perm.copy()
797
  if non_categorical_vars:
798
  non_cat_values = generate_non_categorical_values(
799
+ non_categorical_vars, perm, max_retries
800
  )
801
  row.update(non_cat_values)
802
 
 
812
  progress_bar.progress(min((i + 1) / num_samples, 1.0))
813
 
814
  # Generate a complete row of values
815
+ row = generate_single_row(input_vars, max_retries)
816
  if row:
817
  results.append(row)
818
 
819
  # Ensure we have the requested number of samples
820
  while len(results) < num_samples:
821
  # Generate additional rows if needed
822
+ row = generate_single_row(input_vars, max_retries)
823
  if row:
824
  results.append(row)
825
 
 
936
  return all_permutations
937
 
938
 
939
+ def generate_non_categorical_values(non_cat_vars, existing_values, max_retries):
940
  """Generate values for non-categorical variables given existing categorical values."""
941
  if not non_cat_vars:
942
  return {}
 
972
 
973
  for attempt in range(max_retries):
974
  try:
975
+ response = call_model_api(
976
  model=st.session_state.model,
977
+ prompt=prompt,
978
  max_tokens=1000,
979
  temperature=st.session_state.temperature,
980
  )
981
 
982
+ result = response.strip()
983
 
984
  # Extract JSON
985
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
 
989
  json_str = json_match.group(1) if json_match.group(1) else result
990
  json_str = re.sub(r"```.*|```", "", json_str).strip()
991
  try:
992
+ values = json.loads(json_str, strict=False)
993
  if isinstance(values, dict):
994
  return values
995
  except:
996
  pass
997
  else:
998
  try:
999
+ values = json.loads(result, strict=False)
1000
  if isinstance(values, dict):
1001
  return values
1002
  except:
 
1010
  return {var["name"]: get_default_value(var) for var in non_cat_vars}
1011
 
1012
 
1013
+ def generate_single_row(all_vars, max_retries):
1014
  """Generate a complete row of data for all variables."""
1015
  # Format the variables for the prompt
1016
  vars_text = "\n".join(
 
1042
 
1043
  for attempt in range(max_retries):
1044
  try:
1045
+ response = call_model_api(
1046
  model=st.session_state.model,
1047
  messages=[{"role": "user", "content": prompt}],
1048
  max_tokens=1000,
1049
  temperature=st.session_state.temperature,
1050
  )
1051
 
1052
+ result = response.strip()
1053
 
1054
  # Extract JSON
1055
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
 
1059
  json_str = json_match.group(1) if json_match.group(1) else result
1060
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1061
  try:
1062
+ values = json.loads(json_str, strict=False)
1063
  if isinstance(values, dict):
1064
  return values
1065
  except:
1066
  pass
1067
  else:
1068
  try:
1069
+ values = json.loads(result, strict=False)
1070
  if isinstance(values, dict):
1071
  return values
1072
  except:
 
1115
  template_spec, input_data, knowledge_base="", max_retries=3
1116
  ):
1117
  """Generate synthetic output data based on template and input data with retry logic."""
 
 
 
 
1118
 
1119
  output_vars = template_spec["output"]
1120
  prompt_template = template_spec["prompt"]
 
1180
  """
1181
 
1182
  output_data = None
 
1183
  for attempt in range(max_retries):
1184
  try:
1185
+ response = call_model_api(
1186
  model=st.session_state.model,
1187
+ prompt=generation_prompt,
1188
  max_tokens=2000,
1189
  temperature=st.session_state.temperature,
1190
  )
1191
 
1192
+ result = response.strip()
1193
 
1194
  # Extract JSON from the response
1195
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
 
1202
  # Clean up any remaining markdown or comments
1203
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1204
  try:
1205
+ output_data = json.loads(json_str, strict=False)
1206
  # Validate that we got a dictionary
1207
  if isinstance(output_data, dict):
1208
  # Check if all required output variables are present
 
1229
  else:
1230
  # Try to parse the entire response as JSON
1231
  try:
1232
+ output_data = json.loads(result, strict=False)
1233
  # Validate that we got a dictionary
1234
  if isinstance(output_data, dict):
1235
  # Check if all required output variables are present
 
1287
 
1288
 
1289
  def suggest_variable_values_from_kb(
1290
+ variable_name, variable_type, knowledge_base, model="gpt-3.5-turbo"
1291
  ):
1292
  """
1293
  Use LLM to suggest possible values for a variable based on the knowledge base content.
1294
  Especially useful for categorical variables to extract options from documents.
1295
  """
1296
+ if not knowledge_base:
1297
  return None
1298
 
1299
  # Truncate knowledge base if it's too long
 
1327
  """
1328
 
1329
  try:
1330
+ result = call_model_api(
1331
  model=model,
1332
+ prompt=prompt,
1333
  max_tokens=1000,
1334
  temperature=0.3,
1335
  )
1336
 
 
 
1337
  # Extract JSON from the response
1338
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
1339
  json_match = re.search(json_pattern, result)
 
1342
  json_str = json_match.group(1) if json_match.group(1) else result
1343
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1344
  try:
1345
+ suggestions = json.loads(json_str, strict=False)
1346
  return suggestions
1347
  except:
1348
  pass
1349
  else:
1350
  try:
1351
+ suggestions = json.loads(result, strict=False)
1352
  return suggestions
1353
  except:
1354
  pass
 
1360
 
1361
 
1362
  @st.cache_data
1363
+ def analyze_knowledge_base(knowledge_base, model="gpt-4o-mini"):
1364
  """
1365
  Analyze the knowledge base to extract potential variable names and values.
1366
  This can be used to suggest variables when creating a new template.
1367
  """
1368
+ if not knowledge_base:
1369
  return None
1370
 
1371
  # Truncate knowledge base if it's too long
 
1401
  """
1402
 
1403
  try:
1404
+ result = call_model_api(
1405
  model=model,
1406
+ prompt=prompt,
1407
  max_tokens=2000,
1408
  temperature=0.3,
1409
  )
1410
 
 
 
1411
  # Extract JSON from the response
1412
  json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\[[\s\S]*\]\s*$"
1413
  json_match = re.search(json_pattern, result)
 
1416
  json_str = json_match.group(1) if json_match.group(1) else result
1417
  json_str = re.sub(r"```.*|```", "", json_str).strip()
1418
  try:
1419
+ suggestions = json.loads(json_str, strict=False)
1420
  return suggestions
1421
  except:
1422
  pass
1423
  else:
1424
  try:
1425
+ suggestions = json.loads(result, strict=False)
1426
  return suggestions
1427
  except:
1428
  pass
 
1454
  st.title("Template Generator")
1455
  st.write("Create templates for generating content with LLMs.")
1456
 
1457
+ # API Key inputs
1458
+ st.subheader("API Keys")
1459
  api_key = st.text_input("OpenAI API Key", type="password")
1460
  if api_key:
1461
  st.session_state.api_key = api_key
1462
 
1463
+ anthropic_api_key = st.text_input("Anthropic API Key", type="password")
1464
+ if anthropic_api_key:
1465
+ st.session_state.anthropic_api_key = anthropic_api_key
1466
+
1467
  # Model selection
1468
+ st.subheader("Model Selection")
1469
+ model_provider = st.radio(
1470
+ "Select Model Provider",
1471
+ options=["OpenAI", "Anthropic"],
1472
  index=0,
1473
  )
1474
 
1475
+ if model_provider == "OpenAI":
1476
+ st.session_state.model = st.selectbox(
1477
+ "Select OpenAI Model",
1478
+ options=["gpt-4o-mini", "gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-turbo"],
1479
+ index=0,
1480
+ )
1481
+ else: # Anthropic
1482
+ st.session_state.model = st.selectbox(
1483
+ "Select Claude Model",
1484
+ options=[
1485
+ "claude-3-7-sonnet-latest",
1486
+ "claude-3-5-haiku-latest",
1487
+ "claude-3-5-sonnet-latest",
1488
+ "claude-3-opus-latest",
1489
+ ],
1490
+ index=1, # Default to Sonnet as a good balance of capability and cost
1491
+ )
1492
+
1493
  # Main application layout
1494
  st.title("Template Generator")
1495
 
 
1623
 
1624
  # Generate Template button
1625
  if st.button("Generate Template"):
1626
+ if not st.session_state.get("api_key") and not st.session_state.get(
1627
+ "anthropic_api_key"
1628
+ ):
1629
  st.error(
1630
  "Please provide an OpenAI API key in the sidebar before generating a template."
1631
  )
 
1898
  else:
1899
  with st.spinner("Analyzing knowledge base..."):
1900
  suggested_vars = analyze_knowledge_base(
1901
+ st.session_state.knowledge_base
1902
  )
1903
  if suggested_vars:
1904
  st.session_state.suggested_variables = (
 
2155
  input_var["name"],
2156
  "categorical",
2157
  st.session_state.knowledge_base,
 
2158
  )
2159
  )
2160
  if (
 
2337
  output_var["name"],
2338
  "categorical",
2339
  st.session_state.knowledge_base,
 
2340
  )
2341
  )
2342
  if suggestions and "options" in suggestions:
 
2463
  # Generate Output button
2464
  if st.button("Generate Output", key="generate_button"):
2465
  # Check if API key is provided
2466
+ if not st.session_state.get("api_key") and not st.session_state.get(
2467
+ "anthropic_api_key"
2468
+ ):
2469
  st.error(
2470
+ "Please provide an OpenAI or Anthropic API key in the sidebar before generating output."
2471
  )
2472
  else:
2473
  # Fill the prompt template with user-provided values
 
2715
 
2716
  # Generate inputs button
2717
  if st.button("Generate Synthetic Inputs"):
2718
+ if not st.session_state.get("api_key") and not st.session_state.get(
2719
+ "anthropic_api_key"
2720
+ ):
2721
+ st.error(
2722
+ "Please provide an OpenAI or Anthropic API key in the sidebar."
2723
+ )
2724
  else:
2725
  with st.spinner(f"Generating {num_samples} synthetic input samples..."):
2726
  # Use the modified template spec with selected options
 
2913
 
2914
  # Generate outputs button
2915
  if st.button("Generate Outputs for Selected Samples"):
2916
+ if not st.session_state.get("api_key") and not st.session_state.get(
2917
+ "anthropic_api_key"
2918
+ ):
2919
+ st.error(
2920
+ "Please provide an OpenAI or Anthropic API key in the sidebar."
2921
+ )
2922
  elif not st.session_state.selected_samples:
2923
  st.error("No samples selected for output generation.")
2924
  else:
 
3008
  if st.session_state.combined_data:
3009
  st.subheader("Complete Dataset (Inputs + Outputs)")
3010
 
3011
+ # Add this function before the prepare_dataframe_with_json_columns function
3012
+
3013
def prepare_dataframe_for_parquet(df):
    """
    Convert DataFrame columns to types compatible with Parquet format.

    Two passes per column:
      1. Serialize list/dict cells to JSON strings (Arrow cannot reliably
         infer a schema for columns of nested Python containers).
      2. If the column still holds anything other than primitive scalars
         (bool/int/float/str) or nulls, stringify the whole column.

    Fix over the previous version: ``None`` is now treated as an acceptable
    value, since Parquet supports nulls natively — previously a column
    containing ``None`` was stringified, turning nulls into the literal
    string ``"None"``.

    Args:
        df (pd.DataFrame): Input DataFrame

    Returns:
        pd.DataFrame: A new DataFrame (input is not mutated) with converted types
    """
    df_copy = df.copy()

    for col in df_copy.columns:
        # Pass 1: JSON-encode nested containers, leaving other cells untouched.
        if df_copy[col].apply(lambda x: isinstance(x, (list, dict))).any():
            df_copy[col] = df_copy[col].apply(
                lambda x: json.dumps(x) if isinstance(x, (list, dict)) else x
            )

        # Pass 2: columns of consistent primitives (or nulls, incl. NaN which
        # is a float) are Parquet-safe as-is; anything else is coerced to str.
        is_safe = df_copy[col].apply(
            lambda x: x is None or isinstance(x, (bool, int, float, str))
        )
        if not is_safe.all():
            df_copy[col] = df_copy[col].apply(str)

    return df_copy
3048
+
3049
  # Create a function to prepare the dataframe with JSON columns
3050
  def prepare_dataframe_with_json_columns(
3051
  data, template_spec, show_json_columns=False
 
3123
  try:
3124
  # Create a BytesIO object to hold the Parquet file
3125
  parquet_buffer = BytesIO()
3126
+ # Convert DataFrame to Parquet-compatible types
3127
+ parquet_df = prepare_dataframe_for_parquet(full_df)
3128
  # Write the DataFrame to the BytesIO object in Parquet format
3129
+ parquet_df.to_parquet(parquet_buffer, index=False)
3130
  # Reset the buffer's position to the beginning
3131
  parquet_buffer.seek(0)
3132
 
 
3144
  else:
3145
  st.info(
3146
  "No template has been generated yet. Go to the 'Setup' tab to create one."
3147
+ )