DJHumanRPT commited on
Commit
e16fb8a
·
verified ·
1 Parent(s): 7cc7a46

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1265 -0
app.py ADDED
@@ -0,0 +1,1265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import PyPDF2
4
+ import re
5
+ from io import BytesIO
6
+ import openai
7
+ import pandas as pd
8
+
9
# Setup page config — this must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Template Generator", layout="wide", initial_sidebar_state="expanded"
)
15
+
16
+
17
# Initialize OpenAI client (you'll need to provide your API key)
def get_openai_client():
    """Return an OpenAI client built from the session's API key, or None.

    The key is read from ``st.session_state["api_key"]``; when nothing has
    been entered yet (missing or empty string), no client can be created.
    """
    key = st.session_state.get("api_key", "")
    if not key:
        return None
    return openai.OpenAI(api_key=key)
23
+
24
+
25
# Define helper functions for PDF parsing
def parse_pdf(file):
    """Extract and concatenate the text of every page in a PDF file.

    Parameters:
        file: a binary file-like object (e.g. BytesIO) containing PDF data.

    Returns:
        str: the concatenated page text, or "" when the PDF cannot be
        parsed (the error is shown to the user via st.error).
    """
    try:
        pdf_reader = PyPDF2.PdfReader(file)
        # Iterate pages directly instead of indexing range(len(...)), and
        # join once instead of repeated string concatenation.
        # extract_text() may return None for pages without a text layer.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
    except Exception as e:
        st.error(f"Error parsing PDF: {str(e)}")
        return ""
37
+
38
+
39
def parse_documents(uploaded_files):
    """Parse multiple document files and extract their text content.

    Supported types are PDF and plain text; anything else produces a
    warning. Each document's text is followed by a blank line.
    """
    chunks = []
    for uploaded in uploaded_files:
        try:
            suffix = uploaded.name.split(".")[-1].lower()
            if suffix == "pdf":
                # Wrap the raw bytes in a fresh buffer so the PDF reader
                # gets a clean, seekable stream regardless of prior reads.
                pdf_buffer = BytesIO(uploaded.getvalue())
                chunks.append(parse_pdf(pdf_buffer) + "\n\n")
            elif suffix == "txt":
                chunks.append(uploaded.getvalue().decode("utf-8") + "\n\n")
            else:
                st.warning(f"Unsupported file type: {uploaded.name}")
        except Exception as e:
            st.error(f"Error processing file {uploaded.name}: {str(e)}")
    return "".join(chunks)
56
+
57
+
58
def parse_template_file(uploaded_template):
    """Parse an uploaded template JSON file and validate its structure.

    Parameters:
        uploaded_template: an uploaded-file-like object exposing ``.name``
            and ``.getvalue()`` (e.g. a Streamlit UploadedFile).

    Returns:
        tuple: (template_spec, None) on success, or (None, error_message)
        on any validation or parsing failure.
    """
    try:
        # Guard clause: only .json uploads are accepted.
        if not uploaded_template.name.endswith(".json"):
            return None, "Uploaded file must be a JSON file"

        template_content = uploaded_template.getvalue().decode("utf-8")
        template_spec = json.loads(template_content)

        # Robustness: a top-level JSON array/scalar would otherwise slip
        # through the key checks with a confusing error.
        if not isinstance(template_spec, dict):
            return None, "Invalid template: top-level value must be an object"

        # Validate the template structure
        for key in ("name", "version", "description", "input", "output", "prompt"):
            if key not in template_spec:
                return None, f"Invalid template: Missing '{key}' field"

        # Validate input and output arrays (one loop instead of two
        # copy-pasted blocks; error ordering matches the original).
        for section in ("input", "output"):
            if not isinstance(template_spec[section], list):
                return None, f"Invalid template: '{section}' must be an array"

        # Check that each input and output has required fields
        for section in ("input", "output"):
            for i, var in enumerate(template_spec[section]):
                if not all(k in var for k in ("name", "description", "type")):
                    return (
                        None,
                        f"Invalid template: {section.capitalize()} variable at index {i} is missing required fields",
                    )

        return template_spec, None
    except json.JSONDecodeError:
        return None, "Invalid JSON format in the uploaded template file"
    except Exception as e:
        return None, f"Error parsing template file: {str(e)}"
108
+
109
+
110
# LLM call function
def call_llm(prompt, model="gpt-3.5-turbo"):
    """Call the LLM API to generate text based on the prompt.

    When a template is loaded in the editor, a block describing the
    template's output variables is appended to the prompt so the model
    knows what shape of answer is expected. Returns the model's reply,
    or an "Error: ..." string on failure.
    """
    try:
        client = get_openai_client()
        if not client:
            st.error("Please provide an OpenAI API key in the sidebar.")
            return "Error: No API key provided."

        # Build the (possibly empty) output-specification block from the
        # currently loaded template.
        output_specs = ""
        if st.session_state.show_template_editor and st.session_state.template_spec:
            output_vars = st.session_state.template_spec.get("output", [])
            if output_vars:
                spec_lines = ["Please generate output with the following specifications:"]
                for var in output_vars:
                    entry = f"- {var['name']}: {var['description']} (Type: {var['type']})"
                    if var.get("options"):
                        entry += f", Options: {var['options']}"
                    spec_lines.append(entry)
                output_specs = "\n".join(spec_lines) + "\n"

        # Append the specs (empty string when no template is active).
        prompt = f"{prompt}\n\n{output_specs}"

        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"Error calling LLM API: {str(e)}")
        return f"Error: {str(e)}"
148
+
149
+
150
# Function to generate a template based on instructions and documents
def generate_template_from_instructions(instructions, document_content=""):
    """Use the LLM to generate a template specification.

    Parameters:
        instructions: free-text description of the desired template.
        document_content: optional knowledge-base text; the first 2000
            characters are embedded in the design prompt.

    Returns:
        dict: a template spec, or the fallback template when no API key
        is set, the LLM response is not valid JSON, or the call fails.
    """
    client = get_openai_client()
    if not client:
        st.error("Please provide an OpenAI API key to generate a template.")
        return create_fallback_template(instructions)

    # Prepare the prompt for the LLM
    prompt = f"""
    You are a template designer for an LLM-powered content generation system.
    Create a template specification based on the following instructions:

    INSTRUCTIONS:
    {instructions}

    {"DOCUMENT CONTENT (EXCERPT):" + document_content[:2000] + "..." if document_content else "NO DOCUMENTS PROVIDED"}

    Generate a JSON template specification with the following structure:
    {{
        "name": "A descriptive name for the template",
        "version": "1.0.0",
        "description": "A brief description of what this template does",
        "input": [
            {{
                "name": "variable_name",
                "description": "What this variable represents",
                "type": "string/int/float/bool/categorical",
                "min": minimum_value_or_length,
                "max": maximum_value_or_length,
                "options": ["option1", "option2"] (only for categorical type)
            }},
            ... more input variables
        ],
        "output": [
            {{
                "name": "output_variable_name",
                "description": "What this output represents",
                "type": "string/int/float/bool/categorical"
            }},
            ... more output variables
        ],
        "prompt": "A template string with {{variable_name}} placeholders that will be replaced with actual values"
    }}

    Make sure the prompt includes all input variables and is designed to produce the expected outputs.
    If a 'lore' or 'knowledge_base' should be incorporated, include {{lore}} in the prompt template.
    If document content was provided, design the template to effectively use that information.
    """

    try:
        # Call the LLM to generate the template
        response = client.chat.completions.create(
            model=st.session_state.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=2000,
            temperature=0.7,
        )

        template_text = response.choices[0].message.content

        # Extract the JSON part from the response (fenced ```json block or
        # a bare top-level object).
        json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*{[\s\S]*}\s*$"
        json_match = re.search(json_pattern, template_text)

        if json_match:
            json_str = json_match.group(1) if json_match.group(1) else template_text
            # Clean up any remaining markdown or comments
            json_str = re.sub(r"```.*|```", "", json_str).strip()
            return json.loads(json_str)

        # If no JSON format found, try to parse the entire response.
        try:
            return json.loads(template_text)
        # Bug fix: this was a bare `except:`, which swallowed even
        # KeyboardInterrupt/SystemExit; only parse errors belong here.
        except json.JSONDecodeError:
            st.warning("LLM didn't return valid JSON. Using fallback template.")
            return create_fallback_template(instructions)

    except Exception as e:
        st.error(f"Error generating template: {str(e)}")
        return create_fallback_template(instructions)
236
+
237
+
238
def generate_improved_prompt_template(template_spec, knowledge_base=""):
    """Use the LLM to rewrite the template's prompt based on its variables.

    Returns the improved prompt text, or the existing prompt unchanged
    when no API key is available or the LLM call fails.
    """
    client = get_openai_client()
    if not client:
        st.error("Please provide an OpenAI API key to rewrite the prompt.")
        return template_spec["prompt"]

    purpose = template_spec["description"]

    # Describe every input variable (with options when present).
    described_inputs = []
    for spec in template_spec["input"]:
        line = f"- {spec['name']}: {spec['description']} (Type: {spec['type']})"
        if spec.get("options"):
            line += f", Options: {spec['options']}"
        described_inputs.append(line)
    input_vars_text = "\n".join(described_inputs)

    # Describe every output variable.
    output_vars_text = "\n".join(
        f"- {spec['name']}: {spec['description']} (Type: {spec['type']})"
        for spec in template_spec["output"]
    )

    # Prepare the prompt for the LLM
    prompt = f"""
    You are an expert at designing effective prompts for LLMs. Rewrite the prompt template based on the following details:

    TEMPLATE PURPOSE:
    {purpose}

    INPUT VARIABLES:
    {input_vars_text}

    OUTPUT VARIABLES:
    {output_vars_text}

    {"KNOWLEDGE BASE AVAILABLE:" if knowledge_base else "NO KNOWLEDGE BASE AVAILABLE."}
    {knowledge_base[:500] + "..." if len(knowledge_base) > 500 else knowledge_base if knowledge_base else ""}

    Current prompt template:
    {template_spec["prompt"]}

    Please create an improved prompt template that:
    1. Uses all input variables (in curly braces like {{variable_name}})
    2. Is designed to generate the specified outputs
    3. Includes {{lore}} where background information or context should be inserted
    4. Is clear, specific, and well-structured
    5. Provides enough guidance to the LLM to generate high-quality results

    Return ONLY the revised prompt template text, with no additional explanations.
    """

    try:
        # Call the LLM to generate the improved prompt template
        response = client.chat.completions.create(
            model=st.session_state.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.7,
        )

        improved_template = response.choices[0].message.content.strip()
        # Strip any markdown code-fence lines the model may have added.
        return re.sub(r"```.*\n|```", "", improved_template)
    except Exception as e:
        st.error(f"Error generating improved prompt: {str(e)}")
        return template_spec["prompt"]
318
+
319
+
320
# Fallback template if generation fails
def create_fallback_template(instructions=""):
    """Build the minimal default template used when LLM generation fails.

    The caller's instructions (if any) are preserved as the template
    description so the user's intent is not lost.
    """
    fallback_prompt = (
        "Based on the following information:\n{input_1}\n\n"
        "And considering this additional context:\n{lore}\n\n"
        "Generate the following output."
    )
    default_input = {
        "name": "input_1",
        "description": "First input variable",
        "type": "string",
        "min": 1,
        "max": 100,
    }
    default_output = {
        "name": "output_1",
        "description": "Generated output",
        "type": "string",
        "min": 10,
        "max": 1000,
    }
    return {
        "name": "Generated Template",
        "version": "1.0.0",
        "description": instructions,
        "input": [default_input],
        "output": [default_output],
        "prompt": fallback_prompt,
    }
347
+
348
+
349
def generate_synthetic_inputs(template_spec, num_samples=1, max_retries=3):
    """Generate synthetic input data based on template specifications with retry logic.

    Parameters:
        template_spec: dict containing an "input" list of variable specs.
        num_samples: how many sample objects to request from the LLM.
        max_retries: number of LLM calls to attempt before giving up.

    Returns:
        list[dict]: one dict of input values per sample, or [] on failure.
    """
    client = get_openai_client()
    if not client:
        st.error("Please provide an OpenAI API key to generate synthetic data.")
        return []

    input_vars = template_spec["input"]

    # Format variable information for the prompt
    input_vars_text = "\n".join(
        [
            f"- {var['name']}: {var['description']} (Type: {var['type']})"
            + (
                f", Min: {var.get('min', 'N/A')}, Max: {var.get('max', 'N/A')}"
                if var["type"] in ["string", "int", "float"]
                else ""
            )
            + (f", Options: {var['options']}" if var.get("options") else "")
            for var in input_vars
        ]
    )

    prompt = f"""
    You are a synthetic data generator. Generate {num_samples} realistic sample(s) for the following input variables:

    {input_vars_text}

    Return the data as a JSON array of objects, where each object contains values for all input variables.
    Each object should follow this structure:
    {{
        "variable_name_1": value1,
        "variable_name_2": value2,
        ...
    }}

    Make sure to:
    1. Use appropriate data types (strings, numbers, booleans)
    2. Stay within min/max constraints
    3. Only use provided options for categorical variables
    4. Generate realistic and diverse values
    5. Return ONLY the JSON array with no additional text or explanation
    6. The response must be valid JSON that can be parsed directly
    """

    def _parse_samples(json_str, attempt):
        """Parse json_str; return a list of dicts, or None (after warning)."""
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError:
            st.warning(f"Attempt {attempt+1}: Failed to parse JSON. Retrying...")
            return None
        # Validate that we got a list of dictionaries
        if isinstance(data, list) and all(isinstance(item, dict) for item in data):
            return data
        st.warning(
            f"Attempt {attempt+1}: Generated data is not in the expected format. Retrying..."
        )
        return None

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=st.session_state.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=2000,
                temperature=0.8,
            )

            result = response.choices[0].message.content.strip()

            # Extract JSON from the response (fenced block or bare array);
            # the two original copy-pasted parse branches now share
            # _parse_samples.
            json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\[[\s\S]*\]\s*$"
            json_match = re.search(json_pattern, result)
            if json_match:
                json_str = json_match.group(1) if json_match.group(1) else result
                # Clean up any remaining markdown or comments
                json_str = re.sub(r"```.*|```", "", json_str).strip()
            else:
                json_str = result

            synthetic_inputs = _parse_samples(json_str, attempt)
            if synthetic_inputs is not None:
                return synthetic_inputs

        except Exception as e:
            st.warning(
                f"Attempt {attempt+1}: Error generating synthetic inputs: {str(e)}. Retrying..."
            )
            if attempt == max_retries - 1:
                st.error(
                    f"Failed to generate synthetic inputs after {max_retries} attempts: {str(e)}"
                )
                return []

    st.error(f"Failed to generate valid synthetic inputs after {max_retries} attempts.")
    return []
462
+
463
+
464
def generate_synthetic_outputs(
    template_spec, input_data, knowledge_base="", max_retries=3
):
    """Generate synthetic output data based on template and input data with retry logic.

    For each input dict, the template prompt is filled in (including the
    {lore} knowledge-base placeholder), the LLM is asked for a JSON object
    of output values, and the result is merged with the input dict.

    Parameters:
        template_spec: dict with "output" variable specs and a "prompt".
        input_data: list of input dicts (e.g. from generate_synthetic_inputs).
        knowledge_base: text substituted for the {lore} placeholder.
        max_retries: LLM attempts per input before recording an error entry.

    Returns:
        list[dict]: input values merged with generated outputs; failed
        items carry an "error" key instead.
    """
    client = get_openai_client()
    if not client:
        st.error("Please provide an OpenAI API key to generate synthetic outputs.")
        return []

    output_vars = template_spec["output"]
    prompt_template = template_spec["prompt"]

    # Format output variable information for the prompt
    output_vars_text = "\n".join(
        [
            f"- {var['name']}: {var['description']} (Type: {var['type']}) {'Options: '+str(var['options']) if var.get('options') else ''}"
            for var in output_vars
        ]
    )

    required_vars = [var["name"] for var in output_vars]

    def _is_valid_output(candidate, attempt, i):
        """True when candidate is a dict containing every required output var."""
        if not isinstance(candidate, dict):
            st.warning(
                f"Attempt {attempt+1} for input {i+1}: Generated output is not a dictionary. Retrying..."
            )
            return False
        missing_vars = [var for var in required_vars if var not in candidate]
        if missing_vars:
            st.warning(
                f"Attempt {attempt+1} for input {i+1}: Missing output variables: {missing_vars}. Retrying..."
            )
            return False
        return True

    results = []

    # Create a progress bar
    progress_bar = st.progress(0)

    try:
        for i, input_item in enumerate(input_data):
            # Fill the prompt template with input values
            filled_prompt = prompt_template
            for var_name, var_value in input_item.items():
                filled_prompt = filled_prompt.replace(f"{{{var_name}}}", str(var_value))

            # Replace {lore} with knowledge base if present
            if "{lore}" in filled_prompt:
                filled_prompt = filled_prompt.replace("{lore}", knowledge_base)

            # Create a prompt for generating synthetic output
            generation_prompt = f"""
            You are generating synthetic output data based on the following input:

            INPUT DATA:
            {json.dumps(input_item, indent=2)}

            PROMPT USED:
            {filled_prompt}

            REQUIRED OUTPUT VARIABLES:
            {output_vars_text}

            Generate realistic output data for these variables. Return ONLY a JSON object with the output variables:
            {{
                "output_variable_1": value1,
                "output_variable_2": value2,
                ...
            }}

            Use appropriate data types for each variable. Return ONLY the JSON object with no additional text or explanation.
            The response must be valid JSON that can be parsed directly.
            """

            output_data = None
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=st.session_state.model,
                        messages=[{"role": "user", "content": generation_prompt}],
                        max_tokens=2000,
                        temperature=0.7,
                    )

                    result = response.choices[0].message.content.strip()

                    # Extract JSON from the response (fenced block or bare
                    # object); both original parse branches now share the
                    # same validation helper.
                    json_pattern = r"```json\s*([\s\S]*?)\s*```|^\s*\{[\s\S]*\}\s*$"
                    json_match = re.search(json_pattern, result)
                    if json_match:
                        json_str = json_match.group(1) if json_match.group(1) else result
                        # Clean up any remaining markdown or comments
                        json_str = re.sub(r"```.*|```", "", json_str).strip()
                    else:
                        json_str = result

                    try:
                        output_data = json.loads(json_str)
                    except json.JSONDecodeError:
                        st.warning(
                            f"Attempt {attempt+1} for input {i+1}: Failed to parse JSON. Retrying..."
                        )
                    else:
                        if _is_valid_output(output_data, attempt, i):
                            break  # Valid output, exit retry loop

                except Exception as e:
                    st.warning(
                        f"Attempt {attempt+1} for input {i+1}: Error generating output: {str(e)}. Retrying..."
                    )

                # If we've reached the max retries, log the error
                if attempt == max_retries - 1:
                    st.error(
                        f"Failed to generate valid output for input {i+1} after {max_retries} attempts."
                    )
                    output_data = {
                        "error": f"Failed to generate valid output after {max_retries} attempts"
                    }

            # Combine input and output data
            if output_data:
                results.append({**input_item, **output_data})
            else:
                results.append({**input_item, "error": "Failed to generate output"})

            # Update progress bar
            progress_bar.progress((i + 1) / len(input_data))

    finally:
        # Ensure progress bar reaches 100% when done
        if len(input_data) > 0:
            progress_bar.progress(1.0)

    return results
629
+
630
+
631
# Initialize session state: seed each key with its default exactly once so
# values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "template_spec": None,
    "knowledge_base": "",
    "show_template_editor": False,
    "user_inputs": {},
    "generated_output": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
642
+
643
# Sidebar setup: API key entry and model selection.
_MODEL_OPTIONS = ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]

with st.sidebar:
    st.title("Template Generator")
    st.write("Create templates for generating content with LLMs.")

    # API Key input — only persisted to session state once non-empty.
    entered_key = st.text_input("OpenAI API Key", type="password")
    if entered_key:
        st.session_state.api_key = entered_key

    # Model selection
    st.session_state.model = st.selectbox(
        "Select LLM Model",
        options=_MODEL_OPTIONS,
        index=0,
    )
659
+
660
# Main application layout
st.title("Template Generator")

# Workflow tabs: setup -> edit -> use -> generate.
_TAB_LABELS = ["Setup", "Edit Template", "Use Template", "Generate Data"]
tab1, tab2, tab3, tab4 = st.tabs(_TAB_LABELS)
667
+
668
with tab1:
    st.header("Project Setup")

    # Add option to either upload a template or create a new one
    setup_option = st.radio(
        "Choose how to start your project",
        options=["Upload existing template", "Create new template from documents"],
        index=1,
    )

    # Robustness: pre-initialize so the compound condition below never
    # depends on short-circuit evaluation to avoid a NameError.
    uploaded_template = None

    if setup_option == "Upload existing template":
        st.subheader("Upload Template File")
        uploaded_template = st.file_uploader(
            "Upload a template JSON file",
            type=["json"],
            help="Upload a previously created template file (.json)",
        )

        if uploaded_template:
            template_spec, error = parse_template_file(uploaded_template)
            if error:
                st.error(error)
            else:
                st.success(f"Successfully loaded template: {template_spec['name']}")

                # Show template preview
                with st.expander("Template Preview", expanded=True):
                    st.json(template_spec)

                # Button to use this template
                if st.button("Use This Template"):
                    st.session_state.template_spec = template_spec
                    st.session_state.show_template_editor = True
                    st.success(
                        "Template loaded! Go to the 'Edit Template' tab to customize it."
                    )

    # Show the document-driven flow when creating a new template, or when
    # upload mode is selected but no file has been provided yet.
    # Bug-prone original relied on `and` binding tighter than `or`;
    # parentheses make the intended grouping explicit.
    if setup_option == "Create new template from documents" or (
        setup_option == "Upload existing template" and not uploaded_template
    ):
        # Step 1: Upload Knowledge Base
        st.subheader("Step 1: Upload Knowledge Base")
        uploaded_files = st.file_uploader(
            "Upload documents to use as knowledge base",
            accept_multiple_files=True,
            type=["pdf", "txt"],
        )

        if uploaded_files:
            # Track filenames for UI feedback
            st.session_state.uploaded_filenames = [file.name for file in uploaded_files]

            with st.spinner("Processing documents..."):
                st.session_state.knowledge_base = parse_documents(uploaded_files)
                st.success(f"Processed {len(uploaded_files)} documents")

            with st.expander("Preview extracted content"):
                # Bug fix: the ellipsis previously appeared for anything over
                # 1000 chars even though the preview shows up to 10000.
                preview_limit = 10000
                st.text_area(
                    "Extracted Text",
                    value=st.session_state.knowledge_base[:preview_limit]
                    + (
                        "..."
                        if len(st.session_state.knowledge_base) > preview_limit
                        else ""
                    ),
                    height=200,
                    disabled=True,
                )

        # Step 2: Provide Instructions
        st.subheader("Step 2: Provide Instructions")
        instructions = st.text_area(
            "Describe what you want to create",
            placeholder="Describe what you want to create (e.g., 'Create a character background generator with name, faction, and race as inputs...')",
            height=150,
        )

        # Generate Template button
        if st.button("Generate Template"):
            if not st.session_state.get("api_key"):
                st.error(
                    "Please provide an OpenAI API key in the sidebar before generating a template."
                )
            elif instructions:
                with st.spinner("Analyzing instructions and generating template..."):
                    # Generate template based on instructions and document content
                    st.session_state.template_spec = (
                        generate_template_from_instructions(
                            instructions, st.session_state.knowledge_base
                        )
                    )
                    st.session_state.show_template_editor = True
                    st.success(
                        "Template generated! Go to the 'Edit Template' tab to customize it."
                    )
            else:
                st.warning("Please provide instructions first")
764
+
765
+ with tab2:
766
+ if st.session_state.show_template_editor and st.session_state.template_spec:
767
+ st.header("Template Editor")
768
+
769
+ # Basic template information
770
+ with st.expander("Template Information", expanded=True):
771
+ col1, col2 = st.columns(2)
772
+ with col1:
773
+ st.session_state.template_spec["name"] = st.text_input(
774
+ "Template Name", value=st.session_state.template_spec["name"]
775
+ )
776
+ with col2:
777
+ st.session_state.template_spec["version"] = st.text_input(
778
+ "Version", value=st.session_state.template_spec["version"]
779
+ )
780
+
781
+ st.session_state.template_spec["description"] = st.text_area(
782
+ "Description",
783
+ value=st.session_state.template_spec["description"],
784
+ height=100,
785
+ )
786
+
787
+ # Prompt Template Section
788
+ with st.expander("Prompt Template", expanded=True):
789
+ st.info("Use {variable_name} to refer to input variables in your template")
790
+
791
+ # Add buttons for prompt management
792
+ col1, col2 = st.columns([1, 1])
793
+ with col1:
794
+ rewrite_prompt = st.button("AI Rewrite Prompt")
795
+ with col2:
796
+ reroll_prompt = st.button("Reroll Prompt Variation")
797
+
798
+ # Handle prompt rewriting
799
+ if rewrite_prompt or reroll_prompt:
800
+ with st.spinner("Generating improved prompt template..."):
801
+ improved_template = generate_improved_prompt_template(
802
+ st.session_state.template_spec, st.session_state.knowledge_base
803
+ )
804
+ # Only update if we got a valid result back
805
+ if improved_template and len(improved_template) > 10:
806
+ st.session_state.template_spec["prompt"] = improved_template
807
+ st.success("Prompt template updated!")
808
+
809
+ # Display the prompt template
810
+ prompt_template = st.text_area(
811
+ "Edit the prompt template",
812
+ value=st.session_state.template_spec["prompt"],
813
+ height=200,
814
+ )
815
+ st.session_state.template_spec["prompt"] = prompt_template
816
+
817
+ # Input Variables Editor
818
+ with st.expander("Input Variables", expanded=True):
819
+ st.subheader("Input Variables")
820
+
821
+ # Add input variable button
822
+ if st.button("Add Input Variable"):
823
+ new_var = {
824
+ "name": f"new_input_{len(st.session_state.template_spec['input']) + 1}",
825
+ "description": "New input variable",
826
+ "type": "string",
827
+ "min": 0,
828
+ "max": 100,
829
+ }
830
+ st.session_state.template_spec["input"].append(new_var)
831
+ st.rerun()
832
+
833
+ # Display input variables
834
+ for i, input_var in enumerate(st.session_state.template_spec["input"]):
835
+ with st.container():
836
+ st.markdown(f"##### {input_var['name']}")
837
+
838
+ col1, col2, col3 = st.columns([2, 2, 1])
839
+
840
+ with col1:
841
+ input_var["name"] = st.text_input(
842
+ "Name", value=input_var["name"], key=f"input_name_{i}"
843
+ )
844
+
845
+ input_var["description"] = st.text_input(
846
+ "Description",
847
+ value=input_var["description"],
848
+ key=f"input_desc_{i}",
849
+ )
850
+
851
+ with col2:
852
+ var_type = st.selectbox(
853
+ "Type",
854
+ options=["string", "int", "float", "bool", "categorical"],
855
+ index=[
856
+ "string",
857
+ "int",
858
+ "float",
859
+ "bool",
860
+ "categorical",
861
+ ].index(input_var["type"]),
862
+ key=f"input_type_{i}",
863
+ )
864
+ input_var["type"] = var_type
865
+
866
+ if var_type in ["string", "int", "float"]:
867
+ col_min, col_max = st.columns(2)
868
+ with col_min:
869
+ input_var["min"] = st.number_input(
870
+ "Min",
871
+ value=int(input_var.get("min", 0)),
872
+ key=f"input_min_{i}",
873
+ )
874
+ with col_max:
875
+ input_var["max"] = st.number_input(
876
+ "Max",
877
+ value=int(input_var.get("max", 100)),
878
+ key=f"input_max_{i}",
879
+ )
880
+
881
+ if var_type == "categorical":
882
+ options = input_var.get("options", [])
883
+ options_str = st.text_area(
884
+ "Options (one per line)",
885
+ value="\n".join(options),
886
+ key=f"input_options_{i}",
887
+ )
888
+ input_var["options"] = [
889
+ opt.strip()
890
+ for opt in options_str.split("\n")
891
+ if opt.strip()
892
+ ]
893
+
894
+ with col3:
895
+ if st.button("Remove", key=f"remove_input_{i}"):
896
+ st.session_state.template_spec["input"].pop(i)
897
+ st.rerun()
898
+
899
+ st.divider()
900
+
901
+ # Output Variables Editor
902
+ with st.expander("Output Variables", expanded=True):
903
+ st.subheader("Output Variables")
904
+
905
+ # Add output variable button
906
+ if st.button("Add Output Variable"):
907
+ new_var = {
908
+ "name": f"new_output_{len(st.session_state.template_spec['output']) + 1}",
909
+ "description": "New output variable",
910
+ "type": "string",
911
+ "min": 0,
912
+ "max": 100,
913
+ }
914
+ st.session_state.template_spec["output"].append(new_var)
915
+ st.rerun()
916
+
917
+ # Display output variables
918
+ for i, output_var in enumerate(st.session_state.template_spec["output"]):
919
+ with st.container():
920
+ st.markdown(f"##### {output_var['name']}")
921
+
922
+ col1, col2, col3 = st.columns([2, 2, 1])
923
+
924
+ with col1:
925
+ output_var["name"] = st.text_input(
926
+ "Name", value=output_var["name"], key=f"output_name_{i}"
927
+ )
928
+
929
+ output_var["description"] = st.text_input(
930
+ "Description",
931
+ value=output_var["description"],
932
+ key=f"output_desc_{i}",
933
+ )
934
+
935
+ with col2:
936
+ var_type = st.selectbox(
937
+ "Type",
938
+ options=["string", "int", "float", "bool", "categorical"],
939
+ index=[
940
+ "string",
941
+ "int",
942
+ "float",
943
+ "bool",
944
+ "categorical",
945
+ ].index(output_var["type"]),
946
+ key=f"output_type_{i}",
947
+ )
948
+ output_var["type"] = var_type
949
+
950
+ if var_type in ["string", "int", "float"]:
951
+ col_min, col_max = st.columns(2)
952
+ with col_min:
953
+ output_var["min"] = st.number_input(
954
+ "Min",
955
+ value=int(output_var.get("min", 0)),
956
+ key=f"output_min_{i}",
957
+ )
958
+ with col_max:
959
+ output_var["max"] = st.number_input(
960
+ "Max",
961
+ value=int(output_var.get("max", 100)),
962
+ key=f"output_max_{i}",
963
+ )
964
+
965
+ if var_type == "categorical":
966
+ options = output_var.get("options", [])
967
+ options_str = st.text_area(
968
+ "Options (one per line)",
969
+ value="\n".join(options),
970
+ key=f"output_options_{i}",
971
+ )
972
+ output_var["options"] = [
973
+ opt.strip()
974
+ for opt in options_str.split("\n")
975
+ if opt.strip()
976
+ ]
977
+
978
+ with col3:
979
+ if st.button("Remove", key=f"remove_output_{i}"):
980
+ st.session_state.template_spec["output"].pop(i)
981
+ st.rerun()
982
+
983
+ st.divider()
984
+
985
+ # Template Specification and Download Section
986
+ with st.expander("Template JSON", expanded=False):
987
+ st.json(st.session_state.template_spec)
988
+
989
+ # Download button
990
+ template_json = json.dumps(st.session_state.template_spec, indent=2)
991
+ st.download_button(
992
+ label="Download Template JSON",
993
+ data=template_json,
994
+ file_name="template_spec.json",
995
+ mime="application/json",
996
+ )
997
+ else:
998
+ st.info(
999
+ "No template has been generated yet. Go to the 'Setup' tab to create one."
1000
+ )
1001
+
1002
with tab3:
    if st.session_state.show_template_editor and st.session_state.template_spec:
        st.header("Use Template")

        # Reset user inputs when the template changes. Compare against a
        # serialized snapshot: storing the template_spec dict itself would keep
        # a *reference*, so later in-place edits to the spec would also mutate
        # the saved copy and the change would never be detected.
        template_snapshot = json.dumps(st.session_state.template_spec, sort_keys=True)
        if st.session_state.get("last_template") != template_snapshot:
            st.session_state.user_inputs = {}
            st.session_state.last_template = template_snapshot

        # Create input fields based on the template specification.
        for input_var in st.session_state.template_spec["input"]:
            var_name = input_var["name"]
            var_type = input_var["type"]
            var_desc = input_var["description"]

            st.markdown(f"##### {var_desc}")

            if var_type == "string":
                st.session_state.user_inputs[var_name] = st.text_input(
                    f"Enter value for {var_name}", key=f"use_{var_name}"
                )

            elif var_type == "int":
                # Coerce stored bounds to int: the spec may hold floats (e.g.
                # after JSON round-trips), and Streamlit rejects float bounds
                # on an integer number_input with step=1.
                raw_min = input_var.get("min")
                raw_max = input_var.get("max")
                st.session_state.user_inputs[var_name] = st.number_input(
                    f"Enter value for {var_name}",
                    min_value=int(raw_min) if raw_min is not None else None,
                    max_value=int(raw_max) if raw_max is not None else None,
                    step=1,
                    key=f"use_{var_name}",
                )

            elif var_type == "float":
                st.session_state.user_inputs[var_name] = st.number_input(
                    f"Enter value for {var_name}",
                    min_value=float(input_var.get("min", 0)),
                    max_value=float(input_var.get("max", 100)),
                    key=f"use_{var_name}",
                )

            elif var_type == "bool":
                st.session_state.user_inputs[var_name] = st.checkbox(
                    f"Select value for {var_name}", key=f"use_{var_name}"
                )

            elif var_type == "categorical":
                options = input_var.get("options", [])
                if options:
                    st.session_state.user_inputs[var_name] = st.selectbox(
                        f"Select value for {var_name}",
                        options=options,
                        key=f"use_{var_name}",
                    )
                else:
                    st.warning(f"No options defined for {var_name}")

        # Handle the lore/knowledge base as a special variable: only shown when
        # the prompt template actually references {lore}.
        prompt_template = st.session_state.template_spec["prompt"]
        if "{lore}" in prompt_template:
            st.markdown("##### Document Knowledge Base")

            # Display info about the knowledge base
            if st.session_state.knowledge_base:
                st.success(
                    f"Using content from {len(st.session_state.uploaded_filenames) if 'uploaded_filenames' in st.session_state else 'uploaded'} documents as knowledge base"
                )

                with st.expander("View knowledge base content"):
                    # Preview only the first 2000 characters to keep the UI fast.
                    st.text_area(
                        "Knowledge base content",
                        value=st.session_state.knowledge_base[:2000]
                        + (
                            "..." if len(st.session_state.knowledge_base) > 2000 else ""
                        ),
                        height=200,
                        disabled=True,
                    )

                # Add option to edit if needed
                use_edited_lore = st.checkbox("Edit knowledge base content")
                if use_edited_lore:
                    st.session_state.user_inputs["lore"] = st.text_area(
                        "Edit knowledge base for this generation",
                        value=st.session_state.knowledge_base,
                        height=300,
                    )
                else:
                    st.session_state.user_inputs["lore"] = (
                        st.session_state.knowledge_base
                    )
            else:
                st.warning("No documents uploaded. You can provide custom lore below.")
                st.session_state.user_inputs["lore"] = st.text_area(
                    "Enter background information or context",
                    placeholder="Enter custom lore or background information here...",
                    height=150,
                )

        # Generate Output button
        if st.button("Generate Output", key="generate_button"):
            # Check if API key is provided
            if not st.session_state.get("api_key"):
                st.error(
                    "Please provide an OpenAI API key in the sidebar before generating output."
                )
            else:
                # Fill the prompt template with user-provided values by
                # replacing each {var_name} placeholder with its string form.
                filled_prompt = prompt_template
                for var_name, var_value in st.session_state.user_inputs.items():
                    filled_prompt = filled_prompt.replace(
                        f"{{{var_name}}}", str(var_value)
                    )

                # Show the filled prompt
                with st.expander("View populated prompt"):
                    st.text_area(
                        "Prompt sent to LLM",
                        value=filled_prompt,
                        height=200,
                        disabled=True,
                    )

                # Call LLM with the filled prompt
                with st.spinner("Generating output..."):
                    model_selected = st.session_state.model
                    generated_output = call_llm(filled_prompt, model=model_selected)
                    st.session_state.generated_output = generated_output

        # Display generated output
        if st.session_state.generated_output:
            st.header("Generated Output")
            st.markdown("### Result")
            st.write(st.session_state.generated_output)

            # Option to save the output
            st.download_button(
                label="Download Output",
                data=st.session_state.generated_output,
                file_name="generated_output.txt",
                mime="text/plain",
            )
    else:
        st.info(
            "No template has been generated yet. Go to the 'Setup' tab to create one."
        )
1149
+
1150
with tab4:
    if st.session_state.show_template_editor and st.session_state.template_spec:
        st.header("Generate Synthetic Data")

        with st.expander("Template Information", expanded=False):
            st.json(st.session_state.template_spec)

        # Data generation controls
        st.subheader("Generation Settings")

        settings_left, settings_right = st.columns(2)
        with settings_left:
            num_samples = st.number_input(
                "Number of samples to generate", min_value=1, max_value=100, value=5
            )
        with settings_right:
            temperature = st.slider(
                "Temperature (creativity)",
                min_value=0.1,
                max_value=1.0,
                value=0.7,
                step=0.1,
            )
            # Persist so downstream generation helpers can read it.
            st.session_state.temperature = temperature

        # Make sure the containers for generated data exist in session state.
        for state_key in ("synthetic_inputs", "synthetic_outputs", "combined_data"):
            st.session_state.setdefault(state_key, [])

        # Generate inputs button
        if st.button("Generate Synthetic Inputs"):
            if not st.session_state.get("api_key"):
                st.error("Please provide an OpenAI API key in the sidebar.")
            else:
                with st.spinner(f"Generating {num_samples} synthetic input samples..."):
                    st.session_state.synthetic_inputs = generate_synthetic_inputs(
                        st.session_state.template_spec, num_samples=num_samples
                    )

                if st.session_state.synthetic_inputs:
                    st.success(
                        f"Generated {len(st.session_state.synthetic_inputs)} input samples"
                    )

        # Display generated inputs if available
        if st.session_state.synthetic_inputs:
            st.subheader("Generated Input Data")

            # Show data in a table
            inputs_frame = pd.DataFrame(st.session_state.synthetic_inputs)
            st.dataframe(inputs_frame)

            # Download button for inputs
            st.download_button(
                label="Download Input Data (CSV)",
                data=inputs_frame.to_csv(index=False),
                file_name="synthetic_inputs.csv",
                mime="text/csv",
            )

            # Generate outputs button
            if st.button("Generate Outputs for These Inputs"):
                if not st.session_state.get("api_key"):
                    st.error("Please provide an OpenAI API key in the sidebar.")
                else:
                    with st.spinner("Generating outputs for each input..."):
                        st.session_state.combined_data = generate_synthetic_outputs(
                            st.session_state.template_spec,
                            st.session_state.synthetic_inputs,
                            st.session_state.knowledge_base,
                        )

                    if st.session_state.combined_data:
                        st.success(
                            f"Generated outputs for {len(st.session_state.combined_data)} inputs"
                        )

        # Display combined data if available
        if st.session_state.combined_data:
            st.subheader("Complete Dataset (Inputs + Outputs)")

            # Show data in a table
            dataset_frame = pd.DataFrame(st.session_state.combined_data)
            st.dataframe(dataset_frame)

            # Download buttons for different formats
            csv_col, json_col = st.columns(2)

            with csv_col:
                st.download_button(
                    label="Download Complete Dataset (CSV)",
                    data=dataset_frame.to_csv(index=False),
                    file_name="synthetic_dataset.csv",
                    mime="text/csv",
                )

            with json_col:
                st.download_button(
                    label="Download Complete Dataset (JSON)",
                    data=json.dumps(st.session_state.combined_data, indent=2),
                    file_name="synthetic_dataset.json",
                    mime="application/json",
                )
    else:
        st.info(
            "No template has been generated yet. Go to the 'Setup' tab to create one."
        )