Suhasdev committed on
Commit
cacd4d0
·
0 Parent(s):

Deploy Universal Prompt Optimizer to HF Spaces (clean)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .gitignore +27 -0
  3. README.md +44 -0
  4. app.py +1563 -0
  5. requirements.txt +23 -0
  6. src/gepa_optimizer.egg-info/PKG-INFO +439 -0
  7. src/gepa_optimizer.egg-info/SOURCES.txt +65 -0
  8. src/gepa_optimizer.egg-info/dependency_links.txt +1 -0
  9. src/gepa_optimizer.egg-info/entry_points.txt +2 -0
  10. src/gepa_optimizer.egg-info/requires.txt +29 -0
  11. src/gepa_optimizer.egg-info/top_level.txt +1 -0
  12. src/gepa_optimizer/__init__.py +295 -0
  13. src/gepa_optimizer/cli.py +239 -0
  14. src/gepa_optimizer/core/__init__.py +8 -0
  15. src/gepa_optimizer/core/base_adapter.py +85 -0
  16. src/gepa_optimizer/core/custom_adapter.py +389 -0
  17. src/gepa_optimizer/core/optimizer.py +1279 -0
  18. src/gepa_optimizer/core/result.py +180 -0
  19. src/gepa_optimizer/core/universal_adapter.py +0 -0
  20. src/gepa_optimizer/data/__init__.py +27 -0
  21. src/gepa_optimizer/data/converters.py +265 -0
  22. src/gepa_optimizer/data/index_caching_loader.py +278 -0
  23. src/gepa_optimizer/data/loaders.py +237 -0
  24. src/gepa_optimizer/data/scroll_dataset_loader.py +334 -0
  25. src/gepa_optimizer/data/validation_dataset_loader.py +376 -0
  26. src/gepa_optimizer/data/validators.py +207 -0
  27. src/gepa_optimizer/evaluation/__init__.py +28 -0
  28. src/gepa_optimizer/evaluation/base_evaluator.py +51 -0
  29. src/gepa_optimizer/evaluation/index_caching_evaluator.py +357 -0
  30. src/gepa_optimizer/evaluation/scroll_evaluator.py +251 -0
  31. src/gepa_optimizer/evaluation/ui_evaluator.py +297 -0
  32. src/gepa_optimizer/evaluation/universal_evaluator.py +911 -0
  33. src/gepa_optimizer/evaluation/validation_evaluator.py +495 -0
  34. src/gepa_optimizer/infrastructure/__init__.py +15 -0
  35. src/gepa_optimizer/infrastructure/logging/__init__.py +43 -0
  36. src/gepa_optimizer/infrastructure/logging/context.py +257 -0
  37. src/gepa_optimizer/infrastructure/logging/formatters.py +259 -0
  38. src/gepa_optimizer/infrastructure/logging/logger.py +260 -0
  39. src/gepa_optimizer/llms/__init__.py +10 -0
  40. src/gepa_optimizer/llms/base_llm.py +56 -0
  41. src/gepa_optimizer/llms/batch_llm.py +712 -0
  42. src/gepa_optimizer/llms/llego_enhanced_llm.py +1625 -0
  43. src/gepa_optimizer/llms/vision_llm.py +813 -0
  44. src/gepa_optimizer/models/__init__.py +15 -0
  45. src/gepa_optimizer/models/config.py +488 -0
  46. src/gepa_optimizer/models/dataset.py +89 -0
  47. src/gepa_optimizer/models/result.py +204 -0
  48. src/gepa_optimizer/operators/__init__.py +45 -0
  49. src/gepa_optimizer/operators/base_operator.py +107 -0
  50. src/gepa_optimizer/operators/crossover.py +120 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+
8
+ # Virtual environments
9
+ venv/
10
+ env/
11
+ ENV/
12
+
13
+ # IDE
14
+ .vscode/
15
+ .idea/
16
+ *.swp
17
+ *.swo
18
+
19
+ # OS
20
+ .DS_Store
21
+ Thumbs.db
22
+
23
+ # Build artifacts
24
+ *.egg-info/
25
+ dist/
26
+ build/
27
+
README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ---
4
+ title: Universal Prompt Optimizer
5
+ emoji: 🧬
6
+ colorFrom: blue
7
+ colorTo: cyan
8
+ sdk: gradio
9
+ sdk_version: 4.0.0
10
+ app_file: app.py
11
+ pinned: false
12
+ license: mit
13
+ ---
14
+ # Universal Prompt Optimizer
15
+
16
+ A powerful genetic evolutionary prompt optimization tool built with GEPA (Genetic Evolutionary Prompt Agent). Optimize your prompts using genetic algorithms with optional LLEGO crossover for faster convergence.
17
+
18
+ ## Features
19
+
20
+ - 🧬 **Genetic Algorithm Optimization**: Evolve prompts through multiple iterations
21
+ - 🎯 **Multi-Model Support**: Works with OpenAI, Anthropic, Google, and custom models
22
+ - 📊 **Real-time Metrics**: Track optimization progress and improvements
23
+ - 🖼️ **Multi-modal Support**: Include images in your training examples
24
+ - ⚡ **LLEGO Crossover**: Advanced genetic operations for faster convergence
25
+
26
+ ## How to Use
27
+
28
+ 1. **Select Model**: Choose your target LLM (GPT-4, Claude, Gemini, or custom)
29
+ 2. **Enter Seed Prompt**: Describe your task, constraints, and desired output format
30
+ 3. **Add Training Examples**: Provide input/output pairs (images optional)
31
+ 4. **Configure Optimization**: Set evolution rounds, batch size, and enable LLEGO
32
+ 5. **Start Optimization**: Watch as the genetic algorithm evolves your prompt
33
+
34
+ ## API Keys
35
+
36
+ API keys are stored in-session only and never logged. You can provide them in the UI or set them as environment variables:
37
+
38
+ - `OPENAI_API_KEY`
39
+ - `ANTHROPIC_API_KEY`
40
+ - `GOOGLE_API_KEY`
41
+
42
+ ## License
43
+
44
+ MIT License
app.py ADDED
@@ -0,0 +1,1563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🚀 Universal Prompt Optimizer - Enhanced Production UI v8.0
3
+ Principal Engineer Edition: Linear/Vercel-style Dark Mode with Premium UX
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ # Add src directory to Python path for gepa_optimizer imports
9
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
10
+
11
+ import gradio as gr
12
+ import json
13
+ import base64
14
+ import io
15
+ import os
16
+ import logging
17
+ import traceback
18
+ import html
19
+ import numpy as np
20
+ from PIL import Image as PILImage
21
+ from typing import List, Dict, Optional, Any, Tuple
22
+ import threading
23
+ from collections import deque
24
+
25
+ # Optional import for URL image downloads
26
+ try:
27
+ import requests
28
+ REQUESTS_AVAILABLE = True
29
+ except ImportError:
30
+ REQUESTS_AVAILABLE = False
31
+
32
# ==========================================
# 0. LOGGING & BACKEND UTILS
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Global Candidates Store (Thread-safe)
# Bounded ring buffer of recent prompt candidates, shared between the
# optimization flow and the UI display. Every access must hold 'lock'.
_candidates_store = {
    'candidates': deque(maxlen=100),  # oldest entries drop off past 100
    'lock': threading.Lock(),
    'iteration': 0  # current evolution round; bumped by increment_iteration()
}
47
+
48
def add_candidate_to_store(candidate: Dict[str, Any]):
    """Record one prompt candidate in the shared store (thread-safe).

    Only the known fields are copied out of *candidate*; anything missing
    gets a neutral default so the display code never hits a KeyError.
    """
    with _candidates_store['lock']:
        entry = {
            'iteration': _candidates_store['iteration'],
            'source': candidate.get('source', 'unknown'),
            'prompt': candidate.get('prompt', ''),
            'timestamp': candidate.get('timestamp', ''),
            # 1-based position in the deque; note the deque is bounded, so
            # once full this value stops growing — fine for display use.
            'index': len(_candidates_store['candidates']) + 1,
        }
        _candidates_store['candidates'].append(entry)
57
+
58
def get_candidates_from_store() -> List[Dict[str, Any]]:
    """Return a point-in-time snapshot of the candidate deque as a list."""
    with _candidates_store['lock']:
        snapshot = list(_candidates_store['candidates'])
    return snapshot
61
+
62
def clear_candidates_store():
    """Empty the candidate deque and reset the iteration counter to zero."""
    with _candidates_store['lock']:
        store = _candidates_store
        store['candidates'].clear()
        store['iteration'] = 0
66
+
67
def increment_iteration():
    """Advance the shared evolution-round counter by one (thread-safe)."""
    with _candidates_store['lock']:
        _candidates_store['iteration'] = _candidates_store['iteration'] + 1
70
+
71
+ # ==========================================
72
+ # 1. MOCK BACKEND (Kept as provided)
73
+ # ==========================================
74
# Import the real optimizer if installed; otherwise install a mock backend
# with the same public names so the UI stays fully demoable.
try:
    from gepa_optimizer import quick_optimize_sync, OptimizedResult
    BACKEND_AVAILABLE = True
except ImportError:
    BACKEND_AVAILABLE = False
    from dataclasses import dataclass

    # Minimal stand-in mirroring the fields the UI reads from the real
    # gepa_optimizer.OptimizedResult.
    @dataclass
    class OptimizedResult:
        optimized_prompt: str
        improvement_metrics: dict
        iteration_history: list

    def quick_optimize_sync(seed_prompt, dataset, model, **kwargs):
        """Mock optimizer: sleeps to simulate work, then returns a canned
        OptimizedResult built from the supplied parameters.

        Honors max_iterations / batch_size / use_llego from **kwargs;
        other kwargs (dataset, verbose, ...) are accepted but ignored.
        """
        import time
        iterations = kwargs.get('max_iterations', 5)
        batch_size = kwargs.get('batch_size', 4)
        use_llego = kwargs.get('use_llego', True)

        # Simulate processing time based on iterations
        time.sleep(0.5 * iterations)

        llego_note = "with LLEGO crossover" if use_llego else "standard mutation only"

        return OptimizedResult(
            optimized_prompt=f"""# OPTIMIZED PROMPT FOR {model}
# ----------------------------------------
# Optimization: {iterations} iterations, batch size {batch_size}, {llego_note}

## Task Context
{seed_prompt}

## Refined Instructions
1. Analyse the input constraints strictly.
2. Verify output format against expected schema.
3. Apply chain-of-thought reasoning before answering.
4. Cross-reference with provided examples for consistency.

## Safety & Edge Cases
- If input is ambiguous, ask for clarification.
- Maintain a professional, neutral tone.
- Handle edge cases gracefully with informative responses.""",
            # Fixed demo metrics — not derived from the dataset.
            improvement_metrics={
                "baseline_score": 0.45,
                "final_score": 0.92,
                "improvement": "+104.4%",
                "iterations_run": iterations,
                "candidates_evaluated": iterations * batch_size,
            },
            # Canned history, truncated to the requested iteration count.
            iteration_history=[
                f"Iter 1: Baseline evaluation - Score: 0.45",
                f"Iter 2: Added Chain-of-Thought constraints - Score: 0.62",
                f"Iter 3: Refined output formatting rules - Score: 0.78",
                f"Iter 4: {'LLEGO crossover applied' if use_llego else 'Mutation applied'} - Score: 0.88",
                f"Iter 5: Final refinement - Score: 0.92",
            ][:iterations],
        )
131
+
132
+ # ==========================================
133
+ # 2. HELPER FUNCTIONS
134
+ # ==========================================
135
def gradio_image_to_base64(image_input) -> Optional[str]:
    """Convert a Gradio image input to a PNG data-URL string.

    Accepts a numpy array, a PIL image, or a file-system path.

    Returns:
        A "data:image/png;base64,..." string, or None when the input is
        missing or cannot be converted. Never raises; failures are logged.
    """
    if image_input is None:
        return None

    try:
        pil_image = None

        if isinstance(image_input, np.ndarray):
            try:
                # Validate array shape and dtype
                if image_input.size == 0:
                    logger.warning("Empty image array provided")
                    return None
                pil_image = PILImage.fromarray(image_input)
            except (ValueError, TypeError) as e:
                logger.error(f"Failed to convert numpy array to PIL Image: {str(e)}")
                return None
        elif isinstance(image_input, PILImage.Image):
            pil_image = image_input
        elif isinstance(image_input, str):
            if not os.path.exists(image_input):
                logger.warning(f"Image file not found: {image_input}")
                return None
            try:
                # BUGFIX: verify() leaves the Image object unusable, and the
                # old code then "reopened" it from tobytes() — raw pixel
                # data, not an encoded stream — which always failed. Verify
                # a throwaway handle instead and reopen from the path.
                with PILImage.open(image_input) as probe:
                    probe.verify()
                pil_image = PILImage.open(image_input)
            except (IOError, OSError, SyntaxError) as e:
                logger.error(f"Failed to open image file: {str(e)}")
                return None
        else:
            logger.warning(f"Unsupported image input type: {type(image_input)}")
            return None

        if pil_image is None:
            return None

        try:
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return f"data:image/png;base64,{img_str}"
        except (IOError, OSError, ValueError) as e:
            logger.error(f"Failed to encode image to base64: {str(e)}")
            return None
    except Exception as e:
        logger.error(f"Unexpected error in image conversion: {str(e)}\n{traceback.format_exc()}")
        return None
191
+
192
def validate_dataset(dataset: List[Dict]) -> Tuple[bool, str]:
    """Check that *dataset* is a non-empty list of {'input','output'} dicts.

    Returns:
        (True, "") when every example is well-formed, otherwise
        (False, reason) describing the first problem found.
    """
    if not isinstance(dataset, list):
        return False, "Dataset must be a list of examples."
    if not dataset:
        return False, "Dataset is empty. Add at least one example."

    for idx, example in enumerate(dataset, start=1):
        if not isinstance(example, dict):
            return False, f"Dataset item {idx} must be a dictionary with 'input' and 'output' keys."
        if "input" not in example or "output" not in example:
            return False, f"Dataset item {idx} is missing required 'input' or 'output' field."
        if not all(isinstance(example.get(key), str) for key in ("input", "output")):
            return False, f"Dataset item {idx} has invalid 'input' or 'output' type (must be strings)."
        # Whitespace-only values are treated the same as empty ones.
        if not example["input"].strip() or not example["output"].strip():
            return False, f"Dataset item {idx} has empty 'input' or 'output' field."

    return True, ""
215
+
216
def validate_model(model: str, custom_model: str) -> Tuple[bool, str]:
    """Validate the model selection and, if 'custom', the custom model ID.

    A custom model ID must look like 'provider/model_name' with both halves
    non-blank. Returns (ok, error_message); the message is empty on success.
    """
    if not model:
        return False, "Please select a foundation model."

    # Only the 'custom' choice needs further checks.
    if model != "custom":
        return True, ""

    if not (custom_model and custom_model.strip()):
        return False, "Custom model selected but no model ID provided."

    provider, sep, model_name = custom_model.strip().partition("/")
    # Exactly one '/' is required: partition must find one and the
    # remainder must not contain another.
    if not sep or "/" in model_name:
        return False, "Custom model ID must be in format 'provider/model_name' (e.g., 'openai/gpt-4')."
    if not provider.strip() or not model_name.strip():
        return False, "Custom model ID provider and model name cannot be empty."

    return True, ""
234
+
235
def validate_api_keys(model: str, api_keys: Dict[str, str]) -> Tuple[bool, str]:
    """Check that the provider key needed by *model* is available.

    A key counts as available if it was entered in the UI (*api_keys*) or
    is already set in the environment. When *api_keys* is falsy the check
    is skipped entirely (keys are assumed to come from the environment).
    """
    if not api_keys:
        return True, ""  # Keys are optional if already set in environment

    # Provider prefix of a namespaced model is used as-is (only bare model
    # names are lowercased), matching the original behavior.
    provider = model.split("/")[0] if "/" in model else model.lower()

    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
    }

    env_var = env_var_by_provider.get(provider)
    if env_var is None:
        # Unknown/custom providers are not validated here.
        return True, ""

    ui_key = (api_keys.get(provider) or "").strip()
    if not ui_key and not os.environ.get(env_var):
        return False, f"API key for {provider.capitalize()} is required for model '{model}' but not provided."

    return True, ""
264
+
265
def safe_optimize(seed_prompt, dataset, model, custom_model="", max_iterations=5, max_metric_calls=50, batch_size=4, use_llego=True, api_keys=None):
    """Validate all inputs, then run prompt optimization, never raising.

    Args:
        seed_prompt: Non-empty starting prompt to optimize.
        dataset: List of {'input', 'output'} training examples.
        model: Model id chosen in the UI (or "custom").
        custom_model: Optional override; any non-blank value takes
            precedence over *model*, even when *model* is not "custom".
        max_iterations: Evolution rounds, must be 1-50.
        max_metric_calls: Evaluation budget, must be 10-500.
        batch_size: Candidates per round, must be 1-20.
        use_llego: Whether to enable LLEGO crossover.
        api_keys: Optional {'openai'|'google'|'anthropic': key} dict from
            the UI; non-blank values are exported into os.environ.

    Returns:
        (success, message, result): *result* is an OptimizedResult on
        success, otherwise None with a user-facing error in *message*.
    """
    try:
        # Validate seed prompt
        if not seed_prompt or not isinstance(seed_prompt, str):
            return False, "Seed prompt is required and must be a string.", None

        if not seed_prompt.strip():
            return False, "Seed prompt cannot be empty.", None

        # Validate dataset
        is_valid, msg = validate_dataset(dataset)
        if not is_valid:
            return False, msg, None

        # Determine final model
        final_model = custom_model.strip() if custom_model and custom_model.strip() else model

        # Validate model (dropdown value + custom-model format)
        model_valid, model_msg = validate_model(model, custom_model)
        if not model_valid:
            return False, model_msg, None

        # Validate API keys against the resolved model id
        api_valid, api_msg = validate_api_keys(final_model, api_keys or {})
        if not api_valid:
            return False, api_msg, None

        # Validate optimization parameters (hard bounds enforced here, not
        # just in the UI widgets)
        if not isinstance(max_iterations, int) or max_iterations < 1 or max_iterations > 50:
            return False, "Max iterations must be between 1 and 50.", None

        if not isinstance(max_metric_calls, int) or max_metric_calls < 10 or max_metric_calls > 500:
            return False, "Max metric calls must be between 10 and 500.", None

        if not isinstance(batch_size, int) or batch_size < 1 or batch_size > 20:
            return False, "Batch size must be between 1 and 20.", None

        # Check backend availability; the mock fallback keeps the demo alive
        if not BACKEND_AVAILABLE:
            logger.warning("Backend not available, using mock optimizer")

        # Set API keys from UI if provided (exported so downstream SDKs
        # pick them up from the environment)
        if api_keys:
            try:
                key_mapping = {
                    "openai": "OPENAI_API_KEY",
                    "google": "GOOGLE_API_KEY",
                    "anthropic": "ANTHROPIC_API_KEY",
                }
                for provider, env_var in key_mapping.items():
                    if api_keys.get(provider) and api_keys[provider].strip():
                        os.environ[env_var] = api_keys[provider].strip()
                        # Log the provider only — never the key value
                        logger.info(f"Set {provider} API key from UI")
            except Exception as e:
                logger.error(f"Failed to set API keys: {str(e)}")
                return False, f"Failed to configure API keys: {str(e)}", None

        # Run optimization
        try:
            result = quick_optimize_sync(
                seed_prompt=seed_prompt,
                dataset=dataset,
                model=final_model,
                max_iterations=max_iterations,
                max_metric_calls=max_metric_calls,
                batch_size=batch_size,
                use_llego=use_llego,
                verbose=True,
            )

            # Validate result structure before handing it to the UI
            if not result:
                return False, "Optimization returned no result.", None

            if not hasattr(result, 'optimized_prompt'):
                return False, "Optimization result is missing required fields.", None

            return True, "Success", result

        # Specific failures first, each mapped to a user-friendly message
        except KeyboardInterrupt:
            logger.warning("Optimization interrupted by user")
            return False, "Optimization was interrupted.", None
        except TimeoutError:
            logger.error("Optimization timed out")
            return False, "Optimization timed out. Try reducing max_iterations or max_metric_calls.", None
        except ConnectionError as e:
            logger.error(f"Connection error during optimization: {str(e)}")
            return False, f"Connection error: {str(e)}. Check your internet connection and API keys.", None
        except ValueError as e:
            logger.error(f"Invalid parameter in optimization: {str(e)}")
            return False, f"Invalid configuration: {str(e)}", None
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Optimization failed: {error_msg}\n{traceback.format_exc()}")
            # Provide user-friendly error messages by keyword-sniffing the
            # exception text (heuristic; order matters)
            if "api" in error_msg.lower() or "key" in error_msg.lower():
                return False, f"API error: {error_msg}. Please check your API keys.", None
            elif "rate limit" in error_msg.lower():
                return False, "Rate limit exceeded. Please wait a moment and try again.", None
            elif "quota" in error_msg.lower():
                return False, "API quota exceeded. Please check your account limits.", None
            else:
                return False, f"Optimization failed: {error_msg}", None

    except Exception as e:
        # Last-resort guard: this function must never raise into the UI
        logger.error(f"Unexpected error in safe_optimize: {str(e)}\n{traceback.format_exc()}")
        return False, f"Unexpected error: {str(e)}", None
373
+
374
+ # ==========================================
375
+ # 3. UI LOGIC
376
+ # ==========================================
377
def add_example(input_text, output_text, image_input, current_dataset):
    """Append one training example to the dataset state.

    Raises gr.Error (surfaced as a toast by Gradio) on invalid input.
    Returns (updated_dataset, "", "", None) so the event hook also clears
    the input, output, and image widgets.
    """
    try:
        # Validate inputs
        if not input_text:
            raise gr.Error("Input text is required.")

        if not output_text:
            raise gr.Error("Output text is required.")

        if not isinstance(input_text, str) or not isinstance(output_text, str):
            raise gr.Error("Input and Output must be text strings.")

        input_text = input_text.strip()
        output_text = output_text.strip()

        # Re-check after stripping: whitespace-only entries are rejected
        if not input_text:
            raise gr.Error("Input text cannot be empty.")

        if not output_text:
            raise gr.Error("Output text cannot be empty.")

        # Validate dataset state
        if not isinstance(current_dataset, list):
            raise gr.Error("Dataset state is invalid. Please refresh the page.")

        # Process image with error handling; the image is optional, so a
        # conversion failure degrades to a text-only example
        img_b64 = None
        try:
            img_b64 = gradio_image_to_base64(image_input)
        except Exception as e:
            logger.warning(f"Image processing failed, continuing without image: {str(e)}")
            # Continue without image - it's optional

        # Create new item
        try:
            new_item = {
                "input": input_text,
                "output": output_text,
                "image": img_b64,
                "image_preview": "🖼️ Image" if img_b64 else "-"
            }

            # Validate item structure
            # NOTE(review): a gr.Error raised here is caught by the except
            # below (gr.Error subclasses Exception) and re-wrapped, so the
            # user sees it prefixed with "Failed to add example:".
            if not isinstance(new_item["input"], str) or not isinstance(new_item["output"], str):
                raise gr.Error("Failed to create dataset item: invalid data types.")

            # Mutates the gr.State list in place and returns it
            current_dataset.append(new_item)

            return current_dataset, "", "", None

        except Exception as e:
            logger.error(f"Failed to add example to dataset: {str(e)}")
            raise gr.Error(f"Failed to add example: {str(e)}")

    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in add_example: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")
438
+
439
def update_table(dataset):
    """Build display rows for the dataset table.

    Each row is [row_number, input_preview, output_preview, image_marker];
    text previews are truncated to 50 characters. Malformed items are
    skipped so one bad entry never blanks the whole table.
    """
    if not dataset:
        return []
    if not isinstance(dataset, list):
        logger.error(f"Invalid dataset type: {type(dataset)}")
        return []

    def _preview(value):
        # Falsy values (missing key, empty string) render as an empty cell.
        return str(value)[:50] if value else ""

    rows = []
    try:
        for row_num, item in enumerate(dataset, start=1):
            try:
                if not isinstance(item, dict):
                    logger.warning(f"Skipping invalid dataset item {row_num}: not a dictionary")
                    continue
                rows.append([
                    row_num,
                    _preview(item.get("input")),
                    _preview(item.get("output")),
                    str(item.get("image_preview", "-")),
                ])
            except Exception as e:
                logger.warning(f"Error processing dataset item {row_num}: {str(e)}")
                continue
        return rows
    except Exception as e:
        logger.error(f"Error updating table: {str(e)}\n{traceback.format_exc()}")
        return []
470
+
471
def clear_dataset():
    """Reset the dataset state and the table display to empty lists.

    Returns:
        ([], []) — one fresh list for the gr.State, one for the Dataframe.
    """
    # The original wrapped this literal return in try/except, but a tuple
    # of two list literals cannot raise, so the handler was dead code.
    return [], []
478
+
479
def get_candidates_display():
    """Render the candidate store as an HTML card list for the live panel.

    Shows the 10 most recent candidates, newest first. Every value taken
    from a candidate dict is HTML-escaped before interpolation. Returns a
    friendly placeholder when the store is empty and an error card if
    rendering fails entirely.
    """
    try:
        candidates = get_candidates_from_store()

        if not candidates:
            return "<div style='padding: 2rem; text-align: center; color: #6b7280;'><div style='font-size: 3rem; opacity: 0.3; margin-bottom: 1rem;'>🧬</div><p>Waiting for optimization to start...</p></div>"

        if not isinstance(candidates, list):
            logger.error(f"Invalid candidates type: {type(candidates)}")
            return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>"

        html_output = "<div style='display: flex; flex-direction: column; gap: 12px;'>"

        # Show last 10 candidates, rendered newest-first
        candidates_to_show = list(candidates)[-10:]
        for c in reversed(candidates_to_show):
            try:
                if not isinstance(c, dict):
                    continue

                iteration = str(c.get('iteration', '?'))
                source = str(c.get('source', 'unknown')).upper()
                # Prompt text is truncated to 200 chars for the card view
                prompt = str(c.get('prompt', ''))[:200]

                # Escape HTML to prevent XSS
                iteration = html.escape(iteration)
                source = html.escape(source)
                prompt = html.escape(prompt)

                html_output += f"""
<div style='background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%); border: 1px solid #334155; border-radius: 8px; padding: 16px; position: relative; overflow: hidden;'>
<div style='position: absolute; top: 0; left: 0; width: 100%; height: 2px; background: linear-gradient(90deg, #06b6d4, #3b82f6);'></div>
<div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px;'>
<span style='font-family: "JetBrains Mono", monospace; font-size: 0.75rem; color: #06b6d4; font-weight: 600;'>ITERATION {iteration}</span>
<span style='background: #1e293b; border: 1px solid #334155; padding: 2px 8px; border-radius: 4px; font-size: 0.7rem; color: #94a3b8;'>{source}</span>
</div>
<div style='font-family: "JetBrains Mono", monospace; font-size: 0.85rem; color: #cbd5e1; line-height: 1.6;'>{prompt}...</div>
</div>
"""
            except Exception as e:
                # One bad candidate should not blank the whole panel
                logger.warning(f"Error rendering candidate: {str(e)}")
                continue

        html_output += "</div>"
        return html_output

    except Exception as e:
        logger.error(f"Error generating candidates display: {str(e)}\n{traceback.format_exc()}")
        return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates display.</div>"
529
+
530
def run_optimization_flow(seed, dataset, model, custom_model, iter_count, call_count, batch, llego, k_openai, k_google, k_anthropic, progress=gr.Progress()):
    """Drive the end-to-end optimization flow as a Gradio generator.

    Each ``yield`` emits an 8-tuple matching the ``btn_optimize`` outputs:
    (status_panel update, empty_state update, results_panel update,
    status markdown, optimized prompt, metrics dict, history text,
    candidates HTML).

    Args:
        seed: Seed prompt text (required).
        dataset: List of training examples from gr.State (required).
        model: Model id selected in the dropdown (required).
        custom_model: Optional custom model id; resolved inside safe_optimize.
        iter_count: Evolution rounds; coerced to int, default 5.
        call_count: Max LLM metric calls; coerced to int, default 50.
        batch: Candidates per iteration; coerced to int, default 4.
        llego: Whether to enable LLEGO crossover.
        k_openai, k_google, k_anthropic: Provider API keys (in-session only).
        progress: Gradio progress tracker (injected by Gradio).

    Raises:
        gr.Error: on invalid inputs or any optimization failure.
    """
    import time

    try:
        # --- Input validation -------------------------------------------
        if not seed:
            raise gr.Error("Seed prompt is required.")

        if not dataset:
            raise gr.Error("Dataset is required. Add at least one example.")

        if not model:
            raise gr.Error("Model selection is required.")

        # Coerce numeric parameters, falling back to safe defaults.
        try:
            iter_count = int(iter_count) if iter_count else 5
            call_count = int(call_count) if call_count else 50
            batch = int(batch) if batch else 4
        except (ValueError, TypeError) as e:
            raise gr.Error(f"Invalid optimization parameters: {str(e)}")

        # NOTE: the previous `final_model` computation was dead code —
        # safe_optimize() receives `model` and `custom_model` separately
        # and resolves the override itself, so it has been removed.

        # Reset the live-candidates stream; a failure here is non-fatal.
        try:
            clear_candidates_store()
        except Exception as e:
            logger.warning(f"Error clearing candidates store: {str(e)}")

        # API keys stay in-session only; empty string means "not provided".
        # (A literal dict cannot raise, so no try/except is needed here.)
        api_keys = {
            "openai": k_openai if k_openai else "",
            "google": k_google if k_google else "",
            "anthropic": k_anthropic if k_anthropic else ""
        }

        # --- Initial UI state: show status panel, hide the rest ----------
        try:
            yield (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                "🚀 Initializing Genetic Algorithm...",
                "", {}, "", ""
            )
            time.sleep(0.5)  # Brief pause for UI update
        except Exception as e:
            logger.error(f"Error in initial UI update: {str(e)}")
            raise gr.Error(f"Failed to initialize UI: {str(e)}")

        # --- Evolution loop (visual progress only; the actual work
        #     happens in safe_optimize below) ----------------------------
        try:
            for i in range(1, iter_count + 1):
                try:
                    increment_iteration()
                    add_candidate_to_store({
                        "source": "evolution_step",
                        "prompt": f"Candidate {i}: Optimizing instruction clarity and task alignment...",
                        "timestamp": "now"
                    })

                    progress(i/iter_count, desc=f"Evolution Round {i}/{iter_count}")
                    yield (
                        gr.update(), gr.update(), gr.update(),
                        f"🧬 **Evolution Round {i}/{iter_count}**\n\n• Generating {batch} prompt mutations\n• Evaluating fitness scores\n• Selecting top candidates",
                        "", {}, "", get_candidates_display()
                    )
                    time.sleep(0.3)  # Pause to show progress
                except Exception as e:
                    # A failed cosmetic step must not abort the run.
                    logger.warning(f"Error in evolution step {i}: {str(e)}")
                    continue
        except Exception as e:
            logger.error(f"Error in evolution loop: {str(e)}")
            # Fall through to the real optimization attempt.

        # --- Final optimization (the real work) --------------------------
        try:
            success, msg, result = safe_optimize(
                seed_prompt=seed,
                dataset=dataset,
                model=model,
                custom_model=custom_model,
                max_iterations=iter_count,
                max_metric_calls=call_count,
                batch_size=batch,
                use_llego=llego,
                api_keys=api_keys
            )

            if not success:
                # Surface the failure in the status panel before raising.
                yield (
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    f"❌ **Optimization Failed**\n\n{msg}",
                    "", {}, "", get_candidates_display()
                )
                raise gr.Error(msg)

            # Validate result before displaying.
            if not result:
                raise gr.Error("Optimization completed but returned no result.")

            if not hasattr(result, 'optimized_prompt'):
                raise gr.Error("Optimization result is missing required fields.")

            # --- Render results ------------------------------------------
            try:
                optimized_prompt = result.optimized_prompt if result.optimized_prompt else ""
                improvement_metrics = getattr(result, 'improvement_metrics', {})
                iteration_history = getattr(result, 'iteration_history', [])

                # FIX: coerce each history entry via str() — a list with
                # non-string entries previously made "\n".join raise
                # TypeError, which surfaced as "Failed to display results".
                if isinstance(iteration_history, list):
                    history_text = "\n".join(str(entry) for entry in iteration_history)
                else:
                    history_text = str(iteration_history)

                yield (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True),
                    "✅ Optimization Complete",
                    optimized_prompt,
                    improvement_metrics,
                    history_text,
                    get_candidates_display()
                )
            except Exception as e:
                logger.error(f"Error displaying results: {str(e)}")
                raise gr.Error(f"Failed to display results: {str(e)}")

        except gr.Error:
            # Re-raise Gradio errors untouched.
            raise
        except Exception as e:
            logger.error(f"Error in optimization: {str(e)}\n{traceback.format_exc()}")
            raise gr.Error(f"Optimization error: {str(e)}")

    except gr.Error:
        # Re-raise Gradio errors as-is.
        raise
    except KeyboardInterrupt:
        logger.warning("Optimization interrupted by user")
        raise gr.Error("Optimization was interrupted.")
    except Exception as e:
        logger.error(f"Unexpected error in optimization flow: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")
687
+
688
# ==========================================
# 4. ENHANCED CSS (Linear/Vercel-style)
# ==========================================
# Single stylesheet string injected into the Gradio app. Design tokens
# (colors, strokes, shadows, radii) are declared once under :root and
# reused by the component rules below.
# NOTE(review): this string is passed to launch() at the bottom of the
# file, but Gradio expects custom CSS via gr.Blocks(css=...) — confirm
# the stylesheet is actually applied.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');

:root {
    --bg0: #070A0F;
    --bg1: #0B1020;
    --bg2: rgba(255,255,255,0.04);
    --bg3: rgba(255,255,255,0.06);

    --stroke0: rgba(148,163,184,0.14);
    --stroke1: rgba(148,163,184,0.22);

    --text0: #EAF0FF;
    --text1: rgba(234,240,255,0.74);
    --text2: rgba(234,240,255,0.56);

    --teal: #06B6D4;
    --blue: #3B82F6;

    --ok: #10B981;
    --okGlow: rgba(16,185,129,0.18);

    --bad: #EF4444;

    --shadow: 0 12px 40px rgba(0,0,0,0.45);
    --shadowSoft: 0 10px 24px rgba(0,0,0,0.32);

    --radius: 14px;
    --radiusSm: 10px;
}

html, body {
    background: radial-gradient(1200px 700px at 20% -10%, rgba(6,182,212,0.13), transparent 55%),
                radial-gradient(1000px 650px at 90% 0%, rgba(59,130,246,0.10), transparent 60%),
                linear-gradient(180deg, var(--bg0) 0%, var(--bg1) 100%);
    color: var(--text0);
    font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
}

.gradio-container {
    max-width: 1520px !important;
    padding: 12px 18px !important;
    margin: 0 auto !important;
}

/* --- App shell --- */
.app-shell { min-height: auto !important; }
.topbar {
    padding: 12px 14px 12px 14px;
    margin-bottom: 4px;
    border: 1px solid var(--stroke0);
    border-radius: var(--radius);
    background: linear-gradient(180deg, rgba(255,255,255,0.04) 0%, rgba(255,255,255,0.02) 100%);
    box-shadow: var(--shadowSoft);
}
.topbar-wrap { margin-bottom: 0 !important; }

.brand-row { display: flex; align-items: center; justify-content: space-between; gap: 16px; }
.brand-left { display: flex; align-items: center; gap: 14px; }
.brand-mark {
    width: 44px; height: 44px; border-radius: 12px;
    background: linear-gradient(135deg, rgba(6,182,212,0.26), rgba(59,130,246,0.20));
    border: 1px solid rgba(6,182,212,0.30);
    box-shadow: 0 0 0 4px rgba(6,182,212,0.10);
    display: flex; align-items: center; justify-content: center;
    font-weight: 800;
}
.h1 {
    font-size: 22px; font-weight: 800; letter-spacing: -0.02em;
    margin: 0; line-height: 1.2;
}
.subtitle { margin-top: 4px; color: var(--text1); font-weight: 500; font-size: 13px; }

.status-pill {
    display: inline-flex; align-items: center; gap: 10px;
    padding: 10px 12px; border-radius: 999px;
    background: rgba(255,255,255,0.03);
    border: 1px solid var(--stroke0);
    color: var(--text1);
    font-size: 12px; font-weight: 700; letter-spacing: 0.08em;
    text-transform: uppercase;
}
.dot {
    width: 10px; height: 10px; border-radius: 999px;
    background: var(--ok);
    box-shadow: 0 0 16px rgba(16,185,129,0.40);
    animation: pulse 1.8s ease-in-out infinite;
}
@keyframes pulse { 0%, 100% { transform: scale(1); opacity: 0.95; } 50% { transform: scale(1.18); opacity: 0.70; } }

/* --- Two-column layout helpers --- */
.left-col, .right-col { min-width: 280px; }

/* --- Cards / Sections --- */
.card {
    border-radius: var(--radius);
    background: linear-gradient(180deg, rgba(255,255,255,0.045) 0%, rgba(255,255,255,0.022) 100%);
    border: 1px solid var(--stroke0);
    box-shadow: var(--shadowSoft);
    padding: 16px;
}
.card + .card { margin-top: 14px; }

.card-head {
    display: flex; align-items: center; justify-content: space-between;
    gap: 12px;
    padding-bottom: 12px;
    margin-bottom: 12px;
    border-bottom: 1px solid var(--stroke0);
}
.card-title {
    display: flex; align-items: center; gap: 10px;
    font-size: 13px; font-weight: 800; letter-spacing: 0.12em;
    text-transform: uppercase; color: var(--text1);
}
.step {
    width: 30px; height: 30px; border-radius: 10px;
    background: linear-gradient(135deg, rgba(6,182,212,0.95), rgba(59,130,246,0.95));
    box-shadow: 0 10px 20px rgba(6,182,212,0.18);
    display: flex; align-items: center; justify-content: center;
    color: white; font-weight: 900; font-size: 13px;
}
.hint { color: var(--text2); font-size: 12px; line-height: 1.4; }

.ds-count span {
    display: inline-flex;
    align-items: center;
    padding: 7px 10px;
    border-radius: 999px;
    border: 1px solid var(--stroke0);
    background: rgba(255,255,255,0.02);
    color: var(--text1) !important;
    font-weight: 700;
    font-size: 12px;
}

/* --- Inputs --- */
label { color: var(--text1) !important; font-weight: 650 !important; font-size: 12px !important; }

textarea, input, select {
    background: rgba(255,255,255,0.03) !important;
    border: 1px solid var(--stroke0) !important;
    border-radius: 12px !important;
    color: var(--text0) !important;
    transition: border-color 0.15s ease, box-shadow 0.15s ease, transform 0.15s ease;
}

textarea:focus, input:focus, select:focus {
    outline: none !important;
    border-color: rgba(6,182,212,0.55) !important;
    box-shadow: 0 0 0 4px rgba(6,182,212,0.14) !important;
}

.keybox input { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; }

.seed textarea { min-height: 160px !important; }
.mono textarea { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; font-size: 12.5px !important; }

/* --- Buttons --- */
.cta button {
    width: 100% !important;
    border: 0 !important;
    border-radius: 14px !important;
    padding: 14px 16px !important;
    font-size: 13px !important;
    font-weight: 900 !important;
    letter-spacing: 0.12em !important;
    text-transform: uppercase !important;
    color: white !important;
    background: linear-gradient(135deg, rgba(6,182,212,1) 0%, rgba(59,130,246,1) 100%) !important;
    box-shadow: 0 18px 48px rgba(6,182,212,0.22) !important;
    position: relative !important;
    overflow: hidden !important;
}
.cta button::after {
    content: "";
    position: absolute; inset: -120px;
    background: radial-gradient(closest-side, rgba(255,255,255,0.18), transparent 60%);
    transform: translateX(-40%);
    transition: transform 0.45s ease;
}
.cta button:hover { transform: translateY(-1px); }
.cta button:hover::after { transform: translateX(40%); }
.cta button:active { transform: translateY(0px); }

.btn-secondary button {
    border-radius: 12px !important;
    border: 1px solid var(--stroke1) !important;
    background: rgba(255,255,255,0.03) !important;
    color: var(--text0) !important;
    font-weight: 800 !important;
}
.btn-secondary button:hover { border-color: rgba(6,182,212,0.55) !important; }

.btn-danger button {
    border-radius: 12px !important;
    border: 1px solid rgba(239,68,68,0.55) !important;
    background: rgba(239,68,68,0.06) !important;
    color: rgba(255,170,170,1) !important;
    font-weight: 900 !important;
}

/* --- Dataframe --- */
.dataframe {
    border-radius: 14px !important;
    border: 1px solid var(--stroke0) !important;
    background: rgba(255,255,255,0.02) !important;
    overflow: hidden !important;
}
.dataframe thead th {
    background: rgba(255,255,255,0.04) !important;
    color: var(--text1) !important;
    font-weight: 900 !important;
    font-size: 11px !important;
    letter-spacing: 0.10em !important;
    text-transform: uppercase !important;
    border-bottom: 1px solid var(--stroke0) !important;
}
.dataframe tbody td {
    color: var(--text0) !important;
    font-size: 12px !important;
    border-bottom: 1px solid rgba(148,163,184,0.10) !important;
}
.dataframe tbody tr:hover { background: rgba(255,255,255,0.03) !important; }

/* --- Status / Results --- */
.panel {
    border-radius: var(--radius);
    border: 1px solid var(--stroke0);
    background: linear-gradient(180deg, rgba(255,255,255,0.045), rgba(255,255,255,0.020));
    box-shadow: var(--shadowSoft);
    padding: 16px;
}
.panel-title {
    display: flex; align-items: center; justify-content: space-between;
    gap: 10px;
    padding-bottom: 12px; margin-bottom: 12px;
    border-bottom: 1px solid var(--stroke0);
}
.panel-title h3 { margin: 0; font-size: 13px; letter-spacing: 0.12em; text-transform: uppercase; color: var(--text1); }
.running-pill {
    display: inline-flex; align-items: center; gap: 10px;
    padding: 8px 10px; border-radius: 999px;
    border: 1px solid rgba(6,182,212,0.38);
    background: rgba(6,182,212,0.08);
    color: rgba(153,246,228,0.95);
    font-weight: 900; font-size: 11px; letter-spacing: 0.10em; text-transform: uppercase;
}
.running-dot { width: 9px; height: 9px; border-radius: 99px; background: var(--teal); box-shadow: 0 0 18px rgba(6,182,212,0.45); animation: pulse 1.8s ease-in-out infinite; }

.empty {
    border-radius: var(--radius);
    border: 1px dashed rgba(148,163,184,0.26);
    background: rgba(255,255,255,0.02);
    padding: 28px;
    text-align: center;
    color: var(--text2);
}
.empty .big { font-size: 40px; opacity: 0.22; margin-bottom: 10px; }
.empty .t { color: var(--text1); font-weight: 800; margin-bottom: 6px; }
.empty .s { font-size: 12px; }

.results {
    border-radius: var(--radius);
    border: 1px solid rgba(16,185,129,0.55);
    background: linear-gradient(180deg, rgba(16,185,129,0.12), rgba(255,255,255,0.02));
    box-shadow: 0 0 0 4px rgba(16,185,129,0.10), 0 20px 60px rgba(0,0,0,0.42);
    padding: 16px;
}
.results-banner {
    display: flex; align-items: center; justify-content: space-between;
    gap: 12px;
    padding-bottom: 12px; margin-bottom: 12px;
    border-bottom: 1px solid rgba(16,185,129,0.28);
}
.results-banner .k { display: flex; align-items: center; gap: 10px; }
.results-banner .k .icon {
    width: 36px; height: 36px; border-radius: 12px;
    background: rgba(16,185,129,0.18);
    border: 1px solid rgba(16,185,129,0.45);
    display: flex; align-items: center; justify-content: center;
}
.results-banner .k .title { font-weight: 900; color: rgba(189,255,225,0.98); letter-spacing: 0.06em; text-transform: uppercase; font-size: 12px; }
.results-banner .k .sub { margin-top: 2px; color: rgba(189,255,225,0.70); font-size: 12px; }

.tabs { background: transparent !important; }
.tab-nav button {
    background: transparent !important;
    border: 0 !important;
    border-bottom: 2px solid transparent !important;
    color: var(--text2) !important;
    font-weight: 800 !important;
    padding: 10px 12px !important;
}
.tab-nav button[aria-selected="true"] {
    color: rgba(153,246,228,0.98) !important;
    border-bottom-color: rgba(6,182,212,0.75) !important;
}
.tab-nav button:hover { color: var(--text0) !important; }

.small-note { color: var(--text2); font-size: 12px; }

/* --- Candidates stream --- */
.cand-empty { padding: 28px; text-align: center; color: var(--text2); }
.cand-empty-icon { font-size: 40px; opacity: 0.25; margin-bottom: 10px; }
.cand-empty-title { color: var(--text1); font-weight: 900; margin-bottom: 4px; }
.cand-empty-sub { font-size: 12px; }

.cand-stream { display: flex; flex-direction: column; gap: 10px; }
.cand-card {
    border-radius: 14px;
    border: 1px solid rgba(148,163,184,0.18);
    background: linear-gradient(135deg, rgba(15,23,42,0.85), rgba(2,6,23,0.45));
    overflow: hidden;
}
.cand-topbar { height: 2px; background: linear-gradient(90deg, var(--teal), var(--blue)); }
.cand-header {
    display: flex; align-items: center; justify-content: space-between;
    gap: 10px;
    padding: 10px 12px 0 12px;
}
.cand-iter { font-family: "JetBrains Mono", ui-monospace; font-size: 11px; color: rgba(153,246,228,0.92); font-weight: 800; letter-spacing: 0.08em; }
.cand-pill {
    font-size: 10px; font-weight: 900; letter-spacing: 0.10em;
    padding: 5px 8px; border-radius: 999px;
    border: 1px solid rgba(148,163,184,0.20);
    background: rgba(255,255,255,0.03);
    color: var(--text2);
}
.cand-body {
    padding: 10px 12px 12px 12px;
    font-family: "JetBrains Mono", ui-monospace;
    font-size: 12px;
    line-height: 1.6;
    color: rgba(234,240,255,0.75);
}

/* --- Responsive --- */
@media (max-width: 980px) {
    .gradio-container { padding: 16px 12px !important; }
    .brand-row { flex-direction: column; align-items: flex-start; }
    .status-pill { align-self: stretch; justify-content: center; }
}
"""
1035
+
1036
# Client-side snippet that forces Gradio's dark theme by appending
# ?__theme=dark to the URL and reloading once if it is missing.
# NOTE(review): this is passed to launch() at the bottom of the file, but
# Gradio expects custom JS via gr.Blocks(js=...) — confirm it executes.
FORCE_DARK_JS = """
function forceDarkTheme() {
    try {
        const url = new URL(window.location.href);
        if (url.searchParams.get("__theme") !== "dark") {
            url.searchParams.set("__theme", "dark");
            window.location.replace(url.toString());
        }
    } catch (e) {
        // no-op
    }
}
forceDarkTheme();
"""
1050
+
1051
# ==========================================
# 5. UI CONSTRUCTION (Redesigned)
# ==========================================
APP_TITLE = "Universal Prompt Optimizer"
APP_SUBTITLE = "Genetic Evolutionary Prompt Agent (GEPA)"
STATUS_READY = "System Ready"

with gr.Blocks(
    title="Universal Prompt Optimizer",
    theme=gr.themes.Base(),
    # FIX: custom CSS and JS must be supplied to the gr.Blocks constructor.
    # They were previously passed to launch(), which does not accept them,
    # so the stylesheet and dark-theme script were never applied.
    css=CUSTOM_CSS,
    js=FORCE_DARK_JS
) as app:
    # Session-scoped list of training examples ({input, output, image, ...}).
    dataset_state = gr.State([])

    # TOP BAR: brand mark, title/subtitle, and a static "ready" pill.
    gr.HTML(
        f"""
        <div class="topbar">
            <div class="brand-row">
                <div class="brand-left">
                    <div class="brand-mark">GE</div>
                    <div>
                        <div class="h1">{APP_TITLE}</div>
                        <div class="subtitle">{APP_SUBTITLE}</div>
                    </div>
                </div>
                <div class="status-pill"><span class="dot"></span> {STATUS_READY}</div>
            </div>
        </div>
        """,
        elem_classes=["topbar-wrap"]
    )

    # MAIN LAYOUT: configuration on the left, status/results on the right.
    with gr.Row():

        # LEFT COLUMN: Configuration
        with gr.Column(scale=5):

            # Step 1: model selection and provider credentials.
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">1</div> Model & Credentials</div>
                        <div class="hint">Select a target model, then provide keys (stored in-session only).</div>
                    </div>
                    """
                )

                with gr.Row():
                    model_select = gr.Dropdown(
                        label="Foundation Model",
                        choices=[
                            "openai/gpt-4o",
                            "openai/gpt-4-turbo",
                            "anthropic/claude-3-5-sonnet",
                            "google/gemini-1.5-pro",
                            "custom"
                        ],
                        value="openai/gpt-4o",
                        scale=2
                    )
                    custom_model_input = gr.Textbox(
                        label="Custom Model ID",
                        placeholder="provider/model_name",
                        scale=1
                    )

                gr.HTML('<div class="subsection-title">API Access Keys</div>')
                gr.Markdown("*Keys are stored in-session only and never logged*", elem_classes=["text-xs"])

                with gr.Row():
                    key_openai = gr.Textbox(
                        label="OpenAI API Key",
                        type="password",
                        placeholder="sk-...",
                        scale=1
                    )
                    key_google = gr.Textbox(
                        label="Google API Key",
                        type="password",
                        placeholder="AIza...",
                        scale=1
                    )
                    key_anthropic = gr.Textbox(
                        label="Anthropic API Key",
                        type="password",
                        placeholder="sk-ant...",
                        scale=1
                    )

            # Step 2: the seed prompt the optimizer evolves.
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">2</div> Seed Prompt</div>
                        <div class="hint">Describe the task, constraints, output format, and tone.</div>
                    </div>
                    """
                )
                seed_input = gr.Textbox(
                    label="Task Description",
                    placeholder="Example: You are a code reviewer that identifies security vulnerabilities in Python code. Return a JSON report with severity and fixes...",
                    lines=7,
                    max_lines=14,
                    elem_classes=["seed", "mono"]
                )

            # Step 3: training examples (manual entry or bulk JSON import).
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">3</div> Training Examples</div>
                        <div class="hint">Add a few high-quality I/O pairs (images optional) to shape the optimizer.</div>
                    </div>
                    """
                )

                with gr.Tabs():
                    with gr.Tab("Manual Entry"):
                        with gr.Row():
                            with gr.Column(scale=2):
                                d_in = gr.Textbox(
                                    label="Input / User Prompt",
                                    placeholder="Example user input...",
                                    lines=3
                                )
                                d_out = gr.Textbox(
                                    label="Ideal Output",
                                    placeholder="Expected AI response...",
                                    lines=3
                                )
                            with gr.Column(scale=1):
                                d_img = gr.Image(
                                    label="Attach Image (Optional)",
                                    type="numpy",
                                    height=170
                                )

                        btn_add = gr.Button(
                            "Add Example",
                            elem_classes=["btn-secondary"]
                        )

                    with gr.Tab("Bulk Import (JSON)"):
                        gr.Markdown(
                            "Paste a JSON array like: `[{\"input\": \"...\", \"output\": \"...\"}]`",
                            elem_classes=["small-note"]
                        )
                        bulk_json = gr.Textbox(
                            show_label=False,
                            placeholder='[{"input": "...", "output": "..."}]',
                            lines=6
                        )
                        btn_import = gr.Button(
                            "Import JSON",
                            elem_classes=["btn-secondary"]
                        )

                with gr.Row():
                    gr.HTML("<div class='hint'>Current dataset</div>")
                    ds_count = gr.HTML(
                        "<span style='color: var(--text-secondary);'>0 examples loaded</span>",
                        elem_classes=["ds-count"]
                    )

                ds_table = gr.Dataframe(
                    headers=["ID", "Input", "Output", "Media"],
                    datatype=["number", "str", "str", "str"],
                    row_count=6,
                    column_count=(4, "fixed"),
                    interactive=False
                )

                with gr.Row():
                    btn_clear = gr.Button(
                        "Clear All",
                        elem_classes=["btn-danger"],
                        size="sm"
                    )

            # Step 4: optimization budget controls (prominent, not buried).
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">4</div> Optimization Controls</div>
                        <div class="hint">Tune evolution budget. Defaults are safe for quick runs.</div>
                    </div>
                    """
                )

                with gr.Row():
                    slider_iter = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Evolution Rounds",
                        info="Number of genetic iterations"
                    )
                    slider_calls = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max LLM Calls",
                        info="Total API call budget"
                    )

                with gr.Row():
                    slider_batch = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="Batch Size",
                        info="Candidates per iteration"
                    )
                    check_llego = gr.Checkbox(
                        value=True,
                        label="Enable LLEGO Crossover",
                        info="Use advanced genetic operations"
                    )

                btn_optimize = gr.Button(
                    "Start Optimization",
                    elem_classes=["cta", "mt-6"]
                )

        # RIGHT COLUMN: status, empty placeholder, and results.
        with gr.Column(scale=5, elem_classes=["right-col"]):
            # STATUS PANEL (hidden until an optimization starts).
            status_panel = gr.Group(visible=False, elem_classes=["panel"])
            with status_panel:
                gr.HTML(
                    """
                    <div class="panel-title">
                        <h3>Optimization status</h3>
                        <div class="running-pill"><span class="running-dot"></span> Running</div>
                    </div>
                    """
                )
                txt_status = gr.Markdown("Initializing genetic algorithm...")

            # EMPTY STATE (visible until the first run).
            empty_state = gr.HTML(
                """
                <div class="empty">
                    <div class="big">🧬</div>
                    <div class="t">Ready to optimize</div>
                    <div class="s">Fill Steps 1–3, then click <b>Start Optimization</b> to begin prompt evolution.</div>
                </div>
                """,
                visible=True
            )

            # RESULTS PANEL (hidden until an optimization succeeds).
            results_panel = gr.Group(visible=False, elem_classes=["results"])
            with results_panel:
                gr.HTML(
                    """
                    <div class="results-banner">
                        <div class="k">
                            <div class="icon">✓</div>
                            <div>
                                <div class="title">Optimization successful</div>
                                <div class="sub">Review the optimized prompt, metrics, and evolution traces.</div>
                            </div>
                        </div>
                    </div>
                    """
                )

                with gr.Tabs():
                    with gr.Tab("Optimized Prompt"):
                        res_prompt = gr.Textbox(
                            label="Optimized Prompt",
                            lines=18,
                            max_lines=28,
                            interactive=False,
                            show_label=True,
                            elem_classes=["mono"]
                        )

                    with gr.Tab("Metrics & Log"):
                        res_metrics = gr.JSON(label="Performance Gains")
                        res_history = gr.TextArea(
                            label="Evolution Log",
                            interactive=False,
                            lines=10
                        )

                    with gr.Tab("🧬 Live Candidates"):
                        gr.Markdown("Real-time stream of generated prompt candidates during optimization:")
                        live_candidates = gr.HTML()
                        btn_refresh_cand = gr.Button(
                            "🔄 Refresh Stream",
                            elem_classes=["secondary-btn"],
                            size="sm"
                        )
1354
+
1355
# ==========================================
# 6. EVENT HANDLERS
# ==========================================

# Dataset Management
def update_dataset_count(dataset):
    """Render the HTML badge showing how many examples are loaded.

    Returns a fallback "0 examples loaded" badge for non-list input and
    an "Error" badge if rendering fails for any reason.
    """
    try:
        if not isinstance(dataset, list):
            return "<span style='color: var(--text-secondary);'>0 examples loaded</span>"
        n = len(dataset)
        suffix = "s" if n != 1 else ""
        return f"<span style='color: var(--text-secondary);'>{n} example{suffix} loaded</span>"
    except Exception as e:
        logger.error(f"Error updating dataset count: {str(e)}")
        return "<span style='color: var(--text-secondary);'>Error</span>"
1370
+
1371
    # Wrap event handlers with error handling
    def safe_add_example(*args):
        """Guarded wrapper around the file-level add_example().

        Receives (input_text, output_text, image, dataset_state) positionally
        from the btn_add wiring. gr.Error is passed through so Gradio shows
        its message; anything else is logged and re-raised as gr.Error.
        """
        try:
            return add_example(*args)
        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Unexpected error in add_example: {str(e)}")
            raise gr.Error(f"Failed to add example: {str(e)}")

    def safe_update_table(dataset):
        """Guarded wrapper around the file-level update_table().

        Falls back to an empty table (instead of raising) so a render
        failure never blocks the event chain.
        """
        try:
            return update_table(dataset)
        except Exception as e:
            logger.error(f"Error updating table: {str(e)}")
            return []

    def safe_clear_dataset():
        """Guarded wrapper around the file-level clear_dataset().

        Falls back to (empty state, empty table) on failure so the UI
        still resets.
        """
        try:
            return clear_dataset()
        except Exception as e:
            logger.error(f"Error clearing dataset: {str(e)}")
            return [], []
1397
+
1398
    # Add-example chain: append to state, then refresh the table and the
    # count badge from the updated dataset_state.
    btn_add.click(
        safe_add_example,
        inputs=[d_in, d_out, d_img, dataset_state],
        outputs=[dataset_state, d_in, d_out, d_img]
    ).then(
        safe_update_table,
        inputs=[dataset_state],
        outputs=[ds_table]
    ).then(
        update_dataset_count,
        inputs=[dataset_state],
        outputs=[ds_count]
    )

    # Clear chain: wipe state and table, then reset the count badge.
    btn_clear.click(
        safe_clear_dataset,
        outputs=[dataset_state, ds_table]
    ).then(
        lambda: "<span style='color: var(--text-secondary);'>0 examples loaded</span>",
        outputs=[ds_count]
    )
1419
+
1420
# Bulk Import
def import_bulk_json(json_text, current_dataset):
    """Import training examples from a JSON array string.

    Args:
        json_text: Raw JSON text; must parse to a non-empty array of
            objects with non-empty string "input" and "output" fields
            (an optional "image" field is carried through).
        current_dataset: Current example list from gr.State.

    Returns:
        (updated_dataset, ""): a NEW list containing the old examples plus
        the imported ones, and an empty string to clear the JSON textbox.

    Raises:
        gr.Error: on empty input, malformed JSON, wrong structure, or when
            no item could be imported.
    """
    try:
        # Validate inputs
        if not json_text or not json_text.strip():
            raise gr.Error("JSON input is empty. Please provide a JSON array.")

        if not isinstance(current_dataset, list):
            raise gr.Error("Dataset state is invalid. Please refresh the page.")

        # Parse JSON
        try:
            data = json.loads(json_text.strip())
        except json.JSONDecodeError as e:
            raise gr.Error(f"Invalid JSON format: {str(e)}. Please check your JSON syntax.")

        # Validate structure
        if not isinstance(data, list):
            raise gr.Error("JSON must be an array of objects. Example: [{\"input\": \"...\", \"output\": \"...\"}]")

        if len(data) == 0:
            raise gr.Error("JSON array is empty. Add at least one example object.")

        # FIX: append to a copy instead of mutating the caller's list in
        # place — previously a failed import (imported_count == 0 below)
        # could still leave the shared gr.State list partially mutated.
        new_dataset = list(current_dataset)

        # Validate and import items
        imported_count = 0
        errors = []

        for i, item in enumerate(data):
            try:
                if not isinstance(item, dict):
                    errors.append(f"Item {i+1}: not a dictionary")
                    continue

                if "input" not in item or "output" not in item:
                    errors.append(f"Item {i+1}: missing 'input' or 'output' field")
                    continue

                input_val = item["input"]
                output_val = item["output"]

                if not isinstance(input_val, str) or not isinstance(output_val, str):
                    errors.append(f"Item {i+1}: 'input' and 'output' must be strings")
                    continue

                if not input_val.strip() or not output_val.strip():
                    errors.append(f"Item {i+1}: 'input' and 'output' cannot be empty")
                    continue

                # Add valid item
                new_dataset.append({
                    "input": input_val.strip(),
                    "output": output_val.strip(),
                    "image": item.get("image"),  # Optional
                    "image_preview": "🖼️ Image" if item.get("image") else "-"
                })
                imported_count += 1

            except Exception as e:
                errors.append(f"Item {i+1}: {str(e)}")
                logger.warning(f"Error importing item {i+1}: {str(e)}")
                continue

        # Report results
        if imported_count == 0:
            error_msg = "No valid examples imported. "
            if errors:
                error_msg += "Errors: " + "; ".join(errors[:3])
                if len(errors) > 3:
                    error_msg += f" (and {len(errors) - 3} more)"
            raise gr.Error(error_msg)

        if errors:
            # Partial success: log (the UI only shows the updated table).
            warning_msg = f"Imported {imported_count} example(s). "
            if len(errors) <= 3:
                warning_msg += f"Warnings: {'; '.join(errors)}"
            else:
                warning_msg += f"{len(errors)} items had errors."
            logger.warning(warning_msg)

        return new_dataset, ""

    except gr.Error:
        # Re-raise Gradio errors
        raise
    except Exception as e:
        logger.error(f"Unexpected error in import_bulk_json: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Failed to import JSON: {str(e)}")
1508
+
1509
    # Bulk-import chain: parse/merge JSON into state, clear the textbox,
    # then refresh the table and the count badge.
    btn_import.click(
        import_bulk_json,
        inputs=[bulk_json, dataset_state],
        outputs=[dataset_state, bulk_json]
    ).then(
        safe_update_table,
        inputs=[dataset_state],
        outputs=[ds_table]
    ).then(
        update_dataset_count,
        inputs=[dataset_state],
        outputs=[ds_count]
    )

    # Main Optimization Flow
    # run_optimization_flow is a generator; each yield must emit an 8-tuple
    # matching this outputs list, in order.
    btn_optimize.click(
        run_optimization_flow,
        inputs=[
            seed_input, dataset_state, model_select, custom_model_input,
            slider_iter, slider_calls, slider_batch, check_llego,
            key_openai, key_google, key_anthropic
        ],
        outputs=[
            status_panel, empty_state, results_panel,
            txt_status, res_prompt, res_metrics, res_history, live_candidates
        ]
    )
1536
+
1537
    # Refresh Candidates
    def safe_get_candidates_display():
        """Guarded wrapper around get_candidates_display().

        Returns a static error placeholder (instead of raising) so the
        refresh button can never break the Live Candidates tab.
        """
        try:
            return get_candidates_display()
        except Exception as e:
            logger.error(f"Error refreshing candidates: {str(e)}")
            return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>"

    # Manual refresh of the live candidates stream.
    btn_refresh_cand.click(
        safe_get_candidates_display,
        outputs=[live_candidates]
    )
1550
+
1551
# ==========================================
# 7. LAUNCH
# ==========================================
if __name__ == "__main__":
    # FIX: Blocks.launch() does not accept `css`/`js` keyword arguments —
    # passing them raised TypeError at startup. Custom CSS/JS belong on the
    # gr.Blocks(...) constructor where the app is built.
    app.queue().launch(
        server_name="0.0.0.0",  # bind all interfaces (required on HF Spaces)
        server_port=7860,
        share=False,  # Set to False for HF Spaces
        show_error=True
    )
1563
+
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies - gepa from git
2
+ git+https://github.com/gepa-ai/gepa.git
3
+ numpy>=1.21.0
4
+ pandas>=1.5.0
5
+ pydantic>=2.0.0
6
+ python-dotenv>=1.0.0
7
+
8
+ # HTTP/API clients
9
+ requests>=2.31.0
10
+ aiohttp>=3.8.0
11
+ asyncio-throttle>=1.0.0
12
+
13
+ # LLM Provider SDKs
14
+ openai>=1.0.0
15
+ anthropic>=0.18.0
16
+ google-generativeai>=0.3.0
17
+ google-genai>=0.2.0
18
+
19
+ # Image processing
20
+ Pillow>=9.0.0
21
+
22
+ # Gradio UI (version will be set by README.md sdk_version)
23
+ gradio>=4.0.0
src/gepa_optimizer.egg-info/PKG-INFO ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: gepa-optimizer
3
+ Version: 0.1.0
4
+ Summary: Universal prompt optimization framework based on GEPA
5
+ Home-page: https://github.com/suhasb-dev/Prompt-Optimizer
6
+ Author: Suhas
7
+ Author-email: Suhas <s8hasgrylls@gmail.com>
8
+ License: MIT
9
+ Project-URL: Homepage, https://github.com/suhasb-dev/Prompt-Optimizer
10
+ Project-URL: Repository, https://github.com/suhasb-dev/Prompt-Optimizer
11
+ Project-URL: Documentation, https://suhasb-dev.gitbook.io/gepa-universal-prompt-optimizer/
12
+ Project-URL: Bug Reports, https://github.com/suhasb-dev/Prompt-Optimizer/issues
13
+ Keywords: prompt-optimization,llm,gepa,ai,machine-learning,ui-tree-extraction
14
+ Classifier: Development Status :: 3 - Alpha
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Intended Audience :: Science/Research
17
+ Classifier: License :: OSI Approved :: MIT License
18
+ Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.8
20
+ Classifier: Programming Language :: Python :: 3.9
21
+ Classifier: Programming Language :: Python :: 3.10
22
+ Classifier: Programming Language :: Python :: 3.11
23
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
24
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
25
+ Requires-Python: >=3.8
26
+ Description-Content-Type: text/markdown
27
+ License-File: LICENSE
28
+ Requires-Dist: gepa>=0.0.12
29
+ Requires-Dist: pandas>=1.5.0
30
+ Requires-Dist: pydantic>=2.0.0
31
+ Requires-Dist: python-dotenv>=1.0.0
32
+ Requires-Dist: requests>=2.31.0
33
+ Requires-Dist: aiohttp>=3.8.0
34
+ Requires-Dist: asyncio-throttle>=1.0.0
35
+ Requires-Dist: google-generativeai>=0.3.0
36
+ Requires-Dist: Pillow>=9.0.0
37
+ Provides-Extra: dev
38
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
39
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
40
+ Requires-Dist: black>=23.0.0; extra == "dev"
41
+ Requires-Dist: flake8>=6.0.0; extra == "dev"
42
+ Requires-Dist: mypy>=1.0.0; extra == "dev"
43
+ Provides-Extra: docs
44
+ Requires-Dist: sphinx>=5.0.0; extra == "docs"
45
+ Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == "docs"
46
+ Provides-Extra: all
47
+ Requires-Dist: pytest>=7.0.0; extra == "all"
48
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "all"
49
+ Requires-Dist: black>=23.0.0; extra == "all"
50
+ Requires-Dist: flake8>=6.0.0; extra == "all"
51
+ Requires-Dist: mypy>=1.0.0; extra == "all"
52
+ Requires-Dist: sphinx>=5.0.0; extra == "all"
53
+ Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == "all"
54
+ Dynamic: author
55
+ Dynamic: home-page
56
+ Dynamic: license-file
57
+ Dynamic: requires-python
58
+
59
+ # GEPA Optimizer
60
+
61
+ [![PyPI version](https://badge.fury.io/py/gepa-optimizer.svg)](https://badge.fury.io/py/gepa-optimizer)
62
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
63
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
64
+
65
+ A universal prompt optimization framework built on [GEPA](https://arxiv.org/abs/2507.19457) with optional [LLEGO](https://arxiv.org/abs/2503.14217) genetic operators for accelerated convergence.
66
+
67
+ ## Overview
68
+
69
+ GEPA Optimizer provides a modular architecture for optimizing prompts through reflective evolution. It requires custom evaluators and LLM clients, enabling domain-specific optimization for any use case.
70
+
71
+ **Key capabilities:**
72
+ - Multi-modal support (text + vision models)
73
+ - Hybrid GEPA + LLEGO optimization modes
74
+ - Configurable train/val/test data splitting
75
+ - Batch API support for cost reduction
76
+ - Async-first architecture
77
+
78
+ ## Installation
79
+
80
+ ```bash
81
+ pip install gepa-optimizer
82
+ ```
83
+
84
+ **From source:**
85
+ ```bash
86
+ git clone https://github.com/suhasb-dev/Prompt-Optimizer.git
87
+ cd Prompt-Optimizer
88
+ pip install -e .
89
+ ```
90
+
91
+ ## Quick Start
92
+
93
+ ```python
94
+ import asyncio
95
+ from gepa_optimizer import (
96
+ GepaOptimizer,
97
+ OptimizationConfig,
98
+ BaseEvaluator,
99
+ BaseLLMClient
100
+ )
101
+
102
+ # Define custom evaluator
103
+ class MyEvaluator(BaseEvaluator):
104
+ def evaluate(self, predicted: str, expected: str) -> dict:
105
+ score = 1.0 if predicted.strip() == expected.strip() else 0.0
106
+ return {"accuracy": score, "composite_score": score}
107
+
108
+ # Define custom LLM client
109
+ class MyLLMClient(BaseLLMClient):
110
+ def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> dict:
111
+ # Your LLM integration here
112
+ return {"content": "response"}
113
+
114
+ async def main():
115
+ config = OptimizationConfig(
116
+ model="openai/gpt-4o",
117
+ reflection_model="openai/gpt-4o",
118
+ max_iterations=5,
119
+ max_metric_calls=50,
120
+ batch_size=8
121
+ )
122
+
123
+ optimizer = GepaOptimizer(
124
+ config=config,
125
+ llm_client=MyLLMClient("openai", "gpt-4o"),
126
+ evaluator=MyEvaluator()
127
+ )
128
+
129
+ result = await optimizer.train(
130
+ seed_prompt="Your initial prompt",
131
+ dataset=your_dataset
132
+ )
133
+
134
+ print(f"Optimized: {result.prompt}")
135
+ print(f"Score: {result.improvement_data}")
136
+
137
+ asyncio.run(main())
138
+ ```
139
+
140
+ ## Project Structure
141
+
142
+ ```
143
+ src/gepa_optimizer/
144
+ ├── core/ # Core optimization logic
145
+ │ ├── optimizer.py # GepaOptimizer main class
146
+ │   ├── base_adapter.py      # BaseGepaAdapter interface
147
+ │ └── universal_adapter.py
148
+ ├── evaluation/ # Evaluator implementations
149
+ │ ├── base_evaluator.py # BaseEvaluator abstract class
150
+ │ ├── scroll_evaluator.py
151
+ │ ├── validation_evaluator.py
152
+ │ └── index_caching_evaluator.py
153
+ ├── llms/ # LLM client implementations
154
+ │ ├── base_llm.py # BaseLLMClient abstract class
155
+ │ ├── vision_llm.py # VisionLLMClient (OpenAI, Google, Anthropic)
156
+ │ └── batch_llm.py # BatchLLMClient (50% cost savings)
157
+ ├── operators/ # LLEGO genetic operators
158
+ │ └── llego_operators.py # FitnessGuidedCrossover, DiversityGuidedMutation
159
+ ├── data/ # Dataset loaders and converters
160
+ ├── models/ # Configuration and result models
161
+ └── utils/ # Utilities and helpers
162
+ ```
163
+
164
+ ## Configuration
165
+
166
+ ### Basic Configuration
167
+
168
+ ```python
169
+ from gepa_optimizer import OptimizationConfig, ModelConfig
170
+
171
+ config = OptimizationConfig(
172
+ # Required parameters
173
+ model="openai/gpt-4o", # or ModelConfig instance
174
+ reflection_model="openai/gpt-4o",
175
+ max_iterations=10,
176
+ max_metric_calls=100,
177
+ batch_size=8,
178
+
179
+ # Data splitting (train/val/test)
180
+ data_split=DataSplitConfig(
181
+ train_ratio=0.6,
182
+ val_ratio=0.2,
183
+ test_ratio=0.2
184
+ ),
185
+
186
+ # Optional settings
187
+ reflection_examples=3, # Examples per reflection (2-5 recommended)
188
+ evaluate_on_test=True, # Final evaluation on held-out test set
189
+ log_level="INFO" # DEBUG, INFO, WARNING, ERROR
190
+ )
191
+ ```
192
+
193
+ ### LLEGO Genetic Operators
194
+
195
+ Enable LLEGO for faster convergence through fitness-guided crossover and diversity-guided mutation:
196
+
197
+ ```python
198
+ config = OptimizationConfig(
199
+ model="openai/gpt-4o",
200
+ reflection_model="openai/gpt-4o",
201
+ max_iterations=5,
202
+ max_metric_calls=50,
203
+ batch_size=8,
204
+
205
+ # Enable LLEGO
206
+ use_llego_operators=True,
207
+ alpha=0.15, # Fitness extrapolation factor
208
+ tau=10.0, # Diversity temperature
209
+ nu=4, # Parent arity
210
+ n_crossover=2, # Crossover offspring per iteration
211
+ n_mutation=3, # Mutation offspring per iteration
212
+ population_size=15
213
+ )
214
+ ```
215
+
216
+ ### Hybrid Mode (GEPA + LLEGO)
217
+
218
+ Combine GEPA's semantic reflection with LLEGO's structural diversity:
219
+
220
+ ```python
221
+ config = OptimizationConfig(
222
+ model="openai/gpt-4o",
223
+ reflection_model="openai/gpt-4o",
224
+ max_iterations=6,
225
+ max_metric_calls=200,
226
+ batch_size=10,
227
+
228
+ # Hybrid mode
229
+ use_llego_operators=True,
230
+ enable_gepa_reflection_with_llego=True,
231
+ num_gepa_reflection_candidates=3,
232
+ n_crossover=3,
233
+ n_mutation=3
234
+ # Total: 9 candidates per iteration (3 GEPA + 3 crossover + 3 mutation)
235
+ )
236
+ ```
237
+
238
+ ### Batch API (Cost Optimization)
239
+
240
+ Use batch processing for 50% cost reduction:
241
+
242
+ ```python
243
+ from gepa_optimizer.llms import BatchLLMClient
244
+
245
+ llm_client = BatchLLMClient(
246
+ provider="google",
247
+ model_name="gemini-2.5-flash",
248
+ batch_size=20,
249
+ polling_interval=30
250
+ )
251
+
252
+ optimizer = GepaOptimizer(
253
+ config=config,
254
+ llm_client=llm_client,
255
+ evaluator=evaluator
256
+ )
257
+ ```
258
+
259
+ ## Built-in Components
260
+
261
+ ### LLM Clients
262
+
263
+ | Client | Description | Use Case |
264
+ |--------|-------------|----------|
265
+ | `VisionLLMClient` | Multi-modal client for OpenAI, Google, Anthropic | Real-time requests |
266
+ | `BatchLLMClient` | Batch processing client | Cost-sensitive workloads |
267
+
268
+ ### Evaluators
269
+
270
+ | Evaluator | Description |
271
+ |-----------|-------------|
272
+ | `ScrollElementEvaluator` | UI element detection scoring |
273
+ | `ValidationEvaluator` | Screen validation tasks |
274
+ | `IndexCachingEvaluator` | Index-based element selection |
275
+ | `UITreeEvaluator` | UI tree extraction |
276
+
277
+ ### Dataset Loaders
278
+
279
+ | Loader | Description |
280
+ |--------|-------------|
281
+ | `load_scroll_dataset()` | Load scroll detection datasets |
282
+ | `load_validation_split()` | Load validation datasets with splits |
283
+ | `load_index_caching_split()` | Load index caching datasets |
284
+
285
+ ## Creating Custom Components
286
+
287
+ ### Custom Evaluator
288
+
289
+ ```python
290
+ from gepa_optimizer import BaseEvaluator
291
+
292
+ class CustomEvaluator(BaseEvaluator):
293
+ def __init__(self):
294
+ super().__init__(metric_weights={
295
+ "accuracy": 0.5,
296
+ "completeness": 0.3,
297
+ "format": 0.2
298
+ })
299
+
300
+ def evaluate(self, predicted: str, expected: str) -> dict:
301
+ accuracy = self._compute_accuracy(predicted, expected)
302
+ completeness = self._compute_completeness(predicted, expected)
303
+ format_score = self._compute_format(predicted)
304
+
305
+ composite = (
306
+ accuracy * 0.5 +
307
+ completeness * 0.3 +
308
+ format_score * 0.2
309
+ )
310
+
311
+ return {
312
+ "accuracy": accuracy,
313
+ "completeness": completeness,
314
+ "format": format_score,
315
+ "composite_score": composite # Required key
316
+ }
317
+ ```
318
+
319
+ ### Custom LLM Client
320
+
321
+ ```python
322
+ from gepa_optimizer import BaseLLMClient
323
+
324
+ class CustomLLMClient(BaseLLMClient):
325
+ def __init__(self, api_key: str):
326
+ super().__init__(provider="custom", model_name="my-model")
327
+ self.api_key = api_key
328
+
329
+ def generate(
330
+ self,
331
+ system_prompt: str,
332
+ user_prompt: str,
333
+ image_base64: str = None,
334
+ **kwargs
335
+ ) -> dict:
336
+ # Your API call here
337
+ response = call_your_api(system_prompt, user_prompt, image_base64)
338
+ return {"content": response}
339
+ ```
340
+
341
+ ## Examples
342
+
343
+ | File | Description |
344
+ |------|-------------|
345
+ | [`examples/basic_usage.py`](examples/basic_usage.py) | Basic optimization workflow |
346
+ | [`examples/advanced_usage.py`](examples/advanced_usage.py) | Advanced configuration |
347
+ | [`examples/batch_api_example.py`](examples/batch_api_example.py) | Batch API usage |
348
+ | [`examples/gemini_usage.py`](examples/gemini_usage.py) | Google Gemini integration |
349
+
350
+ **Run examples:**
351
+ ```bash
352
+ python examples/basic_usage.py
353
+ ```
354
+
355
+ ## Testing
356
+
357
+ ```bash
358
+ # Run all tests
359
+ pytest tests/
360
+
361
+ # Run unit tests only
362
+ pytest tests/unit/
363
+
364
+ # Run integration tests
365
+ pytest tests/integration/
366
+ ```
367
+
368
+ ## API Reference
369
+
370
+ ### GepaOptimizer
371
+
372
+ ```python
373
+ class GepaOptimizer:
374
+ def __init__(
375
+ self,
376
+ config: OptimizationConfig,
377
+ llm_client: BaseLLMClient,
378
+ evaluator: BaseEvaluator,
379
+ adapter_type: str = "universal"
380
+ )
381
+
382
+ async def train(
383
+ self,
384
+ seed_prompt: str,
385
+ dataset: Union[List, Dict],
386
+ **kwargs
387
+ ) -> OptimizedResult
388
+ ```
389
+
390
+ ### OptimizationConfig
391
+
392
+ | Parameter | Type | Default | Description |
393
+ |-----------|------|---------|-------------|
394
+ | `model` | `str \| ModelConfig` | Required | Target model |
395
+ | `reflection_model` | `str \| ModelConfig` | Required | Reflection model |
396
+ | `max_iterations` | `int` | Required | Maximum optimization iterations |
397
+ | `max_metric_calls` | `int` | Required | Maximum evaluation calls |
398
+ | `batch_size` | `int` | Required | Samples per evaluation batch |
399
+ | `use_llego_operators` | `bool` | `False` | Enable LLEGO genetic operators |
400
+ | `enable_gepa_reflection_with_llego` | `bool` | `False` | Enable hybrid mode |
401
+ | `use_llm_as_judge` | `bool` | `True` | Enable LLM-as-Judge feedback |
402
+ | `log_level` | `str` | `"INFO"` | Logging verbosity |
403
+
404
+ ### OptimizedResult
405
+
406
+ | Attribute | Type | Description |
407
+ |-----------|------|-------------|
408
+ | `prompt` | `str` | Optimized prompt |
409
+ | `original_prompt` | `str` | Initial seed prompt |
410
+ | `improvement_data` | `dict` | Score improvements |
411
+ | `optimization_time` | `float` | Total time in seconds |
412
+ | `is_successful` | `bool` | Optimization success status |
413
+
414
+ ## Environment Variables
415
+
416
+ | Variable | Description |
417
+ |----------|-------------|
418
+ | `OPENAI_API_KEY` | OpenAI API key |
419
+ | `ANTHROPIC_API_KEY` | Anthropic API key |
420
+ | `GOOGLE_API_KEY` | Google AI API key |
421
+
422
+ ## References
423
+
424
+ - **GEPA Paper:** [Reflective Prompt Evolution Can Outperform Reinforcement Learning](https://arxiv.org/abs/2507.19457)
425
+ - **LLEGO Paper:** [Decision Tree Induction Through LLMs via Semantically-Aware Evolution](https://arxiv.org/abs/2503.14217)
426
+ - **GEPA Library:** [github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa)
427
+
428
+ ## License
429
+
430
+ MIT License - see [LICENSE](LICENSE) for details.
431
+
432
+ ## Contributing
433
+
434
+ Contributions welcome. Please open an issue or submit a pull request.
435
+
436
+ ## Support
437
+
438
+ - **Issues:** [GitHub Issues](https://github.com/suhasb-dev/Prompt-Optimizer/issues)
439
+ - **Documentation:** [GitBook](https://suhasb-dev.gitbook.io/gepa-universal-prompt-optimizer/)
src/gepa_optimizer.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ src/gepa_optimizer/__init__.py
6
+ src/gepa_optimizer/cli.py
7
+ src/gepa_optimizer/types.py
8
+ src/gepa_optimizer/version.py
9
+ src/gepa_optimizer.egg-info/PKG-INFO
10
+ src/gepa_optimizer.egg-info/SOURCES.txt
11
+ src/gepa_optimizer.egg-info/dependency_links.txt
12
+ src/gepa_optimizer.egg-info/entry_points.txt
13
+ src/gepa_optimizer.egg-info/requires.txt
14
+ src/gepa_optimizer.egg-info/top_level.txt
15
+ src/gepa_optimizer/core/__init__.py
16
+ src/gepa_optimizer/core/base_adapter.py
17
+ src/gepa_optimizer/core/custom_adapter.py
18
+ src/gepa_optimizer/core/optimizer.py
19
+ src/gepa_optimizer/core/result.py
20
+ src/gepa_optimizer/core/universal_adapter.py
21
+ src/gepa_optimizer/data/__init__.py
22
+ src/gepa_optimizer/data/converters.py
23
+ src/gepa_optimizer/data/index_caching_loader.py
24
+ src/gepa_optimizer/data/loaders.py
25
+ src/gepa_optimizer/data/scroll_dataset_loader.py
26
+ src/gepa_optimizer/data/validation_dataset_loader.py
27
+ src/gepa_optimizer/data/validators.py
28
+ src/gepa_optimizer/evaluation/__init__.py
29
+ src/gepa_optimizer/evaluation/base_evaluator.py
30
+ src/gepa_optimizer/evaluation/index_caching_evaluator.py
31
+ src/gepa_optimizer/evaluation/scroll_evaluator.py
32
+ src/gepa_optimizer/evaluation/ui_evaluator.py
33
+ src/gepa_optimizer/evaluation/universal_evaluator.py
34
+ src/gepa_optimizer/evaluation/validation_evaluator.py
35
+ src/gepa_optimizer/infrastructure/__init__.py
36
+ src/gepa_optimizer/infrastructure/logging/__init__.py
37
+ src/gepa_optimizer/infrastructure/logging/context.py
38
+ src/gepa_optimizer/infrastructure/logging/formatters.py
39
+ src/gepa_optimizer/infrastructure/logging/logger.py
40
+ src/gepa_optimizer/llms/__init__.py
41
+ src/gepa_optimizer/llms/base_llm.py
42
+ src/gepa_optimizer/llms/batch_llm.py
43
+ src/gepa_optimizer/llms/llego_enhanced_llm.py
44
+ src/gepa_optimizer/llms/vision_llm.py
45
+ src/gepa_optimizer/models/__init__.py
46
+ src/gepa_optimizer/models/config.py
47
+ src/gepa_optimizer/models/dataset.py
48
+ src/gepa_optimizer/models/result.py
49
+ src/gepa_optimizer/operators/__init__.py
50
+ src/gepa_optimizer/operators/base_operator.py
51
+ src/gepa_optimizer/operators/crossover.py
52
+ src/gepa_optimizer/operators/llego_operators.py
53
+ src/gepa_optimizer/operators/models.py
54
+ src/gepa_optimizer/operators/mutation.py
55
+ src/gepa_optimizer/utils/__init__.py
56
+ src/gepa_optimizer/utils/api_keys.py
57
+ src/gepa_optimizer/utils/candidate_collector.py
58
+ src/gepa_optimizer/utils/clean_logger.py
59
+ src/gepa_optimizer/utils/exceptions.py
60
+ src/gepa_optimizer/utils/helpers.py
61
+ src/gepa_optimizer/utils/llm_judge_prompt.py
62
+ src/gepa_optimizer/utils/log_parser.py
63
+ src/gepa_optimizer/utils/logging.py
64
+ src/gepa_optimizer/utils/metrics.py
65
+ src/gepa_optimizer/utils/pareto_logger.py
src/gepa_optimizer.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/gepa_optimizer.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ gepa-optimize = gepa_optimizer.cli:main
src/gepa_optimizer.egg-info/requires.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gepa>=0.0.12
2
+ pandas>=1.5.0
3
+ pydantic>=2.0.0
4
+ python-dotenv>=1.0.0
5
+ requests>=2.31.0
6
+ aiohttp>=3.8.0
7
+ asyncio-throttle>=1.0.0
8
+ google-generativeai>=0.3.0
9
+ Pillow>=9.0.0
10
+
11
+ [all]
12
+ pytest>=7.0.0
13
+ pytest-asyncio>=0.21.0
14
+ black>=23.0.0
15
+ flake8>=6.0.0
16
+ mypy>=1.0.0
17
+ sphinx>=5.0.0
18
+ sphinx-rtd-theme>=1.2.0
19
+
20
+ [dev]
21
+ pytest>=7.0.0
22
+ pytest-asyncio>=0.21.0
23
+ black>=23.0.0
24
+ flake8>=6.0.0
25
+ mypy>=1.0.0
26
+
27
+ [docs]
28
+ sphinx>=5.0.0
29
+ sphinx-rtd-theme>=1.2.0
src/gepa_optimizer.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gepa_optimizer
src/gepa_optimizer/__init__.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GEPA Universal Prompt Optimizer
3
+
4
+ A modern, modular Python library for universal prompt optimization powered by GEPA.
5
+
6
+ Quick Start (No custom evaluator needed!):
7
+
8
+ from gepa_optimizer import quick_optimize
9
+
10
+ result = await quick_optimize(
11
+ seed_prompt="Your initial prompt",
12
+ dataset=[
13
+ {"input": "task1", "output": "expected1"},
14
+ {"input": "task2", "output": "expected2"},
15
+ ],
16
+ model="openai/gpt-4o" # or any: "google/gemini-1.5-pro", "anthropic/claude-3-5-sonnet-20241022"
17
+ )
18
+ print(result.optimized_prompt)
19
+ """
20
+
21
+ # Core functionality
22
+ from .core import GepaOptimizer
23
+ from .core.base_adapter import BaseGepaAdapter
24
+ from .core.universal_adapter import UniversalGepaAdapter
25
+
26
+ # Configuration and models
27
+ from .models import OptimizationConfig, OptimizationResult, OptimizedResult, ModelConfig
28
+
29
+ # Data processing
30
+ from .data import UniversalConverter, DataLoader, DataValidator
31
+ from .data.scroll_dataset_loader import ScrollDatasetLoader, load_scroll_dataset
32
+ from .data.validation_dataset_loader import ValidationDatasetLoader, load_validation_dataset, load_validation_split
33
+ from .data.index_caching_loader import IndexCachingDatasetLoader, load_index_caching_dataset, load_index_caching_split
34
+
35
+ # LLM clients
36
+ from .llms import VisionLLMClient
37
+ from .llms.base_llm import BaseLLMClient
38
+ from .llms.batch_llm import BatchLLMClient
39
+
40
+ # Evaluators - including Universal Semantic Evaluator (works for ANY task!)
41
+ from .evaluation import (
42
+ BaseEvaluator,
43
+ UniversalSemanticEvaluator,
44
+ create_universal_evaluator,
45
+ UITreeEvaluator,
46
+ ScrollElementEvaluator,
47
+ ValidationEvaluator,
48
+ IndexCachingEvaluator
49
+ )
50
+
51
+ # LLEGO Genetic Operators
52
+ from .operators import (
53
+ # Base interfaces
54
+ BaseGeneticOperator,
55
+ BaseCrossoverOperator,
56
+ BaseMutationOperator,
57
+ # Concrete operators
58
+ FitnessGuidedCrossover,
59
+ DiversityGuidedMutation,
60
+ LLEGOIntegrationLayer,
61
+ # Data models
62
+ PromptCandidate,
63
+ PromptMetadata
64
+ )
65
+
66
+ # Utilities
67
+ from .utils import setup_logging, calculate_metrics, sanitize_prompt, APIKeyManager
68
+ from .utils.exceptions import GepaOptimizerError, GepaDependencyError, InvalidInputError, DatasetError
69
+
70
+ # Logging infrastructure
71
+ from .infrastructure.logging import get_logger, configure_logging, LogContext
72
+
73
+ # Type definitions (for type hints in user code)
74
+ from .types import (
75
+ DatasetItem,
76
+ EvaluationResult,
77
+ LLMResponse,
78
+ CandidateDict,
79
+ LLMClientProtocol,
80
+ EvaluatorProtocol,
81
+ )
82
+
83
+ __version__ = "0.1.0"
84
+
85
+
86
+ # ═══════════════════════════════════════════════════════════════════════════════
87
+ # CONVENIENCE FUNCTION: quick_optimize
88
+ # No evaluator needed - uses Universal Semantic Evaluator automatically
89
+ # ═══════════════════════════════════════════════════════════════════════════════
90
+
91
async def quick_optimize(
    seed_prompt: str,
    dataset: list,
    model: str,
    max_iterations: int = 5,
    max_metric_calls: int = 50,
    batch_size: int = 4,
    use_llego: bool = True,
    verbose: bool = True
) -> OptimizedResult:
    """🚀 Quick prompt optimization - no custom evaluator needed!

    Uses the Universal Semantic Evaluator, which works for ANY task.

    Args:
        seed_prompt: Initial prompt to optimize.
        dataset: List of dicts with 'input' and 'output' (expected) keys;
            an optional 'image' key enables multi-modal tasks.
        model: LLM to use in "provider/model-name" format (REQUIRED).
            Examples:
            - "google/gemini-1.5-pro"
            - "google/gemini-2.5-flash-preview-05-20"
            - "openai/gpt-4o"
            - "openai/gpt-4-turbo"
            - "anthropic/claude-3-5-sonnet-20241022"
        max_iterations: Maximum optimization iterations (default: 5).
        max_metric_calls: Maximum evaluation calls (default: 50).
        batch_size: Samples per evaluation batch (default: 4).
        use_llego: Enable LLEGO genetic operators (default: True).
        verbose: Show progress logs (default: True).

    Returns:
        OptimizedResult with the optimized prompt and improvement metrics.

    Example:
        >>> result = await quick_optimize(
        ...     seed_prompt="Count the objects in the image",
        ...     dataset=[
        ...         {"input": "image1.jpg", "output": "5 objects", "image": "base64..."},
        ...         {"input": "image2.jpg", "output": "3 objects", "image": "base64..."},
        ...     ],
        ...     model="openai/gpt-4o",  # or "google/gemini-1.5-pro", etc.
        ...     max_iterations=3
        ... )
        >>> print(result.optimized_prompt)
    """
    import logging

    if verbose:
        logging.basicConfig(level=logging.INFO)

    # A single client serves both target-model generation and the
    # evaluator's semantic analysis.
    client = VisionLLMClient.from_model_string(model)
    scorer = UniversalSemanticEvaluator(
        llm_client=client,
        use_llm_analysis=True
    )

    # The use_llego flag drives both the genetic operators and hybrid
    # GEPA-reflection mode together.
    run_config = OptimizationConfig(
        model=model,
        reflection_model=model,
        max_iterations=max_iterations,
        max_metric_calls=max_metric_calls,
        batch_size=batch_size,
        use_llego_operators=use_llego,
        enable_gepa_reflection_with_llego=use_llego,
        num_gepa_reflection_candidates=3,
        n_crossover=2,
        n_mutation=2,
        verbose=verbose
    )

    optimizer = GepaOptimizer(
        config=run_config,
        llm_client=client,
        evaluator=scorer
    )

    return await optimizer.train(seed_prompt=seed_prompt, dataset=dataset)
181
+
182
+
183
def quick_optimize_sync(
    seed_prompt: str,
    dataset: list,
    model: str,
    max_iterations: int = 5,
    max_metric_calls: int = 50,
    batch_size: int = 4,
    use_llego: bool = True,
    verbose: bool = True
) -> OptimizedResult:
    """🚀 Synchronous version of quick_optimize.

    Blocks until optimization completes; otherwise identical to
    quick_optimize (see that function for full documentation).

    Args:
        model: LLM model to use in format "provider/model-name" (REQUIRED)
            Examples: "openai/gpt-4o", "google/gemini-1.5-pro",
            "anthropic/claude-3-5-sonnet-20241022"
    """
    import asyncio

    # Forward every parameter unchanged to the async implementation.
    forwarded = dict(
        seed_prompt=seed_prompt,
        dataset=dataset,
        model=model,
        max_iterations=max_iterations,
        max_metric_calls=max_metric_calls,
        batch_size=batch_size,
        use_llego=use_llego,
        verbose=verbose,
    )
    return asyncio.run(quick_optimize(**forwarded))
215
+
216
+
217
# Public API surface: the names exported by `from gepa_optimizer import *`.
__all__ = [
    # 🚀 Quick Start (recommended for new users)
    "quick_optimize",
    "quick_optimize_sync",

    # Core functionality
    "GepaOptimizer",
    "BaseGepaAdapter",
    "UniversalGepaAdapter",

    # Configuration
    "OptimizationConfig",
    "OptimizationResult",
    "OptimizedResult",
    "ModelConfig",

    # Data processing
    "UniversalConverter",
    "DataLoader",
    "DataValidator",

    # Dataset loaders
    "ScrollDatasetLoader",
    "load_scroll_dataset",
    "ValidationDatasetLoader",
    "load_validation_dataset",
    "load_validation_split",
    "IndexCachingDatasetLoader",
    "load_index_caching_dataset",
    "load_index_caching_split",

    # LLM clients
    "VisionLLMClient",
    "BaseLLMClient",
    "BatchLLMClient",

    # Evaluators (Universal recommended for general use)
    "UniversalSemanticEvaluator",
    "create_universal_evaluator",
    "BaseEvaluator",
    "UITreeEvaluator",
    "ScrollElementEvaluator",
    "ValidationEvaluator",
    "IndexCachingEvaluator",

    # LLEGO Genetic Operators - Base interfaces
    "BaseGeneticOperator",
    "BaseCrossoverOperator",
    "BaseMutationOperator",
    # LLEGO Genetic Operators - Concrete implementations
    "FitnessGuidedCrossover",
    "DiversityGuidedMutation",
    "LLEGOIntegrationLayer",
    "PromptCandidate",
    "PromptMetadata",

    # Utilities
    "APIKeyManager",
    "GepaOptimizerError",
    "GepaDependencyError",
    "InvalidInputError",
    "DatasetError",
    "setup_logging",
    "calculate_metrics",
    "sanitize_prompt",

    # Logging infrastructure
    "get_logger",
    "configure_logging",
    "LogContext",

    # Type definitions
    "DatasetItem",
    "EvaluationResult",
    "LLMResponse",
    "CandidateDict",
    "LLMClientProtocol",
    "EvaluatorProtocol",
]
src/gepa_optimizer/cli.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Command Line Interface for GEPA Optimizer
3
+ """
4
+
5
+ import argparse
6
+ import sys
7
+ import json
8
+ import asyncio
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ from .core import GepaOptimizer
13
+ from .models import OptimizationConfig, ModelConfig
14
+ from .utils import setup_logging, APIKeyManager
15
+
16
+
17
def main():
    """CLI entry point: parse arguments, build a config, run optimization."""
    args = _build_cli_parser().parse_args()

    # The verbose flag upgrades logging for the entire run.
    setup_logging(level="DEBUG" if args.verbose else "INFO")

    try:
        # Configuration comes from a JSON file when given, else from flags.
        if args.config:
            config = load_config_from_file(args.config)
        else:
            config = create_config_from_args(args)

        # Fail fast if provider API keys are not set.
        validate_api_keys(config)

        optimizer = GepaOptimizer(config=config)

        # train() is a coroutine, so drive it with asyncio.run().
        print(f"🚀 Starting optimization with model: {config.model.model_name}")
        result = asyncio.run(optimizer.train(
            seed_prompt=args.prompt,
            dataset=args.dataset
        ))

        output_results(result, args.output)

        print("✅ Optimization completed successfully!")

    except Exception as e:
        print(f"❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)


def _build_cli_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the gepa-optimize command."""
    parser = argparse.ArgumentParser(
        description="GEPA Universal Prompt Optimizer CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  gepa-optimize --model openai/gpt-4-turbo --prompt "Extract UI elements" --dataset data.json
  gepa-optimize --config config.json --prompt "Analyze interface" --dataset images/
        """
    )

    # Required arguments
    parser.add_argument(
        "--prompt",
        required=True,
        help="Initial seed prompt to optimize"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        help="Path to dataset file or directory"
    )

    # Model configuration
    parser.add_argument(
        "--model",
        help="Model specification (e.g., 'openai/gpt-4-turbo')"
    )
    parser.add_argument(
        "--reflection-model",
        help="Reflection model specification"
    )
    parser.add_argument(
        "--config",
        help="Path to configuration JSON file"
    )

    # Optimization parameters
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=10,
        help="Maximum optimization iterations (default: 10)"
    )
    parser.add_argument(
        "--max-metric-calls",
        type=int,
        default=100,
        help="Maximum metric evaluation calls (default: 100)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size for evaluation (default: 4)"
    )

    # GEPA-specific parameters.
    # NOTE(review): these five flags are parsed here but never forwarded by
    # create_config_from_args() into OptimizationConfig — confirm whether
    # they should be wired through or removed.
    parser.add_argument(
        "--candidate-selection-strategy",
        type=str,
        default="pareto",
        choices=["pareto", "best"],
        help="Strategy for selecting candidates (default: pareto)"
    )
    parser.add_argument(
        "--skip-perfect-score",
        action="store_true",
        help="Skip updating candidates with perfect scores"
    )
    parser.add_argument(
        "--reflection-minibatch-size",
        type=int,
        default=None,
        help="Number of examples to use for reflection (default: use batch_size)"
    )
    parser.add_argument(
        "--perfect-score",
        type=float,
        default=1.0,
        help="Perfect score threshold (default: 1.0)"
    )
    parser.add_argument(
        "--module-selector",
        type=str,
        default="round_robin",
        choices=["round_robin", "all"],
        help="Component selection strategy (default: round_robin)"
    )

    # Output options
    parser.add_argument(
        "--output",
        help="Output file path for results (default: stdout)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    return parser
152
+
153
+
154
def load_config_from_file(config_path: str) -> OptimizationConfig:
    """Load an OptimizationConfig from a JSON file.

    Args:
        config_path: Path to a JSON file whose top-level keys match
            OptimizationConfig fields. Nested ``model`` and
            ``reflection_model`` dicts are promoted to ModelConfig instances.

    Returns:
        The parsed OptimizationConfig.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    # Read with an explicit encoding so behavior does not depend on the
    # platform's default locale (the original used the implicit default).
    config_data = json.loads(path.read_text(encoding="utf-8"))

    # Promote nested dicts to typed model configs.
    for key in ("model", "reflection_model"):
        if key in config_data and isinstance(config_data[key], dict):
            config_data[key] = ModelConfig(**config_data[key])

    return OptimizationConfig(**config_data)
171
+
172
+
173
def create_config_from_args(args) -> OptimizationConfig:
    """Build an OptimizationConfig from parsed CLI arguments.

    Raises:
        ValueError: If neither --model nor --config was supplied.
    """
    if not args.model:
        raise ValueError("Either --model or --config must be specified")

    # "provider/model-name" strings are parsed by ModelConfig itself.
    primary = ModelConfig.from_string(args.model)
    reflection = (
        ModelConfig.from_string(args.reflection_model)
        if args.reflection_model
        else None
    )

    return OptimizationConfig(
        model=primary,
        reflection_model=reflection,
        max_iterations=args.max_iterations,
        max_metric_calls=args.max_metric_calls,
        batch_size=args.batch_size
    )
192
+
193
+
194
def validate_api_keys(config: OptimizationConfig):
    """Exit the process with a helpful message if provider API keys are missing.

    Checks the main model's provider and, when configured, the reflection
    model's provider against the environment via APIKeyManager.
    """
    api_manager = APIKeyManager()

    required_providers = [config.model.provider]
    if config.reflection_model:
        required_providers.append(config.reflection_model.provider)

    missing_keys = api_manager.get_missing_keys(required_providers)
    if not missing_keys:
        return

    print("❌ Missing API keys for the following providers:")
    for provider in missing_keys:
        print(f"  - {provider.upper()}_API_KEY")
    print("\nPlease set the required environment variables or use a .env file")
    sys.exit(1)
210
+
211
def output_results(result, output_path: Optional[str]):
    """Write optimization results to a JSON file or pretty-print to stdout.

    Args:
        result: Result object exposing ``prompt``, ``original_prompt``,
            ``improvement_data``, ``optimization_time``, ``status`` and
            ``session_id`` attributes.
        output_path: Destination JSON file path, or None to print to stdout.
    """
    output_data = {
        "optimized_prompt": result.prompt,
        "original_prompt": result.original_prompt,
        "improvement_metrics": result.improvement_data,
        "optimization_time": result.optimization_time,
        "status": result.status,
        "session_id": result.session_id
    }

    if output_path:
        # ensure_ascii=False keeps non-ASCII prompt text human-readable;
        # explicit utf-8 avoids locale-dependent encoding.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"📄 Results saved to: {output_path}")
    else:
        print("\n📊 Optimization Results:")
        print(f"Session ID: {result.session_id}")
        print(f"Status: {result.status}")
        print(f"Time: {result.optimization_time:.2f}s")
        print(f"\nOriginal Prompt:\n{result.original_prompt}")
        print(f"\nOptimized Prompt:\n{result.prompt}")

        # Bug fix: the original ran `'improvement_percent' in
        # result.improvement_data` unconditionally, which raises TypeError
        # when improvement_data is None.
        if result.improvement_data and 'improvement_percent' in result.improvement_data:
            print(f"\nImprovement: {result.improvement_data['improvement_percent']:.2f}%")
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main()
src/gepa_optimizer/core/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core functionality for GEPA Universal Prompt Optimizer
3
+ """
4
+
5
+ from .optimizer import GepaOptimizer
6
+ from .result import ResultProcessor
7
+
8
+ __all__ = ["GepaOptimizer", "ResultProcessor"]
src/gepa_optimizer/core/base_adapter.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base adapter class for all GEPA adapters.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Dict, List, Optional
7
+ import logging
8
+ from gepa.core.adapter import GEPAAdapter, EvaluationBatch
9
+
10
+ from ..llms.base_llm import BaseLLMClient
11
+ from ..evaluation.base_evaluator import BaseEvaluator
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
class BaseGepaAdapter(GEPAAdapter, ABC):
    """Abstract foundation for task-specific GEPA adapters.

    Subclasses wire an LLM client and an evaluator into the GEPA framework
    by implementing :meth:`evaluate` and :meth:`make_reflective_dataset`.
    """

    def __init__(self, llm_client: BaseLLMClient, evaluator: BaseEvaluator):
        """Store the LLM client and evaluator after type-checking them.

        Args:
            llm_client: LLM client for generating responses.
            evaluator: Evaluator for scoring predictions.

        Raises:
            TypeError: If either argument has the wrong type.
        """
        if not isinstance(llm_client, BaseLLMClient):
            raise TypeError("llm_client must be an instance of BaseLLMClient")
        if not isinstance(evaluator, BaseEvaluator):
            raise TypeError("evaluator must be an instance of BaseEvaluator")

        self.llm_client = llm_client
        self.evaluator = evaluator
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

        # Bookkeeping surfaced through get_performance_stats().
        self._evaluation_count = 0
        self._best_score = 0.0
        self._best_candidate = None

    @abstractmethod
    def evaluate(self, batch: List[Dict[str, Any]], candidate: Dict[str, str],
                 capture_traces: bool = False) -> EvaluationBatch:
        """Evaluate a prompt candidate against a batch of data items.

        Args:
            batch: Data items to evaluate.
            candidate: Prompt candidate under evaluation.
            capture_traces: When True, also record detailed trajectories.

        Returns:
            EvaluationBatch with outputs, scores, and optional trajectories.
        """
        pass

    @abstractmethod
    def make_reflective_dataset(self, candidate: Dict[str, str],
                                eval_batch: EvaluationBatch,
                                components_to_update: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Build the reflection dataset GEPA uses to improve the candidate.

        Args:
            candidate: Current prompt candidate.
            eval_batch: Results produced by :meth:`evaluate`.
            components_to_update: Component names to build data for.

        Returns:
            Mapping from component name to a list of reflection records.
        """
        pass

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return adapter statistics useful for monitoring."""
        return {
            'evaluation_count': self._evaluation_count,
            'best_score': self._best_score,
            'model_info': self.llm_client.get_model_info(),
            'evaluator_class': self.evaluator.__class__.__name__
        }
+ }
src/gepa_optimizer/core/custom_adapter.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom GEPA Adapter for the GEPA Universal Prompt Optimizer
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import re
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ # Import ModelConfig
11
+ from ..models import ModelConfig
12
+
13
+ from gepa.core.adapter import GEPAAdapter, EvaluationBatch
14
+ from ..llms.vision_llm import VisionLLMClient
15
+ from ..evaluation.ui_evaluator import UITreeEvaluator
16
+ from .base_adapter import BaseGepaAdapter
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
class CustomGepaAdapter(BaseGepaAdapter):
    """
    Custom adapter for the GEPA Universal Prompt Optimizer.

    Specialized for UI-tree extraction: a vision LLM turns screenshots into
    JSON UI trees, which a UITreeEvaluator scores against ground truth.
    """

    def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None):
        """Initialize the custom GEPA adapter with model configuration.

        Args:
            model_config: Model configuration. A bare string is coerced to an
                OpenAI ModelConfig for backward compatibility.
            metric_weights: Optional per-metric weights for the evaluator.
        """
        # Backward compatibility: accept a plain model-name string.
        if not isinstance(model_config, ModelConfig):
            model_config = ModelConfig(
                provider='openai',
                model_name=str(model_config),
                api_key=None
            )

        llm_client = VisionLLMClient(
            provider=model_config.provider,
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=model_config.temperature,
            max_tokens=model_config.max_tokens,
            top_p=model_config.top_p,
            frequency_penalty=model_config.frequency_penalty,
            presence_penalty=model_config.presence_penalty
        )

        evaluator = UITreeEvaluator(metric_weights=metric_weights)

        super().__init__(llm_client, evaluator)

        # Track the last-seen prompt so each new candidate is logged once.
        self._last_candidate = None
        self._evaluation_count = 0

        self.logger.info(f"🚀 Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}")

    def _parse_json_safely(self, json_str: str) -> Dict[str, Any]:
        """Parse a JSON string into a dict, falling back through several
        recovery strategies; returns {} if nothing works."""
        if not json_str or not isinstance(json_str, str):
            return {}

        # 1) Direct parse.
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        # 2) JSON inside a markdown code fence.
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # 3) First {...} span anywhere in the string.
        json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        # 4) Heuristic repair, then parse.
        repaired_json = self._repair_json(json_str)
        if repaired_json:
            try:
                return json.loads(repaired_json)
            except json.JSONDecodeError:
                pass

        self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...")
        return {}

    def _repair_json(self, json_str: str) -> str:
        """Attempt to repair common JSON issues (markdown fences, trailing
        commas, unquoted keys); returns "" if repair itself fails."""
        try:
            # Strip markdown formatting.
            json_str = re.sub(r'```(?:json)?\s*', '', json_str)
            json_str = re.sub(r'```\s*$', '', json_str)

            # Keep only the outermost {...} span.
            json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)

            # Fix common issues.
            json_str = re.sub(r',\s*}', '}', json_str)  # trailing commas in objects
            json_str = re.sub(r',\s*]', ']', json_str)  # trailing commas in arrays
            # NOTE: this can also touch braces/commas inside string values —
            # acceptable as a last-resort heuristic before giving up.
            json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str)  # quote unquoted keys

            return json_str
        except Exception as e:
            self.logger.warning(f"🔧 JSON repair failed: {e}")
            return ""

    def evaluate(
        self,
        batch: List[Dict[str, Any]],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Evaluate the candidate prompt on a batch of (text, image, ground
        truth) samples and score each prediction with the UI-tree evaluator.

        Args:
            batch: Items with 'input', 'image' (base64) and 'output' keys.
            candidate: Candidate with a 'system_prompt' entry.
            capture_traces: When True, collect per-sample trajectories for
                later reflection.

        Returns:
            EvaluationBatch with raw outputs, composite scores, and optional
            trajectories.
        """
        outputs = []
        scores = []
        trajectories = [] if capture_traces else None

        system_prompt = candidate.get('system_prompt', '')

        # Only log a candidate the first time we see it.
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt

        self.logger.info(f"📊 Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")

        for i, item in enumerate(batch):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')

            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)

            # The client may return either a dict with a 'content' key or a
            # raw string; normalize to a string.
            if isinstance(llm_response, dict):
                llm_output_json_str = llm_response.get("content", "")
                if not llm_output_json_str:
                    llm_output_json_str = str(llm_response)
            else:
                llm_output_json_str = str(llm_response) if llm_response else ""

            self.logger.debug(f"🔍 Sample {i+1} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"🔍 Sample {i+1} - Response Length: {len(llm_output_json_str)} chars")

            outputs.append(llm_output_json_str)

            llm_output_dict = self._parse_json_safely(llm_output_json_str)
            ground_truth_dict = self._parse_json_safely(ground_truth_json)

            # Default metric values used when parsing failed on either side.
            evaluation_results = {
                "composite_score": 0.0,
                "element_completeness": 0.0,
                "element_type_accuracy": 0.0,
                "text_content_accuracy": 0.0,
                "hierarchy_accuracy": 0.0,
                "style_accuracy": 0.0
            }

            # Score: small non-zero floors keep GEPA's search from flat-lining
            # when parsing fails entirely.
            if not llm_output_dict and not ground_truth_dict:
                composite_score = 0.1
                evaluation_results = {k: 0.1 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Empty results - using default score: {composite_score}")
            elif not llm_output_dict or not ground_truth_dict:
                composite_score = 0.05
                evaluation_results = {k: 0.05 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Incomplete results - using low score: {composite_score}")
            else:
                evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
                composite_score = evaluation_results["composite_score"]

                llm_children = len(llm_output_dict.get('children', []))
                gt_children = len(ground_truth_dict.get('children', []))

                if composite_score < 0.1:
                    self.logger.warning(f"⚠️ Sample {i+1}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
                    self.logger.debug(f"   Score breakdown: {evaluation_results}")
                else:
                    self.logger.info(f"✅ Sample {i+1}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")

            scores.append(composite_score)

            if capture_traces:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results
                })

        avg_score = sum(scores) / len(scores) if scores else 0.0

        # Remember the best candidate seen so far.
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"🎯 New best candidate found with score: {avg_score:.4f}")

        self.logger.info(f"📈 Batch evaluation complete - Average score: {avg_score:.4f}")

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Build the per-component reflection dataset GEPA uses to propose
        improved prompts.

        Args:
            candidate: Current prompt candidate.
            eval_batch: Evaluation results (should have been produced with
                capture_traces=True so trajectories are available).
            components_to_update: Component names to build data for.

        Returns:
            Mapping from component name to a list of reflection records.
        """
        reflective_dataset = {}
        system_prompt = candidate.get('system_prompt', '')

        self.logger.info(f"📝 Creating reflection dataset for prompt: '{system_prompt[:100]}...'")

        self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update)

        # Bug fix: trajectories is None when evaluate() ran with
        # capture_traces=False — guard instead of crashing on iteration.
        trajectories = eval_batch.trajectories or []

        for component in components_to_update:
            reflective_dataset[component] = []
            for trace in trajectories:
                feedback = self._generate_feedback(trace['evaluation_results'])
                reflective_dataset[component].append({
                    "current_prompt": system_prompt,
                    "input_text": trace['input_text'],
                    "image_base64": trace['image_base64'],
                    "generated_json": trace['llm_output_json'],
                    "ground_truth_json": trace['ground_truth_json'],
                    "score": trace['evaluation_results']["composite_score"],
                    "feedback": feedback,
                    "detailed_scores": trace['evaluation_results']
                })

        total_samples = sum(len(data) for data in reflective_dataset.values())
        avg_score = sum(trace['score'] for data in reflective_dataset.values() for trace in data) / total_samples if total_samples > 0 else 0.0
        self.logger.info(f"📝 Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}")

        return reflective_dataset

    def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str:
        """Turn numeric evaluation metrics into natural-language feedback for
        the reflection model."""
        composite_score = evaluation_results.get("composite_score", 0.0)

        feedback_parts = []

        # Overall quality assessment.
        if composite_score >= 0.8:
            feedback_parts.append("The overall quality is good.")
        elif composite_score >= 0.5:
            feedback_parts.append("The overall quality is moderate.")
        else:
            feedback_parts.append("The overall quality is low. Focus on fundamental accuracy.")

        # Metric-specific hints (0.7 threshold per metric).
        if evaluation_results.get("element_completeness", 0.0) < 0.7:
            feedback_parts.append("Element completeness is low. Ensure all UI elements are captured.")

        if evaluation_results.get("element_type_accuracy", 0.0) < 0.7:
            feedback_parts.append("Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.).")

        if evaluation_results.get("text_content_accuracy", 0.0) < 0.7:
            feedback_parts.append("Text content accuracy is low. Improve text extraction fidelity.")

        if evaluation_results.get("hierarchy_accuracy", 0.0) < 0.7:
            feedback_parts.append("Hierarchy accuracy is low. Ensure correct parent-child relationships.")

        if evaluation_results.get("style_accuracy", 0.0) < 0.7:
            feedback_parts.append("Style accuracy is low. Capture more styling properties (colors, sizes, positioning).")

        return " ".join(feedback_parts)

    def get_best_candidate(self) -> Optional[Dict[str, str]]:
        """Get the best candidate found so far (or None)."""
        return self._best_candidate

    def get_best_score(self) -> float:
        """Get the best average score found so far."""
        return self._best_score

    def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0):
        """
        Log a newly proposed candidate prompt in a banner format.

        Args:
            candidate: The new candidate prompt from GEPA.
            iteration: Current optimization iteration.
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("="*80)
        logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})")
        logger.info("="*80)
        logger.info(f"PROPOSED PROMPT:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)
        logger.info(f"Prompt Length: {len(system_prompt)} characters")
        logger.info(f"Word Count: {len(system_prompt.split())} words")
        logger.info("="*80)

    def _log_reflection_dataset_creation(self, candidate: Dict[str, str], eval_batch: EvaluationBatch,
                                         components_to_update: List[str]):
        """
        Log a summary of the reflection dataset creation process.

        Args:
            candidate: Current candidate being evaluated.
            eval_batch: Evaluation results.
            components_to_update: Components being updated.
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("="*80)
        logger.info("REFLECTION DATASET CREATION")
        logger.info("="*80)

        logger.info(f"CURRENT PROMPT BEING ANALYZED:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)

        logger.info(f"EVALUATION SUMMARY:")
        logger.info("-" * 40)
        if eval_batch.scores:
            avg_score = sum(eval_batch.scores) / len(eval_batch.scores)
            min_score = min(eval_batch.scores)
            max_score = max(eval_batch.scores)
            logger.info(f"  Average Score: {avg_score:.4f}")
            logger.info(f"  Min Score: {min_score:.4f}")
            logger.info(f"  Max Score: {max_score:.4f}")
            logger.info(f"  Total Samples: {len(eval_batch.scores)}")

        logger.info(f"COMPONENTS TO UPDATE:")
        logger.info("-" * 40)
        for i, component in enumerate(components_to_update, 1):
            logger.info(f"  {i}. {component}")

        if eval_batch.trajectories:
            logger.debug(f"DETAILED ANALYSIS:")
            logger.debug("-" * 40)
            # Only the first 3 samples are detailed to keep logs readable.
            for i, trace in enumerate(eval_batch.trajectories[:3], 1):
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)

                logger.debug(f"  Sample {i} (Score: {composite_score:.4f}):")

                input_text = trace['input_text'][:100] + "..." if len(trace['input_text']) > 100 else trace['input_text']
                logger.debug(f"    Input: \"{input_text}\"")

                predicted_output = trace['llm_output_json'][:100] + "..." if len(trace['llm_output_json']) > 100 else trace['llm_output_json']
                logger.debug(f"    Output: \"{predicted_output}\"")

                logger.debug(f"    Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric != "composite_score":
                        logger.debug(f"      {metric.replace('_', ' ').title()}: {score:.4f}")

                feedback = self._generate_feedback(evaluation_results)
                logger.debug(f"    Feedback: \"{feedback}\"")

            if len(eval_batch.trajectories) > 3:
                logger.debug(f"  ... and {len(eval_batch.trajectories) - 3} more samples")

        logger.info("="*80)
src/gepa_optimizer/core/optimizer.py ADDED
@@ -0,0 +1,1279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main GepaOptimizer class - the heart of the optimization system
3
+ """
4
+
5
+ import time
6
+ import logging
7
+ from typing import Any, Dict, List, Optional, Union
8
+ import asyncio
9
+ import io
10
+ import sys
11
+ from contextlib import redirect_stdout, redirect_stderr
12
+
13
+ import gepa
14
+ from ..utils.api_keys import APIKeyManager
15
+ from .result import ResultProcessor
16
+ from ..data.converters import UniversalConverter
17
+ from ..models.result import OptimizationResult, OptimizedResult
18
+ from ..models.config import OptimizationConfig, ModelConfig
19
+ from ..utils.helpers import sanitize_prompt
20
+ from ..utils.exceptions import GepaDependencyError, InvalidInputError, DatasetError, GepaOptimizerError
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ class GepaOptimizer:
25
+ """
26
+ Main class for prompt optimization using GEPA
27
+
28
+ This is the primary interface that users interact with.
29
+ Provides both simple and advanced optimization capabilities.
30
+ """
31
+
32
+ def __init__(self, config: Optional[OptimizationConfig] = None,
33
+ adapter_type: str = "universal",
34
+ custom_adapter: Optional[Any] = None,
35
+ llm_model_name: Optional[str] = None,
36
+ metric_weights: Optional[Dict[str, float]] = None,
37
+ **kwargs):
38
+ """
39
+ Initialize the optimizer
40
+
41
+ Args:
42
+ config: Optimization configuration (required)
43
+ adapter_type: Type of adapter to use ("universal" only - fully configurable)
44
+ custom_adapter: Custom adapter instance (overrides adapter_type)
45
+ llm_model_name: [Deprecated] Use config.model instead. Will be removed in future versions.
46
+ metric_weights: [Deprecated] Not used - evaluator handles metrics. Will be removed in future versions.
47
+ **kwargs: Additional parameters for universal adapter (llm_client, evaluator, etc.)
48
+
49
+ Raises:
50
+ ValueError: If required configuration is missing
51
+ GepaDependencyError: If GEPA library is not available
52
+ """
53
+ if config is None:
54
+ raise ValueError("config parameter is required. Use OptimizationConfig to configure the optimizer.")
55
+
56
+ # Initialize logger first
57
+ self.logger = logging.getLogger(__name__)
58
+
59
+ self.config = config
60
+ self.converter = UniversalConverter(data_split_config=config.data_split)
61
+ self.api_manager = APIKeyManager()
62
+ self.result_processor = ResultProcessor()
63
+
64
+ # Initialize adapter based on configuration
65
+ if custom_adapter:
66
+ # User provided custom adapter
67
+ from .base_adapter import BaseGepaAdapter
68
+ if not isinstance(custom_adapter, BaseGepaAdapter):
69
+ raise TypeError("custom_adapter must be an instance of BaseGepaAdapter")
70
+ self.adapter = custom_adapter
71
+ self.logger.info("Using user-provided custom adapter")
72
+ elif adapter_type == "universal":
73
+ # Universal adapter requires user to provide components
74
+ llm_client = kwargs.get('llm_client')
75
+ evaluator = kwargs.get('evaluator')
76
+
77
+ if not llm_client or not evaluator:
78
+ raise ValueError(
79
+ "llm_client and evaluator are required for universal adapter. "
80
+ "Example: GepaOptimizer(config=config, adapter_type='universal', "
81
+ "llm_client=llm_client, evaluator=evaluator)"
82
+ )
83
+
84
+ from .universal_adapter import UniversalGepaAdapter
85
+ self.adapter = UniversalGepaAdapter(
86
+ llm_client=llm_client,
87
+ evaluator=evaluator,
88
+ data_converter=kwargs.get('data_converter')
89
+ )
90
+ self.logger.info("Using universal adapter")
91
+ else:
92
+ raise ValueError(
93
+ f"Unknown adapter_type: {adapter_type}. "
94
+ f"Only 'universal' is supported. "
95
+ f"Provide llm_client and evaluator when using universal adapter."
96
+ )
97
+
98
+ # Keep backward compatibility
99
+ self.custom_adapter = self.adapter
100
+
101
+ # Log model configuration
102
+ model_info = self.adapter.get_performance_stats()
103
+ self.logger.info(f"Initialized adapter: {model_info}")
104
+
105
+ # Set up logging
106
+ logging.basicConfig(
107
+ level=logging.INFO,
108
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
109
+ )
110
+
111
+ # Validate GEPA availability
112
+ if gepa is None:
113
+ raise GepaDependencyError("GEPA library is not available. Please install it with: pip install gepa")
114
+
115
async def train(self,
                seed_prompt: str,
                dataset: Union[List[Any], str, Dict, Any],
                **kwargs) -> OptimizedResult:
    """
    Main training method for prompt optimization.

    Pipeline:
        1. Validate the seed prompt and model configuration.
        2. Convert the dataset to GEPA format (honoring a user-provided
           train/val/test pre-split when given).
        3. Score the seed prompt on the validation set as the baseline.
        4. Run GEPA optimization and extract the best candidate.
        5. Compute validation-set improvement and (optionally) a held-out
           test-set score.
        6. Package everything into an OptimizedResult.

    Args:
        seed_prompt: Initial prompt to optimize.
        dataset: Training data in any supported format. A dict containing
            'train', 'val' and 'test' keys is treated as a pre-split
            dataset and used as-is (no re-splitting).
        **kwargs: Additional parameters that can override config values.

    Returns:
        OptimizedResult: Optimization result with the improved prompt.
        Failures are not propagated to the caller: a result with
        status='failed' and the original prompt is returned instead.
    """
    start_time = time.time()
    session_id = f"opt_{int(start_time)}_{id(self)}"

    try:
        self.logger.info(f"Starting optimization session: {session_id}")
        self.logger.info(f"Using model: {self.config.model.model_name} (provider: {self.config.model.provider})")

        # NOTE(fix): removed leftover agent-debug instrumentation that appended
        # JSON to a hard-coded developer-local absolute path; because it ran
        # inside this try-block, it made every optimization fail with
        # FileNotFoundError on any machine other than the original developer's.

        # 🔥 FIX E: Reset Pareto logger at start of each optimization run
        from ..utils.pareto_logger import reset_pareto_logger
        reset_pareto_logger()
        self.logger.info("✅ Reset Pareto logger for new optimization run")

        # Update config with any overrides from kwargs
        self._update_config_from_kwargs(kwargs)

        # Step 1: Validate inputs
        self._validate_inputs(seed_prompt)

        # Step 2: Convert dataset to GEPA format with 3-way split
        # 🔥 FIX: Support pre-split datasets (user-provided train/val/test)
        if isinstance(dataset, dict) and all(k in dataset for k in ['train', 'val', 'test']):
            # User provided pre-split dataset - use it directly
            self.logger.info("✅ Detected pre-split dataset - using user's split (no re-splitting)")
            trainset_raw = dataset.get('train', [])
            valset_raw = dataset.get('val', [])
            testset_raw = dataset.get('test', [])

            # Still need to standardize the format (convert to GEPA format)
            trainset = self.converter._standardize(trainset_raw)
            valset = self.converter._standardize(valset_raw)
            testset = self.converter._standardize(testset_raw) if testset_raw else []

            self.logger.info(
                f"Using pre-split dataset: {len(trainset)} train (Dfeedback), "
                f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)"
            )
        else:
            # Standard path: convert and split automatically
            self.logger.info("Converting dataset to GEPA format with 3-way split...")
            trainset, valset, testset = self.converter.convert(
                dataset,
                split_config=self.config.data_split
            )

            # Log split with adaptive strategy info
            split_strategy = self.config.data_split.small_dataset_strategy
            strategy_note = ""
            if split_strategy == 'adaptive':
                total_size = len(trainset) + len(valset) + len(testset)
                train_ratio, val_ratio, test_ratio = self.config.data_split.get_adaptive_ratios(total_size)
                strategy_note = f" (adaptive: {train_ratio*100:.0f}%/{val_ratio*100:.0f}%/{test_ratio*100:.0f}% ratios)"
            self.logger.info(
                f"Dataset split{strategy_note}: {len(trainset)} train (Dfeedback), "
                f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)"
            )

        if not trainset:
            raise DatasetError("Dataset appears to be empty after conversion")

        # Step 3: Create seed candidate
        seed_candidate = self._create_seed_candidate(seed_prompt)

        # 🔥 CRITICAL: Set valset info in adapter BEFORE baseline evaluation
        # so the adapter correctly detects the 'dpareto' dataset type.
        # Direct assignment (not hasattr-guarded) to guarantee the attributes
        # exist; AttributeError covers adapters with __slots__/properties.
        try:
            self.adapter._valset_size = len(valset) if valset else 0
            self.logger.info(f"✅ Set valset_size in adapter: {len(valset) if valset else 0} for Dpareto detection")
        except AttributeError:
            self.logger.warning("⚠️ Could not set _valset_size in adapter - attribute not supported")

        try:
            self.adapter._valset = valset
            self.logger.info(f"✅ Stored valset in adapter ({len(valset) if valset else 0} samples)")
        except AttributeError:
            self.logger.warning("⚠️ Could not set _valset in adapter - attribute not supported")

        # Step 3.5: Calculate baseline score on VALIDATION set (not test set).
        # Optimization uses the validation set for Pareto selection, so the
        # baseline must come from the same set for a fair comparison.
        baseline_val_score = None
        if valset:
            self.logger.info("📊 Evaluating seed prompt on validation set for baseline...")
            # Flag the adapter so it knows this is baseline, not optimization
            try:
                self.adapter._is_baseline_evaluation = True
                self.logger.info("✅ Set baseline evaluation flag in adapter")
            except AttributeError:
                self.logger.warning("⚠️ Could not set _is_baseline_evaluation in adapter")

            try:
                # Evaluate on the same set GEPA will use for Pareto selection
                eval_result = self.adapter.evaluate(
                    batch=valset,
                    candidate=seed_candidate,
                    capture_traces=False
                )
                baseline_val_score = sum(eval_result.scores) / len(eval_result.scores) if eval_result.scores else 0.0
                self.logger.info(f"📊 Baseline validation score: {baseline_val_score:.4f} (on {len(valset)} samples)")

                # Store baseline in adapter for later use
                if hasattr(self.adapter, '_baseline_score'):
                    self.adapter._baseline_score = baseline_val_score

                # 🔥 CRITICAL FIX: Also set baseline in the Pareto logger so
                # candidates can be properly evaluated against the baseline.
                from ..utils.pareto_logger import get_pareto_logger
                pareto_log = get_pareto_logger()
                pareto_log.set_baseline(baseline_val_score)
                self.logger.info(f"✅ Baseline set in Pareto logger: {baseline_val_score:.4f}")

            except Exception as e:
                # Baseline is best-effort: a failed baseline must not abort
                # the optimization itself.
                self.logger.warning(f"Baseline evaluation failed: {e}")
                import traceback
                self.logger.debug(f"Baseline evaluation error: {traceback.format_exc()}")
            finally:
                try:
                    self.adapter._is_baseline_evaluation = False
                    self.logger.debug("✅ Reset baseline evaluation flag - optimization can begin")
                except AttributeError:
                    pass  # Ignore if attribute not supported

        # Step 4: Run GEPA optimization
        self.logger.info("Starting GEPA optimization...")
        gepa_result, actual_iterations = await self._run_gepa_optimization(
            adapter=self.adapter,
            seed_candidate=seed_candidate,
            trainset=trainset,
            valset=valset,
            **kwargs
        )

        # Step 5: Extract best candidate
        best_candidate = self._extract_best_candidate(gepa_result)

        # 🔥 CRITICAL: Extract the actual optimized prompt GEPA found
        self.logger.info(f"\n{'═'*80}")
        self.logger.info(f"📝 EXTRACTING OPTIMIZED PROMPT FROM GEPA RESULT")
        self.logger.info(f"{'═'*80}")
        self.logger.info(f"best_candidate keys: {list(best_candidate.keys()) if isinstance(best_candidate, dict) else 'N/A'}")

        optimized_prompt = best_candidate.get('system_prompt', seed_prompt)
        if not optimized_prompt or optimized_prompt.strip() == '':
            # Fallback: try other keys or use seed prompt
            optimized_prompt = best_candidate.get('prompt', best_candidate.get('text', seed_prompt))

        # 🔥 FIX(precedence): the previous one-liner
        #   best_candidate.get('fitness') or self.adapter.get_best_score() if hasattr(...) else None
        # parsed as "(fitness or get_best_score()) if hasattr(...) else None",
        # so the candidate's fitness was discarded whenever the adapter lacked
        # get_best_score(), and 'or' dropped a legitimate fitness of 0.0.
        best_fitness = best_candidate.get('fitness')
        if best_fitness is None and hasattr(self.adapter, 'get_best_score'):
            best_fitness = self.adapter.get_best_score()
        candidate_source = best_candidate.get('source', 'unknown')

        self.logger.info(f"\n✅ EXTRACTED OPTIMIZED PROMPT:")
        self.logger.info(f"   Source: {candidate_source}")
        if best_fitness is not None:
            self.logger.info(f"   Fitness: f={best_fitness:.4f}")
        self.logger.info(f"   Length: {len(optimized_prompt)} characters")
        self.logger.info(f"   Words: {len(optimized_prompt.split())} words")
        self.logger.info(f"\n📝 FULL OPTIMIZED PROMPT TEXT:")
        self.logger.info(f"{'─'*80}")
        self.logger.info(optimized_prompt)
        self.logger.info(f"{'─'*80}")

        if optimized_prompt != seed_prompt:
            self.logger.info(f"\n✅ SUCCESS: Prompt WAS OPTIMIZED!")
            self.logger.info(f"   Seed length: {len(seed_prompt)} chars")
            self.logger.info(f"   Optimized length: {len(optimized_prompt)} chars")
            self.logger.info(f"   Difference: {len(optimized_prompt) - len(seed_prompt):+d} chars")
            if best_fitness is not None:
                # Default assumed baseline for this log line only; the real
                # improvement metric below uses the measured baseline_val_score.
                baseline_fitness = 0.5
                improvement = best_fitness - baseline_fitness
                improvement_pct = (improvement / baseline_fitness * 100) if baseline_fitness > 0 else 0
                self.logger.info(f"   Fitness: f={best_fitness:.4f} (improvement: {improvement:+.4f} ({improvement_pct:+.1f}%))")
        else:
            self.logger.warning(f"\n⚠️ WARNING: Optimized prompt is IDENTICAL to seed prompt")
            self.logger.warning(f"   This means GEPA didn't modify the prompt during optimization")
            if best_fitness is not None:
                self.logger.warning(f"   Best fitness found: f={best_fitness:.4f}")
                self.logger.warning(f"   💡 Check if LLEGO best candidate is being properly extracted")

        self.logger.info(f"{'═'*80}\n")

        # Step 5.5: Calculate improvement metrics (validation vs validation).
        # Both scores come from the same validation set (Dpareto), which keeps
        # the comparison fair.
        optimized_test_score = None
        improvement_data = {}

        optimized_val_score = best_fitness  # best candidate's fitness is from the validation set (Dpareto)

        if baseline_val_score is not None and optimized_val_score is not None:
            absolute_improvement = optimized_val_score - baseline_val_score
            relative_improvement = (
                (absolute_improvement / baseline_val_score * 100)
                if baseline_val_score > 0 else 0
            )

            improvement_data = {
                'baseline_val_score': baseline_val_score,
                'optimized_val_score': optimized_val_score,
                'absolute_improvement': absolute_improvement,
                'relative_improvement_percent': relative_improvement
            }

            self.logger.info(
                f"📈 Validation improvement: {relative_improvement:+.2f}% "
                f"(baseline val: {baseline_val_score:.4f} → optimized val: {optimized_val_score:.4f})"
            )

        # Step 5.6: Evaluate optimized prompt on test set (if available) for final reporting
        if testset and self.config.evaluate_on_test:
            self.logger.info("📊 Evaluating optimized prompt on test set...")

            # 🔥 CRITICAL FIX: Clear LLEGO candidate queues before test
            # evaluation so the wrapper does not intercept test calls and
            # return stale candidates instead of running the optimized prompt.
            from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient
            if hasattr(self.adapter, 'llm_client') and isinstance(self.adapter.llm_client, LLEGOEnhancedLLMClient):
                if hasattr(self.adapter.llm_client, '_adapter_generated_candidates'):
                    self.adapter.llm_client._adapter_generated_candidates = []
                    self.logger.info("✅ Cleared LLEGO candidate queue for clean test evaluation")
                if hasattr(self.adapter.llm_client, '_candidate_queue'):
                    self.adapter.llm_client._candidate_queue = []
                    self.logger.info("✅ Cleared LLEGO hybrid candidate queue for clean test evaluation")

            # Evaluate on test set for final reporting (improvement stays validation-based)
            try:
                optimized_test_score = self._evaluate_candidate_on_testset(
                    best_candidate,
                    testset
                )
                self.logger.info(f"📊 Optimized test score: {optimized_test_score:.4f}")

                # Recorded for reference only; improvement is validation-based
                improvement_data['optimized_test_score'] = optimized_test_score

                if baseline_val_score is not None:
                    test_vs_baseline = (
                        ((optimized_test_score - baseline_val_score) / baseline_val_score * 100)
                        if baseline_val_score > 0 else 0
                    )
                    self.logger.info(
                        f"📊 Test set vs validation baseline: {test_vs_baseline:+.2f}% "
                        f"(baseline val: {baseline_val_score:.4f} → optimized test: {optimized_test_score:.4f})"
                    )
            except Exception as e:
                # Test-set scoring is reporting-only; never fail the run for it
                self.logger.warning(f"Test evaluation failed: {e}")

        # Step 6: Process results
        optimization_time = time.time() - start_time

        processed_result = self.result_processor.process_full_result(
            result=gepa_result,
            original_prompt=seed_prompt,
            optimization_time=optimization_time,
            actual_iterations=actual_iterations,
            test_metrics=improvement_data  # Add test metrics
        )

        # Merge improvement data (values computed here win over processed ones)
        final_improvement_data = {**processed_result.get('improvement_data', {}), **improvement_data}

        # Step 7: Create result objects
        # 🔥 CRITICAL: Use extracted optimized_prompt instead of processed_result
        result = OptimizedResult(
            original_prompt=seed_prompt,
            optimized_prompt=optimized_prompt,  # Use extracted prompt, not processed_result!
            improvement_data=final_improvement_data,
            optimization_time=optimization_time,
            dataset_size=len(trainset) + len(valset) + len(testset),
            total_iterations=processed_result.get('total_iterations', 0),
            status=processed_result.get('status', 'completed'),
            error_message=processed_result.get('error_message'),
            detailed_result=OptimizationResult(
                session_id=session_id,
                original_prompt=seed_prompt,
                optimized_prompt=optimized_prompt,  # Use extracted prompt!
                improvement_data=final_improvement_data,
                optimization_time=optimization_time,
                dataset_size=len(trainset) + len(valset) + len(testset),
                total_iterations=processed_result.get('total_iterations', 0),
                status=processed_result.get('status', 'completed'),
                error_message=processed_result.get('error_message')
            )
        )

        self.logger.info(f"✅ Optimization completed in {optimization_time:.2f}s")
        return result

    except Exception as e:
        optimization_time = time.time() - start_time
        error_msg = f"Optimization failed: {str(e)}"
        # exc_info=True keeps the full traceback in the logs for debugging
        self.logger.error(error_msg, exc_info=True)

        # Return failed result carrying the original prompt
        return OptimizedResult(
            original_prompt=seed_prompt,
            optimized_prompt=seed_prompt,  # Return original on failure
            improvement_data={'error': error_msg},
            optimization_time=optimization_time,
            dataset_size=0,
            total_iterations=0,
            status='failed',
            error_message=error_msg
        )
+ def _update_config_from_kwargs(self, kwargs: Dict[str, Any]) -> None:
449
+ """Update configuration with runtime overrides from kwargs."""
450
+ updated_params = []
451
+
452
+ for key, value in kwargs.items():
453
+ if hasattr(self.config, key):
454
+ setattr(self.config, key, value)
455
+ updated_params.append(f"{key}={value}")
456
+ else:
457
+ self.logger.warning(f"Unknown parameter '{key}' ignored")
458
+
459
+ if updated_params:
460
+ self.logger.info(f"Updated config parameters: {', '.join(updated_params)}")
461
+
462
+ def _validate_inputs(self, seed_prompt: str) -> None:
463
+ """
464
+ Validate input parameters for optimization
465
+
466
+ Args:
467
+ seed_prompt: The seed prompt to validate
468
+
469
+ Raises:
470
+ InvalidInputError: If validation fails
471
+ """
472
+ if not seed_prompt or not isinstance(seed_prompt, str):
473
+ raise InvalidInputError("Seed prompt must be a non-empty string")
474
+
475
+ if len(seed_prompt.strip()) < 10:
476
+ raise InvalidInputError("Seed prompt is too short (minimum 10 characters)")
477
+
478
+ # Validate model configuration
479
+ model_config = self.config.model
480
+ if not hasattr(model_config, 'model_name') or not model_config.model_name:
481
+ raise InvalidInputError("Model name is required")
482
+
483
+ reflection_config = self.config.reflection_model
484
+ if not hasattr(reflection_config, 'model_name') or not reflection_config.model_name:
485
+ raise InvalidInputError("Reflection model name is required")
486
+
487
+ def _clean_reflection_prompt(self, prompt: str, max_length: int = 50000) -> str:
488
+ """
489
+ Clean reflection prompt by removing base64 images and truncating if too long.
490
+
491
+ 🔥 CRITICAL: GEPA's reflective dataset includes base64 images which create
492
+ massive prompts (7MB+) that exceed token limits. This function:
493
+ 1. Strips all base64 image data
494
+ 2. Removes excessive detailed_scores entries
495
+ 3. Truncates to reasonable size
496
+ 4. Preserves essential feedback information
497
+
498
+ Args:
499
+ prompt: Original prompt from GEPA (may contain base64)
500
+ max_length: Maximum length after cleaning (default: 50K chars)
501
+
502
+ Returns:
503
+ Cleaned prompt without base64, within size limits
504
+ """
505
+ import re
506
+
507
+ # Step 1: Remove base64 image strings (typically very long alphanumeric strings)
508
+ # Base64 images are usually 50K+ characters of A-Za-z0-9+/= pattern
509
+ # Look for very long base64-like sequences
510
+ base64_pattern = r'[A-Za-z0-9+/=]{5000,}' # Sequences of 5000+ base64 chars
511
+ cleaned = re.sub(base64_pattern, '[IMAGE_DATA_REMOVED]', prompt)
512
+
513
+ # Step 2: Remove detailed_scores sections that might contain base64 references
514
+ # These are usually in markdown format: "### detailed_scores\n...base64..."
515
+ detailed_scores_pattern = r'### detailed_scores[^\n]*\n[^#]*(?:image_base64|base64)[^\n]*(?:\n[^#]*)*'
516
+ cleaned = re.sub(detailed_scores_pattern, '### detailed_scores: [REMOVED_FOR_BREVITY]', cleaned, flags=re.IGNORECASE | re.MULTILINE)
517
+
518
+ # Step 3: Remove any remaining image_base64 references
519
+ cleaned = re.sub(r'image_base64[^\n]*', 'image_base64: [REMOVED]', cleaned, flags=re.IGNORECASE)
520
+ cleaned = re.sub(r'"[A-Za-z0-9+/=]{10000,}"', '[LARGE_DATA_STRING_REMOVED]', cleaned) # Very long strings likely base64
521
+
522
+ # Step 4: Truncate if still too long (keep the beginning which usually has the most important info)
523
+ if len(cleaned) > max_length:
524
+ # Keep first part (usually contains prompt and key feedback)
525
+ # Add truncation notice
526
+ truncated_size = len(cleaned) - max_length
527
+ cleaned = cleaned[:max_length] + f"\n\n[TRUNCATED {truncated_size} characters of detailed evaluation data]"
528
+ self.logger.warning(f"⚠️ Prompt truncated: {len(prompt)} → {len(cleaned)} chars")
529
+
530
+ return cleaned
531
+
532
+ def _validate_models(self, task_lm, reflection_lm):
533
+ """
534
+ Validate if specified models are supported.
535
+
536
+ Note: No hardcoded restrictions - the API provider will validate model existence.
537
+ This method is kept for potential future validation logic but doesn't restrict users.
538
+ """
539
+ # No hardcoded model restrictions - users can specify any model
540
+ # The API provider will handle validation and return errors if model doesn't exist
541
+ self.logger.debug(f"Using task model: {task_lm}, reflection model: {reflection_lm}")
542
+
543
def _create_seed_candidate(self, seed_prompt: str) -> Dict[str, str]:
    """Build the initial GEPA candidate mapping from a sanitized seed prompt."""
    return {'system_prompt': sanitize_prompt(seed_prompt)}
+ async def _run_gepa_optimization(self, adapter, seed_candidate: Any, trainset: List[Any], valset: List[Any], **kwargs) -> tuple: # Return tuple
549
+ """
550
+ Run GEPA optimization with the given adapter and data
551
+
552
+ Args:
553
+ adapter: Custom adapter for GEPA
554
+ seed_candidate: Initial prompt candidate
555
+ trainset: Training dataset
556
+ valset: Validation dataset
557
+ **kwargs: Additional optimization parameters that can override config
558
+
559
+ Returns:
560
+ Dict with optimization results
561
+
562
+ Raises:
563
+ GepaOptimizerError: If optimization fails
564
+
565
+ Note:
566
+ The following parameters are required in the config:
567
+ - max_metric_calls: Maximum number of metric evaluations
568
+ - batch_size: Batch size for evaluation
569
+ - max_iterations: Maximum number of optimization iterations
570
+ """
571
+ try:
572
+ # Get optimization parameters from config (these are required fields)
573
+ max_metric_calls = self.config.max_metric_calls
574
+ batch_size = self.config.batch_size
575
+ max_iterations = self.config.max_iterations
576
+
577
+ # Create reflection model client
578
+ from ..llms.vision_llm import VisionLLMClient
579
+ base_reflection_lm_client = VisionLLMClient(
580
+ provider=self.config.reflection_model.provider,
581
+ model_name=self.config.reflection_model.model_name,
582
+ api_key=self.config.reflection_model.api_key,
583
+ base_url=self.config.reflection_model.base_url,
584
+ temperature=self.config.reflection_model.temperature,
585
+ max_tokens=self.config.reflection_model.max_tokens,
586
+ top_p=self.config.reflection_model.top_p,
587
+ frequency_penalty=self.config.reflection_model.frequency_penalty,
588
+ presence_penalty=self.config.reflection_model.presence_penalty
589
+ )
590
+ # reflection_lm_client will be set below (may be wrapped with LLEGO)
591
+ reflection_lm_client = base_reflection_lm_client
592
+
593
+ # 🆕 LLEGO Integration: Create enhanced reflection callable
594
+ if self.config.use_llego_operators:
595
+ self.logger.info("🧬 LLEGO genetic operators ENABLED")
596
+ self.logger.info(f" α={self.config.alpha}, τ={self.config.tau}, ν={self.config.nu}")
597
+ self.logger.info(f" Crossover offspring: {self.config.n_crossover}, Mutation offspring: {self.config.n_mutation}")
598
+
599
+ # Import LLEGO operators
600
+ from ..operators.llego_operators import LLEGOIntegrationLayer, PromptCandidate
601
+
602
+ # Initialize LLEGO integration layer
603
+ llego = LLEGOIntegrationLayer(
604
+ alpha=self.config.alpha,
605
+ tau=self.config.tau,
606
+ nu=self.config.nu,
607
+ population_size=self.config.population_size,
608
+ n_crossover=self.config.n_crossover,
609
+ n_mutation=self.config.n_mutation
610
+ )
611
+
612
+ # Initialize with seed prompt
613
+ llego.initialize_population(
614
+ seed_prompt=seed_candidate.get('system_prompt', ''),
615
+ initial_fitness=0.5
616
+ )
617
+
618
+ # 🔥 HYBRID MODE FIX: Wrap reflection_lm_client with LLEGO for hybrid mode
619
+ # This ensures reflection calls go through LLEGO wrapper for candidate generation
620
+ if self.config.enable_gepa_reflection_with_llego:
621
+ self.logger.info("🔥 HYBRID MODE: Wrapping reflection_lm_client with LLEGO")
622
+ from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient
623
+
624
+ # Wrap reflection_lm_client with LLEGO so hybrid generation is triggered
625
+ reflection_lm_client = LLEGOEnhancedLLMClient(
626
+ base_llm=base_reflection_lm_client,
627
+ llego_layer=llego,
628
+ config=self.config, # Pass config for hybrid mode!
629
+ verbose=True
630
+ )
631
+ self.logger.info("✅ reflection_lm_client wrapped with LLEGO (hybrid mode enabled)")
632
+
633
+ # 🔥 CRITICAL: Store reflection_lm_client reference in adapter so it can set context
634
+ # This allows make_reflective_dataset to set reflection context on BOTH clients
635
+ if hasattr(adapter, 'reflection_lm_client'):
636
+ adapter.reflection_lm_client = reflection_lm_client
637
+ self.logger.info("✅ Stored reflection_lm_client reference in adapter")
638
+ else:
639
+ # Add reflection_lm_client attribute to adapter
640
+ adapter.reflection_lm_client = reflection_lm_client
641
+ self.logger.info("✅ Added reflection_lm_client attribute to adapter")
642
+
643
+ # 🔥 NEW: Also store config and reflection_lm_client for adapter-level generation
644
+ if hasattr(adapter, '_config'):
645
+ adapter._config = self.config
646
+ self.logger.info("✅ Stored config in adapter for hybrid mode")
647
+ else:
648
+ adapter._config = self.config
649
+ self.logger.info("✅ Added _config attribute to adapter")
650
+
651
+ if hasattr(adapter, '_reflection_lm_client'):
652
+ adapter._reflection_lm_client = reflection_lm_client
653
+ self.logger.info("✅ Stored _reflection_lm_client in adapter for hybrid mode")
654
+ else:
655
+ adapter._reflection_lm_client = reflection_lm_client
656
+ self.logger.info("✅ Added _reflection_lm_client attribute to adapter")
657
+
658
+ # 🔥 CRITICAL FIX: Ensure LLEGO layer is stored in adapter
659
+ # Without this, adapter.llego will be None and population updates are skipped!
660
+ if hasattr(adapter, 'llego'):
661
+ if adapter.llego is None:
662
+ adapter.llego = llego
663
+ self.logger.info("✅ CRITICAL: Set LLEGO layer in adapter (was None)")
664
+ else:
665
+ self.logger.debug("✅ LLEGO layer already set in adapter")
666
+ else:
667
+ # Add llego attribute if it doesn't exist
668
+ adapter.llego = llego
669
+ self.logger.info("✅ CRITICAL: Added LLEGO layer to adapter")
670
+
671
+ # 🔥 CRITICAL: Always set _reflection_lm_client in adapter (even without hybrid mode)
672
+ # This is required for propose_new_texts() to work
673
+ if not hasattr(adapter, '_reflection_lm_client') or adapter._reflection_lm_client is None:
674
+ adapter._reflection_lm_client = reflection_lm_client
675
+ self.logger.info("✅ Set _reflection_lm_client in adapter (required for propose_new_texts)")
676
+
677
+ # 🔥 HYBRID MODE FIX: Inject config into LLEGO wrapper for hybrid mode
678
+ # The adapter already has LLEGO wrapper, we just need to update its config
679
+ if self.config.enable_gepa_reflection_with_llego:
680
+ # HYBRID MODE: Update the LLEGO wrapper's config
681
+ self.logger.info("🔥 HYBRID MODE: Enabling hybrid candidate generation in LLEGO wrapper")
682
+
683
+ # Get the LLM client (may already be wrapped)
684
+ llm_client = self.adapter.llm_client
685
+ from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient
686
+
687
+ if isinstance(llm_client, LLEGOEnhancedLLMClient):
688
+ # Already wrapped, just update config
689
+ llm_client.config = self.config
690
+ self.logger.info("✅ Updated LLEGO wrapper with hybrid mode config")
691
+ else:
692
+ # Not wrapped yet, wrap it now with config
693
+ llego_wrapped_llm = LLEGOEnhancedLLMClient(
694
+ base_llm=llm_client,
695
+ llego_layer=llego,
696
+ config=self.config, # ← Pass config for hybrid mode!
697
+ verbose=True
698
+ )
699
+ # Update adapter's LLM client
700
+ self.adapter.llm_client = llego_wrapped_llm
701
+ self.logger.info("✅ Wrapped LLM client with LLEGO (hybrid mode enabled)")
702
+
703
+ adapter = self.adapter
704
+ else:
705
+ # LLEGO-ONLY MODE: Wrap adapter with LLEGO layer (no hybrid)
706
+ self.logger.info("🧬 LLEGO-ONLY MODE: Recreating adapter with LLEGO integration...")
707
+ if hasattr(self, 'adapter') and self.adapter:
708
+ from .universal_adapter import UniversalGepaAdapter
709
+
710
+ # Get original LLM client and evaluator from current adapter
711
+ original_llm = self.adapter.llm_client
712
+ # If it's already wrapped, unwrap it
713
+ if hasattr(original_llm, 'base_llm'):
714
+ original_llm = original_llm.base_llm
715
+
716
+ evaluator = self.adapter.evaluator
717
+ data_converter = self.adapter.data_converter
718
+
719
+ # Recreate adapter with LLEGO (no hybrid mode config)
720
+ from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient
721
+ llego_wrapped_llm = LLEGOEnhancedLLMClient(
722
+ base_llm=original_llm,
723
+ llego_layer=llego,
724
+ config=None, # No hybrid mode
725
+ verbose=True
726
+ )
727
+
728
+ adapter = UniversalGepaAdapter(
729
+ llm_client=llego_wrapped_llm,
730
+ evaluator=evaluator,
731
+ data_converter=data_converter,
732
+ llego_layer=llego
733
+ )
734
+ self.logger.info("✅ Adapter recreated with LLEGO-enhanced LLM client")
735
+ else:
736
+ adapter = self.adapter
737
+
738
+ # Create LLEGO-enhanced reflection callable
739
+ # When hybrid mode is enabled, reflection_lm_client is wrapped with LLEGO
740
+ # The wrapper will automatically generate hybrid candidates when called
741
+ def reflection_lm_callable(prompt: str) -> str:
742
+ """
743
+ Reflection callable that delegates to LLEGO-wrapped client.
744
+ In hybrid mode, the wrapper generates candidates from both GEPA and LLEGO.
745
+
746
+ 🔥 CRITICAL: Clean the prompt to remove base64 images and truncate if too long.
747
+ """
748
+ # 🔥 FIX: Clean prompt to remove base64 images and truncate excessive data
749
+ cleaned_prompt = self._clean_reflection_prompt(prompt)
750
+
751
+ self.logger.info(f"\n{'🔥'*40}")
752
+ self.logger.info(f"🔥 reflection_lm_callable CALLED (delegating to LLEGO wrapper)")
753
+ self.logger.info(f"🔥 Original prompt length: {len(prompt)} chars")
754
+ self.logger.info(f"🔥 Cleaned prompt length: {len(cleaned_prompt)} chars")
755
+ self.logger.info(f"🔥 Truncation: {len(prompt) - len(cleaned_prompt)} chars removed")
756
+ self.logger.info(f"🔥 First 200 chars (cleaned): {cleaned_prompt[:200]}...")
757
+ self.logger.info(f"{'🔥'*40}\n")
758
+
759
+ try:
760
+ # 🔥 CRITICAL: Set reflection context BEFORE generating
761
+ # This signals to the LLEGO wrapper that we're in reflection mode
762
+ if isinstance(reflection_lm_client, LLEGOEnhancedLLMClient):
763
+ reflection_lm_client.set_reflection_context(
764
+ current_prompt=cleaned_prompt, # Use cleaned prompt
765
+ feedback=None,
766
+ in_reflection=True # Enable reflection mode
767
+ )
768
+ self.logger.info("✅ Reflection context set on reflection_lm_client")
769
+
770
+ # 🔥 HYBRID MODE: If reflection_lm_client is wrapped with LLEGO,
771
+ # calling generate() will trigger hybrid candidate generation
772
+ # The wrapper handles queuing and returns candidates one by one
773
+
774
+ # 🔥 CRITICAL: System prompt must instruct LLM to generate improved prompt, not feedback
775
+ optimization_system_prompt = """You are an expert prompt engineer specializing in iterative prompt optimization.
776
+
777
+ Your task: Given the CURRENT PROMPT and its EVALUATION FEEDBACK, generate an IMPROVED version of the prompt that addresses all identified issues.
778
+
779
+ Core Requirements:
780
+ 1. OUTPUT ONLY the improved prompt text (no explanations, no analysis, no meta-commentary)
781
+ 2. START directly with the prompt (e.g., "You are a mobile GUI agent..." or similar task-appropriate opening)
782
+ 3. PRESERVE the core task domain and output format requirements
783
+ 4. INTEGRATE improvements from feedback naturally into the prompt structure
784
+ 5. MAINTAIN clarity, specificity, and actionability
785
+
786
+ Quality Standards:
787
+ - Be specific and concrete (avoid vague instructions)
788
+ - Use clear, imperative language for task instructions
789
+ - Include edge case handling if feedback identifies confusion
790
+ - Ensure the prompt is self-contained and unambiguous
791
+
792
+ DO NOT include:
793
+ - Analysis of what went wrong
794
+ - Explanations of your changes
795
+ - Meta-text like "Here's an improved version..." or "Based on feedback..."
796
+ - Recommendations or suggestions (those are already in the feedback)
797
+
798
+ Output the improved prompt directly and only the prompt."""
799
+
800
+ result = reflection_lm_client.generate(
801
+ system_prompt=optimization_system_prompt,
802
+ user_prompt=cleaned_prompt, # Use cleaned prompt (no base64, truncated)
803
+ image_base64=""
804
+ )
805
+
806
+ # Extract content from result
807
+ if isinstance(result, dict):
808
+ candidate = result.get("content", str(result))
809
+ source = result.get("source", "unknown")
810
+ self.logger.info(f"✅ Candidate from {source} (FULL TEXT):")
811
+ self.logger.info(f" '{candidate}'")
812
+ return candidate
813
+ else:
814
+ candidate = str(result)
815
+ self.logger.info(f"✅ Candidate generated (FULL TEXT):")
816
+ self.logger.info(f" '{candidate}'")
817
+ return candidate
818
+
819
+ except Exception as e:
820
+ self.logger.error(f"❌ Error in reflection_lm_callable: {e}")
821
+ import traceback
822
+ self.logger.error(traceback.format_exc())
823
+ # Fallback: return prompt as-is
824
+ return prompt
825
+
826
+ # Set up reflection context for LLEGO wrapper
827
+ if self.config.enable_gepa_reflection_with_llego and isinstance(reflection_lm_client, LLEGOEnhancedLLMClient):
828
+ # Store current prompt in reflection context for LLEGO operators
829
+ reflection_lm_client.set_reflection_context(
830
+ current_prompt=seed_candidate.get('system_prompt', ''),
831
+ feedback=None,
832
+ in_reflection=True
833
+ )
834
+
835
+ else:
836
+ # Standard GEPA reflection (no LLEGO)
837
+ adapter = self.adapter # Use the original adapter
838
+
839
+ # 🔥 CRITICAL: Always set _reflection_lm_client in adapter (even without LLEGO)
840
+ # This is required for propose_new_texts() to work
841
+ if not hasattr(adapter, '_reflection_lm_client') or adapter._reflection_lm_client is None:
842
+ adapter._reflection_lm_client = reflection_lm_client
843
+ self.logger.info("✅ Set _reflection_lm_client in adapter (required for propose_new_texts)")
844
+
845
+ # Define standard reflection callable (no LLEGO enhancement)
846
+ def reflection_lm_callable(prompt: str) -> str:
847
+ """Standard callable wrapper for reflection model that GEPA expects"""
848
+ try:
849
+ # 🔥 CRITICAL: System prompt must instruct LLM to generate improved prompt, not feedback
850
+ optimization_system_prompt = """You are an expert prompt engineer specializing in iterative prompt optimization.
851
+
852
+ Your task: Given the CURRENT PROMPT and its EVALUATION FEEDBACK, generate an IMPROVED version of the prompt that addresses all identified issues.
853
+
854
+ Core Requirements:
855
+ 1. OUTPUT ONLY the improved prompt text (no explanations, no analysis, no meta-commentary)
856
+ 2. START directly with the prompt (e.g., "You are a mobile GUI agent..." or similar task-appropriate opening)
857
+ 3. PRESERVE the core task domain and output format requirements
858
+ 4. INTEGRATE improvements from feedback naturally into the prompt structure
859
+ 5. MAINTAIN clarity, specificity, and actionability
860
+
861
+ Quality Standards:
862
+ - Be specific and concrete (avoid vague instructions)
863
+ - Use clear, imperative language for task instructions
864
+ - Include edge case handling if feedback identifies confusion
865
+ - Ensure the prompt is self-contained and unambiguous
866
+
867
+ DO NOT include:
868
+ - Analysis of what went wrong
869
+ - Explanations of your changes
870
+ - Meta-text like "Here's an improved version..." or "Based on feedback..."
871
+ - Recommendations or suggestions (those are already in the feedback)
872
+
873
+ Output the improved prompt directly and only the prompt."""
874
+
875
+ # For reflection, we only need text generation (no images)
876
+ result = reflection_lm_client.generate(
877
+ system_prompt=optimization_system_prompt,
878
+ user_prompt=prompt,
879
+ image_base64="" # No image for reflection
880
+ )
881
+
882
+ # Extract string content from the result dictionary
883
+ if isinstance(result, dict):
884
+ return result.get("content", str(result))
885
+ else:
886
+ return str(result)
887
+
888
+ except Exception as e:
889
+ self.logger.error(f"Reflection model error: {e}")
890
+ return prompt # Return original prompt on error
891
+ self.logger.info(
892
+ f"Starting GEPA optimization with {max_iterations} iterations, "
893
+ f"batch size {batch_size}, max metric calls: {max_metric_calls}"
894
+ )
895
+ self.logger.info(
896
+ f"GEPA parameters: candidate_selection_strategy=pareto, "
897
+ f"reflection_minibatch_size={batch_size}, "
898
+ f"skip_perfect_score=False, "
899
+ f"module_selector=round_robin"
900
+ )
901
+
902
+ # Prepare optimization parameters with ONLY valid GEPA parameters
903
+ # Note: 'adapter' variable is set above (either LLEGO-enhanced or standard)
904
+ # 🔥 REMOVED: Excessive diagnostic warnings - moved to DEBUG level
905
+ reflection_lm_passed = reflection_lm_callable if self.config.use_llego_operators else None
906
+ if reflection_lm_passed:
907
+ self.logger.debug(f"reflection_lm_callable passed to GEPA (may be ignored in adapter mode)")
908
+
909
+ # #region agent log
910
+ import json as _json_debug
911
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
912
+ with open(_debug_log_path, "a") as _f:
913
+ _f.write(_json_debug.dumps({"hypothesisId": "A", "location": "optimizer.py:gepa_params", "message": "GEPA params construction", "data": {"max_iterations_from_config": max_iterations, "max_metric_calls": max_metric_calls, "batch_size": batch_size}, "timestamp": int(time.time() * 1000), "sessionId": "debug-session"}) + "\n")
914
+ # #endregion
915
+
916
+ gepa_params = {
917
+ 'adapter': adapter, # Use the adapter created above (with or without LLEGO)
918
+ 'seed_candidate': seed_candidate,
919
+ 'trainset': trainset,
920
+ 'valset': valset,
921
+ 'max_metric_calls': max_metric_calls,
922
+ # NOTE: GEPA does NOT have num_iterations - it uses max_metric_calls to control iterations
923
+
924
+ # 🔥 CRITICAL: When using an adapter, GEPA expects:
925
+ # - adapter.make_reflective_dataset() to create feedback data
926
+ # - GEPA's internal proposer to generate candidates from that data
927
+ # - task_lm and reflection_lm must be None (GEPA will use model from adapter)
928
+ 'task_lm': None, # Don't pass - adapter handles this
929
+ 'reflection_lm': reflection_lm_passed, # Pass LLEGO-enhanced reflection (may be ignored!)
930
+
931
+ # Valid GEPA parameters based on actual library
932
+ 'candidate_selection_strategy': 'pareto', # Use Pareto selection
933
+ 'skip_perfect_score': False, # Don't skip perfect scores
934
+ 'reflection_minibatch_size': batch_size, # Use batch size for reflection
935
+ 'perfect_score': 1.0, # Perfect score threshold
936
+ 'module_selector': 'round_robin', # Cycle through components
937
+ 'display_progress_bar': self.config.verbose, # Show progress if verbose
938
+ 'raise_on_exception': True, # Raise exceptions for debugging
939
+ }
940
+
941
+ # 🔥 CRITICAL FIX: Filter kwargs to only include valid GEPA parameters
942
+ # GEPA does NOT accept num_iterations, max_iterations, or other non-GEPA params
943
+ VALID_GEPA_PARAMS = {
944
+ 'seed_candidate', 'trainset', 'valset', 'adapter', 'task_lm', 'reflection_lm',
945
+ 'candidate_selection_strategy', 'skip_perfect_score', 'batch_sampler',
946
+ 'reflection_minibatch_size', 'perfect_score', 'reflection_prompt_template',
947
+ 'module_selector', 'use_merge', 'max_merge_invocations', 'merge_val_overlap_floor',
948
+ 'max_metric_calls', 'stop_callbacks', 'logger', 'run_dir', 'use_wandb',
949
+ 'wandb_api_key', 'wandb_init_kwargs', 'use_mlflow', 'mlflow_tracking_uri',
950
+ 'mlflow_experiment_name', 'track_best_outputs', 'display_progress_bar',
951
+ 'use_cloudpickle', 'seed', 'raise_on_exception', 'val_evaluation_policy'
952
+ }
953
+
954
+ # Only add valid kwargs that aren't already in gepa_params
955
+ for key, value in kwargs.items():
956
+ if key in VALID_GEPA_PARAMS and key not in gepa_params:
957
+ gepa_params[key] = value
958
+ elif key not in VALID_GEPA_PARAMS:
959
+ self.logger.debug(f"⚠️ Filtering out invalid GEPA parameter: {key}")
960
+
961
+ # #region agent log
962
+ with open(_debug_log_path, "a") as _f:
963
+ _f.write(_json_debug.dumps({"hypothesisId": "A", "location": "optimizer.py:gepa_params_final", "message": "Final GEPA params keys", "data": {"params_keys": list(gepa_params.keys()), "max_metric_calls": gepa_params.get('max_metric_calls', 'NOT_PASSED')}, "timestamp": int(time.time() * 1000), "sessionId": "debug-session"}) + "\n")
964
+ # #endregion
965
+
966
+ # 🎯 NEW: Capture GEPA's internal logging for pareto front information
967
+ gepa_output = io.StringIO()
968
+
969
+ # Log iteration start
970
+ from ..utils.clean_logger import get_clean_logger
971
+ clean_log = get_clean_logger()
972
+ clean_log.log_iteration_start(1, seed_prompt=seed_candidate.get('system_prompt', ''))
973
+
974
+ # 🔥 CRITICAL: Pass valset size to adapter for better dataset type detection
975
+ if hasattr(adapter, '_valset_size'):
976
+ adapter._valset_size = len(valset)
977
+ self.logger.debug(f"✅ Set valset_size in adapter: {len(valset)} for Dpareto detection")
978
+
979
+ # 🔥 CRITICAL FIX: Store valset in adapter so we can evaluate generated candidates on it
980
+ # This ensures generated candidates are evaluated on Dpareto for Pareto selection
981
+ if hasattr(adapter, '_valset'):
982
+ adapter._valset = valset
983
+ self.logger.debug(f"✅ Stored valset in adapter ({len(valset)} samples) for Dpareto evaluation of generated candidates")
984
+ else:
985
+ # Add _valset attribute if it doesn't exist
986
+ adapter._valset = valset
987
+ self.logger.debug(f"✅ Added _valset attribute to adapter ({len(valset)} samples)")
988
+
989
+ # Run GEPA optimization (synchronous call wrapped in async)
990
+ result = await asyncio.get_event_loop().run_in_executor(
991
+ None,
992
+ lambda: self._run_gepa_with_logging(gepa_params, gepa_output)
993
+ )
994
+
995
+ # 🎯 NEW: Process and log pareto front information, extract iteration count
996
+ gepa_logs = gepa_output.getvalue()
997
+ actual_iterations = self._log_pareto_front_info(gepa_logs) # Get iteration count
998
+
999
+ return result, actual_iterations # Return both result and iteration count
1000
+ except Exception as e:
1001
+ # Try to extract partial results before failing
1002
+ self.logger.warning(f"GEPA optimization failed: {e}")
1003
+
1004
+ # Check if we have any cached results from the adapter
1005
+ best_candidate = adapter.get_best_candidate()
1006
+ best_score = adapter.get_best_score()
1007
+
1008
+ if best_candidate and best_score > 0:
1009
+ self.logger.info(f"🎯 Using cached best result with score: {best_score:.4f}")
1010
+
1011
+ # Create a mock GEPA result with the best candidate found
1012
+ return {
1013
+ 'best_candidate': best_candidate,
1014
+ 'best_score': best_score,
1015
+ 'partial_result': True,
1016
+ 'error': f'GEPA failed but returning best result found: {str(e)}'
1017
+ }
1018
+ else:
1019
+ # If no cached results, re-raise the error
1020
+ raise GepaOptimizerError(f"GEPA optimization failed: {str(e)}")
1021
+
1022
def _run_gepa_with_logging(self, gepa_params: Dict[str, Any], output_buffer: io.StringIO) -> Any:
    """Run gepa.optimize() while capturing everything it prints.

    Both stdout and stderr are redirected into *output_buffer* so the caller
    can later mine GEPA's textual output (iteration markers, pareto-front
    updates, scores) from the buffer.

    Args:
        gepa_params: Keyword arguments forwarded verbatim to gepa.optimize().
        output_buffer: In-memory sink receiving GEPA's console output.

    Returns:
        Whatever gepa.optimize() returns.
    """
    with redirect_stdout(output_buffer):
        with redirect_stderr(output_buffer):
            return gepa.optimize(**gepa_params)
1027
+
1028
+ def _log_pareto_front_info(self, gepa_logs: str) -> int: # Return int instead of None
1029
+ """Extract and log pareto front information from GEPA logs. Returns max iteration count."""
1030
+ lines = gepa_logs.split('\n')
1031
+ current_iteration = 0
1032
+ max_iteration = 0 # Track max iteration
1033
+
1034
+ for line in lines:
1035
+ # Look for iteration information
1036
+ if 'iteration' in line.lower():
1037
+ # Try to extract iteration number
1038
+ import re
1039
+ iteration_match = re.search(r'iteration\s+(\d+)', line.lower())
1040
+ if iteration_match:
1041
+ current_iteration = int(iteration_match.group(1))
1042
+ max_iteration = max(max_iteration, current_iteration) # Track max
1043
+ # Log iteration change
1044
+ from ..utils.clean_logger import get_clean_logger
1045
+ clean_log = get_clean_logger()
1046
+ if current_iteration > clean_log.current_iteration:
1047
+ clean_log.current_iteration = current_iteration
1048
+
1049
+ # Look for pareto front information
1050
+ if 'pareto front' in line.lower() or 'new program' in line.lower():
1051
+ self.logger.info(f"GEPA Pareto Update: {line.strip()}")
1052
+ elif 'iteration' in line.lower() and ('score' in line.lower() or 'program' in line.lower()):
1053
+ self.logger.debug(f"{line.strip()}")
1054
+ elif 'best' in line.lower() and 'score' in line.lower():
1055
+ self.logger.info(f"{line.strip()}")
1056
+
1057
+ # Look for evaluation information
1058
+ if 'evaluating' in line.lower() and 'candidate' in line.lower():
1059
+ self.logger.debug(f"{line.strip()}")
1060
+
1061
+ self.logger.info(f"GEPA Optimization Complete: {max_iteration} iterations")
1062
+
1063
+ # #region agent log
1064
+ import json as _json_debug
1065
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
1066
+ with open(_debug_log_path, "a") as _f:
1067
+ _f.write(_json_debug.dumps({"hypothesisId": "F", "location": "optimizer.py:gepa_complete", "message": "GEPA optimization complete - iteration count", "data": {"max_iteration_from_logs": max_iteration, "expected_iterations": self.config.max_iterations, "off_by_one": max_iteration != self.config.max_iterations, "gepa_logs_length": len(gepa_logs)}, "timestamp": int(time.time() * 1000), "sessionId": "debug-session"}) + "\n")
1068
+ # #endregion
1069
+
1070
+ return max_iteration # Return the max iteration count
1071
+
1072
def _extract_best_candidate(self, gepa_result: Any) -> Dict[str, str]:
    """
    Extract the best candidate from the GEPA Pareto front (single source of truth).

    Every candidate (GEPA reflection, LLEGO crossover, LLEGO mutation) is
    evaluated on Dpareto, and every non-dominated one is added to the GEPA
    Pareto front — so the best candidate must live there. *gepa_result* is
    consulted only as an edge-case fallback when the front is empty.

    Args:
        gepa_result: Raw result from gepa.optimize() (fallback only)

    Returns:
        Best candidate dictionary with prompt components
    """
    try:
        divider = '═' * 80
        self.logger.info(f"\n{divider}")
        self.logger.info("🔍 EXTRACTING BEST CANDIDATE FROM GEPA PARETO FRONT")
        self.logger.info(f"{divider}")

        from ..utils.pareto_logger import get_pareto_logger
        pareto_log = get_pareto_logger()

        # ---- PRIMARY: highest-scoring entry on the GEPA Pareto front ----
        if pareto_log.pareto_front:
            try:
                top_entry = max(pareto_log.pareto_front, key=lambda entry: entry['score'])
                top_score = top_entry['score']
                top_prompt = top_entry['prompt']
                top_type = top_entry.get('type', 'unknown')
                top_notation = top_entry.get('notation', 'S')

                chosen = {
                    'system_prompt': top_prompt,
                    'fitness': top_score,
                    'source': 'gepa_pareto_front',
                    'candidate_type': top_type,
                    'notation': top_notation
                }

                self.logger.info("✅ SELECTED: Best candidate from GEPA Pareto front")
                self.logger.info(f" Notation: {top_notation}")
                self.logger.info(f" Fitness: f({top_notation})={top_score:.4f}")
                self.logger.info(f" Type: {top_type}")
                self.logger.info(f" Prompt length: {len(top_prompt)} chars")
                self.logger.info(" 💡 GEPA Pareto front is single source of truth (all candidates evaluated on Dpareto)")

                return chosen

            except Exception as e:
                self.logger.error(f"❌ Failed to extract from GEPA Pareto front: {e}")
                import traceback
                self.logger.error(traceback.format_exc())

        # ---- EDGE-CASE FALLBACK: front empty (should not happen) ----
        self.logger.warning("⚠️ GEPA Pareto front is empty - using gepa_result as fallback")
        self.logger.warning(" This should not happen if all candidates are evaluated on Dpareto")

        if hasattr(gepa_result, 'best_candidate'):
            raw_candidate = gepa_result.best_candidate
            if isinstance(raw_candidate, dict):
                fallback_prompt = raw_candidate.get('system_prompt')
            else:
                fallback_prompt = str(raw_candidate)
            fallback_score = getattr(gepa_result, 'best_score', None)

            if fallback_prompt:
                self.logger.info("✅ Using gepa_result.best_candidate as fallback")
                return {
                    'system_prompt': fallback_prompt,
                    'fitness': float(fallback_score) if fallback_score is not None else None,
                    'source': 'gepa_result_fallback',
                    'candidate_type': 'unknown',
                    'notation': 'S'
                }

        # Last resort: nothing anywhere — return an empty prompt.
        self.logger.error("❌ No candidates found anywhere - returning empty prompt")
        return {'system_prompt': ''}

    except Exception as e:
        self.logger.error(f"❌ Error extracting best candidate: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        return {'system_prompt': ''}
1160
+
1161
def _evaluate_candidate_on_testset(
    self,
    candidate: Dict[str, str],
    testset: List[Dict]
) -> float:
    """
    Evaluate a candidate prompt on the held-out test set.

    Args:
        candidate: Prompt candidate to evaluate
        testset: Test dataset (never shown to GEPA during optimization)

    Returns:
        Average composite score across the test set

    Raises:
        TestSetEvaluationError: If evaluation fails or yields no scores
    """
    from ..utils.exceptions import TestSetEvaluationError

    try:
        # Same evaluation path GEPA uses internally, minus trace capture.
        outcome = self.adapter.evaluate(
            batch=testset,
            candidate=candidate,
            capture_traces=False
        )

        scores = outcome.scores
        if not scores:
            # Deliberately re-wrapped by the except below (historic behavior).
            raise TestSetEvaluationError("No scores returned from test evaluation")

        mean_score = sum(scores) / len(scores)

        self.logger.debug(
            f"Test set evaluation: {len(scores)} samples, "
            f"scores: {scores}, avg: {mean_score:.4f}"
        )

        return mean_score

    except Exception as e:
        raise TestSetEvaluationError(f"Failed to evaluate on test set: {str(e)}")
1204
+
1205
def optimize_sync(self,
                  model: str,
                  seed_prompt: str,
                  dataset: Any,
                  reflection_lm: str,
                  max_metric_calls: int = 150,
                  **kwargs) -> "OptimizedResult":
    """
    Synchronous wrapper around the async optimization entry point.

    Args:
        model: Target model to optimize for
        seed_prompt: Initial prompt to optimize
        dataset: Training data in any format
        reflection_lm: Model for reflection
        max_metric_calls: Budget for optimization attempts
        **kwargs: Additional optimization parameters

    Returns:
        OptimizedResult: Optimization result
    """
    # BUG FIX: the previous implementation did new_event_loop() +
    # set_event_loop() + run_until_complete() + close(), which left a
    # *closed* loop installed as the thread's current event loop — any
    # later asyncio.get_event_loop() in the same thread then returned a
    # dead loop. asyncio.run() creates, runs and tears down the loop
    # cleanly (and restores the thread state).
    return asyncio.run(
        self.train(model, seed_prompt, dataset, reflection_lm, max_metric_calls, **kwargs)
    )
1237
+
1238
+
1239
# Convenience function for quick optimization
def optimize_prompt(
    model: "Union[str, ModelConfig]",
    seed_prompt: str,
    dataset: Any,
    reflection_model: "Optional[Union[str, ModelConfig]]" = None,
    **kwargs
) -> "OptimizedResult":
    """
    Convenience function for quick prompt optimization without creating an
    optimizer instance.

    Args:
        model: Target model configuration
        seed_prompt: Initial prompt to optimize
        dataset: Training data
        reflection_model: Model for reflection (defaults to *model*)
        **kwargs: Additional optimization parameters
            (max_iterations, max_metric_calls, batch_size, ...)

    Returns:
        OptimizedResult: Optimization result
    """
    # Reflection defaults to the target model when not specified.
    if reflection_model is None:
        reflection_model = model

    config = OptimizationConfig(
        model=model,
        reflection_model=reflection_model,
        max_iterations=kwargs.get('max_iterations', 10),
        max_metric_calls=kwargs.get('max_metric_calls', 50),
        batch_size=kwargs.get('batch_size', 4)
    )

    optimizer = GepaOptimizer(config=config)
    # BUG FIX: train() expects (model, seed_prompt, dataset, reflection_lm, ...)
    # — see GepaOptimizer.optimize_sync, which calls it exactly that way.
    # The old call passed only (seed_prompt, dataset), shifting every
    # positional argument one slot to the left (seed_prompt bound as model,
    # dataset as seed_prompt, ...).
    return asyncio.run(
        optimizer.train(model, seed_prompt, dataset, reflection_model, **kwargs)
    )
1274
+
1275
+
1276
+
1277
+
1278
+
1279
+
src/gepa_optimizer/core/result.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Result processing for GEPA Optimizer
3
+ Handles extraction and processing of GEPA optimization results
4
+ """
5
+
6
+ from typing import Any, Dict, Optional
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class ResultProcessor:
    """
    Turns raw GEPA optimization results into clean, structured data.

    All methods are static and defensive: attributes the result object
    lacks are simply skipped, and extraction failures are logged as
    warnings instead of raised.
    """

    @staticmethod
    def extract_optimized_prompt(result: Any) -> str:
        """
        Pull the optimized prompt text out of a GEPA result object.

        Args:
            result: Raw GEPA optimization result

        Returns:
            str: The optimized prompt text
        """
        try:
            if hasattr(result, 'best_candidate'):
                best = result.best_candidate
                if isinstance(best, dict):
                    # Conventional prompt keys, checked in priority order.
                    for prompt_key in ('system_prompt', 'prompt', 'text'):
                        if prompt_key in best:
                            return str(best[prompt_key])
                # Dict without a known key, or a non-dict candidate.
                return str(best)

            # No best_candidate attribute: stringify the whole result.
            return str(result)

        except Exception as e:
            logger.warning(f"Failed to extract optimized prompt: {e}")
            return "Optimization completed (prompt extraction failed)"

    @staticmethod
    def extract_metrics(result: Any) -> Dict[str, Any]:
        """
        Collect performance metrics from a GEPA result.

        Args:
            result: Raw GEPA optimization result

        Returns:
            Dict[str, Any]: Whatever metrics the result exposes
        """
        metrics: Dict[str, Any] = {}

        try:
            # Numeric attributes copied over with an explicit cast.
            for attr_name, cast in (
                ('best_score', float),
                ('baseline_score', float),
                ('improvement', float),
                ('iterations', int),
            ):
                if hasattr(result, attr_name):
                    metrics[attr_name] = cast(getattr(result, attr_name))

            # Derive a percentage improvement when both scores are present.
            if 'best_score' in metrics and 'baseline_score' in metrics:
                base = metrics['baseline_score']
                if base > 0:
                    pct = ((metrics['best_score'] - base) / base) * 100
                    metrics['improvement_percent'] = round(pct, 2)

            if hasattr(result, 'metadata'):
                metrics['metadata'] = result.metadata

        except Exception as e:
            logger.warning(f"Failed to extract metrics: {e}")

        return metrics

    @staticmethod
    def extract_reflection_history(result: Any) -> list:
        """
        Collect the reflection/optimization history from a GEPA result.

        Args:
            result: Raw GEPA optimization result

        Returns:
            list: One entry per recorded iteration (empty if unavailable)
        """
        history = []

        try:
            if hasattr(result, 'optimization_history'):
                for idx, step in enumerate(result.optimization_history):
                    history.append({
                        'iteration': idx,
                        'score': step.get('score', 0.0),
                        'candidate': step.get('candidate', {}),
                        'feedback': step.get('feedback', ''),
                        'improvement': step.get('improvement', 0.0),
                    })

        except Exception as e:
            logger.warning(f"Failed to extract reflection history: {e}")

        return history

    @staticmethod
    def process_full_result(
        result: Any,
        original_prompt: str,
        optimization_time: float,
        actual_iterations: Optional[int] = None,
        test_metrics: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Assemble the complete processed result dictionary.

        Args:
            result: Raw GEPA optimization result
            original_prompt: Original seed prompt
            optimization_time: Wall-clock time spent optimizing
            actual_iterations: Iteration count parsed from GEPA logs, if any
            test_metrics: Held-out test-set metrics, if any

        Returns:
            Dict[str, Any]: Structured result (raw result included too)
        """
        metrics = ResultProcessor.extract_metrics(result)

        # Iteration count priority: explicit log-derived value, then the
        # result object's own attributes, then whatever landed in metrics.
        total_iterations = 0
        try:
            if actual_iterations is not None:
                total_iterations = actual_iterations
            elif hasattr(result, 'iterations'):
                total_iterations = int(result.iterations)
            elif hasattr(result, 'num_iterations'):
                total_iterations = int(result.num_iterations)
            elif hasattr(result, 'optimization_history'):
                total_iterations = len(result.optimization_history)
            elif 'iterations' in metrics:
                total_iterations = metrics['iterations']
        except Exception as e:
            logger.warning(f"Failed to extract iterations: {e}")

        # Test-set metrics are surfaced under improvement_data.
        improvement_data: Dict[str, Any] = {}
        if test_metrics:
            improvement_data.update(test_metrics)

        return {
            'original_prompt': original_prompt,
            'optimized_prompt': ResultProcessor.extract_optimized_prompt(result),
            'metrics': metrics,
            'improvement_data': improvement_data,
            'reflection_history': ResultProcessor.extract_reflection_history(result),
            'optimization_time': optimization_time,
            'total_iterations': total_iterations,
            'status': 'completed',
            'raw_result': result  # kept for advanced users
        }
src/gepa_optimizer/core/universal_adapter.py ADDED
The diff for this file is too large to render. See raw diff
 
src/gepa_optimizer/data/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data module for GEPA Optimizer
3
+ """
4
+
5
+ from .converters import UniversalConverter
6
+ from .loaders import DataLoader
7
+ from .validators import DataValidator
8
+ from .scroll_dataset_loader import ScrollDatasetLoader, load_scroll_dataset
9
+ from .validation_dataset_loader import ValidationDatasetLoader, load_validation_dataset, load_validation_split
10
+ from .index_caching_loader import IndexCachingDatasetLoader, load_index_caching_dataset, load_index_caching_split
11
+
12
+ __all__ = [
13
+ "UniversalConverter",
14
+ "DataLoader",
15
+ "DataValidator",
16
+ # Scroll dataset
17
+ "ScrollDatasetLoader",
18
+ "load_scroll_dataset",
19
+ # Validation dataset
20
+ "ValidationDatasetLoader",
21
+ "load_validation_dataset",
22
+ "load_validation_split",
23
+ # Index caching dataset
24
+ "IndexCachingDatasetLoader",
25
+ "load_index_caching_dataset",
26
+ "load_index_caching_split",
27
+ ]
src/gepa_optimizer/data/converters.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Universal converter for dataset to GEPA format with 3-way split (train/val/test)
3
+ """
4
+
5
+ import os
6
+ import json
7
+ from typing import Any, List, Tuple, Union, Dict, Optional
8
+ from pathlib import Path
9
+ import pandas as pd
10
+ import logging
11
+
12
+ from .loaders import DataLoader
13
+ from ..utils.exceptions import DatasetError
14
+ from ..models.config import DataSplitConfig
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class UniversalConverter:
19
+ """
20
+ Universal converter for datasets to GEPA format.
21
+
22
+ Handles 3-way splitting (train/val/test) with configurable ratios and
23
+ graceful handling of small datasets.
24
+ """
25
+
26
def __init__(self, data_split_config: Optional[DataSplitConfig] = None):
    """
    Initialize the converter with an optional split configuration.

    Args:
        data_split_config: Train/val/test split configuration. Falls back
            to a default DataSplitConfig (60/20/20) when omitted.
    """
    # File formats the converter knows how to ingest.
    self.supported_extensions = [
        '.csv', '.json', '.jsonl',
        '.txt', '.md',
        '.png', '.jpg', '.jpeg',
    ]
    self.loader = DataLoader()
    self.data_split_config = data_split_config or DataSplitConfig()
40
+
41
def convert(
    self,
    dataset: Union[List[Any], str, Any, Dict[str, Any]],
    split_config: Optional[DataSplitConfig] = None
) -> Tuple[List[dict], List[dict], List[dict]]:
    """
    Convert any supported dataset to GEPA format with a 3-way split.

    Args:
        dataset: Input dataset (list, file path, DataFrame, or a
            UI-tree specification dict)
        split_config: Optional split configuration (overrides instance config)

    Returns:
        (trainset, valset, testset) where trainset feeds reflection
        (Dfeedback in the GEPA paper), valset drives Pareto selection
        (Dpareto), and testset is held out from GEPA entirely.

    Raises:
        DatasetError: If the dataset cannot be converted or is too small
    """
    try:
        active_config = split_config or self.data_split_config

        # Dedicated path for the UI-tree dataset specification dict.
        if isinstance(dataset, dict) and dataset.get('type') == 'ui_tree_dataset':
            return self.convert_ui_tree_dataset(
                dataset.get('json_dir', 'json_tree'),
                dataset.get('screenshots_dir', 'screenshots'),
                split_config=active_config
            )

        # Normalize every other shape into a plain list of records.
        if isinstance(dataset, str):
            records = self._load_from_path(dataset)
        elif hasattr(dataset, 'to_dict'):  # pandas DataFrame
            records = dataset.to_dict(orient='records')
        elif isinstance(dataset, list):
            records = dataset
        else:
            records = [dataset]

        logger.info(f"Normalized data length: {len(records)}")
        normalized = self._standardize(records)
        return self._split_three_way(normalized, active_config)

    except (FileNotFoundError, ValueError, TypeError) as e:
        raise DatasetError(f"Failed to convert dataset: {str(e)}")
88
+
89
def _load_from_path(self, path: str) -> List[Any]:
    """Load data from a file path, wrapping the result in a single-item list."""
    file_path = Path(path)
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    ext = file_path.suffix.lower()
    if ext not in self.supported_extensions:
        raise DatasetError(f"Unsupported file extension: {ext}")
    return [self.loader.load(file_path)]
100
+
101
def _standardize(self, data: List[Any]) -> List[dict]:
    """Normalize heterogeneous records to a {'input', 'output', 'image'} schema.

    Handles both UI tree JSON records ({'screenshot', 'ui_tree',
    'expected_output'}) and simple text records ({'input'/'question'/...,
    'output'/'answer'/...}). Non-dict items are coerced to {'input': str(item)}.
    """
    standardized: List[dict] = []
    for raw in data:
        record = raw if isinstance(raw, dict) else {'input': str(raw)}

        if 'ui_tree' in record and 'screenshot' in record:
            # UI tree JSON format: pull text out of the tree node.
            standardized.append({
                'input': record['ui_tree'].get('text', ''),
                'output': record.get('expected_output', ''),
                'image': record.get('screenshot', ''),
            })
        else:
            # Simple text format: probe a list of common key aliases.
            standardized.append({
                'input': self._extract(record, ['input', 'question', 'text', 'prompt']) or '',
                'output': self._extract(record, ['output', 'result', 'response', 'answer', 'expected_output']) or '',
                'image': self._extract(record, ['image', 'image_base64', 'screenshot']) or '',
            })

    return standardized
128
+
129
def _extract(self, d: dict, keys: List[str]) -> Union[str, None]:
    """Return the value of the first key present in *d*, else None."""
    return next((d[key] for key in keys if key in d), None)
135
+
136
def _split_three_way(
    self,
    data: List[dict],
    config: DataSplitConfig
) -> Tuple[List[dict], List[dict], List[dict]]:
    """
    Partition *data* into train, validation, and test sets.

    Args:
        data: Standardized dataset
        config: Split configuration with ratios and strategies

    Returns:
        Tuple of (train, val, test) datasets

    Raises:
        DatasetError: If the dataset is too small for the configured splits
            or the resulting training set is empty
    """
    total = len(data)

    # Report adaptive ratios up front when the adaptive strategy is active.
    if config.small_dataset_strategy == 'adaptive':
        train_ratio, val_ratio, test_ratio = config.get_adaptive_ratios(total)
        logger.info(
            f"📊 Adaptive dataset splitting (strategy: adaptive, size: {total}): "
            f"ratios = {train_ratio*100:.0f}%/{val_ratio*100:.0f}%/{test_ratio*100:.0f}% "
            f"(prioritizes validation for reliable candidate ranking)"
        )

    # Config computes the boundary indices; surface its error as DatasetError.
    try:
        train_end, val_end, test_end, _ = config.get_split_indices(total)
    except ValueError as e:
        logger.error(f"Dataset split error: {e}")
        raise DatasetError(str(e))

    train = data[:train_end]
    val = data[train_end:val_end]
    test = data[val_end:test_end]

    strategy_note = " (adaptive)" if config.small_dataset_strategy == 'adaptive' else ""
    logger.info(
        f"Dataset split{strategy_note}: {len(train)} train ({len(train)/total*100:.1f}%), "
        f"{len(val)} val ({len(val)/total*100:.1f}%), "
        f"{len(test)} test ({len(test)/total*100:.1f}%)"
    )

    # Guard against degenerate splits.
    if not train:
        raise DatasetError("Training set is empty after split")
    if not val:
        logger.warning("Validation set is empty - this may cause issues with Pareto selection")
        val = [train[-1]]  # Fallback: reuse the last training sample
    if not test:
        logger.warning("Test set is empty - final evaluation will not be performed")

    return train, val, test
197
+
198
def _split(self, data: List[dict], ratio: float = 0.8) -> Tuple[List[dict], List[dict]]:
    """
    DEPRECATED: Legacy 2-way split kept for backwards compatibility.

    Prefer _split_three_way() in production code.

    Args:
        data: Standardized dataset
        ratio: Train ratio (0.0-1.0)

    Returns:
        Tuple of (train, val) datasets; val is never empty (falls back to
        the last sample when the split leaves it empty).
    """
    import warnings
    warnings.warn(
        "_split() is deprecated. Use _split_three_way() for 3-way splitting.",
        DeprecationWarning,
        stacklevel=2
    )

    cut = max(1, int(len(data) * ratio))
    return data[:cut], (data[cut:] or data[-1:])
222
+
223
def convert_ui_tree_dataset(
    self,
    json_dir: str,
    screenshots_dir: str,
    split_config: Optional[DataSplitConfig] = None
) -> Tuple[List[dict], List[dict], List[dict]]:
    """
    Convert UI tree dataset (JSON + screenshots) to GEPA format with 3-way split.

    Args:
        json_dir: Directory containing JSON files
        screenshots_dir: Directory containing screenshot images
        split_config: Optional split configuration (overrides instance config)

    Returns:
        Tuple of (train_data, val_data, test_data) in GEPA format

    Raises:
        DatasetError: If dataset cannot be loaded or is invalid
    """
    try:
        # Load paired dataset (one screenshot + JSON tree per sample)
        dataset = self.loader.load_ui_tree_dataset(json_dir, screenshots_dir)

        if not dataset:
            raise DatasetError("No valid image-JSON pairs found")

        logger.info(f"Loaded {len(dataset)} UI tree samples")

        # Use provided config or instance default
        config = split_config or self.data_split_config

        train, val, test = self._split_three_way(dataset, config)

        logger.info(
            f"Split UI tree dataset: {len(train)} train, "
            f"{len(val)} validation, {len(test)} test"
        )
        return train, val, test

    except DatasetError:
        # Already a domain error (raised above or by the splitter);
        # re-raise as-is instead of double-wrapping the message.
        raise
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise DatasetError(f"Failed to convert UI tree dataset: {str(e)}") from e
src/gepa_optimizer/data/index_caching_loader.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Index Caching Dataset Loader
3
+
4
+ Loads index caching dataset from JSON file (note2_debug.json format) and converts to GEPA-compatible format.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import base64
10
+ import logging
11
+ from typing import List, Dict, Any, Optional
12
+ from pathlib import Path
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class IndexCachingDatasetLoader:
    """
    Load an index-caching dataset from a JSON manifest (note2_debug.json format).

    Each manifest entry references a command, an element image, an XML dump and
    an ``expected`` result dict:

        [
          {
            "command": "Tap on first option from the suggestion",
            "image": "element_images/....png",
            "xml": "xml/....xml",
            "expected": {"is_index_based": true, "index_value": 1, ...}
          },
          ...
        ]

    Entries are converted to GEPA format:
      - "input": command + embedded XML content (used for evaluation)
      - "reflection_input": just the command (used for reflection)
      - "output": the ``expected`` dict serialized as a JSON string
      - "image_base64": base64-encoded image at the TOP level (for UniversalConverter)
      - "metadata": all original fields plus resolved paths and raw XML
    """

    def __init__(self, json_path: Optional[str] = None, base_dir: Optional[str] = None):
        """
        Initialize the loader.

        Args:
            json_path: Path to the JSON manifest. Falls back to the
                INDEX_CACHING_DATASET_PATH env var, then "./note2_debug.json".
            base_dir: Base directory for resolving relative paths referenced
                by the manifest. Defaults to the manifest's own directory.

        Raises:
            FileNotFoundError: If the manifest file doesn't exist
        """
        if json_path is None:
            json_path = os.getenv("INDEX_CACHING_DATASET_PATH", "./note2_debug.json")

        self.json_path = Path(json_path).resolve()

        if not self.json_path.exists():
            raise FileNotFoundError(
                f"Dataset file not found: {self.json_path}\n"
                f"Make sure note2_debug.json exists in the project root."
            )

        # Relative image/xml paths in the manifest are resolved against this.
        self.base_dir = Path(base_dir if base_dir is not None else self.json_path.parent).resolve()

    def load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load the manifest and convert every entry to GEPA format.

        Returns:
            List of GEPA-format dicts (see class docstring for the shape).

        Raises:
            FileNotFoundError: If a referenced image or XML file is missing
            json.JSONDecodeError: If the manifest is not valid JSON
        """
        manifest = json.loads(self.json_path.read_text(encoding="utf-8"))

        items: List[Dict[str, Any]] = []

        for idx, entry in enumerate(manifest):
            command = entry.get("command", "")
            image_rel = entry.get("image", "")
            xml_rel = entry.get("xml", "")
            expected = entry.get("expected", {})

            # Resolve manifest-relative paths.
            abs_image_path = (self.base_dir / image_rel).resolve()
            abs_xml_path = (self.base_dir / xml_rel).resolve()

            if not abs_image_path.exists():
                raise FileNotFoundError(
                    f"Image file not found: {abs_image_path}\n"
                    f"Entry {idx + 1}: {command}"
                )
            if not abs_xml_path.exists():
                raise FileNotFoundError(
                    f"XML file not found: {abs_xml_path}\n"
                    f"Entry {idx + 1}: {command}"
                )

            image_base64 = base64.b64encode(abs_image_path.read_bytes()).decode("utf-8")
            xml_content = abs_xml_path.read_text(encoding="utf-8")
            expected_json = json.dumps(expected, ensure_ascii=False)

            # Evaluation prompt embeds the XML; reflection only needs the
            # command, since reflection improves the prompt from feedback
            # rather than analyzing XML structure.
            user_prompt = f"{command}\n\nXML Content:\n\n```xml\n{xml_content}\n```"

            items.append({
                "input": user_prompt,
                "reflection_input": command,
                "output": expected_json,
                "image_base64": image_base64,  # TOP LEVEL for UniversalConverter
                "metadata": {
                    "command": command,
                    "image_path": str(image_rel),
                    "xml_path": str(xml_rel),
                    "abs_image_path": str(abs_image_path),
                    "abs_xml_path": str(abs_xml_path),
                    "xml_content": xml_content,
                    "expected": expected,
                    "dataset_index": idx,
                },
            })

        return items

    def load_split(
        self,
        train_ratio: float = 0.6,
        val_ratio: float = 0.4
    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Load the dataset and split it into train/val sets (no test set).

        Args:
            train_ratio: Fraction for the training set (default: 0.6)
            val_ratio: Fraction for the validation set (default: 0.4)

        Returns:
            Tuple of (train_set, val_set)

        Raises:
            ValueError: If the two ratios don't sum to 1.0 (within 0.01)
        """
        if abs(train_ratio + val_ratio - 1.0) > 0.01:
            raise ValueError(
                f"Split ratios must sum to 1.0, got {train_ratio + val_ratio:.3f}"
            )

        items = self.load_dataset()
        cut = int(len(items) * train_ratio)
        return items[:cut], items[cut:]
205
+
206
+
207
def load_index_caching_dataset(
    json_path: Optional[str] = None,
    base_dir: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Convenience wrapper: build an IndexCachingDatasetLoader and return the
    full dataset in GEPA format.

    Args:
        json_path: Path to JSON file
        base_dir: Base directory for resolving relative paths

    Returns:
        List of dataset items in GEPA format
    """
    return IndexCachingDatasetLoader(json_path=json_path, base_dir=base_dir).load_dataset()
223
+
224
+
225
def load_index_caching_split(
    json_path: Optional[str] = None,
    base_dir: Optional[str] = None,
    train_ratio: float = 0.6,
    val_ratio: float = 0.4
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience wrapper: load the index caching dataset and split it into
    train/val sets (no test set is produced).

    Args:
        json_path: Path to JSON file
        base_dir: Base directory for resolving relative paths
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set

    Returns:
        Tuple of (train_set, val_set) - no test set
    """
    loader = IndexCachingDatasetLoader(json_path=json_path, base_dir=base_dir)
    return loader.load_split(train_ratio=train_ratio, val_ratio=val_ratio)
245
+
246
+
247
# Example usage / smoke test when run as a script
if __name__ == "__main__":
    print("🚀 Testing Index Caching Dataset Loader...")

    try:
        demo_loader = IndexCachingDatasetLoader(json_path="./note2_debug.json")
        demo_data = demo_loader.load_dataset()

        print(f"\n✅ Loaded {len(demo_data)} items")

        # Show the first item, if any, to eyeball the converted shape.
        if demo_data:
            first = demo_data[0]
            meta = first['metadata']
            print(f"\n📝 Sample Item:")
            print(f"   Command: {first['input']}")
            print(f"   Image path: {meta['image_path']}")
            print(f"   XML path: {meta['xml_path']}")
            print(f"   Expected: {first['output'][:100]}...")
            print(f"   Image base64 length: {len(first['image_base64'])}")
            print(f"   XML content length: {len(meta.get('xml_content', ''))}")

        # Exercise the 2-way split as well.
        train_part, val_part = demo_loader.load_split()
        print(f"\n📊 Dataset Split:")
        print(f"   Training: {len(train_part)} samples")
        print(f"   Validation: {len(val_part)} samples")
        print(f"   Test: Not used (no test set)")

    except Exception as e:
        print(f"❌ Error: {e}")
278
+
src/gepa_optimizer/data/loaders.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading utilities for various file formats
3
+ """
4
+
5
+ import json
6
+ import base64
7
+ import pandas as pd
8
+ from typing import Any, Optional, Union, List , Dict
9
+ from pathlib import Path
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class DataLoader:
    """
    Load data from local files in a handful of common formats.

    Supports tabular (.csv, .xlsx), structured (.json, .jsonl), plain text
    (.txt, .md) and image (.png, .jpg, .jpeg) files. Every loader logs and
    returns None on failure rather than raising.
    """

    def __init__(self):
        # Extensions the generic load() entry point can dispatch on.
        self.supported_formats = [
            '.csv', '.json', '.jsonl', '.txt', '.md', '.xlsx',
            '.png', '.jpg', '.jpeg'
        ]

    def load(self, source: Union[str, Path], format_hint: Optional[str] = None) -> Optional[Any]:
        """
        Load data from any supported source.

        Args:
            source: File path or data source
            format_hint: Optional format hint to override auto-detection

        Returns:
            Loaded data or None if failed
        """
        try:
            path = Path(source)
            if not path.exists():
                logger.error(f"File not found: {source}")
                return None

            # Explicit hint wins; otherwise detect from the file extension.
            file_format = format_hint or path.suffix.lower()

            handlers = {
                '.csv': self.load_csv,
                '.json': self.load_json,
                '.jsonl': self.load_jsonl,
                '.txt': self.load_text,
                '.md': self.load_text,
                '.xlsx': self.load_excel,
                '.png': self.load_image_base64,
                '.jpg': self.load_image_base64,
                '.jpeg': self.load_image_base64,
            }
            handler = handlers.get(file_format)
            if handler is None:
                logger.warning(f"Unsupported format: {file_format}")
                return None
            return handler(path)

        except Exception as e:
            logger.error(f"Failed to load data from {source}: {str(e)}")
            return None

    def load_csv(self, path: Union[str, Path]) -> Optional[pd.DataFrame]:
        """Load CSV file as pandas DataFrame"""
        try:
            frame = pd.read_csv(path)
            logger.info(f"Loaded CSV with {len(frame)} rows and {len(frame.columns)} columns")
            return frame
        except Exception as e:
            logger.error(f"Failed to load CSV {path}: {str(e)}")
            return None

    def load_json(self, path: Union[str, Path]) -> Optional[Any]:
        """Load JSON file"""
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                payload = json.load(fh)

            if isinstance(payload, list):
                logger.info(f"Loaded JSON with {len(payload)} items")
            else:
                logger.info("Loaded JSON object")

            return payload
        except Exception as e:
            logger.error(f"Failed to load JSON {path}: {str(e)}")
            return None

    def load_jsonl(self, path: Union[str, Path]) -> Optional[List[Dict]]:
        """Load JSONL (JSON Lines) file; invalid lines are skipped with a warning."""
        try:
            records: List[Dict] = []
            with open(path, 'r', encoding='utf-8') as fh:
                for line_num, raw_line in enumerate(fh, 1):
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    try:
                        records.append(json.loads(stripped))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Invalid JSON on line {line_num}: {str(e)}")

            logger.info(f"Loaded JSONL with {len(records)} items")
            return records
        except Exception as e:
            logger.error(f"Failed to load JSONL {path}: {str(e)}")
            return None

    def load_text(self, path: Union[str, Path]) -> Optional[str]:
        """Load plain text file"""
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                text = fh.read()

            logger.info(f"Loaded text file with {len(text)} characters")
            return text
        except Exception as e:
            logger.error(f"Failed to load text {path}: {str(e)}")
            return None

    def load_excel(self, path: Union[str, Path]) -> Optional[pd.DataFrame]:
        """Load Excel file as pandas DataFrame"""
        try:
            frame = pd.read_excel(path)
            logger.info(f"Loaded Excel with {len(frame)} rows and {len(frame.columns)} columns")
            return frame
        except Exception as e:
            logger.error(f"Failed to load Excel {path}: {str(e)}")
            return None

    def load_image_base64(self, path: Union[str, Path]) -> Optional[str]:
        """Load image file and encode as Base64 string"""
        try:
            with open(path, 'rb') as fh:
                encoded = base64.b64encode(fh.read()).decode('utf-8')
            logger.info(f"Loaded image {path} and encoded to Base64")
            return encoded
        except Exception as e:
            logger.error(f"Failed to load image {path}: {str(e)}")
            return None

    def is_supported_format(self, file_path: Union[str, Path]) -> bool:
        """Check if file format is supported"""
        return Path(file_path).suffix.lower() in self.supported_formats

    def get_file_info(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """Get information about a file"""
        path = Path(file_path)

        if not path.exists():
            return {'exists': False}

        return {
            'exists': True,
            'size': path.stat().st_size,
            'format': path.suffix.lower(),
            'supported': self.is_supported_format(path),
            'name': path.name,
            'stem': path.stem,
            'parent': str(path.parent)
        }

    def load_ui_tree_dataset(self, json_dir: str, screenshots_dir: str) -> List[Dict[str, Any]]:
        """
        Pair JSON UI-tree files with same-named screenshot images.

        Args:
            json_dir: Directory containing JSON files (e.g., "json_tree")
            screenshots_dir: Directory containing screenshot images (e.g., "screenshots")

        Returns:
            List of dictionaries with 'input', 'output', and 'image' keys

        Raises:
            FileNotFoundError: If either directory does not exist
        """
        json_root = Path(json_dir)
        shots_root = Path(screenshots_dir)

        if not json_root.exists():
            raise FileNotFoundError(f"JSON directory not found: {json_dir}")
        if not shots_root.exists():
            raise FileNotFoundError(f"Screenshots directory not found: {screenshots_dir}")

        json_files = list(json_root.glob("*.json"))
        logger.info(f"Found {len(json_files)} JSON files in {json_dir}")

        pairs: List[Dict[str, Any]] = []

        for json_file in json_files:
            stem = json_file.stem

            # Look for a same-named screenshot with any accepted extension.
            image_file = None
            for ext in ['.jpg', '.jpeg', '.png']:
                candidate = shots_root / f"{stem}{ext}"
                if candidate.exists():
                    image_file = candidate
                    break

            if not image_file:
                logger.warning(f"No corresponding image found for {json_file.name}")
                continue

            try:
                tree = self.load_json(json_file)
                if not tree:
                    logger.warning(f"Failed to load JSON: {json_file}")
                    continue

                encoded = self.load_image_base64(image_file)
                if not encoded:
                    logger.warning(f"Failed to load image: {image_file}")
                    continue

                pairs.append({
                    'input': 'Extract UI elements from this screenshot and provide the complete UI tree structure',
                    'output': json.dumps(tree, indent=2),  # serialize tree as string
                    'image': encoded
                })
                logger.debug(f"Loaded pair: {json_file.name} + {image_file.name}")

            except Exception as e:
                logger.error(f"Error loading {json_file.name}: {str(e)}")
                continue

        logger.info(f"Successfully loaded {len(pairs)} image-JSON pairs")
        return pairs
src/gepa_optimizer/data/scroll_dataset_loader.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Scroll Element Dataset Loader for Drizz Mobile App Testing
3
+
4
+ Loads screenshots with bounding boxes and commands to identify scroll elements.
5
+ Converts to GEPA-compatible format for prompt optimization.
6
+ """
7
+
8
+ import base64
9
+ import random
10
+ import logging
11
+ from typing import List, Dict, Any, Tuple, Optional
12
+ from pathlib import Path
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ScrollDatasetLoader:
18
+ """
19
+ GENERIC dataset loader for image-based tasks.
20
+
21
+ This is a LIBRARY class - NO hardcoded assumptions about:
22
+ - What the task is (OCR, element detection, classification, etc.)
23
+ - Input format (questions, commands, descriptions, etc.)
24
+ - Output format (IDs, text, JSON, etc.)
25
+
26
+ Users define their dataset in the test script and pass it here.
27
+
28
+ Dataset format per item: (image_filename, input_text, expected_output)
29
+
30
+ Example usage (ANY task):
31
+ # Define YOUR dataset in YOUR test script
32
+ my_dataset = [
33
+ ("img1.png", "What is the main color?", "blue"),
34
+ ("img2.png", "Count the objects", "5"),
35
+ ("img3.png", "Describe the scene", "A cat on a sofa"),
36
+ ]
37
+
38
+ # Pass to loader
39
+ loader = ScrollDatasetLoader(
40
+ images_dir="images",
41
+ dataset_config=my_dataset
42
+ )
43
+ data = loader.load_dataset()
44
+ """
45
+
46
def __init__(
    self,
    images_dir: str = "images",
    dataset_config: Optional[List[Tuple[str, str, str]]] = None
):
    """
    Set up the loader with an image directory and a user-supplied dataset spec.

    Args:
        images_dir: Directory containing images
        dataset_config: List of (image_filename, input_text, expected_output)
            tuples. REQUIRED — the library ships no defaults to stay generic.

    Raises:
        FileNotFoundError: If images_dir doesn't exist
        ValueError: If dataset_config is None
    """
    self.images_dir = Path(images_dir)

    if not self.images_dir.exists():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")

    if dataset_config is None:
        raise ValueError(
            "dataset_config is required. This is a library class - define your "
            "dataset in the test script:\n"
            "   dataset = [('img1.png', 'your input', 'expected output'), ...]\n"
            "   loader = ScrollDatasetLoader(images_dir='...', dataset_config=dataset)"
        )

    self.dataset_config = dataset_config
77
+
78
def load_dataset(self) -> List[Dict[str, Any]]:
    """
    Build the GEPA-format dataset from the configured (image, input, output) triples.

    Missing or unreadable images are skipped with a warning. Each produced item
    carries the base64 image at the TOP level (so UniversalConverter can find it)
    plus a metadata dict including the element_id extracted from the expected
    output (int, or None when extraction fails).

    Returns:
        List of GEPA-format items:
        {"input": ..., "output": ..., "image_base64": ..., "metadata": {...}}

    Raises:
        ValueError: If no configured image could be loaded
    """
    items: List[Dict[str, Any]] = []

    for image_filename, input_text, expected_output in self.dataset_config:
        image_path = self.images_dir / image_filename

        if not image_path.exists():
            logger.warning(f"Image not found: {image_path}")
            continue

        try:
            encoded = self._encode_image(image_path)
        except Exception as e:
            logger.warning(f"Error encoding {image_filename}: {e}")
            continue

        # Extract element_id for robust downstream evaluation (may be None).
        element_id = self._extract_element_id(expected_output)
        if element_id is None:
            logger.warning(f"Could not extract element_id from '{expected_output}' in {image_filename}")

        # Completely generic item: the library makes no assumptions about
        # what the task is or what the input/output text formats mean.
        items.append({
            "input": input_text,
            "output": expected_output,
            "image_base64": encoded,  # TOP LEVEL for the converter
            "metadata": {
                "image_path": str(image_path),
                "input_text": input_text,
                "expected_output": expected_output,
                "image_filename": image_filename,
                "element_id": element_id,
            },
        })

    if not items:
        raise ValueError("No valid images found in dataset")

    logger.info(f"Loaded {len(items)} scroll element detection samples")
    return items
150
+
151
def _extract_element_id(self, expected_output: str) -> Optional[int]:
    """
    Pull a numeric element ID out of an expected-output string.

    Accepted forms: "Element: 4", "Element 4", a bare number, or a full
    reasoning string like "Element: 4, Description: ...". Only IDs in the
    range 1-100 (plausible UI element IDs) are accepted.

    Args:
        expected_output: Full expected output string with reasoning

    Returns:
        Element ID as an int, or None if no acceptable ID is found
    """
    import re

    if not expected_output:
        return None

    # Prefer explicit "Element ..." references (case insensitive).
    explicit_patterns = (
        r'element[:\s]+(\d+)',   # "Element: 4" or "Element 4"
        r'\belement\s+(\d+)\b',  # "element 4" with word boundaries
    )
    for pattern in explicit_patterns:
        hit = re.search(pattern, expected_output, re.IGNORECASE)
        if hit is None:
            continue
        try:
            candidate = int(hit.group(1))
        except (ValueError, IndexError):
            continue
        if 1 <= candidate <= 100:
            return candidate

    # Fallback: first standalone 1-3 digit number in an acceptable range.
    fallback = re.search(r'\b(\d{1,3})\b', expected_output)
    if fallback is not None:
        try:
            candidate = int(fallback.group(1))
        except ValueError:
            candidate = None
        if candidate is not None and 1 <= candidate <= 100:
            return candidate

    return None
201
+
202
def _encode_image(self, image_path: Path) -> str:
    """
    Return the file at *image_path* encoded as a base64 ASCII string.

    Args:
        image_path: Path to image file

    Returns:
        Base64 encoded image string
    """
    raw = image_path.read_bytes()
    return base64.b64encode(raw).decode('utf-8')
215
+
216
+ def split_dataset(
217
+ self,
218
+ dataset: List[Dict[str, Any]],
219
+ train_size: int = 4,
220
+ val_size: int = 1,
221
+ test_size: int = 1,
222
+ shuffle: bool = True,
223
+ seed: Optional[int] = None
224
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
225
+ """
226
+ Split dataset into train, validation, and test sets.
227
+
228
+ 🔥 NEW: Added shuffling support to ensure different image distribution
229
+ across splits, preventing hard images from always landing in validation set.
230
+
231
+ Args:
232
+ dataset: Complete dataset
233
+ train_size: Number of samples for training (default: 4)
234
+ val_size: Number of samples for validation (default: 1)
235
+ test_size: Number of samples for test (default: 1)
236
+ shuffle: Whether to shuffle dataset before splitting (default: True)
237
+ seed: Random seed for reproducible shuffling (default: None = random)
238
+
239
+ Returns:
240
+ Tuple of (train_set, val_set, test_set)
241
+ """
242
+ n = len(dataset)
243
+
244
+ # Validate split sizes
245
+ total_size = train_size + val_size + test_size
246
+ if total_size > n:
247
+ logger.warning(f"Requested split ({total_size}) exceeds dataset size ({n}). Adjusting split proportionally...")
248
+ ratio = n / total_size
249
+ train_size = int(train_size * ratio)
250
+ val_size = int(val_size * ratio)
251
+ test_size = n - train_size - val_size
252
+
253
+ # 🔥 CRITICAL: Shuffle dataset to ensure different image distribution
254
+ # This prevents the same hard images from always being in validation set
255
+ dataset_copy = dataset.copy() # Don't modify original
256
+ if shuffle:
257
+ if seed is not None:
258
+ random.seed(seed)
259
+ logger.debug(f"Shuffling dataset with seed={seed} for reproducible splits")
260
+ else:
261
+ logger.debug(f"Shuffling dataset randomly (no seed)")
262
+ random.shuffle(dataset_copy)
263
+ else:
264
+ logger.warning(f"Not shuffling dataset - using original order")
265
+
266
+ # Split shuffled dataset
267
+ train_set = dataset_copy[:train_size]
268
+ val_set = dataset_copy[train_size:train_size + val_size]
269
+ test_set = dataset_copy[train_size + val_size:train_size + val_size + test_size]
270
+
271
+ logger.info(f"Dataset split: {len(train_set)} train, {len(val_set)} val, {len(test_set)} test")
272
+
273
+ # Log which images are in each split for debugging
274
+ if shuffle:
275
+ train_images = [item['metadata'].get('image_filename', 'N/A') for item in train_set]
276
+ val_images = [item['metadata'].get('image_filename', 'N/A') for item in val_set]
277
+ test_images = [item['metadata'].get('image_filename', 'N/A') for item in test_set]
278
+ print(f" Train images: {train_images[:5]}{'...' if len(train_images) > 5 else ''}")
279
+ print(f" Val images: {val_images}")
280
+ print(f" Test images: {test_images[:5]}{'...' if len(test_images) > 5 else ''}")
281
+
282
+ return train_set, val_set, test_set
283
+
284
+
285
def load_scroll_dataset(
    images_dir: str = "images",
    dataset_config: List[Tuple[str, str, str]] = None,
    split: bool = True
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience wrapper to load an image-based dataset (task-agnostic).

    Args:
        images_dir: Directory containing images
        dataset_config: List of (image_filename, input_text, expected_output) tuples
        split: Whether to split into train/val/test

    Returns:
        If split=True: (train_set, val_set, test_set)
        If split=False: (full_dataset, [], [])

    Example (works for ANY task):
        dataset_config = [
            ("img1.png", "What color is the sky?", "blue"),
            ("img2.png", "Count the dogs", "2"),
        ]
        train, val, test = load_scroll_dataset(
            images_dir="images",
            dataset_config=dataset_config
        )
    """
    loader = ScrollDatasetLoader(images_dir, dataset_config=dataset_config)
    full_dataset = loader.load_dataset()
    if not split:
        return full_dataset, [], []
    return loader.split_dataset(full_dataset)
319
+
320
+
321
# Example usage (for testing the library loader itself)
if __name__ == "__main__":
    # This module is a library: the actual dataset definition lives in the
    # caller's test script, so we only print a short usage example here.
    usage_lines = (
        "🚀 Testing Scroll Dataset Loader...",
        "⚠️ NOTE: This is a library class. Define your dataset in your test script.",
        "\nExample:",
        " dataset_config = [",
        " ('image1.png', 'Scroll down by 50%', '3'),",
        " ('image2.png', 'Swipe left', '4'),",
        " ]",
        " train, val, test = load_scroll_dataset(",
        " images_dir='images',",
        " dataset_config=dataset_config",
        " )",
    )
    for usage_line in usage_lines:
        print(usage_line)
+
src/gepa_optimizer/data/validation_dataset_loader.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Validation Dataset Loader for UI Validation Use Case
3
+
4
+ Loads validation datapoints from SQLite database and converts to GEPA-compatible format.
5
+ Supports filtering by data_type (trainset/valset/testset) and confirmed status.
6
+ """
7
+
8
+ import os
9
+ import sqlite3
10
+ import base64
11
+ import logging
12
+ from typing import List, Dict, Any, Optional, Literal
13
+ from pathlib import Path
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ValidationDatasetLoader:
    """
    Loads validation dataset from SQLite database.

    Database schema:
    - validation_data: id, image_id, command, result (0/1), reasoning, data_type, confirmed, created_at
    - images: image_id, mime, bytes (BLOB), created_at

    Converts to GEPA format:
    - input: command text (seed prompt will be provided in test script)
    - output: "true" or "false" (converted from 0/1)
    - image_base64: base64 encoded image (TOP LEVEL for UniversalConverter)
    - metadata: All original fields plus converted values

    Note: The seed prompt is NOT stored in database - it will be provided in the test script.
    The input field contains just the command, and the image is at top level.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        confirmed_only: bool = True
    ):
        """
        Initialize validation dataset loader.

        Args:
            db_path: Path to SQLite database file.
                     Default: "./validation_data.db" or from VD_DB_PATH env var
            confirmed_only: If True, only load datapoints where confirmed=1.
                           Default: True (only manually reviewed data)

        Raises:
            FileNotFoundError: If database file doesn't exist
            sqlite3.Error: If database connection fails
        """
        # Get database path from env or use default
        if db_path is None:
            db_path = os.getenv("VD_DB_PATH", "./validation_data.db")

        # Resolve to an absolute path so error messages are unambiguous.
        self.db_path = Path(db_path).resolve()

        # Fail fast: sqlite3.connect() would otherwise silently create an
        # empty database file on first use, masking a misconfigured path.
        if not self.db_path.exists():
            raise FileNotFoundError(
                f"Database file not found: {self.db_path}\n"
                f"Make sure validation_data_ui_server_async.py has been run at least once to create the database."
            )

        self.confirmed_only = confirmed_only

    def load_dataset(
        self,
        data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
        confirmed_only: Optional[bool] = None
    ) -> List[Dict[str, Any]]:
        """
        Load dataset from database and convert to GEPA format.

        Args:
            data_type: Filter by data_type. If None, loads all types.
                      Options: "trainset", "valset", "testset"
            confirmed_only: Override instance default. If True, only load confirmed datapoints.
                           If None, uses instance default (self.confirmed_only)

        Returns:
            List of dataset items in GEPA format:
            [
                {
                    "input": "Validate Submit button is visible",  # Command only (seed prompt in test script)
                    "output": "true",  # or "false" (converted from 0/1)
                    "image_base64": "<base64_encoded_image>",  # TOP LEVEL (image + command together)
                    "metadata": {
                        "id": 1,
                        "image_id": "abc123...",
                        "command": "Validate Submit button is visible",
                        "result": True,  # Boolean
                        "result_int": 1,  # Original 0/1
                        "reasoning": "Detailed explanation...",
                        "data_type": "trainset",
                        "confirmed": True,
                        "created_at": "2024-01-01 12:00:00"
                    }
                },
                ...
            ]

        Note: Seed prompt is provided separately in test script, not in database.

        Raises:
            sqlite3.Error: If database query fails
            ValueError: If no datapoints found matching criteria
        """
        # Use provided confirmed_only or instance default
        use_confirmed = confirmed_only if confirmed_only is not None else self.confirmed_only

        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row  # Access columns by name
        dataset = []

        try:
            # Build query with filters.
            # "WHERE 1=1" lets the optional filters below append with "AND"
            # unconditionally; values go through parameter binding (no SQL injection).
            query = """
                SELECT
                    v.id,
                    v.image_id,
                    v.command,
                    v.result,
                    v.reasoning,
                    v.data_type,
                    v.confirmed,
                    v.created_at,
                    i.mime,
                    i.bytes
                FROM validation_data v
                INNER JOIN images i ON v.image_id = i.image_id
                WHERE 1=1
            """
            params = []

            # Add filters
            if use_confirmed:
                query += " AND v.confirmed = 1"

            if data_type:
                query += " AND v.data_type = ?"
                params.append(data_type)

            query += " ORDER BY v.id ASC"

            # Execute query
            cursor = conn.execute(query, params)
            rows = cursor.fetchall()

            # An empty result is treated as a configuration error: the caller
            # expects data, so describe exactly which filters produced nothing.
            if not rows:
                filter_msg = []
                if use_confirmed:
                    filter_msg.append("confirmed=1")
                if data_type:
                    filter_msg.append(f"data_type='{data_type}'")

                filter_str = " with filters: " + ", ".join(filter_msg) if filter_msg else ""
                raise ValueError(
                    f"No datapoints found{filter_str} in database: {self.db_path}\n"
                    f"Make sure you have generated and saved datapoints using the validation UI."
                )

            # Convert rows to GEPA format
            for row in rows:
                # Convert 0/1 to "true"/"false" string for GEPA
                result_str = "true" if row["result"] == 1 else "false"

                # Encode image bytes to base64
                image_base64 = base64.b64encode(row["bytes"]).decode("utf-8")

                # Create GEPA format item
                # Input: command (seed prompt will be provided in test script)
                # Image: separate at top level (image_base64)
                # Output: "true" or "false" (converted from 0/1)
                dataset_item = {
                    "input": row["command"],  # Just the command - seed prompt will be in test script
                    "output": result_str,  # "true" or "false" (string)
                    "image_base64": image_base64,  # TOP LEVEL for UniversalConverter (image + command together)
                    "metadata": {
                        "id": row["id"],
                        "image_id": row["image_id"],
                        "command": row["command"],  # Keep original for reference
                        "result": bool(row["result"]),  # Boolean for reference
                        "result_int": row["result"],  # Original 0/1 for reference
                        "reasoning": row["reasoning"],
                        "data_type": row["data_type"],
                        "confirmed": bool(row["confirmed"]),
                        "created_at": row["created_at"],
                        "mime": row["mime"],
                    }
                }

                dataset.append(dataset_item)

            # Log summary
            data_type_str = f" ({data_type})" if data_type else ""
            confirmed_str = " (confirmed only)" if use_confirmed else " (all)"
            logger.info(f"Loaded {len(dataset)} validation datapoints{data_type_str}{confirmed_str}")

            return dataset

        finally:
            # Always release the connection, even when the query raises.
            conn.close()

    def load_split_dataset(
        self,
        confirmed_only: Optional[bool] = None
    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Load dataset split by data_type (trainset/valset/testset).

        Convenience method that loads all three splits at once.
        Note: each call to load_dataset() opens its own DB connection, and
        each split raises ValueError if it is empty (see load_dataset).

        Args:
            confirmed_only: Override instance default. If True, only load confirmed datapoints.

        Returns:
            Tuple of (train_set, val_set, test_set) in GEPA format

        Example:
            loader = ValidationDatasetLoader(db_path="./validation_data.db")
            train, val, test = loader.load_split_dataset()
        """
        train_set = self.load_dataset(data_type="trainset", confirmed_only=confirmed_only)
        val_set = self.load_dataset(data_type="valset", confirmed_only=confirmed_only)
        test_set = self.load_dataset(data_type="testset", confirmed_only=confirmed_only)

        logger.info(f"Dataset Split Summary: Training={len(train_set)}, Validation={len(val_set)}, Test={len(test_set)}, Total={len(train_set) + len(val_set) + len(test_set)}")

        return train_set, val_set, test_set

    def get_dataset_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the dataset in the database.

        Returns:
            Dictionary with dataset statistics:
            {
                "total": 100,
                "confirmed": 95,
                "unconfirmed": 5,
                "by_data_type": {
                    "trainset": 70,
                    "valset": 15,
                    "testset": 15
                },
                "by_result": {
                    "true": 50,
                    "false": 50
                }
            }
        """
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row

        try:
            stats = {}

            # Total counts
            total = conn.execute("SELECT COUNT(*) FROM validation_data").fetchone()[0]
            confirmed = conn.execute("SELECT COUNT(*) FROM validation_data WHERE confirmed = 1").fetchone()[0]
            stats["total"] = total
            stats["confirmed"] = confirmed
            stats["unconfirmed"] = total - confirmed

            # By data_type
            data_type_rows = conn.execute("""
                SELECT data_type, COUNT(*) as count
                FROM validation_data
                GROUP BY data_type
            """).fetchall()
            stats["by_data_type"] = {row["data_type"]: row["count"] for row in data_type_rows}

            # By result (true/false)
            result_rows = conn.execute("""
                SELECT result, COUNT(*) as count
                FROM validation_data
                GROUP BY result
            """).fetchall()
            stats["by_result"] = {
                "true": sum(row["count"] for row in result_rows if row["result"] == 1),
                "false": sum(row["count"] for row in result_rows if row["result"] == 0)
            }

            return stats

        finally:
            # Always release the connection, even if a query raised.
            conn.close()
290
+
291
+
292
def load_validation_dataset(
    db_path: Optional[str] = None,
    data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
    confirmed_only: bool = True
) -> List[Dict[str, Any]]:
    """
    Convenience function to load the validation dataset.

    Args:
        db_path: Path to SQLite database file. Default: "./validation_data.db"
        data_type: Filter by data_type. If None, loads all types.
        confirmed_only: If True, only load confirmed datapoints.

    Returns:
        List of dataset items in GEPA format

    Example:
        # Load all confirmed training data
        train_data = load_validation_dataset(data_type="trainset", confirmed_only=True)

        # Load all confirmed data
        all_data = load_validation_dataset(confirmed_only=True)
    """
    # Build the loader, then delegate straight to it.
    return ValidationDatasetLoader(
        db_path=db_path,
        confirmed_only=confirmed_only,
    ).load_dataset(data_type=data_type, confirmed_only=confirmed_only)
317
+
318
+
319
def load_validation_split(
    db_path: Optional[str] = None,
    confirmed_only: bool = True
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience function to load the validation dataset split by data_type.

    Args:
        db_path: Path to SQLite database file. Default: "./validation_data.db"
        confirmed_only: If True, only load confirmed datapoints.

    Returns:
        Tuple of (train_set, val_set, test_set) in GEPA format

    Example:
        train, val, test = load_validation_split(confirmed_only=True)
    """
    # Build the loader, then delegate straight to it.
    return ValidationDatasetLoader(
        db_path=db_path,
        confirmed_only=confirmed_only,
    ).load_split_dataset(confirmed_only=confirmed_only)
338
+
339
+
340
# Example usage and testing
# Smoke-tests the loader against a real local database; requires the DB file
# created by validation_data_ui_server_async.py to exist and contain data.
if __name__ == "__main__":
    print("🚀 Testing Validation Dataset Loader...")

    try:
        # Uses default db_path ("./validation_data.db" or VD_DB_PATH env var).
        loader = ValidationDatasetLoader()

        # Get stats
        print("\n📊 Dataset Statistics:")
        stats = loader.get_dataset_stats()
        print(f" Total: {stats['total']}")
        print(f" Confirmed: {stats['confirmed']}")
        print(f" Unconfirmed: {stats['unconfirmed']}")
        print(f" By data_type: {stats['by_data_type']}")
        print(f" By result: {stats['by_result']}")

        # Load split dataset
        print("\n📦 Loading split dataset...")
        train, val, test = loader.load_split_dataset()

        # Show sample
        if train:
            sample = train[0]
            print(f"\n📝 Sample Training Item:")
            print(f" Input: {sample['input']}")
            print(f" Output: {sample['output']}")
            print(f" Image ID: {sample['metadata']['image_id'][:8]}...")
            print(f" Data Type: {sample['metadata']['data_type']}")
            print(f" Result: {sample['metadata']['result']} (int: {sample['metadata']['result_int']})")

    # Expected failure modes: the DB file doesn't exist yet, or a split is
    # empty (load_dataset raises ValueError). Print guidance instead of a
    # traceback for both.
    except FileNotFoundError as e:
        print(f"❌ {e}")
        print("\n💡 Make sure validation_data_ui_server_async.py has been run to create the database.")
    except ValueError as e:
        print(f"❌ {e}")
        print("\n💡 Generate and save some datapoints using the validation UI first.")
376
+
src/gepa_optimizer/data/validators.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data validation utilities for GEPA optimizer
3
+ """
4
+
5
+ from typing import List, Dict, Any, Optional, Tuple
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class DataValidator:
    """
    Validates datasets for completeness and GEPA compatibility.

    A dataset is a list of dicts, each with a required 'input' string and an
    (expected) 'output' string, plus optional 'metadata', 'id', and 'tags'.
    Validation methods collect errors and return (is_valid, errors) tuples
    instead of raising, so callers can report every problem at once.
    """

    def __init__(self):
        # Fields every item should carry; 'output' may be empty but is
        # expected for supervised optimization.
        self.required_fields = ['input', 'output']
        self.optional_fields = ['metadata', 'id', 'tags']

    def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Tuple[bool, List[str]]:
        """
        Validate entire dataset.

        Args:
            dataset: List of data items to validate

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors: List[str] = []

        # FIX: check the container type before emptiness, so an empty
        # non-list (e.g. {} or ()) is reported as the wrong type rather
        # than as an empty dataset.
        if not isinstance(dataset, list):
            errors.append("Dataset must be a list")
            return False, errors

        if not dataset:
            errors.append("Dataset is empty")
            return False, errors

        # Validate each item, collecting (not short-circuiting) errors.
        for idx, item in enumerate(dataset):
            errors.extend(self.validate_item(item, idx))

        # Need at least two items to produce a meaningful train/val split.
        if len(dataset) < 2:
            errors.append("Dataset should have at least 2 items for proper train/val split")

        # Log validation results
        if errors:
            logger.warning(f"Dataset validation failed with {len(errors)} errors")
        else:
            logger.info(f"Dataset validation passed for {len(dataset)} items")

        return len(errors) == 0, errors

    def validate_item(self, item: Dict[str, Any], index: Optional[int] = None) -> List[str]:
        """
        Validate a single dataset item.

        Args:
            item: Single data item to validate
            index: Optional item index for error reporting

        Returns:
            List[str]: List of validation errors (empty if the item is valid)
        """
        errors: List[str] = []
        item_ref = f"item {index}" if index is not None else "item"

        # A non-dict item cannot be inspected further; stop here.
        if not isinstance(item, dict):
            errors.append(f"{item_ref}: Must be a dictionary")
            return errors

        # 'input' is required and must be a non-empty string.
        if 'input' not in item:
            errors.append(f"{item_ref}: Missing required 'input' field")
        elif not isinstance(item['input'], str):
            errors.append(f"{item_ref}: 'input' field must be a string")
        elif not item['input'].strip():
            errors.append(f"{item_ref}: 'input' field cannot be empty")

        # 'output' may be absent or empty (unsupervised items), but when
        # present it must be a string.
        if 'output' in item:
            if not isinstance(item['output'], str):
                errors.append(f"{item_ref}: 'output' field must be a string")

        # Validate metadata if present
        if 'metadata' in item and not isinstance(item['metadata'], dict):
            errors.append(f"{item_ref}: 'metadata' field must be a dictionary")

        return errors

    def validate_gepa_format(self, gepa_data: List[Dict[str, Any]]) -> Tuple[bool, List[str]]:
        """
        Validate data already converted to GEPA format.

        GEPA items require 'input', 'expected_output', and a dict 'metadata'.

        Args:
            gepa_data: Data in GEPA format

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors: List[str] = []

        if not gepa_data:
            errors.append("GEPA dataset is empty")
            return False, errors

        for idx, item in enumerate(gepa_data):
            if 'input' not in item:
                errors.append(f"GEPA item {idx}: Missing 'input' field")

            if 'expected_output' not in item:
                errors.append(f"GEPA item {idx}: Missing 'expected_output' field")

            if 'metadata' not in item:
                errors.append(f"GEPA item {idx}: Missing 'metadata' field")
            elif not isinstance(item['metadata'], dict):
                errors.append(f"GEPA item {idx}: 'metadata' must be a dictionary")

        return len(errors) == 0, errors

    def validate_split(self, trainset: List[Dict], valset: List[Dict]) -> Tuple[bool, List[str]]:
        """
        Validate a train/validation split.

        Both sets must be non-empty, and the training fraction should fall
        between 50% and 95% of the combined data.

        Args:
            trainset: Training data
            valset: Validation data

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors: List[str] = []

        if not trainset:
            errors.append("Training set is empty")

        if not valset:
            errors.append("Validation set is empty")

        # Check proportions (skip when both sets are empty to avoid /0).
        total_size = len(trainset) + len(valset)
        if total_size > 0:
            train_ratio = len(trainset) / total_size
            if train_ratio < 0.5:
                errors.append(f"Training set too small: {train_ratio:.2%} of total data")
            elif train_ratio > 0.95:
                errors.append(f"Validation set too small: {1-train_ratio:.2%} of total data")

        return len(errors) == 0, errors

    def get_dataset_stats(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get statistics about the dataset.

        Args:
            dataset: Dataset to analyze

        Returns:
            Dict[str, Any]: Counts and average input/output lengths, plus a
            'valid' heuristic flag (non-empty and <50% empty inputs).
        """
        if not dataset:
            return {'total_items': 0, 'valid': False}

        stats = {
            'total_items': len(dataset),
            'has_output': sum(1 for item in dataset if item.get('output')),
            'avg_input_length': 0,
            'avg_output_length': 0,
            'empty_inputs': 0,
            'empty_outputs': 0
        }

        input_lengths = []
        output_lengths = []

        # Non-dict items and non-string fields are skipped rather than
        # counted, mirroring validate_item's tolerance.
        for item in dataset:
            if isinstance(item, dict):
                input_text = item.get('input', '')
                output_text = item.get('output', '')

                if isinstance(input_text, str):
                    input_lengths.append(len(input_text))
                    if not input_text.strip():
                        stats['empty_inputs'] += 1

                if isinstance(output_text, str):
                    output_lengths.append(len(output_text))
                    if not output_text.strip():
                        stats['empty_outputs'] += 1

        if input_lengths:
            stats['avg_input_length'] = sum(input_lengths) / len(input_lengths)

        if output_lengths:
            stats['avg_output_length'] = sum(output_lengths) / len(output_lengths)

        # Determine if dataset looks valid: non-empty, and less than 50%
        # of items have blank inputs.
        stats['valid'] = (
            stats['total_items'] > 0 and
            stats['empty_inputs'] < stats['total_items'] * 0.5
        )

        return stats
src/gepa_optimizer/evaluation/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation module for GEPA Optimizer
3
+
4
+ Includes:
5
+ - UniversalSemanticEvaluator: Works for ANY task (recommended for general use)
6
+ - BaseEvaluator: Abstract base class for custom evaluators
7
+ - Task-specific evaluators for specialized use cases
8
+ """
9
+
10
+ from .base_evaluator import BaseEvaluator
11
+ from .universal_evaluator import UniversalSemanticEvaluator, create_universal_evaluator
12
+ from .ui_evaluator import UITreeEvaluator
13
+ from .scroll_evaluator import ScrollElementEvaluator
14
+ from .validation_evaluator import ValidationEvaluator
15
+ from .index_caching_evaluator import IndexCachingEvaluator
16
+
17
+ __all__ = [
18
+ # Universal (recommended)
19
+ "UniversalSemanticEvaluator",
20
+ "create_universal_evaluator",
21
+ # Base class
22
+ "BaseEvaluator",
23
+ # Task-specific
24
+ "UITreeEvaluator",
25
+ "ScrollElementEvaluator",
26
+ "ValidationEvaluator",
27
+ "IndexCachingEvaluator",
28
+ ]
src/gepa_optimizer/evaluation/base_evaluator.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base evaluator class for all evaluation strategies.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Dict, Optional
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class BaseEvaluator(ABC):
    """
    Abstract base class for all evaluation strategies.

    Enforces a consistent interface while leaving the actual evaluation
    logic fully customizable for any use case.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initialize evaluator with optional metric weights.

        Args:
            metric_weights: Optional weights for different metrics.
                If None, subclasses should provide defaults.
        """
        self.metric_weights = metric_weights or {}
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.logger = logging.getLogger(logger_name)

    @abstractmethod
    def evaluate(self, predicted: Any, expected: Any) -> Dict[str, float]:
        """
        Evaluate predicted output against expected output.

        Args:
            predicted: The model's predicted output
            expected: The ground truth expected output

        Returns:
            Dictionary mapping metric names to scores. Must include a
            'composite_score' key for GEPA integration.
        """
        ...

    def validate_weights(self) -> bool:
        """Check that the configured metric weights sum to approximately 1.0."""
        if not self.metric_weights:
            return True

        # Allow small floating point errors around 1.0.
        deviation = abs(sum(self.metric_weights.values()) - 1.0)
        return deviation < 0.01
src/gepa_optimizer/evaluation/index_caching_evaluator.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Index Caching Evaluator for Index-Based Element Selection Use Case
3
+
4
+ Evaluates predicted index caching results against expected results.
5
+ Compares all 5 fields with equal weight:
6
+ - is_index_based
7
+ - index_value
8
+ - parent_element_id
9
+ - element_id_of_nth_child_of_parent
10
+ - selected_element_is_correct
11
+ """
12
+
13
+ from typing import Dict, Any, Optional
14
+ import json
15
+ import re
16
+ import logging
17
+
18
+ from .base_evaluator import BaseEvaluator
19
+
20
+
21
+ class IndexCachingEvaluator(BaseEvaluator):
22
+ """
23
+ Evaluator for index caching use case.
24
+
25
+ Features:
26
+ - Compares all 5 fields with equal weight (20% each)
27
+ - Parses JSON from LLM response
28
+ - Handles null values correctly
29
+ - Returns detailed field-by-field comparison
30
+ """
31
+
32
+ def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
33
+ """
34
+ Initialize index caching evaluator.
35
+
36
+ Args:
37
+ metric_weights: Weights for evaluation metrics
38
+ Default: Equal weight for all 5 fields (0.2 each)
39
+ """
40
+ # Each field gets 20% weight (5 fields * 0.2 = 1.0)
41
+ default_weights = {
42
+ "is_index_based_match": 0.2,
43
+ "index_value_match": 0.2,
44
+ "parent_element_id_match": 0.2,
45
+ "element_id_of_nth_child_match": 0.2,
46
+ "selected_element_correct_match": 0.2,
47
+ }
48
+
49
+ weights = metric_weights or default_weights
50
+ super().__init__(metric_weights=weights)
51
+
52
+ def evaluate(self, predicted: str, expected: str) -> Dict[str, float]:
53
+ """
54
+ Evaluate predicted index caching result against expected result.
55
+
56
+ Args:
57
+ predicted: LLM's output (JSON string with all 5 fields)
58
+ expected: Expected output (JSON string or dict with all 5 fields)
59
+
60
+ Returns:
61
+ Dictionary with evaluation metrics:
62
+ {
63
+ "is_index_based_match": 1.0 or 0.0,
64
+ "index_value_match": 1.0 or 0.0,
65
+ "parent_element_id_match": 1.0 or 0.0,
66
+ "element_id_of_nth_child_match": 1.0 or 0.0,
67
+ "selected_element_correct_match": 1.0 or 0.0,
68
+ "composite_score": 0.0 to 1.0,
69
+ "predicted_output": str,
70
+ "expected_output": str,
71
+ "field_scores": {...},
72
+ "evaluation_reason": str
73
+ }
74
+ """
75
+ if not predicted or not expected:
76
+ return {
77
+ "is_index_based_match": 0.0,
78
+ "index_value_match": 0.0,
79
+ "parent_element_id_match": 0.0,
80
+ "element_id_of_nth_child_match": 0.0,
81
+ "selected_element_correct_match": 0.0,
82
+ "composite_score": 0.0,
83
+ "predicted_output": str(predicted).strip() if predicted else "",
84
+ "expected_output": str(expected).strip() if expected else "",
85
+ "field_scores": {},
86
+ "evaluation_reason": "❌ Empty or missing input/output"
87
+ }
88
+
89
+ # Parse expected (could be JSON string or dict)
90
+ try:
91
+ if isinstance(expected, str):
92
+ expected_dict = json.loads(expected)
93
+ else:
94
+ expected_dict = expected
95
+ except (json.JSONDecodeError, TypeError):
96
+ # If expected is already a dict from dataset
97
+ expected_dict = expected if isinstance(expected, dict) else {}
98
+
99
+ # Parse predicted (must be JSON string)
100
+ try:
101
+ predicted_dict = self._parse_json_response(predicted)
102
+ except Exception as e:
103
+ # Log the actual response for debugging
104
+ response_preview = predicted[:200] if predicted else "(empty)"
105
+ self.logger.warning(f"Failed to parse predicted JSON: {e}")
106
+ self.logger.warning(f"Response preview: {response_preview}...")
107
+ predicted_dict = {}
108
+
109
+ # NOTE: "notes" field is present in the output but is NOT used for scoring or reflection
110
+ # It's kept for reference but ignored in evaluation
111
+
112
+ # Compare each field (only the 5 core fields, ignoring "notes")
113
+ field_scores = {}
114
+ field_reasons = []
115
+
116
+ # 1. is_index_based (boolean)
117
+ pred_is_index = predicted_dict.get("is_index_based")
118
+ exp_is_index = expected_dict.get("is_index_based")
119
+ is_index_match = (pred_is_index == exp_is_index) if (pred_is_index is not None and exp_is_index is not None) else False
120
+ field_scores["is_index_based"] = 1.0 if is_index_match else 0.0
121
+ field_reasons.append(f"is_index_based: {pred_is_index} vs {exp_is_index} → {'✅' if is_index_match else '❌'}")
122
+
123
+ # 2. index_value (int or null)
124
+ pred_index_val = predicted_dict.get("index_value")
125
+ exp_index_val = expected_dict.get("index_value")
126
+ # Handle null/None comparison
127
+ index_val_match = (pred_index_val == exp_index_val) or (pred_index_val is None and exp_index_val is None)
128
+ field_scores["index_value"] = 1.0 if index_val_match else 0.0
129
+ field_reasons.append(f"index_value: {pred_index_val} vs {exp_index_val} → {'✅' if index_val_match else '❌'}")
130
+
131
+ # 3. parent_element_id (string or null)
132
+ pred_parent = predicted_dict.get("parent_element_id")
133
+ exp_parent = expected_dict.get("parent_element_id")
134
+ # Handle null/None comparison
135
+ parent_match = (pred_parent == exp_parent) or (pred_parent is None and exp_parent is None)
136
+ field_scores["parent_element_id"] = 1.0 if parent_match else 0.0
137
+ field_reasons.append(f"parent_element_id: {pred_parent} vs {exp_parent} → {'✅' if parent_match else '❌'}")
138
+
139
+ # 4. element_id_of_nth_child_of_parent (string or null)
140
+ pred_element = predicted_dict.get("element_id_of_nth_child_of_parent")
141
+ exp_element = expected_dict.get("element_id_of_nth_child_of_parent")
142
+ # Handle null/None comparison
143
+ element_match = (pred_element == exp_element) or (pred_element is None and exp_element is None)
144
+ field_scores["element_id_of_nth_child_of_parent"] = 1.0 if element_match else 0.0
145
+ field_reasons.append(f"element_id_of_nth_child: {pred_element} vs {exp_element} → {'✅' if element_match else '❌'}")
146
+
147
+ # 5. selected_element_is_correct (boolean)
148
+ pred_selected = predicted_dict.get("selected_element_is_correct")
149
+ exp_selected = expected_dict.get("selected_element_is_correct")
150
+ selected_match = (pred_selected == exp_selected) if (pred_selected is not None and exp_selected is not None) else False
151
+ field_scores["selected_element_is_correct"] = 1.0 if selected_match else 0.0
152
+ field_reasons.append(f"selected_element_is_correct: {pred_selected} vs {exp_selected} → {'✅' if selected_match else '❌'}")
153
+
154
+ # Calculate composite score (weighted average)
155
+ composite_score = (
156
+ field_scores["is_index_based"] * 0.2 +
157
+ field_scores["index_value"] * 0.2 +
158
+ field_scores["parent_element_id"] * 0.2 +
159
+ field_scores["element_id_of_nth_child_of_parent"] * 0.2 +
160
+ field_scores["selected_element_is_correct"] * 0.2
161
+ )
162
+
163
+ # Build evaluation reason
164
+ all_match = composite_score == 1.0
165
+ reason = "✅ All fields match!" if all_match else f"❌ Partial match ({composite_score:.1%})"
166
+ reason += "\n" + "\n".join(f" {r}" for r in field_reasons)
167
+
168
+ # Log evaluation details
169
+ self.logger.info(f"\n{'─'*70}")
170
+ self.logger.info(f"📊 INDEX CACHING EVALUATION")
171
+ self.logger.info(f"{'─'*70}")
172
+ self.logger.info(f" 🎯 COMPOSITE SCORE: {composite_score:.2f} ({composite_score:.1%})")
173
+ for field, score in field_scores.items():
174
+ status = "✅" if score == 1.0 else "❌"
175
+ self.logger.info(f" {status} {field}: {score:.0f}")
176
+ self.logger.info(f"{'─'*70}\n")
177
+
178
+ return {
179
+ "is_index_based_match": field_scores["is_index_based"],
180
+ "index_value_match": field_scores["index_value"],
181
+ "parent_element_id_match": field_scores["parent_element_id"],
182
+ "element_id_of_nth_child_match": field_scores["element_id_of_nth_child_of_parent"],
183
+ "selected_element_correct_match": field_scores["selected_element_is_correct"],
184
+ "composite_score": composite_score,
185
+ "predicted_output": predicted,
186
+ "expected_output": json.dumps(expected_dict) if isinstance(expected_dict, dict) else str(expected),
187
+ "predicted_dict": predicted_dict,
188
+ "expected_dict": expected_dict,
189
+ "field_scores": field_scores,
190
+ "evaluation_reason": reason
191
+ }
192
+
193
+ def _parse_json_response(self, response: str) -> Dict[str, Any]:
194
+ """
195
+ Parse JSON from LLM response, handling markdown code blocks and various formats.
196
+
197
+ Args:
198
+ response: LLM response string (may contain markdown)
199
+
200
+ Returns:
201
+ Parsed JSON dictionary (empty dict if parsing fails)
202
+ """
203
+ if not response or not isinstance(response, str):
204
+ return {}
205
+
206
+ response = response.strip()
207
+
208
+ # If response is empty, return empty dict
209
+ if not response:
210
+ return {}
211
+
212
+ # Strategy 1: Try to extract JSON from markdown code block
213
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response, re.DOTALL)
214
+ if json_match:
215
+ try:
216
+ json_str = json_match.group(1).strip()
217
+ return json.loads(json_str)
218
+ except json.JSONDecodeError:
219
+ pass
220
+
221
+ # Strategy 2: Find JSON object in response (handle nested braces)
222
+ json_start = response.find('{')
223
+ if json_start != -1:
224
+ # Find matching closing brace
225
+ brace_count = 0
226
+ json_end = json_start
227
+ for i in range(json_start, len(response)):
228
+ if response[i] == '{':
229
+ brace_count += 1
230
+ elif response[i] == '}':
231
+ brace_count -= 1
232
+ if brace_count == 0:
233
+ json_end = i + 1
234
+ break
235
+
236
+ if brace_count == 0:
237
+ json_str = response[json_start:json_end]
238
+ try:
239
+ return json.loads(json_str)
240
+ except json.JSONDecodeError:
241
+ pass
242
+
243
+ # Strategy 3: Try to find any JSON-like structure (more lenient)
244
+ # Look for patterns like {"key": "value"} even if not perfectly formatted
245
+ json_pattern = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
246
+ if json_pattern:
247
+ try:
248
+ return json.loads(json_pattern.group(0))
249
+ except json.JSONDecodeError:
250
+ pass
251
+
252
+ # Strategy 4: Try parsing entire response as JSON
253
+ try:
254
+ return json.loads(response)
255
+ except json.JSONDecodeError:
256
+ pass
257
+
258
+ # If all strategies fail, return empty dict
259
+ self.logger.debug(f"Could not parse JSON from response: {response[:100]}...")
260
+ return {}
261
+
262
+ def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
263
+ """
264
+ Get summary statistics for a batch of evaluations.
265
+
266
+ Args:
267
+ results: List of evaluation result dictionaries
268
+
269
+ Returns:
270
+ Summary statistics including accuracy per field and overall
271
+ """
272
+ if not results:
273
+ return {
274
+ "total_samples": 0,
275
+ "overall_accuracy": 0.0,
276
+ "field_accuracies": {},
277
+ "perfect_matches": 0
278
+ }
279
+
280
+ total = len(results)
281
+ perfect_matches = sum(1 for r in results if r.get("composite_score", 0.0) == 1.0)
282
+ overall_accuracy = perfect_matches / total if total > 0 else 0.0
283
+
284
+ # Calculate accuracy per field
285
+ field_accuracies = {
286
+ "is_index_based": sum(1 for r in results if r.get("is_index_based_match", 0.0) == 1.0) / total,
287
+ "index_value": sum(1 for r in results if r.get("index_value_match", 0.0) == 1.0) / total,
288
+ "parent_element_id": sum(1 for r in results if r.get("parent_element_id_match", 0.0) == 1.0) / total,
289
+ "element_id_of_nth_child": sum(1 for r in results if r.get("element_id_of_nth_child_match", 0.0) == 1.0) / total,
290
+ "selected_element_is_correct": sum(1 for r in results if r.get("selected_element_correct_match", 0.0) == 1.0) / total,
291
+ }
292
+
293
+ return {
294
+ "total_samples": total,
295
+ "overall_accuracy": overall_accuracy,
296
+ "field_accuracies": field_accuracies,
297
+ "perfect_matches": perfect_matches,
298
+ "partial_matches": total - perfect_matches
299
+ }
300
+
301
+
302
# Example usage and testing
# Manual smoke test: run this module directly to sanity-check the evaluator
# against a few hand-written predicted/expected pairs (no test framework needed).
if __name__ == "__main__":
    print("🚀 Testing Index Caching Evaluator...")

    evaluator = IndexCachingEvaluator()

    # Test cases
    test_cases = [
        # (predicted, expected, should_be_perfect)
        (
            '{"is_index_based": true, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": false, "index_value": null, "parent_element_id": null, "element_id_of_nth_child_of_parent": null, "selected_element_is_correct": true}',
            {"is_index_based": False, "index_value": None, "parent_element_id": None, "element_id_of_nth_child_of_parent": None, "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": true, "index_value": 3, "parent_element_id": null, "element_id_of_nth_child_of_parent": "aaaaaw", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 3, "parent_element_id": None, "element_id_of_nth_child_of_parent": "aaaaaw", "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": true, "index_value": 2, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": True},
            False  # index_value mismatch
        ),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    results = []
    for predicted, expected, should_be_perfect in test_cases:
        result = evaluator.evaluate(predicted, expected)
        # A perfect match requires all five fields to agree (composite == 1.0).
        is_perfect = result["composite_score"] == 1.0

        # ✅ means the evaluator's verdict agreed with the case's expectation.
        status = "✅" if is_perfect == should_be_perfect else "❌"
        print(f"{status} Test: Perfect match = {is_perfect} (expected {should_be_perfect})")
        print(f" Score: {result['composite_score']:.2f}")
        print()

        results.append(result)

    # Summary
    # Aggregate the batch with the evaluator's own summary helper.
    print("\n📊 Summary:")
    summary = evaluator.get_evaluation_summary(results)
    print(f" Total: {summary['total_samples']}")
    print(f" Perfect matches: {summary['perfect_matches']}")
    print(f" Overall accuracy: {summary['overall_accuracy']:.1%}")
    print(f" Field accuracies:")
    for field, acc in summary['field_accuracies'].items():
        print(f" {field}: {acc:.1%}")
357
+
src/gepa_optimizer/evaluation/scroll_evaluator.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GENERIC String Match Evaluator
3
+
4
+ Compares predicted output against expected output (simple string comparison).
5
+ NO assumptions about what the output represents (IDs, text, JSON, etc.).
6
+
7
+ Let GEPA discover the correct output format through evolution and feedback!
8
+ """
9
+
10
+ from typing import Dict, Any
11
+
12
+ try:
13
+ from .base_evaluator import BaseEvaluator
14
+ except ImportError:
15
+ # For standalone testing
16
+ import sys
17
+ from pathlib import Path
18
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
19
+ from gepa_optimizer.evaluation.base_evaluator import BaseEvaluator
20
+
21
+
22
class ScrollElementEvaluator(BaseEvaluator):
    """
    GENERIC evaluator - just compares strings!

    NO assumptions about:
    - Output format (element IDs, text, JSON, etc.)
    - Output structure
    - What the task is

    GEPA will learn the correct format through feedback and evolution.
    """

    def __init__(self, metric_weights: Dict[str, float] = None):
        """
        Initialize evaluator.

        Args:
            metric_weights: Weights for evaluation metrics
                Default: {"output_match": 1.0}
        """
        default_weights = {
            "output_match": 1.0  # Simple string comparison
        }

        weights = metric_weights or default_weights
        super().__init__(metric_weights=weights)

    @staticmethod
    def _extract_element_number(text: str):
        """
        Extract the first element number from free-form text.

        Tries three strategies in order of decreasing specificity:
          A. "Element: X" / "Element X" (explicit 'element' followed by ':' or space)
          B. the word "element" followed by a number anywhere in the text
          C. any bare number (last resort)

        Args:
            text: Text to scan for an element number.

        Returns:
            The re.Match whose group(1) is the number, or None if no number found.
        """
        import re  # local import: `re` is not imported at module level in this file

        # Strategy A: "Element: X" or "Element X" (explicit format)
        match = re.search(r'element[:\s]+(\d+)', text, re.IGNORECASE)
        if match:
            return match

        # Strategy B: "element X" anywhere in the text
        match = re.search(r'\belement\s+(\d+)\b', text, re.IGNORECASE)
        if match:
            return match

        # Strategy C: just find ANY number if other strategies fail (last resort)
        return re.search(r'\b(\d+)\b', text)

    def evaluate(self, predicted: str, expected: str) -> Dict[str, float]:
        """
        Binary evaluation with element ID extraction.

        Phase 1 Implementation:
        - Extracts element IDs using regex patterns (flexible format support)
        - Uses INTEGER comparison for robustness (prevents "4" vs "14" bugs)
        - Binary scoring: correct element = 1.0, wrong/missing = 0.0

        Scoring Strategy:
        1. Extract element ID from both predicted and expected outputs
        2. Compare using integer arithmetic (not string comparison)
        3. Return 1.0 if match, 0.0 otherwise (no partial credit)

        Args:
            predicted: LLM's output (may include verbose explanation)
            expected: Expected output (may include verbose explanation)

        Returns:
            Dictionary with evaluation metrics and extracted element IDs
        """
        # FIX: the logger was previously imported and re-created up to three
        # times inside this method; build it once up front instead.
        import logging
        logger = logging.getLogger(__name__)

        if not predicted or not expected:
            return {
                "content_match": 0.0,
                "output_match": 0.0,
                "composite_score": 0.0,
                "predicted_output": str(predicted).strip() if predicted else "",
                "expected_output": str(expected).strip() if expected else "",
                "predicted_element": "None",
                "expected_element": "None",
                "evaluation_reason": "❌ Empty or missing input/output"
            }

        predicted_str = str(predicted).strip()
        expected_str = str(expected).strip()

        # 1. Extract element numbers from both sides (shared helper — the
        #    extraction logic was previously duplicated inline for each side).
        pred_match = self._extract_element_number(predicted_str)
        exp_match = self._extract_element_number(expected_str)

        # 2. Score the comparison
        if not exp_match:
            # Expected doesn't have element pattern - fallback to exact match
            content_score = 1.0 if predicted_str.lower() == expected_str.lower() else 0.0
        elif not pred_match:
            # Predicted doesn't have element number - WRONG
            content_score = 0.0
        else:
            # Both have element pattern - compare using INTEGER comparison
            pred_element = pred_match.group(1)
            exp_element = exp_match.group(1)

            # 🔥 Phase 1: Use INTEGER comparison for robustness
            # This prevents bugs like "4" != "14" string comparison issues
            try:
                pred_num = int(pred_element)
                exp_num = int(exp_element)

                # Integer comparison (more robust than string)
                content_score = 1.0 if pred_num == exp_num else 0.0

                # Log comparison for debugging
                if pred_num != exp_num:
                    logger.debug(f"Element mismatch: predicted={pred_num}, expected={exp_num}")

            except (ValueError, TypeError) as e:
                # Fallback to string comparison if conversion fails
                logger.warning(f"Could not convert elements to integers: {e}, using string comparison")
                content_score = 1.0 if pred_element == exp_element else 0.0

        # 3. Binary score and reason
        if content_score == 1.0:
            composite_score = 1.0
            reason = "✅ Correct! Element number matches"
        else:
            composite_score = 0.0
            if pred_match and exp_match:
                reason = "❌ Wrong element number (predicted different element)"
            else:
                reason = "❌ Missing or invalid element number"

        pred_element = pred_match.group(1) if pred_match else "None"
        exp_element = exp_match.group(1) if exp_match else "None"

        # Detailed logging for transparency
        logger.info(f"\n{'─'*70}")
        logger.info(f"📊 EVALUATION DETAILS")
        logger.info(f"{'─'*70}")
        logger.info(f" Expected: '{expected_str}' (Element: {exp_element})")
        logger.info(f" Predicted: '{predicted_str}' (Element: {pred_element})")
        logger.info(f" {'─'*66}")
        logger.info(f" 🎯 SCORE: {composite_score:.2f} - {reason}")
        logger.info(f"{'─'*70}\n")

        return {
            "content_match": content_score,
            "output_match": composite_score,  # This is what GEPA uses
            "composite_score": composite_score,
            "predicted_output": predicted_str,
            "expected_output": expected_str,
            "predicted_element": pred_element,
            "expected_element": exp_element,
            "evaluation_reason": reason
        }

    def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
        """
        Get summary statistics for a batch of evaluations.

        Args:
            results: List of evaluation result dictionaries

        Returns:
            Summary statistics
        """
        if not results:
            return {
                "total_samples": 0,
                "accuracy": 0.0,
                "correct_predictions": 0
            }

        total = len(results)
        # A prediction counts as correct only on an exact element match (1.0).
        correct = sum(1 for r in results if r.get("output_match", 0.0) == 1.0)
        accuracy = correct / total if total > 0 else 0.0

        return {
            "total_samples": total,
            "accuracy": accuracy,
            "correct_predictions": correct,
            "incorrect_predictions": total - correct
        }
209
+ }
210
+
211
+
212
# Example usage and testing
# Manual smoke test: run this module directly to exercise the evaluator
# against a spread of output formats (bare number, "Element: X", prose, etc.).
if __name__ == "__main__":
    print("🚀 Testing Scroll Element Evaluator...")

    evaluator = ScrollElementEvaluator()

    # Test cases
    # Each tuple: (predicted text, expected text, should the evaluator match them?)
    test_cases = [
        ("4", "4", True),
        ("Element: 4", "4", True),
        ("Element 4", "4", True),
        ("The element to interact with is 4", "4", True),
        ("Element ID: 4", "4", True),
        ("Click on element 4 to scroll", "4", True),
        ("5", "4", False),
        ("Element: 5", "4", False),
        ("No element found", "4", False),
        ("", "4", False),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    results = []
    for predicted, expected, should_match in test_cases:
        result = evaluator.evaluate(predicted, expected)
        # Binary scoring: only an exact element match yields composite 1.0.
        match = result["composite_score"] == 1.0

        # ✅ means the evaluator's verdict agreed with the case's expectation.
        status = "✅" if match == should_match else "❌"
        print(f"{status} Predicted: '{predicted}' | Expected: '{expected}' | Match: {match}")

        results.append(result)

    # Summary
    # Aggregate the batch with the evaluator's own summary helper.
    print("\n📊 Summary:")
    summary = evaluator.get_evaluation_summary(results)
    print(f" Total: {summary['total_samples']}")
    print(f" Correct: {summary['correct_predictions']}")
    print(f" Accuracy: {summary['accuracy']:.1%}")
251
+
src/gepa_optimizer/evaluation/ui_evaluator.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UI Tree Evaluator for GEPA Optimizer
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import difflib
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from .base_evaluator import BaseEvaluator
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class UITreeEvaluator(BaseEvaluator):
    """
    Comprehensive evaluator for UI tree extraction quality.

    Compares a predicted UI tree (nested dicts with "type", "text", "style" and
    "children" keys) against a ground-truth tree across five weighted metrics:
    completeness, type accuracy, text accuracy, hierarchy accuracy and style
    accuracy. All logging goes through ``self.logger`` for consistency.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initializes the UITreeEvaluator with configurable metric weights.

        Args:
            metric_weights: A dictionary of weights for different metrics.
                            If None, default weights will be used.
        """
        # Set default weights for UI tree evaluation
        default_weights = {
            "element_completeness": 0.3,   # How many elements are captured
            "element_type_accuracy": 0.25, # Correct element types (Button, Text, etc.)
            "text_content_accuracy": 0.2,  # Text content matches
            "hierarchy_accuracy": 0.15,    # Parent-child relationships
            "style_accuracy": 0.1,         # Style properties captured
        }

        # Use provided weights or defaults
        weights = metric_weights or default_weights

        # Initialize parent class
        super().__init__(metric_weights=weights)

        # Normalize weights so the composite score stays in [0, 1]
        self._normalize_weights()

    def _normalize_weights(self):
        """Normalize weights to sum to 1.0 (no-op warning if all weights are zero)."""
        total_weight = sum(self.metric_weights.values())
        if total_weight > 0:
            self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}
        else:
            self.logger.warning("Total metric weight is zero. Scores will be zero.")

    def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]:
        """
        Generates a weighted composite score from individual metrics.

        Args:
            predicted_json: The JSON generated by the LLM.
            expected_json: The ground truth JSON.

        Returns:
            A dictionary of individual metric scores and the composite score.
        """
        scores = {
            "element_completeness": self.calculate_element_completeness(predicted_json, expected_json),
            "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json),
            "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json),
            "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json),
            "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json),
        }

        composite_score = sum(scores[metric] * self.metric_weights.get(metric, 0) for metric in scores)

        # Add small improvement bonus for better prompts (encourage GEPA to accept improvements)
        # This helps GEPA recognize even tiny improvements.
        # BUGFIX: previously the bonus was added to a local variable AFTER the
        # composite had been stored in `scores`, so it never reached the caller.
        # Apply it before storing so the returned score actually carries it.
        if composite_score > 0.05:  # If we have any meaningful content
            composite_score = min(composite_score + 0.001, 1.0)  # Small bonus to encourage acceptance

        scores["composite_score"] = composite_score

        # Add detailed logging for debugging
        self.logger.debug(f"Evaluation scores: {scores}")
        self.logger.debug(f"Composite score: {composite_score:.4f}")

        return scores

    def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates how many UI elements are captured in the predicted JSON.
        This is the most important metric for UI tree extraction.
        """
        def _count_elements(node):
            """Count total elements in the tree (node itself plus all descendants)."""
            if not isinstance(node, dict):
                return 0
            count = 1  # Count current node
            for child in node.get("children", []):
                count += _count_elements(child)
            return count

        try:
            predicted_count = _count_elements(predicted)
            expected_count = _count_elements(expected)

            if expected_count == 0:
                return 1.0 if predicted_count == 0 else 0.0

            # Score based on how many elements are captured
            completeness_ratio = predicted_count / expected_count

            # Give bonus for capturing more elements (up to 1.0)
            # Penalize heavily for missing elements
            if completeness_ratio >= 1.0:
                return 1.0  # Perfect or better
            elif completeness_ratio >= 0.8:
                return completeness_ratio  # Good coverage
            elif completeness_ratio >= 0.5:
                return completeness_ratio * 0.8  # Moderate coverage with penalty
            else:
                return completeness_ratio * 0.5  # Poor coverage with heavy penalty

        except Exception as e:
            self.logger.warning(f"Error calculating element completeness: {e}")
            return 0.0

    def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates element type accuracy by comparing the 'type' attribute of corresponding nodes.
        Focuses on common UI element types like Button, Text, Image, etc.
        """
        def _get_all_types(node):
            # Collect every non-None "type" value in the tree (pre-order).
            if not isinstance(node, dict):
                return []
            types = [node.get("type")]
            for child in node.get("children", []):
                types.extend(_get_all_types(child))
            return [t for t in types if t is not None]

        try:
            predicted_types = _get_all_types(predicted)
            expected_types = _get_all_types(expected)

            if not expected_types:
                return 1.0 if not predicted_types else 0.5

            if not predicted_types:
                return 0.0

            # Count matching types with frequency consideration
            expected_type_counts = {}
            for t in expected_types:
                expected_type_counts[t] = expected_type_counts.get(t, 0) + 1

            predicted_type_counts = {}
            for t in predicted_types:
                predicted_type_counts[t] = predicted_type_counts.get(t, 0) + 1

            # Calculate accuracy based on type matches; credit each type only
            # up to the number of occurrences expected.
            total_matches = 0
            for type_name, expected_count in expected_type_counts.items():
                predicted_count = predicted_type_counts.get(type_name, 0)
                total_matches += min(predicted_count, expected_count)

            return total_matches / len(expected_types) if expected_types else 0.0

        except Exception as e:
            self.logger.warning(f"Error calculating element type accuracy: {e}")
            return 0.0

    def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates hierarchy accuracy by comparing parent-child relationships.
        """
        def _get_hierarchy_structure(node, parent_type="ROOT"):
            """Extract hierarchy structure as (parent_type, child_type) pairs."""
            if not isinstance(node, dict):
                return []

            current_type = node.get("type", "unknown")
            hierarchy = [(parent_type, current_type)]

            for child in node.get("children", []):
                hierarchy.extend(_get_hierarchy_structure(child, current_type))

            return hierarchy

        try:
            predicted_hierarchy = _get_hierarchy_structure(predicted)
            expected_hierarchy = _get_hierarchy_structure(expected)

            if not expected_hierarchy:
                return 1.0 if not predicted_hierarchy else 0.5

            if not predicted_hierarchy:
                return 0.0

            # Count matching hierarchy relationships (set-based: each distinct
            # parent/child type pair counts once regardless of frequency).
            expected_hierarchy_set = set(expected_hierarchy)
            predicted_hierarchy_set = set(predicted_hierarchy)

            matches = len(expected_hierarchy_set.intersection(predicted_hierarchy_set))
            total_expected = len(expected_hierarchy_set)

            return matches / total_expected if total_expected > 0 else 0.0

        except Exception as e:
            self.logger.warning(f"Error calculating hierarchy accuracy: {e}")
            return 0.0

    def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates text content accuracy by comparing the 'text' attribute of corresponding nodes.
        """
        def _get_all_texts(node):
            # Collect every non-empty "text" value in the tree.
            if not isinstance(node, dict):
                return []
            texts = [node.get("text")]
            for child in node.get("children", []):
                texts.extend(_get_all_texts(child))
            return [t for t in texts if t is not None and str(t).strip()]

        try:
            predicted_texts = _get_all_texts(predicted)
            expected_texts = _get_all_texts(expected)

            if not expected_texts:
                return 1.0 if not predicted_texts else 0.5  # Partial credit if predicted has texts but expected doesn't

            if not predicted_texts:
                return 0.0  # No predicted texts, so no match

            # Each predicted text is credited with its best fuzzy match against
            # any expected text (SequenceMatcher ratio in [0, 1]).
            total_similarity = 0.0
            for p_text in predicted_texts:
                best_similarity = 0.0
                for e_text in expected_texts:
                    similarity = difflib.SequenceMatcher(None, str(p_text).strip(), str(e_text).strip()).ratio()
                    best_similarity = max(best_similarity, similarity)
                total_similarity += best_similarity

            # Average best-match similarity over all predicted texts.
            # (Both lists are guaranteed non-empty by the early returns above,
            # so the previously-present empty-list branches here were unreachable
            # dead code and have been removed.)
            return total_similarity / len(predicted_texts)
        except Exception as e:
            self.logger.warning(f"Error calculating text content accuracy: {e}")
            return 0.0

    def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates style accuracy by comparing style properties.
        """
        def _get_all_styles(node):
            """Extract all style dicts from the tree."""
            if not isinstance(node, dict):
                return []

            styles = []
            if "style" in node and isinstance(node["style"], dict):
                styles.append(node["style"])

            for child in node.get("children", []):
                styles.extend(_get_all_styles(child))

            return styles

        try:
            predicted_styles = _get_all_styles(predicted)
            expected_styles = _get_all_styles(expected)

            if not expected_styles:
                return 1.0 if not predicted_styles else 0.5

            if not predicted_styles:
                return 0.0

            # Calculate style property overlap: each expected (name, value) pair
            # is matched against any predicted style dict that carries it.
            total_style_properties = 0
            matching_properties = 0

            for exp_style in expected_styles:
                for prop_name, prop_value in exp_style.items():
                    total_style_properties += 1

                    # Find matching property in predicted styles
                    for pred_style in predicted_styles:
                        if prop_name in pred_style and pred_style[prop_name] == prop_value:
                            matching_properties += 1
                            break

            return matching_properties / total_style_properties if total_style_properties > 0 else 0.0

        except Exception as e:
            self.logger.warning(f"Error calculating style accuracy: {e}")
            return 0.0
297
+ return 0.0
src/gepa_optimizer/evaluation/universal_evaluator.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Universal Semantic Evaluator for ANY prompt optimization use case.
3
+
4
+ This evaluator uses LLM-powered semantic analysis to compare predicted vs expected outputs,
5
+ enabling prompt optimization for ANY task without requiring custom evaluator code.
6
+
7
+ Key Features:
8
+ - Semantic understanding (not just string matching)
9
+ - Works with text, JSON, numbers, structured outputs
10
+ - Provides rich feedback for GEPA reflection
11
+ - No task-specific assumptions
12
+ """
13
+
14
+ import json
15
+ import re
16
+ import logging
17
+ from typing import Dict, Any, Optional, List
18
+ from difflib import SequenceMatcher
19
+
20
+ from .base_evaluator import BaseEvaluator
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class UniversalSemanticEvaluator(BaseEvaluator):
26
+ """
27
+ Universal evaluator using LLM for semantic comparison.
28
+
29
+ Works for ANY task without hardcoded assumptions:
30
+ - Text outputs: "The answer is 42" vs "42"
31
+ - JSON outputs: {"count": 23} vs {"count": 22}
32
+ - Structured data: Lists, nested objects
33
+ - Multi-modal: Image descriptions, analysis results
34
+
35
+ Evaluation Strategy:
36
+ 1. Quick checks (exact match, empty handling)
37
+ 2. Structural comparison (for JSON/structured data)
38
+ 3. LLM semantic analysis (for meaning understanding)
39
+ 4. Combine into composite score with rich feedback
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ llm_client=None,
45
+ use_llm_analysis: bool = True,
46
+ semantic_weight: float = 0.6,
47
+ structural_weight: float = 0.25,
48
+ exact_match_bonus: float = 0.15,
49
+ metric_weights: Optional[Dict[str, float]] = None
50
+ ):
51
+ """
52
+ Initialize Universal Semantic Evaluator.
53
+
54
+ Args:
55
+ llm_client: LLM client for semantic analysis (optional, falls back to heuristics)
56
+ use_llm_analysis: Whether to use LLM for semantic comparison
57
+ semantic_weight: Weight for semantic similarity (0.0-1.0)
58
+ structural_weight: Weight for structural similarity (0.0-1.0)
59
+ exact_match_bonus: Bonus weight for exact matches (0.0-1.0)
60
+ metric_weights: Optional custom weights (overrides above)
61
+ """
62
+ default_weights = metric_weights or {
63
+ "semantic_similarity": semantic_weight,
64
+ "structural_similarity": structural_weight,
65
+ "exact_match": exact_match_bonus
66
+ }
67
+ super().__init__(metric_weights=default_weights)
68
+
69
+ self.llm_client = llm_client
70
+ self.use_llm_analysis = use_llm_analysis and llm_client is not None
71
+
72
+ # Cache for LLM analysis to reduce API calls
73
+ self._analysis_cache: Dict[str, Dict] = {}
74
+
75
+ logger.info(f"🎯 Universal Semantic Evaluator initialized")
76
+ logger.info(f" LLM analysis: {'enabled' if self.use_llm_analysis else 'disabled (using heuristics)'}")
77
+ logger.info(f" Weights: semantic={semantic_weight}, structural={structural_weight}, exact={exact_match_bonus}")
78
+
79
+ def evaluate(self, predicted: Any, expected: Any) -> Dict[str, float]:
80
+ """
81
+ Evaluate predicted output against expected output using semantic understanding.
82
+
83
+ Args:
84
+ predicted: The model's predicted output (string, dict, or any serializable type)
85
+ expected: The ground truth expected output
86
+
87
+ Returns:
88
+ Dictionary with metrics including 'composite_score' (required for GEPA)
89
+ """
90
+ # Convert to strings for comparison
91
+ predicted_str = self._to_string(predicted)
92
+ expected_str = self._to_string(expected)
93
+
94
+ # Initialize result
95
+ result = {
96
+ "composite_score": 0.0,
97
+ "exact_match": 0.0,
98
+ "semantic_similarity": 0.0,
99
+ "structural_similarity": 0.0,
100
+ "predicted_output": predicted_str[:500], # Truncate for logging
101
+ "expected_output": expected_str[:500],
102
+ "analysis": {},
103
+ "improvement_feedback": ""
104
+ }
105
+
106
+ # Handle empty/missing outputs
107
+ if not predicted_str or not predicted_str.strip():
108
+ result["improvement_feedback"] = "❌ Output is EMPTY. The prompt must instruct the model to produce output."
109
+ result["analysis"] = {"status": "empty_predicted"}
110
+ return result
111
+
112
+ if not expected_str or not expected_str.strip():
113
+ result["improvement_feedback"] = "⚠️ Expected output is empty - cannot evaluate."
114
+ result["analysis"] = {"status": "empty_expected"}
115
+ result["composite_score"] = 0.5 # Neutral score
116
+ return result
117
+
118
+ # ─────────────────────────────────────────────────────
119
+ # STEP 1: Exact Match Check (Fast Path)
120
+ # ─────────────────────────────────────────────────────
121
+ normalized_pred = self._normalize(predicted_str)
122
+ normalized_exp = self._normalize(expected_str)
123
+
124
+ if normalized_pred == normalized_exp:
125
+ result["exact_match"] = 1.0
126
+ result["semantic_similarity"] = 1.0
127
+ result["structural_similarity"] = 1.0
128
+ result["composite_score"] = 1.0
129
+ result["improvement_feedback"] = "✅ Perfect match! Output exactly matches expected."
130
+ result["analysis"] = {"status": "exact_match"}
131
+ return result
132
+
133
+ # ─────────────────────────────────────────────────────
134
+ # STEP 1.5: FORMAT MISMATCH DETECTION (CRITICAL FIX)
135
+ # ─────────────────────────────────────────────────────
136
+ # 🔥 CRITICAL: Detect when expected is JSON but predicted is narrative text
137
+ # This causes catastrophically low scores and needs explicit handling
138
+ expected_is_json = self._try_parse_json(expected_str) is not None
139
+ predicted_is_json = self._try_parse_json(predicted_str) is not None
140
+
141
+ format_mismatch = expected_is_json and not predicted_is_json
142
+ if format_mismatch:
143
+ # Expected JSON but got narrative - this is a CRITICAL format error
144
+ # Give partial credit for semantic content but penalize heavily for format
145
+ result["analysis"]["format_mismatch"] = True
146
+ result["improvement_feedback"] = (
147
+ "❌ FORMAT ERROR: Expected JSON output but received narrative text. "
148
+ "The prompt MUST enforce JSON output format. "
149
+ "Add explicit instructions like: 'Output ONLY valid JSON, no explanations.' "
150
+ "Consider adding: 'Do NOT write prose or explanations.'"
151
+ )
152
+ # Still evaluate semantic content but cap the score
153
+ # This gives feedback for improving the prompt
154
+ logger.warning(f"⚠️ Format mismatch: expected JSON ({len(expected_str)} chars), got narrative ({len(predicted_str)} chars)")
155
+
156
+ # ─────────────────────────────────────────────────────
157
+ # STEP 2: Structural Comparison (for JSON/structured data)
158
+ # ─────────────────────────────────────────────────────
159
+ structural_result = self._compare_structure(predicted_str, expected_str)
160
+ result["structural_similarity"] = structural_result["score"]
161
+ result["analysis"]["structural"] = structural_result.get("details", {})
162
+
163
+ # ─────────────────────────────────────────────────────
164
+ # STEP 3: Semantic Analysis
165
+ # ─────────────────────────────────────────────────────
166
+ if self.use_llm_analysis:
167
+ semantic_result = self._llm_semantic_analysis(predicted_str, expected_str)
168
+ else:
169
+ semantic_result = self._heuristic_semantic_analysis(predicted_str, expected_str)
170
+
171
+ result["semantic_similarity"] = semantic_result["score"]
172
+ result["analysis"]["semantic"] = semantic_result.get("details", {})
173
+ result["improvement_feedback"] = semantic_result.get("feedback", "")
174
+
175
+ # ─────────────────────────────────────────────────────
176
+ # STEP 4: Compute Composite Score
177
+ # ─────────────────────────────────────────────────────
178
+ weights = self.metric_weights
179
+ composite = (
180
+ result["semantic_similarity"] * weights.get("semantic_similarity", 0.6) +
181
+ result["structural_similarity"] * weights.get("structural_similarity", 0.25) +
182
+ result["exact_match"] * weights.get("exact_match", 0.15)
183
+ )
184
+
185
+ # 🔥 CRITICAL FIX: Apply format mismatch penalty
186
+ # If expected JSON but got narrative, cap the score to encourage format compliance
187
+ if result.get("analysis", {}).get("format_mismatch"):
188
+ # Cap at 0.3 to indicate "partial semantic match but wrong format"
189
+ # This ensures format-correct outputs always score higher
190
+ composite = min(composite, 0.30)
191
+ logger.debug(f"📊 Format mismatch penalty applied: score capped at {composite:.3f}")
192
+
193
+ result["composite_score"] = min(max(composite, 0.0), 1.0)
194
+
195
+ # Add score breakdown to feedback
196
+ if not result["improvement_feedback"]:
197
+ result["improvement_feedback"] = self._generate_default_feedback(result)
198
+
199
+ # Log evaluation
200
+ logger.debug(f"📊 Evaluation: composite={result['composite_score']:.3f}, "
201
+ f"semantic={result['semantic_similarity']:.3f}, "
202
+ f"structural={result['structural_similarity']:.3f}")
203
+
204
+ # #region agent log
205
+ try:
206
+ import json as _json_debug
207
+ import time as _time_debug
208
+ import os as _os_debug
209
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
210
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
211
+ with open(_debug_log_path, "a") as _f:
212
+ _f.write(_json_debug.dumps({"hypothesisId": "G", "location": "universal_evaluator.py:final_score", "message": "Final evaluation score breakdown", "data": {"composite": result["composite_score"], "semantic": result["semantic_similarity"], "structural": result["structural_similarity"], "exact_match": result["exact_match"], "format_mismatch": result.get("analysis", {}).get("format_mismatch", False), "predicted_preview": predicted_str[:150] if predicted_str else "EMPTY", "expected_preview": expected_str[:150] if expected_str else "EMPTY"}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
213
+ except Exception as _e:
214
+ pass # Silent fail for instrumentation
215
+ # #endregion
216
+
217
+ return result
218
+
219
+ def _to_string(self, value: Any) -> str:
220
+ """Convert any value to string for comparison."""
221
+ if value is None:
222
+ return ""
223
+ if isinstance(value, str):
224
+ return value.strip()
225
+ if isinstance(value, dict):
226
+ try:
227
+ return json.dumps(value, sort_keys=True, indent=2)
228
+ except (TypeError, ValueError):
229
+ return str(value)
230
+ if isinstance(value, (list, tuple)):
231
+ try:
232
+ return json.dumps(list(value), sort_keys=True)
233
+ except (TypeError, ValueError):
234
+ return str(value)
235
+ return str(value).strip()
236
+
237
+ def _normalize(self, text: str) -> str:
238
+ """Normalize text for comparison (lowercase, whitespace)."""
239
+ # Lowercase and normalize whitespace
240
+ normalized = ' '.join(text.lower().split())
241
+ # Remove common punctuation that doesn't affect meaning
242
+ normalized = re.sub(r'[.,;:!?\'"]+$', '', normalized)
243
+ return normalized
244
+
245
+ def _compare_structure(self, predicted: str, expected: str) -> Dict[str, Any]:
246
+ """
247
+ Compare structural similarity (especially for JSON/structured outputs).
248
+
249
+ Returns:
250
+ Dict with 'score' (0.0-1.0) and 'details'
251
+ """
252
+ result = {"score": 0.0, "details": {}}
253
+
254
+ # Try to parse as JSON
255
+ pred_json = self._try_parse_json(predicted)
256
+ exp_json = self._try_parse_json(expected)
257
+
258
+ if pred_json is not None and exp_json is not None:
259
+ # Both are valid JSON - do structural comparison
260
+ return self._compare_json_structures(pred_json, exp_json)
261
+
262
+ # Fallback: Compare as text structure
263
+ return self._compare_text_structure(predicted, expected)
264
+
265
+ def _try_parse_json(self, text: str) -> Optional[Any]:
266
+ """
267
+ Try to parse text as JSON with robust extraction.
268
+
269
+ 🔥 FIX: LLMs often wrap JSON in markdown code blocks or add extra text.
270
+ This method now handles multiple formats:
271
+ - Direct JSON
272
+ - ```json ... ``` blocks
273
+ - ``` ... ``` blocks (no language tag)
274
+ - JSON embedded in prose
275
+ - Escaped newlines and quotes
276
+ """
277
+ if not text or not isinstance(text, str):
278
+ return None
279
+
280
+ # 🔥 PREPROCESSING: Clean common LLM output issues
281
+ cleaned = text.strip()
282
+
283
+ # Remove BOM and other invisible characters
284
+ cleaned = cleaned.lstrip('\ufeff\u200b\u200c\u200d')
285
+
286
+ # Strategy 1: Try direct parse (cleanest case)
287
+ try:
288
+ return json.loads(cleaned)
289
+ except json.JSONDecodeError:
290
+ pass
291
+
292
+ # Strategy 2: Extract JSON from markdown code block (```json ... ```)
293
+ # More permissive regex that handles optional language tags
294
+ json_match = re.search(r'```(?:json|JSON)?\s*([\{|\[].*?[\}|\]])\s*```', cleaned, re.DOTALL)
295
+ if json_match:
296
+ try:
297
+ return json.loads(json_match.group(1))
298
+ except json.JSONDecodeError:
299
+ pass
300
+
301
+ # Strategy 3: Find JSON using balanced brace matching (handles nested objects)
302
+ def extract_balanced_json(s: str, start_char: str, end_char: str) -> Optional[str]:
303
+ """Extract JSON with balanced braces/brackets."""
304
+ count = 0
305
+ start_idx = -1
306
+ for i, char in enumerate(s):
307
+ if char == start_char:
308
+ if count == 0:
309
+ start_idx = i
310
+ count += 1
311
+ elif char == end_char:
312
+ count -= 1
313
+ if count == 0 and start_idx >= 0:
314
+ return s[start_idx:i+1]
315
+ return None
316
+
317
+ # Try to find JSON object
318
+ json_obj = extract_balanced_json(cleaned, '{', '}')
319
+ if json_obj:
320
+ try:
321
+ return json.loads(json_obj)
322
+ except json.JSONDecodeError:
323
+ # Try to repair common issues
324
+ repaired = self._repair_json(json_obj)
325
+ try:
326
+ return json.loads(repaired)
327
+ except json.JSONDecodeError:
328
+ pass
329
+
330
+ # Try to find JSON array
331
+ json_arr = extract_balanced_json(cleaned, '[', ']')
332
+ if json_arr:
333
+ try:
334
+ return json.loads(json_arr)
335
+ except json.JSONDecodeError:
336
+ repaired = self._repair_json(json_arr)
337
+ try:
338
+ return json.loads(repaired)
339
+ except json.JSONDecodeError:
340
+ pass
341
+
342
+ return None
343
+
344
+ def _repair_json(self, json_str: str) -> str:
345
+ """
346
+ Attempt to repair common JSON issues from LLM output.
347
+
348
+ Fixes:
349
+ - Trailing commas before } or ]
350
+ - Single quotes instead of double quotes
351
+ - Unquoted keys
352
+ - Comments (// and /* */)
353
+ """
354
+ repaired = json_str
355
+
356
+ # Remove trailing commas
357
+ repaired = re.sub(r',\s*}', '}', repaired)
358
+ repaired = re.sub(r',\s*]', ']', repaired)
359
+
360
+ # Remove single-line comments
361
+ repaired = re.sub(r'//[^\n]*', '', repaired)
362
+
363
+ # Remove multi-line comments
364
+ repaired = re.sub(r'/\*.*?\*/', '', repaired, flags=re.DOTALL)
365
+
366
+ # Replace single quotes with double quotes (but be careful with apostrophes)
367
+ # Only replace when it looks like a JSON delimiter
368
+ def replace_single_quotes(match):
369
+ content = match.group(0)
370
+ # Skip if it looks like an apostrophe in a word
371
+ if re.match(r"'\w+'\s*:", content) or re.match(r":\s*'[^']*'", content):
372
+ return content.replace("'", '"')
373
+ return content
374
+
375
+ # Basic single quote replacement for keys
376
+ repaired = re.sub(r"'([^']+)'\s*:", r'"\1":', repaired)
377
+
378
+ return repaired
379
+
380
+ def _compare_json_structures(self, pred: Any, exp: Any) -> Dict[str, Any]:
381
+ """Compare two JSON structures."""
382
+ result = {"score": 0.0, "details": {"type": "json", "matches": [], "mismatches": []}}
383
+
384
+ if type(pred) != type(exp):
385
+ result["details"]["mismatches"].append(f"Type mismatch: predicted={type(pred).__name__}, expected={type(exp).__name__}")
386
+ result["score"] = 0.2 # Some credit for being JSON
387
+ return result
388
+
389
+ if isinstance(pred, dict) and isinstance(exp, dict):
390
+ return self._compare_dicts(pred, exp)
391
+ elif isinstance(pred, list) and isinstance(exp, list):
392
+ return self._compare_lists(pred, exp)
393
+ else:
394
+ # Primitive types
395
+ if pred == exp:
396
+ result["score"] = 1.0
397
+ result["details"]["matches"].append(f"Values match: {pred}")
398
+ else:
399
+ result["score"] = self._value_similarity(pred, exp)
400
+ result["details"]["mismatches"].append(f"Value mismatch: predicted={pred}, expected={exp}")
401
+ return result
402
+
403
+ def _compare_dicts(self, pred: dict, exp: dict) -> Dict[str, Any]:
404
+ """
405
+ Compare two dictionaries with CASE-INSENSITIVE key matching.
406
+
407
+ 🔥 FIX: LLMs often produce keys like 'Category' when expected is 'category'.
408
+ This method now normalizes keys before comparison for fair scoring.
409
+ """
410
+ result = {"score": 0.0, "details": {"type": "dict", "matches": [], "mismatches": [], "missing_keys": [], "extra_keys": []}}
411
+
412
+ # 🔥 NORMALIZE: Convert all keys to lowercase for comparison
413
+ # Also handle common variations like underscores vs camelCase
414
+ def normalize_key(key: str) -> str:
415
+ """Normalize key: lowercase, underscores to nothing, strip spaces."""
416
+ return re.sub(r'[_\s-]', '', str(key).lower())
417
+
418
+ # Build normalized key mappings
419
+ pred_normalized = {normalize_key(k): (k, v) for k, v in pred.items()}
420
+ exp_normalized = {normalize_key(k): (k, v) for k, v in exp.items()}
421
+
422
+ pred_norm_keys = set(pred_normalized.keys())
423
+ exp_norm_keys = set(exp_normalized.keys())
424
+
425
+ # Check for missing/extra keys (using normalized comparison)
426
+ missing_norm = exp_norm_keys - pred_norm_keys
427
+ extra_norm = pred_norm_keys - exp_norm_keys
428
+ common_norm = pred_norm_keys & exp_norm_keys
429
+
430
+ # Convert back to original key names for reporting
431
+ missing = [exp_normalized[k][0] for k in missing_norm]
432
+ extra = [pred_normalized[k][0] for k in extra_norm]
433
+
434
+ result["details"]["missing_keys"] = missing
435
+ result["details"]["extra_keys"] = extra
436
+
437
+ if not exp_norm_keys:
438
+ result["score"] = 1.0 if not pred_norm_keys else 0.5
439
+ return result
440
+
441
+ # Score based on key overlap (normalized)
442
+ key_score = len(common_norm) / len(exp_norm_keys) if exp_norm_keys else 1.0
443
+
444
+ # Score based on value matches
445
+ value_scores = []
446
+ for norm_key in common_norm:
447
+ pred_orig_key, pred_val = pred_normalized[norm_key]
448
+ exp_orig_key, exp_val = exp_normalized[norm_key]
449
+
450
+ if pred_val == exp_val:
451
+ value_scores.append(1.0)
452
+ result["details"]["matches"].append(f"{exp_orig_key}: {exp_val}")
453
+ else:
454
+ sim = self._value_similarity(pred_val, exp_val)
455
+ value_scores.append(sim)
456
+ if sim < 0.8:
457
+ result["details"]["mismatches"].append(f"{exp_orig_key}: predicted={pred_val}, expected={exp_val}")
458
+
459
+ value_score = sum(value_scores) / len(value_scores) if value_scores else 0.0
460
+
461
+ # Combine scores
462
+ result["score"] = 0.3 * key_score + 0.7 * value_score
463
+
464
+ # Penalty for missing keys (reduced from 0.1 to 0.05 per key)
465
+ if missing:
466
+ result["score"] *= (1 - 0.05 * len(missing))
467
+
468
+ result["score"] = max(0.0, min(1.0, result["score"]))
469
+ return result
470
+
471
+ def _compare_lists(self, pred: list, exp: list) -> Dict[str, Any]:
472
+ """Compare two lists."""
473
+ result = {"score": 0.0, "details": {"type": "list", "length_match": False, "item_matches": 0}}
474
+
475
+ if not exp:
476
+ result["score"] = 1.0 if not pred else 0.5
477
+ return result
478
+
479
+ result["details"]["length_match"] = len(pred) == len(exp)
480
+
481
+ # Compare items (order-sensitive)
482
+ matches = 0
483
+ for i, exp_item in enumerate(exp):
484
+ if i < len(pred):
485
+ if pred[i] == exp_item:
486
+ matches += 1
487
+ else:
488
+ # Check if item exists elsewhere
489
+ if exp_item in pred:
490
+ matches += 0.5 # Partial credit for wrong position
491
+
492
+ result["details"]["item_matches"] = matches
493
+ result["score"] = matches / len(exp)
494
+
495
+ # Penalty for length mismatch
496
+ if len(pred) != len(exp):
497
+ len_ratio = min(len(pred), len(exp)) / max(len(pred), len(exp))
498
+ result["score"] *= (0.7 + 0.3 * len_ratio)
499
+
500
+ return result
501
+
502
+ def _value_similarity(self, pred: Any, exp: Any) -> float:
503
+ """
504
+ Calculate similarity between two values.
505
+
506
+ 🔥 ENHANCED: Now handles:
507
+ - Case-insensitive string comparison
508
+ - Semantic similarity for common variations
509
+ - Underscore/space/dash normalization
510
+ - Numeric comparison with tolerance
511
+ """
512
+ # Same value (exact match)
513
+ if pred == exp:
514
+ return 1.0
515
+
516
+ # Numeric comparison
517
+ try:
518
+ pred_num = float(pred)
519
+ exp_num = float(exp)
520
+ if exp_num == 0:
521
+ return 1.0 if pred_num == 0 else 0.0
522
+ # Relative error with tolerance
523
+ error = abs(pred_num - exp_num) / abs(exp_num)
524
+ return max(0.0, 1.0 - error)
525
+ except (ValueError, TypeError):
526
+ pass
527
+
528
+ # String comparison with normalization
529
+ pred_str = str(pred).strip()
530
+ exp_str = str(exp).strip()
531
+
532
+ # Case-insensitive exact match
533
+ if pred_str.lower() == exp_str.lower():
534
+ return 0.98 # Slight penalty for case mismatch
535
+
536
+ # Normalize strings (remove underscores, spaces, dashes for comparison)
537
+ def normalize_str(s: str) -> str:
538
+ return re.sub(r'[_\s\-]+', '', s.lower())
539
+
540
+ pred_norm = normalize_str(pred_str)
541
+ exp_norm = normalize_str(exp_str)
542
+
543
+ if pred_norm == exp_norm:
544
+ return 0.95 # Good match despite formatting differences
545
+
546
+ # Check if one contains the other (partial match)
547
+ if pred_norm in exp_norm or exp_norm in pred_norm:
548
+ ratio = min(len(pred_norm), len(exp_norm)) / max(len(pred_norm), len(exp_norm))
549
+ return 0.7 + (0.2 * ratio) # 0.7-0.9 for partial matches
550
+
551
+ # 🔥 SEMANTIC SIMILARITY: Check for common equivalent terms
552
+ semantic_equivalents = {
553
+ # Priority levels
554
+ 'low': ['low', 'minor', 'trivial', 'p3', 'p4'],
555
+ 'medium': ['medium', 'normal', 'moderate', 'p2'],
556
+ 'high': ['high', 'important', 'major', 'p1', 'critical', 'urgent'],
557
+ # Boolean variations
558
+ 'true': ['true', 'yes', '1', 'on', 'enabled'],
559
+ 'false': ['false', 'no', '0', 'off', 'disabled'],
560
+ # Status variations
561
+ 'success': ['success', 'succeeded', 'completed', 'done', 'passed'],
562
+ 'failure': ['failure', 'failed', 'error', 'crashed'],
563
+ 'pending': ['pending', 'waiting', 'queued', 'in_progress', 'processing'],
564
+ }
565
+
566
+ for canonical, equivalents in semantic_equivalents.items():
567
+ pred_match = any(eq in pred_norm for eq in equivalents)
568
+ exp_match = any(eq in exp_norm for eq in equivalents)
569
+ if pred_match and exp_match:
570
+ return 0.85 # Semantic match
571
+
572
+ # Sequence matching (character-level similarity)
573
+ ratio = SequenceMatcher(None, pred_str.lower(), exp_str.lower()).ratio()
574
+
575
+ # 🔥 WORD-LEVEL SIMILARITY: Check word overlap
576
+ pred_words = set(re.findall(r'\w+', pred_str.lower()))
577
+ exp_words = set(re.findall(r'\w+', exp_str.lower()))
578
+
579
+ if pred_words and exp_words:
580
+ word_overlap = len(pred_words & exp_words) / max(len(pred_words), len(exp_words))
581
+ # Combine character and word similarity
582
+ return max(ratio, word_overlap * 0.9)
583
+
584
+ def _compare_text_structure(self, predicted: str, expected: str) -> Dict[str, Any]:
585
+ """Compare text structure when not JSON."""
586
+ result = {"score": 0.0, "details": {"type": "text"}}
587
+
588
+ # Word overlap
589
+ pred_words = set(predicted.lower().split())
590
+ exp_words = set(expected.lower().split())
591
+
592
+ if not exp_words:
593
+ result["score"] = 1.0 if not pred_words else 0.5
594
+ return result
595
+
596
+ overlap = len(pred_words & exp_words)
597
+ result["details"]["word_overlap"] = overlap
598
+ result["details"]["expected_words"] = len(exp_words)
599
+
600
+ # Jaccard similarity
601
+ union = len(pred_words | exp_words)
602
+ result["score"] = overlap / union if union > 0 else 0.0
603
+
604
+ return result
605
+
606
    def _llm_semantic_analysis(self, predicted: str, expected: str) -> Dict[str, Any]:
        """
        Use LLM for semantic analysis of predicted vs expected.

        Uses XML-delimited prompt structure to prevent context bleeding
        and Multi-Dimensional Scoring (Semantics vs. Syntax).

        Falls back to _heuristic_semantic_analysis when the LLM call fails
        or its reply contains no parseable JSON.

        Returns:
            Dict with 'score' (0.0-1.0), 'details', and 'feedback'
        """
        # Check cache
        # NOTE(review): hash()-based keys can collide and vary per process
        # (PYTHONHASHSEED); consider keying on the strings themselves or a
        # cryptographic digest — confirm before relying on cross-run caching.
        cache_key = f"{hash(predicted)}:{hash(expected)}"
        if cache_key in self._analysis_cache:
            return self._analysis_cache[cache_key]

        result = {"score": 0.0, "details": {}, "feedback": ""}

        try:
            # Truncate for token limits but preserve enough context
            expected_truncated = expected[:10000]
            predicted_truncated = predicted[:10000]

            # OPTIMIZED: Penalty-based scoring with self-verification
            # Starts at 1.0 and deducts for failures - more consistent than subjective scoring
            analysis_prompt = f"""<system_role>
You are a **Semantic Logic Engine** tasked with grading AI performance.
You must compare a [PREDICTED] output against a [EXPECTED] truth.
</system_role>

<input_data>
<expected_output>
{expected_truncated}
</expected_output>

<predicted_output>
{predicted_truncated}
</predicted_output>
</input_data>

<scoring_algorithm>
Calculate the score based on these STRICT rules. Start with 1.0 and deduct penalties.

1. **Information Completeness (Max -0.5)**:
   - If key facts/fields are missing, deduct proportional to importance.
   - If a nested JSON field is missing, deduct 0.1 per field.

2. **Accuracy & Hallucination (Max -1.0)**:
   - If factual numbers/IDs are wrong: Score = 0 immediately.
   - If the model invents information NOT in the input: Deduct 0.3.

3. **Format Compliance (Max -0.3)**:
   - If JSON is requested but Markdown is returned: Deduct 0.3.
   - If keys are lowercase instead of snake_case: Deduct 0.1.

4. **Semantic Equivalence (No Penalty)**:
   - Synonyms are ACCEPTED (e.g., "Purchase" == "Buy").
   - Formatting differences (whitespace) are IGNORED.
</scoring_algorithm>

<self_verification>
Before finalizing the score, ask: "If I used the predicted output in code expecting the original output, would the code crash?"
- If YES (Crash) -> Score must be < 0.5.
- If NO (Safe) -> Score can be high.
</self_verification>

<output_schema>
Return JSON ONLY:
{{
  "semantic_similarity": 0.0-1.0,
  "structural_similarity": 0.0-1.0,
  "verdict": "PERFECT" | "ACCEPTABLE" | "FORMAT_ERROR" | "DATA_CORRUPTION",
  "critical_failures": ["List specific failures that caused score < 1.0"],
  "penalty_breakdown": {{"completeness": -0.0, "accuracy": -0.0, "format": -0.0}},
  "fix_directive": "Imperative command to fix the prompt"
}}
</output_schema>
"""

            # image_base64 is passed empty: this is a text-only comparison,
            # but the client's generate() signature expects the argument.
            response = self.llm_client.generate(
                system_prompt="You are a Semantic Logic Engine. Calculate scores using penalty-based deduction from 1.0. Respond only with valid JSON.",
                user_prompt=analysis_prompt,
                image_base64=""
            )

            # Clients may return either a dict with a 'content' field or a raw string.
            content = response.get("content", str(response)) if isinstance(response, dict) else str(response)

            # Parse JSON response
            analysis = self._extract_json_from_response(content)

            if analysis:
                # Extract semantic similarity (primary score)
                semantic_sim = float(analysis.get("semantic_similarity", 0.5))
                structural_sim = float(analysis.get("structural_similarity", semantic_sim))

                # Compute weighted score based on verdict (updated for new schema)
                verdict = analysis.get("verdict", "ACCEPTABLE")
                # Maps both the current verdict vocabulary and legacy names to
                # a multiplier; unknown verdicts get a neutral 0.5.
                verdict_multiplier = {
                    "PERFECT": 1.0,
                    "ACCEPTABLE": 0.85,
                    "FORMAT_ERROR": 0.6,  # New: was WRONG_FORMAT
                    "DATA_CORRUPTION": 0.1,  # New: replaces WRONG_CONTENT + HALLUCINATION
                    # Legacy support
                    "WRONG_FORMAT": 0.6,
                    "WRONG_CONTENT": 0.3,
                    "HALLUCINATION": 0.1
                }.get(verdict, 0.5)

                # Final score: weighted combination
                result["score"] = min(1.0, semantic_sim * 0.6 + structural_sim * 0.3 + verdict_multiplier * 0.1)

                # Extract penalty breakdown if available
                penalty_breakdown = analysis.get("penalty_breakdown", {})
                critical_failures = analysis.get("critical_failures", [])

                result["details"] = {
                    "verdict": verdict,
                    "semantic_similarity": semantic_sim,
                    "structural_similarity": structural_sim,
                    "critical_failures": critical_failures,
                    "penalty_breakdown": penalty_breakdown,
                    # Legacy field support
                    "key_matches": analysis.get("key_matches", []),
                    "key_differences": analysis.get("key_differences", critical_failures),
                    "value_errors": analysis.get("value_errors", {}),
                    "reasoning": analysis.get("reasoning", "")
                }
                result["feedback"] = analysis.get("fix_directive", "")
            else:
                # Fallback if JSON parsing fails
                result = self._heuristic_semantic_analysis(predicted, expected)

            # Cache result
            self._analysis_cache[cache_key] = result

        except Exception as e:
            # Any LLM/client failure degrades gracefully to the heuristic path.
            logger.warning(f"LLM semantic analysis failed: {e}, falling back to heuristics")
            result = self._heuristic_semantic_analysis(predicted, expected)

        return result
745
+
746
+ def _extract_json_from_response(self, content: str) -> Optional[Dict]:
747
+ """Extract JSON from LLM response."""
748
+ # Try to find JSON in response
749
+ json_match = re.search(r'\{[\s\S]*\}', content)
750
+ if json_match:
751
+ try:
752
+ return json.loads(json_match.group(0))
753
+ except json.JSONDecodeError:
754
+ pass
755
+ return None
756
+
757
+ def _heuristic_semantic_analysis(self, predicted: str, expected: str) -> Dict[str, Any]:
758
+ """
759
+ Heuristic-based semantic analysis when LLM is not available.
760
+
761
+ Uses multiple signals:
762
+ - Word overlap (Jaccard)
763
+ - Sequence matching (SequenceMatcher)
764
+ - Number extraction and comparison
765
+ - Key phrase matching
766
+ """
767
+ result = {"score": 0.0, "details": {}, "feedback": ""}
768
+
769
+ pred_lower = predicted.lower()
770
+ exp_lower = expected.lower()
771
+
772
+ # 1. Sequence similarity
773
+ seq_sim = SequenceMatcher(None, pred_lower, exp_lower).ratio()
774
+
775
+ # 2. Word overlap (Jaccard)
776
+ pred_words = set(pred_lower.split())
777
+ exp_words = set(exp_lower.split())
778
+ jaccard = len(pred_words & exp_words) / len(pred_words | exp_words) if (pred_words | exp_words) else 0.0
779
+
780
+ # 3. Number comparison
781
+ pred_nums = re.findall(r'-?\d+\.?\d*', predicted)
782
+ exp_nums = re.findall(r'-?\d+\.?\d*', expected)
783
+
784
+ num_score = 1.0
785
+ num_errors = []
786
+ if exp_nums:
787
+ matches = 0
788
+ for exp_num in exp_nums:
789
+ if exp_num in pred_nums:
790
+ matches += 1
791
+ else:
792
+ # Check for close matches
793
+ try:
794
+ exp_val = float(exp_num)
795
+ for pred_num in pred_nums:
796
+ pred_val = float(pred_num)
797
+ if abs(pred_val - exp_val) <= 1: # Off by 1
798
+ matches += 0.9
799
+ num_errors.append(f"Number close: expected {exp_num}, got {pred_num}")
800
+ break
801
+ else:
802
+ num_errors.append(f"Number missing: expected {exp_num}")
803
+ except ValueError:
804
+ pass
805
+ num_score = matches / len(exp_nums) if exp_nums else 1.0
806
+
807
+ # 4. Key entity extraction (simple approach)
808
+ # Look for capitalized words, quoted strings, etc.
809
+ pred_entities = set(re.findall(r'\b[A-Z][a-z]+\b', predicted))
810
+ exp_entities = set(re.findall(r'\b[A-Z][a-z]+\b', expected))
811
+ entity_overlap = len(pred_entities & exp_entities) / len(exp_entities) if exp_entities else 1.0
812
+
813
+ # Combine scores
814
+ result["score"] = (
815
+ 0.3 * seq_sim +
816
+ 0.25 * jaccard +
817
+ 0.25 * num_score +
818
+ 0.2 * entity_overlap
819
+ )
820
+
821
+ result["details"] = {
822
+ "sequence_similarity": seq_sim,
823
+ "word_overlap": jaccard,
824
+ "number_accuracy": num_score,
825
+ "entity_overlap": entity_overlap,
826
+ "number_errors": num_errors
827
+ }
828
+
829
+ # Generate feedback
830
+ feedback_parts = []
831
+ if jaccard < 0.5:
832
+ feedback_parts.append("Low word overlap - output may be missing key terms.")
833
+ if num_errors:
834
+ feedback_parts.append(f"Number issues: {'; '.join(num_errors[:3])}")
835
+ if entity_overlap < 0.5 and exp_entities:
836
+ missing = exp_entities - pred_entities
837
+ feedback_parts.append(f"Missing entities: {', '.join(list(missing)[:3])}")
838
+
839
+ if feedback_parts:
840
+ result["feedback"] = " | ".join(feedback_parts)
841
+ else:
842
+ result["feedback"] = "Output is semantically similar but not exact match."
843
+
844
+ return result
845
+
846
+ def _generate_default_feedback(self, result: Dict) -> str:
847
+ """Generate default feedback based on scores."""
848
+ score = result["composite_score"]
849
+ semantic = result["semantic_similarity"]
850
+ structural = result["structural_similarity"]
851
+
852
+ if score >= 0.9:
853
+ return "✅ Excellent match! Minor differences only."
854
+ elif score >= 0.7:
855
+ return f"⚠️ Good match (semantic={semantic:.0%}, structural={structural:.0%}). Some differences to address."
856
+ elif score >= 0.5:
857
+ return f"⚠️ Partial match (semantic={semantic:.0%}, structural={structural:.0%}). Significant differences found."
858
+ else:
859
+ return f"❌ Poor match (semantic={semantic:.0%}, structural={structural:.0%}). Major issues to fix."
860
+
861
def get_evaluation_summary(self, results: List[Dict]) -> Dict[str, Any]:
    """Aggregate a batch of evaluation results into summary statistics.

    Args:
        results: Per-sample evaluation dicts (reads "composite_score",
            "semantic_similarity", "structural_similarity"; each
            defaults to 0.0 when absent).

    Returns:
        Summary dict: accuracy (share of samples with composite score
        >= 0.8), per-metric averages, and min/max composite score. An
        empty batch yields an all-zero summary.
    """
    if not results:
        return {
            "total_samples": 0,
            "accuracy": 0.0,
            "avg_semantic_similarity": 0.0,
            "avg_structural_similarity": 0.0
        }

    count = len(results)
    composite = [item.get("composite_score", 0.0) for item in results]
    semantic = [item.get("semantic_similarity", 0.0) for item in results]
    structural = [item.get("structural_similarity", 0.0) for item in results]

    return {
        "total_samples": count,
        "accuracy": sum(1 for value in composite if value >= 0.8) / count,
        "avg_composite_score": sum(composite) / count,
        "avg_semantic_similarity": sum(semantic) / count,
        "avg_structural_similarity": sum(structural) / count,
        "min_score": min(composite),
        "max_score": max(composite),
    }
893
+
894
+
895
+ # Convenience function to create evaluator
896
def create_universal_evaluator(llm_client=None) -> UniversalSemanticEvaluator:
    """Factory for a configured UniversalSemanticEvaluator.

    Args:
        llm_client: Optional LLM client. When provided, LLM-based
            semantic analysis is enabled; otherwise the evaluator falls
            back to heuristic analysis.

    Returns:
        Configured UniversalSemanticEvaluator instance.
    """
    enable_llm = llm_client is not None
    return UniversalSemanticEvaluator(llm_client=llm_client, use_llm_analysis=enable_llm)
911
+
src/gepa_optimizer/evaluation/validation_evaluator.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Validation Evaluator for UI Validation Use Case
3
+
4
+ Evaluates predicted validation results (true/false) against expected results.
5
+ Extracts reasoning from both predicted and expected outputs for LLM-as-judge feedback.
6
+ """
7
+
8
+ from typing import Dict, Any, Optional
9
+ import re
10
+ import logging
11
+
12
+ try:
13
+ from .base_evaluator import BaseEvaluator
14
+ except ImportError:
15
+ # For standalone testing
16
+ import sys
17
+ from pathlib import Path
18
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
19
+ from gepa_optimizer.evaluation.base_evaluator import BaseEvaluator
20
+
21
+
22
+ class ValidationEvaluator(BaseEvaluator):
23
+ """
24
+ Evaluator for validation use case (true/false results).
25
+
26
+ Features:
27
+ - Normalizes boolean formats ("true"/"True"/"1" → True, "false"/"False"/"0" → False)
28
+ - Extracts reasoning from both predicted and expected outputs (REQUIRED for LLM-as-judge)
29
+ - Binary scoring: correct boolean = 1.0, wrong = 0.0
30
+ - Returns reasoning in evaluation results for LLM-as-judge feedback
31
+ """
32
+
33
def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
    """Initialize the validation evaluator.

    Args:
        metric_weights: Optional weight overrides for evaluation
            metrics. Defaults to {"output_match": 1.0}.
    """
    # The binary boolean comparison carries the full weight by default.
    fallback_weights = {"output_match": 1.0}
    super().__init__(metric_weights=metric_weights or fallback_weights)
47
+
48
def evaluate(self, predicted: str, expected: str) -> Dict[str, float]:
    """Score a predicted validation verdict against the expected one.

    Scoring is binary: both outputs are normalized to booleans and must
    match exactly (1.0) or not (0.0). Reasoning text and output-structure
    metadata are extracted from both sides so an LLM-as-judge can produce
    targeted feedback downstream.

    Args:
        predicted: LLM output (boolean verdict, possibly with reasoning).
        expected: Expected output ("true"/"false", possibly with reasoning).

    Returns:
        Dict of metrics, extracted booleans/reasoning, and structure
        metadata (see keys in the return statement below).
    """
    # Guard: nothing to compare.
    if not predicted or not expected:
        return {
            "output_match": 0.0,
            "composite_score": 0.0,
            "predicted_output": str(predicted).strip() if predicted else "",
            "expected_output": str(expected).strip() if expected else "",
            "predicted_boolean": None,
            "expected_boolean": None,
            "predicted_reasoning": "",
            "expected_reasoning": "",
            "evaluation_reason": "❌ Empty or missing input/output"
        }

    pred_text = str(predicted).strip()
    exp_text = str(expected).strip()

    # Normalize both sides to booleans and pull out any reasoning text.
    pred_flag = self._normalize_to_bool(pred_text)
    pred_why = self._extract_reasoning(pred_text)
    exp_flag = self._normalize_to_bool(exp_text)
    exp_why = self._extract_reasoning(exp_text)

    # Detect and compare the output structure of both sides.
    exp_shape = self._detect_output_structure(exp_text)
    pred_shape = self._detect_output_structure(pred_text)
    same_shape = (exp_shape['format'] == pred_shape['format'])

    # Binary scoring on the extracted booleans.
    if pred_flag is None or exp_flag is None:
        score = 0.0
        reason = "❌ Could not extract boolean value"
        if pred_flag is None:
            reason += " from predicted output"
        if exp_flag is None:
            reason += " from expected output"
    elif pred_flag == exp_flag:
        score = 1.0
        reason = f"✅ Correct! Result matches (both are {exp_flag})"
        if not same_shape:
            # Verdict is right but the output shape diverges -- note it.
            reason += f" (but format differs: expected {exp_shape['format']}, got {pred_shape['format']})"
    else:
        score = 0.0
        reason = f"❌ Wrong result (predicted: {pred_flag}, expected: {exp_flag})"

    # Emit a structured trace of this evaluation.
    self.logger.info(f"\n{'─'*70}")
    self.logger.info(f"📊 VALIDATION EVALUATION")
    self.logger.info(f"{'─'*70}")
    self.logger.info(f" Expected: '{exp_text[:100]}...' → {exp_flag}")
    self.logger.info(f" Predicted: '{pred_text[:100]}...' → {pred_flag}")
    self.logger.info(f" {'─'*66}")
    self.logger.info(f" 🎯 SCORE: {score:.2f} - {reason}")
    if pred_why:
        self.logger.info(f" 📝 Predicted Reasoning: {pred_why[:150]}...")
    if exp_why:
        self.logger.info(f" 📝 Expected Reasoning: {exp_why[:150]}...")
    self.logger.info(f" 📐 Expected Format: {exp_shape['format']} (reasoning: {exp_shape['reasoning_quality']})")
    self.logger.info(f" 📐 Predicted Format: {pred_shape['format']} (reasoning: {pred_shape['reasoning_quality']})")
    if not same_shape:
        self.logger.warning(f" ⚠️ OUTPUT STRUCTURE MISMATCH!")
    self.logger.info(f"{'─'*70}\n")

    return {
        "output_match": score,
        "composite_score": score,  # GEPA consumes this field
        "predicted_output": pred_text,
        "expected_output": exp_text,
        "predicted_boolean": pred_flag,
        "expected_boolean": exp_flag,
        "predicted_reasoning": pred_why,  # REQUIRED for LLM-as-judge
        "expected_reasoning": exp_why,  # REQUIRED for LLM-as-judge
        "evaluation_reason": reason,
        # Structure metadata for LLM-as-judge
        "expected_structure": exp_shape,
        "predicted_structure": pred_shape,
        "output_structure_match": same_shape,
        "expected_has_reasoning": exp_shape['has_reasoning'],
        "predicted_has_reasoning": pred_shape['has_reasoning'],
        "reasoning_quality_gap": exp_shape['reasoning_quality'] + " → " + pred_shape['reasoning_quality']
    }
164
+
165
+ def _normalize_to_bool(self, value: str) -> Optional[bool]:
166
+ """
167
+ Normalize various formats to boolean.
168
+
169
+ Handles:
170
+ - "true", "True", "TRUE" → True
171
+ - "false", "False", "FALSE" → False
172
+ - "1", "0" → True, False
173
+ - "yes", "no" → True, False
174
+ - "correct", "incorrect" → True, False
175
+ - JSON: {"result": true} → True
176
+ - Text with boolean: "The result is true because..." → True
177
+
178
+ Args:
179
+ value: String that may contain a boolean value
180
+
181
+ Returns:
182
+ Boolean value or None if cannot be determined
183
+ """
184
+ if not value:
185
+ return None
186
+
187
+ value_lower = value.lower().strip()
188
+
189
+ # Direct boolean strings
190
+ if value_lower in ("true", "1", "yes", "correct", "valid", "pass"):
191
+ return True
192
+ if value_lower in ("false", "0", "no", "incorrect", "invalid", "fail"):
193
+ return False
194
+
195
+ # JSON format: {"action": "TRUE"} or {"action": "FALSE"} or {"action": "LOADING"}
196
+ # This handles the production prompt's JSON output format
197
+ # Match both quoted and unquoted values, case-insensitive
198
+ action_match = re.search(r'["\']?action["\']?\s*:\s*["\']?(true|false|loading)["\']?', value_lower)
199
+ if action_match:
200
+ action_value = action_match.group(1).lower()
201
+ if action_value == "true":
202
+ return True
203
+ elif action_value == "false":
204
+ return False
205
+ elif action_value == "loading":
206
+ # Treat LOADING as False for validation purposes (screen not ready)
207
+ return False
208
+
209
+ # Also try to parse full JSON structure if present (more robust)
210
+ try:
211
+ import json
212
+ # Try to find and parse JSON object
213
+ json_start = value.find('{')
214
+ if json_start != -1:
215
+ # Try to extract JSON from the response
216
+ for end_idx in range(len(value), json_start, -1):
217
+ try:
218
+ json_str = value[json_start:end_idx]
219
+ data = json.loads(json_str)
220
+ # Check for "action" field (production prompt format)
221
+ if "action" in data:
222
+ action_val = str(data["action"]).upper()
223
+ if action_val == "TRUE":
224
+ return True
225
+ elif action_val == "FALSE":
226
+ return False
227
+ elif action_val == "LOADING":
228
+ return False # Treat as False
229
+ # Check for "result" field (alternative format)
230
+ if "result" in data:
231
+ result_val = data["result"]
232
+ if isinstance(result_val, bool):
233
+ return result_val
234
+ elif isinstance(result_val, str):
235
+ return result_val.lower() in ("true", "1", "yes")
236
+ except (json.JSONDecodeError, KeyError, ValueError):
237
+ continue
238
+ except Exception:
239
+ pass # Fall through to other extraction methods
240
+
241
+ # JSON format: {"result": true} or {"result": false}
242
+ json_match = re.search(r'["\']?result["\']?\s*:\s*(true|false)', value_lower)
243
+ if json_match:
244
+ return json_match.group(1) == "true"
245
+
246
+ # Pattern: "result is true" or "result: true"
247
+ pattern_match = re.search(r'result[:\s]+(true|false)', value_lower)
248
+ if pattern_match:
249
+ return pattern_match.group(1) == "true"
250
+
251
+ # Pattern: "is true" or "is false" (standalone)
252
+ is_match = re.search(r'\b(is|are)\s+(true|false)\b', value_lower)
253
+ if is_match:
254
+ return is_match.group(2) == "true"
255
+
256
+ # Pattern: "true" or "false" as standalone word (not in other words)
257
+ standalone_match = re.search(r'\b(true|false)\b', value_lower)
258
+ if standalone_match:
259
+ return standalone_match.group(1) == "true"
260
+
261
+ # Last resort: check if "true" appears before "false" in text
262
+ true_pos = value_lower.find("true")
263
+ false_pos = value_lower.find("false")
264
+
265
+ if true_pos != -1 and false_pos != -1:
266
+ # Both found - use the one that appears first
267
+ return true_pos < false_pos
268
+ elif true_pos != -1:
269
+ return True
270
+ elif false_pos != -1:
271
+ return False
272
+
273
+ # Cannot determine
274
+ return None
275
+
276
def _detect_output_structure(self, output: str) -> Dict[str, Any]:
    """Classify the structural components of an output string.

    Detects whether a boolean verdict and/or reasoning text is present,
    measures reasoning length, and labels the overall output format.

    Args:
        output: Output string to analyze.

    Returns:
        Dict with keys:
            has_boolean: bool
            has_reasoning: bool
            reasoning_length: int
            reasoning_quality: "missing" | "minimal" | "adequate" | "detailed"
            format: "boolean_only" | "boolean_with_reasoning" |
                "reasoning_only" | "unknown" | "empty"
    """
    if not output:
        return {
            "has_boolean": False,
            "has_reasoning": False,
            "reasoning_length": 0,
            "reasoning_quality": "missing",
            "format": "empty"
        }

    text = output.strip()

    bool_present = self._normalize_to_bool(text) is not None
    why = self._extract_reasoning(text)
    why_len = len(why)
    why_present = why_len > 15  # minimum 15 chars to count as reasoning

    # Bucket reasoning quality by length.
    if why_len == 0:
        quality = "missing"
    elif why_len < 30:
        quality = "minimal"  # just a few words
    elif why_len < 100:
        quality = "adequate"  # brief explanation
    else:
        quality = "detailed"  # full explanation

    # Label the overall shape of the output.
    if bool_present:
        shape = "boolean_with_reasoning" if why_present else "boolean_only"
    else:
        shape = "reasoning_only" if why_present else "unknown"

    return {
        "has_boolean": bool_present,
        "has_reasoning": why_present,
        "reasoning_length": why_len,
        "reasoning_quality": quality,
        "format": shape
    }
344
+
345
+ def _extract_reasoning(self, output: str) -> str:
346
+ """
347
+ Extract reasoning/explanation from output string.
348
+
349
+ This is REQUIRED for LLM-as-judge feedback. The reasoning helps
350
+ the judge understand why the result was true/false and compare
351
+ predicted vs expected reasoning.
352
+
353
+ Args:
354
+ output: Full output string that may contain reasoning
355
+
356
+ Returns:
357
+ Extracted reasoning text, or empty string if not found
358
+ """
359
+ if not output:
360
+ return ""
361
+
362
+ # Patterns to find reasoning sections
363
+ reasoning_patterns = [
364
+ r'[Rr]eason[:\s]+(.*?)(?:\n\n|\Z)', # "Reason: ..."
365
+ r'[Ee]xplanation[:\s]+(.*?)(?:\n\n|\Z)', # "Explanation: ..."
366
+ r'[Bb]ecause[:\s]+(.*?)(?:\n\n|\Z)', # "Because: ..."
367
+ r'[Ww]hy[:\s]+(.*?)(?:\n\n|\Z)', # "Why: ..."
368
+ r'[Dd]etails[:\s]+(.*?)(?:\n\n|\Z)', # "Details: ..."
369
+ ]
370
+
371
+ # Try each pattern
372
+ for pattern in reasoning_patterns:
373
+ match = re.search(pattern, output, re.DOTALL | re.IGNORECASE)
374
+ if match:
375
+ reasoning = match.group(1).strip()
376
+ if len(reasoning) > 20: # Only return if substantial
377
+ return reasoning
378
+
379
+ # If no explicit reasoning section, check if output has substantial text
380
+ # after boolean (likely contains reasoning)
381
+ bool_match = re.search(r'\b(true|false)\b', output.lower())
382
+ if bool_match:
383
+ # Get text after the boolean
384
+ bool_pos = bool_match.end()
385
+ remaining = output[bool_pos:].strip()
386
+
387
+ # If remaining text is substantial (more than just punctuation), use it
388
+ if len(remaining) > 30:
389
+ # Clean up common prefixes
390
+ remaining = re.sub(r'^[:\s.,;!?-]+', '', remaining)
391
+ if remaining:
392
+ return remaining
393
+
394
+ # If output is long and doesn't start with boolean, might be all reasoning
395
+ if len(output) > 100 and not re.match(r'^\s*(true|false)\s*$', output, re.IGNORECASE):
396
+ # Return first 500 chars as reasoning
397
+ return output[:500].strip()
398
+
399
+ # No reasoning found
400
+ return ""
401
+
402
def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
    """Summarize a batch of validation evaluations.

    Args:
        results: List of evaluation result dictionaries (reads
            "output_match" and "predicted_boolean").

    Returns:
        Summary statistics: accuracy, correct/incorrect counts, and the
        distribution of True/False predictions. An empty batch yields an
        all-zero summary.
    """
    if not results:
        return {
            "total_samples": 0,
            "accuracy": 0.0,
            "correct_predictions": 0,
            "incorrect_predictions": 0,
            "true_predictions": 0,
            "false_predictions": 0
        }

    total = len(results)
    hits = sum(1 for item in results if item.get("output_match", 0.0) == 1.0)

    return {
        "total_samples": total,
        "accuracy": hits / total if total > 0 else 0.0,
        "correct_predictions": hits,
        "incorrect_predictions": total - hits,
        "true_predictions": sum(1 for item in results if item.get("predicted_boolean") is True),
        "false_predictions": sum(1 for item in results if item.get("predicted_boolean") is False),
    }
438
+
439
+
440
+ # Example usage and testing
441
if __name__ == "__main__":
    # Manual smoke test: exercises ValidationEvaluator against a table of
    # (predicted, expected, should_match) triples and prints a summary.
    print("🚀 Testing Validation Evaluator...")

    evaluator = ValidationEvaluator()

    # Test cases
    test_cases = [
        # (predicted, expected, should_match)
        ("true", "true", True),
        ("false", "false", True),
        ("True", "true", True),
        ("FALSE", "false", True),
        ("1", "true", True),
        ("0", "false", True),
        ("true", "false", False),
        ("false", "true", False),
        ("The result is true because the button is visible", "true", True),
        ("The result is false because the element is not found", "false", True),
        ('{"result": true, "reasoning": "Button is visible"}', "true", True),
        ("Result: true\n\nReasoning: The submit button is clearly visible at the bottom of the screen.", "true", True),
        ("", "true", False),
        ("invalid", "true", False),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    results = []
    for predicted, expected, should_match in test_cases:
        result = evaluator.evaluate(predicted, expected)
        # A sample "matches" when binary scoring awarded full credit.
        match = result["composite_score"] == 1.0

        # ✅ when the observed match agrees with the expectation, ❌ otherwise.
        status = "✅" if match == should_match else "❌"
        pred_bool = result.get("predicted_boolean", "?")
        exp_bool = result.get("expected_boolean", "?")
        pred_reason = result.get("predicted_reasoning", "")[:50]

        print(f"{status} Predicted: '{predicted[:40]}...' → {pred_bool}")
        print(f" Expected: '{expected}' → {exp_bool}")
        print(f" Match: {match} (should be {should_match})")
        if pred_reason:
            print(f" Reasoning: {pred_reason}...")
        print()

        results.append(result)

    # Summary
    print("\n📊 Summary:")
    summary = evaluator.get_evaluation_summary(results)
    print(f" Total: {summary['total_samples']}")
    print(f" Correct: {summary['correct_predictions']}")
    print(f" Accuracy: {summary['accuracy']:.1%}")
    print(f" True predictions: {summary['true_predictions']}")
    print(f" False predictions: {summary['false_predictions']}")
495
+
src/gepa_optimizer/infrastructure/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Infrastructure module for cross-cutting concerns.
3
+
4
+ This module contains infrastructure components that are used across
5
+ the entire application, including logging, metrics, and configuration.
6
+ """
7
+
8
+ from .logging import get_logger, configure_logging, LogContext
9
+
10
+ __all__ = [
11
+ "get_logger",
12
+ "configure_logging",
13
+ "LogContext",
14
+ ]
15
+
src/gepa_optimizer/infrastructure/logging/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Centralized Logging Infrastructure for GEPA Optimizer.
3
+
4
+ This module provides a unified logging system with:
5
+ - Structured logging with context
6
+ - Consistent formatting across all modules
7
+ - Log level configuration
8
+ - Operation tracking with timing
9
+ - Contextual logging for debugging
10
+
11
+ Usage:
12
+ from gepa_optimizer.infrastructure.logging import get_logger, LogContext
13
+
14
+ logger = get_logger(__name__)
15
+ logger.info("Starting optimization", extra={"iteration": 1})
16
+
17
+ with LogContext(logger, "evaluation", sample_id=123):
18
+ logger.info("Evaluating sample")
19
+ """
20
+
21
+ from .logger import (
22
+ get_logger,
23
+ configure_logging,
24
+ LogLevel,
25
+ GEPA_LOGGER_NAME,
26
+ )
27
+ from .context import LogContext, log_operation
28
+ from .formatters import GepaFormatter, JsonFormatter
29
+
30
+ __all__ = [
31
+ # Core logging
32
+ "get_logger",
33
+ "configure_logging",
34
+ "LogLevel",
35
+ "GEPA_LOGGER_NAME",
36
+ # Context management
37
+ "LogContext",
38
+ "log_operation",
39
+ # Formatters
40
+ "GepaFormatter",
41
+ "JsonFormatter",
42
+ ]
43
+
src/gepa_optimizer/infrastructure/logging/context.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logging Context Management.
3
+
4
+ Provides context managers and decorators for:
5
+ - Operation tracking with timing
6
+ - Contextual logging with nested contexts
7
+ - Automatic exception logging
8
+ """
9
+
10
+ import logging
11
+ import time
12
+ import functools
13
+ from contextlib import contextmanager
14
+ from typing import Any, Callable, Dict, Optional, TypeVar, ParamSpec
15
+
16
+ P = ParamSpec('P')
17
+ R = TypeVar('R')
18
+
19
+
20
class LogContext:
    """
    Context manager that logs an operation with timing and shared context.

    On entry it optionally logs "Starting <operation>"; on exit it logs
    either "Completed <operation>" (with duration_ms) or, when an
    exception escaped the body, "Failed <operation>: ..." with traceback.
    Every message carries the keyword context fields supplied at
    construction, and the convenience log methods inherit them too.
    Exceptions are never suppressed.

    Example:
        logger = get_logger(__name__)

        with LogContext(logger, "optimization", iteration=5):
            logger.info("Processing sample")
    """

    def __init__(
        self,
        logger: logging.Logger,
        operation: str,
        log_start: bool = True,
        log_end: bool = True,
        log_level: int = logging.INFO,
        **context_fields: Any
    ):
        """
        Initialize log context.

        Args:
            logger: Logger instance to use.
            operation: Name of the operation being performed.
            log_start: Whether to log when entering the context.
            log_end: Whether to log when exiting the context.
            log_level: Log level for the start/end messages.
            **context_fields: Extra fields attached to every message.
        """
        self.logger = logger
        self.operation = operation
        self.log_start = log_start
        self.log_end = log_end
        self.log_level = log_level
        self.context_fields = context_fields
        self.start_time: Optional[float] = None  # set on __enter__
        self.exception: Optional[Exception] = None  # captured on failure

    def __enter__(self) -> "LogContext":
        """Start the timer and optionally announce the operation."""
        self.start_time = time.perf_counter()
        if self.log_start:
            self.logger.log(
                self.log_level,
                f"Starting {self.operation}",
                extra=self.context_fields
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Log completion or failure (with traceback); never suppress."""
        duration_ms = (time.perf_counter() - self.start_time) * 1000
        extra = {**self.context_fields, "duration_ms": round(duration_ms, 2)}

        if exc_type is not None:
            self.exception = exc_val
            self.logger.error(
                f"Failed {self.operation}: {exc_type.__name__}: {exc_val}",
                extra=extra,
                exc_info=True
            )
            return False  # propagate the exception

        if self.log_end:
            self.logger.log(
                self.log_level,
                f"Completed {self.operation}",
                extra=extra
            )
        return False

    def log(self, level: int, message: str, **extra_fields: Any) -> None:
        """Log a message that inherits this context's fields."""
        merged = {**self.context_fields, **extra_fields}
        self.logger.log(level, message, extra=merged)

    def info(self, message: str, **extra_fields: Any) -> None:
        """Log an INFO message within this context."""
        self.log(logging.INFO, message, **extra_fields)

    def debug(self, message: str, **extra_fields: Any) -> None:
        """Log a DEBUG message within this context."""
        self.log(logging.DEBUG, message, **extra_fields)

    def warning(self, message: str, **extra_fields: Any) -> None:
        """Log a WARNING message within this context."""
        self.log(logging.WARNING, message, **extra_fields)

    def error(self, message: str, **extra_fields: Any) -> None:
        """Log an ERROR message within this context."""
        self.log(logging.ERROR, message, **extra_fields)
137
+
138
+
139
def log_operation(
    logger: Optional[logging.Logger] = None,
    operation: Optional[str] = None,
    log_args: bool = False,
    log_result: bool = False,
    log_level: int = logging.INFO,
) -> "Callable[[Callable[P, R]], Callable[P, R]]":
    """
    Decorator factory for logging function execution.

    Automatically logs function entry (with arguments if configured),
    exit (with result if configured), execution duration, and exceptions.

    Args:
        logger: Logger to use (defaults to a logger named after the
            decorated function's module).
        operation: Operation name (defaults to the function name).
        log_args: Whether to log function arguments.
        log_result: Whether to log the function result.
        log_level: Log level for entry/exit messages.

    Example:
        @log_operation(log_args=True)
        def process_batch(batch_id: int, items: List[str]) -> int:
            return len(items)

        # Output:
        # INFO | Starting process_batch | batch_id=123 items=['a', 'b']
        # INFO | Completed process_batch | duration_ms=45.2 result=2
    """
    def decorator(func: "Callable[P, R]") -> "Callable[P, R]":
        # BUG FIX: previously `nonlocal logger, operation` overwrote the
        # factory's variables on first application, so reusing one
        # decorator instance on a second function logged it under the
        # first function's name/logger. Resolve per-function defaults
        # into locals instead.
        op_logger = logger if logger is not None else logging.getLogger(func.__module__)
        op_name = operation if operation is not None else func.__name__

        @functools.wraps(func)
        def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> "R":
            start_time = time.perf_counter()

            # Build context fields for the entry message.
            extra: Dict[str, Any] = {}
            if log_args:
                # Positional args by parameter name (skip self for methods).
                arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
                for name, value in zip(arg_names, args):
                    if name != 'self':
                        extra[name] = _safe_repr(value)
                # Keyword args as-is.
                for key, value in kwargs.items():
                    extra[key] = _safe_repr(value)

            op_logger.log(log_level, f"Starting {op_name}", extra=extra)

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                duration_ms = (time.perf_counter() - start_time) * 1000
                op_logger.error(
                    f"Failed {op_name}: {type(e).__name__}: {e}",
                    extra={"duration_ms": round(duration_ms, 2)},
                    exc_info=True
                )
                raise

            duration_ms = (time.perf_counter() - start_time) * 1000
            result_extra: Dict[str, Any] = {"duration_ms": round(duration_ms, 2)}
            if log_result:
                result_extra["result"] = _safe_repr(result)
            op_logger.log(log_level, f"Completed {op_name}", extra=result_extra)

            return result

        return wrapper

    return decorator
222
+
223
+
224
@contextmanager
def timed_block(logger: logging.Logger, description: str, log_level: int = logging.DEBUG):
    """
    Lightweight context manager that logs how long a block took.

    Less verbose than LogContext; intended for quick timing measurements.
    The duration is always logged, even when the block raises.

    Example:
        with timed_block(logger, "data processing"):
            process_data()
        # Output: DEBUG | data processing completed in 123.45ms
    """
    began = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - began) * 1000
        logger.log(log_level, f"{description} completed in {elapsed_ms:.2f}ms")
242
+
243
+
244
+ def _safe_repr(value: Any, max_length: int = 100) -> str:
245
+ """
246
+ Create a safe string representation of a value for logging.
247
+
248
+ Truncates long strings and handles non-serializable objects.
249
+ """
250
+ try:
251
+ repr_str = repr(value)
252
+ if len(repr_str) > max_length:
253
+ return repr_str[:max_length] + "..."
254
+ return repr_str
255
+ except Exception:
256
+ return f"<{type(value).__name__}>"
257
+
src/gepa_optimizer/infrastructure/logging/formatters.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Log Formatters for GEPA Optimizer.
3
+
4
+ Provides formatters for:
5
+ - Console output with colors and emoji
6
+ - JSON structured logging for production
7
+ - Plain text for file logging
8
+ """
9
+
10
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional
14
+
15
+
16
+ # ANSI color codes for terminal output
17
class Colors:
    """ANSI escape sequences used for terminal coloring."""

    # Text attributes
    RESET = "\x1b[0m"
    BOLD = "\x1b[1m"
    DIM = "\x1b[2m"

    # Per-level foreground colors
    DEBUG = "\x1b[36m"     # cyan
    INFO = "\x1b[32m"      # green
    WARNING = "\x1b[33m"   # yellow
    ERROR = "\x1b[31m"     # red
    CRITICAL = "\x1b[35m"  # magenta

    # Colors with semantic roles in the formatted output
    TIMESTAMP = "\x1b[90m"  # gray
    MODULE = "\x1b[34m"     # blue
    MESSAGE = "\x1b[0m"     # terminal default
34
+
35
+
36
+ # Emoji prefixes for visual log scanning
37
LEVEL_EMOJI = {
    level: icon
    for level, icon in (
        (logging.DEBUG, "🔍"),
        (logging.INFO, "ℹ️ "),
        (logging.WARNING, "⚠️ "),
        (logging.ERROR, "❌"),
        (logging.CRITICAL, "🚨"),
    )
}
44
+
45
+ # Level colors mapping
46
LEVEL_COLORS = {
    level: color
    for level, color in (
        (logging.DEBUG, Colors.DEBUG),
        (logging.INFO, Colors.INFO),
        (logging.WARNING, Colors.WARNING),
        (logging.ERROR, Colors.ERROR),
        (logging.CRITICAL, Colors.CRITICAL),
    )
}
53
+
54
+
55
class GepaFormatter(logging.Formatter):
    """
    Custom formatter for GEPA Optimizer logs.

    Features:
    - Optional color output for console
    - Optional emoji prefixes for visual scanning
    - Structured extra fields rendered as trailing key=value pairs
    - Clean, readable format

    Example output:
        2024-01-15 10:30:45 | INFO | ℹ️ gepa_optimizer.core.optimizer | Starting optimization iteration=5
    """

    # Attributes present on every LogRecord (logging machinery); everything
    # else on the record is treated as a user-supplied "extra" field.
    # 'message' and 'asctime' are set lazily by Formatter.format(), so they
    # must be excluded too or they would leak into the extras of a record
    # that has already been formatted once.
    _STANDARD_ATTRS = frozenset({
        'name', 'msg', 'args', 'created', 'filename', 'funcName',
        'levelname', 'levelno', 'lineno', 'module', 'msecs',
        'pathname', 'process', 'processName', 'relativeCreated',
        'stack_info', 'exc_info', 'exc_text', 'thread', 'threadName',
        'taskName', 'message', 'asctime',
    })

    def __init__(
        self,
        fmt: Optional[str] = None,
        datefmt: Optional[str] = None,
        use_colors: bool = True,
        include_emoji: bool = True,
    ):
        """
        Initialize the formatter.

        Args:
            fmt: Format string (uses logging's default if not provided)
            datefmt: Date format string
            use_colors: Whether to wrap level/logger names in ANSI colors
            include_emoji: Whether to prefix the level name with an emoji
        """
        super().__init__(fmt=fmt, datefmt=datefmt)
        self.use_colors = use_colors
        self.include_emoji = include_emoji

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record, optionally adding colors, emoji and extras.

        The record is mutated while formatting and fully restored afterwards
        so other handlers sharing the same record see the original values.
        """
        # Snapshot every field we mutate below.
        original_msg = record.msg
        original_levelname = record.levelname
        original_name = record.name

        try:
            # Emoji prefix for quick visual scanning.
            if self.include_emoji:
                emoji = LEVEL_EMOJI.get(record.levelno, "")
                record.levelname = f"{emoji} {record.levelname}"

            # ANSI colors for level and logger name.
            if self.use_colors:
                color = LEVEL_COLORS.get(record.levelno, Colors.RESET)
                record.levelname = f"{color}{record.levelname}{Colors.RESET}"
                record.name = f"{Colors.MODULE}{record.name}{Colors.RESET}"

            # Append structured extra fields to the message, if any.
            extra_str = self._format_extra(record)
            if extra_str:
                record.msg = f"{record.msg} | {extra_str}"

            return super().format(record)

        finally:
            # Restore all mutated fields.  The original implementation forgot
            # record.name, letting ANSI codes accumulate on the record and
            # leak into other (e.g. color-free file) handlers.
            record.msg = original_msg
            record.levelname = original_levelname
            record.name = original_name

    def _format_extra(self, record: logging.LogRecord) -> str:
        """
        Render user-supplied extra fields as space-separated key=value pairs.

        Extra fields are passed via the 'extra' parameter to logging calls:
            logger.info("Message", extra={"key": "value"})
        """
        extra_fields = {
            k: v for k, v in record.__dict__.items()
            if k not in self._STANDARD_ATTRS and not k.startswith('_')
        }

        if not extra_fields:
            return ""

        parts = []
        for key, value in extra_fields.items():
            if isinstance(value, bool):
                # bool must be tested before int/float: bool is an int
                # subclass, so the numeric branch would otherwise swallow it
                # and print "True" instead of the intended "true".
                parts.append(f"{key}={str(value).lower()}")
            elif isinstance(value, (str, int, float)):
                parts.append(f"{key}={value}")
            else:
                parts.append(f"{key}={repr(value)}")

        return " ".join(parts)
160
+
161
+
162
class JsonFormatter(logging.Formatter):
    """
    JSON formatter for structured logging.

    Outputs each log record as a single JSON line, suitable for:
    - Log aggregation systems (ELK, Splunk)
    - Cloud logging (CloudWatch, Stackdriver)
    - Log parsing and analysis

    Example output:
        {"timestamp": "2024-01-15T10:30:45.123Z", "level": "INFO", "logger": "gepa_optimizer.core", "message": "Starting optimization", "iteration": 5}
    """

    # LogRecord attributes belonging to the logging machinery; everything
    # else on the record is emitted as a user-supplied extra field.
    # 'asctime' can be set lazily by other formatters sharing the record,
    # so it is excluded as well.
    _STANDARD_ATTRS = frozenset({
        'name', 'msg', 'args', 'created', 'filename', 'funcName',
        'levelname', 'levelno', 'lineno', 'module', 'msecs',
        'pathname', 'process', 'processName', 'relativeCreated',
        'stack_info', 'exc_info', 'exc_text', 'thread', 'threadName',
        'taskName', 'message', 'asctime',
    })

    def __init__(
        self,
        include_timestamp: bool = True,
        include_location: bool = False,
    ):
        """
        Initialize JSON formatter.

        Args:
            include_timestamp: Include ISO UTC timestamp (trailing "Z")
            include_location: Include file/line/function information
        """
        super().__init__()
        self.include_timestamp = include_timestamp
        self.include_location = include_location

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to a single JSON line."""
        log_dict: Dict[str, Any] = {}

        # Timestamp in UTC with a trailing "Z".
        # datetime.utcfromtimestamp() is deprecated since Python 3.12, so
        # build an aware UTC datetime and strip the tzinfo to keep the exact
        # same output format as before.
        if self.include_timestamp:
            utc_dt = datetime.fromtimestamp(record.created, tz=timezone.utc)
            log_dict["timestamp"] = utc_dt.replace(tzinfo=None).isoformat() + "Z"

        # Core fields
        log_dict["level"] = record.levelname
        log_dict["logger"] = record.name
        log_dict["message"] = record.getMessage()

        # Location info
        if self.include_location:
            log_dict["file"] = record.filename
            log_dict["line"] = record.lineno
            log_dict["function"] = record.funcName

        # Exception info
        if record.exc_info:
            log_dict["exception"] = self.formatException(record.exc_info)

        # Extra fields supplied via logging's `extra=` parameter.
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS and not key.startswith('_'):
                try:
                    # Probe: keep the raw value only if JSON-serializable.
                    json.dumps(value)
                    log_dict[key] = value
                except (TypeError, ValueError):
                    log_dict[key] = str(value)

        return json.dumps(log_dict, default=str)
235
+
236
+
237
class CompactFormatter(logging.Formatter):
    """
    Compact formatter for minimal log output.

    Useful for:
    - CI/CD pipelines
    - Reduced log verbosity
    - Quick debugging

    Example output:
        10:30:45 INFO  optimizer: Starting optimization
    """

    def format(self, record: logging.LogRecord) -> str:
        """Render the record as "HH:MM:SS LEVEL module: message"."""
        # Time-of-day only, no date.
        clock = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
        # Keep only the trailing component of the dotted logger name.
        module = record.name.rsplit(".", 1)[-1]
        return f"{clock} {record.levelname:5s} {module}: {record.getMessage()}"
259
+
src/gepa_optimizer/infrastructure/logging/logger.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core Logger Factory and Configuration.
3
+
4
+ This module provides the centralized logger factory that should be used
5
+ across all GEPA Optimizer modules. It ensures consistent logging behavior
6
+ and formatting throughout the application.
7
+
8
+ Design Principles:
9
+ - Single source of truth for logger configuration
10
+ - Lazy initialization (loggers created on first use)
11
+ - Thread-safe logger access
12
+ - Configurable log levels per module
13
+ """
14
+
15
+ import logging
16
+ import sys
17
+ from enum import Enum
18
+ from typing import Optional, Dict, Any
19
+ from functools import lru_cache
20
+
21
+ from .formatters import GepaFormatter
22
+
23
# Name of the root logger for the package; configure_logging() attaches its
# handlers here, so every "gepa_optimizer.*" child logger inherits them.
GEPA_LOGGER_NAME = "gepa_optimizer"

# Default log line layout and timestamp format, used by configure_logging()
# unless the caller overrides them.
DEFAULT_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
29
+
30
+
31
class LogLevel(str, Enum):
    """Log levels supported by the GEPA logging system (string-valued)."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

    @classmethod
    def from_string(cls, level: str) -> "LogLevel":
        """Parse *level* case-insensitively; unknown names fall back to INFO."""
        name = level.upper()
        if name in cls.__members__:
            return cls[name]
        return cls.INFO
46
+
47
+
48
class LoggerConfig:
    """
    Configuration class for GEPA logging.

    This class holds all logging configuration and can be modified
    before calling configure_logging() to customize behavior.

    All attributes are class-level, so the class behaves as a process-wide
    configuration singleton; configure_logging() reads and mutates it and
    there is no need to instantiate it.
    """

    # Default configuration: global level plus line/timestamp formats.
    level: LogLevel = LogLevel.INFO
    format: str = DEFAULT_FORMAT
    date_format: str = DEFAULT_DATE_FORMAT

    # Module-specific log levels (for fine-grained control).
    # NOTE: class-level mutable dict — deliberately shared by all callers.
    module_levels: Dict[str, LogLevel] = {}

    # Output configuration: console is on by default; file logging is
    # enabled by setting log_to_file to a path.
    log_to_console: bool = True
    log_to_file: Optional[str] = None

    # Formatting options for the console handler.
    use_colors: bool = True
    include_emoji: bool = True  # For visual clarity in development

    @classmethod
    def reset(cls) -> None:
        """Reset every configuration attribute back to its default value."""
        cls.level = LogLevel.INFO
        cls.format = DEFAULT_FORMAT
        cls.date_format = DEFAULT_DATE_FORMAT
        cls.module_levels = {}
        cls.log_to_console = True
        cls.log_to_file = None
        cls.use_colors = True
        cls.include_emoji = True
83
+
84
+
85
# Module-level flag: set True by configure_logging(); get_logger() checks it
# to auto-configure with defaults on first use.
_logging_configured = False
87
+
88
+
89
def configure_logging(
    level: Optional[str] = None,
    log_file: Optional[str] = None,
    use_colors: bool = True,
    include_emoji: bool = True,
    format_string: Optional[str] = None,
    module_levels: Optional[Dict[str, str]] = None,
) -> None:
    """
    Configure the GEPA logging system.

    This should be called once at application startup. Subsequent calls
    will update the configuration.

    Args:
        level: Global log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        log_file: Optional path to log file
        use_colors: Whether to use colored output in console
        include_emoji: Whether to include emoji prefixes for visual clarity
        format_string: Custom format string (optional)
        module_levels: Dict mapping module names to their specific log levels

    Example:
        configure_logging(
            level="DEBUG",
            log_file="optimization.log",
            module_levels={
                "gepa_optimizer.core.optimizer": "INFO",
                "gepa_optimizer.llms": "DEBUG"
            }
        )
    """
    global _logging_configured

    # Update the shared configuration object.
    if level:
        LoggerConfig.level = LogLevel.from_string(level)
    if log_file:
        LoggerConfig.log_to_file = log_file
    LoggerConfig.use_colors = use_colors
    LoggerConfig.include_emoji = include_emoji
    if format_string:
        LoggerConfig.format = format_string
    if module_levels:
        LoggerConfig.module_levels = {
            k: LogLevel.from_string(v) for k, v in module_levels.items()
        }

    # Get or create the root GEPA logger.
    root_logger = logging.getLogger(GEPA_LOGGER_NAME)
    numeric_level = getattr(logging, LoggerConfig.level.value)
    root_logger.setLevel(numeric_level)

    # Remove existing handlers to avoid duplicate output on reconfiguration.
    # Close each handler first: a bare handlers.clear() (as before) leaked
    # the open stream of any previously configured FileHandler.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        try:
            old_handler.close()
        except Exception:
            # Best-effort cleanup; a broken handler must not stop reconfig.
            pass

    # Console handler
    if LoggerConfig.log_to_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(numeric_level)

        # Use custom formatter
        formatter = GepaFormatter(
            fmt=LoggerConfig.format,
            datefmt=LoggerConfig.date_format,
            use_colors=use_colors,
            include_emoji=include_emoji,
        )
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    # File handler (if configured)
    if LoggerConfig.log_to_file:
        file_handler = logging.FileHandler(LoggerConfig.log_to_file)
        file_handler.setLevel(numeric_level)

        # File logs don't use colors or emoji.
        file_formatter = GepaFormatter(
            fmt=LoggerConfig.format,
            datefmt=LoggerConfig.date_format,
            use_colors=False,
            include_emoji=False,
        )
        file_handler.setFormatter(file_formatter)
        root_logger.addHandler(file_handler)

    # Apply module-specific level overrides.
    for module_name, module_level in LoggerConfig.module_levels.items():
        module_logger = logging.getLogger(module_name)
        module_logger.setLevel(getattr(logging, module_level.value))

    _logging_configured = True

    # Log that configuration is complete
    root_logger.debug(
        f"Logging configured: level={LoggerConfig.level.value}, "
        f"file={LoggerConfig.log_to_file}"
    )
186
+
187
+
188
@lru_cache(maxsize=128)
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance for the given module name.

    This is the primary factory function for obtaining loggers.
    All GEPA modules should use this instead of logging.getLogger().

    Results are cached via lru_cache, so repeated calls with the same name
    return the same object without re-checking configuration.
    NOTE(review): because of the cache, overrides added to
    LoggerConfig.module_levels *after* the first call for a given name are
    not re-applied here — use set_log_level() for runtime changes.

    Args:
        name: Module name (typically __name__)

    Returns:
        Configured Logger instance

    Example:
        from gepa_optimizer.infrastructure.logging import get_logger

        logger = get_logger(__name__)
        logger.info("Starting process")
        logger.error("Failed to connect", exc_info=True)
    """
    global _logging_configured

    # First use anywhere in the process: fall back to default configuration.
    if not _logging_configured:
        configure_logging()

    logger = logging.getLogger(name)

    # Apply module-specific level if one was configured for this exact name.
    # (The previous no-op namespace check was dead code and has been removed.)
    if name in LoggerConfig.module_levels:
        logger.setLevel(getattr(logging, LoggerConfig.module_levels[name].value))

    return logger
227
+
228
+
229
def set_log_level(level: str, module: Optional[str] = None) -> None:
    """
    Dynamically change log level at runtime.

    Args:
        level: New log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        module: Optional module name. If None, changes global level.

    Example:
        # Enable debug for specific module
        set_log_level("DEBUG", "gepa_optimizer.core.optimizer")

        # Change global level
        set_log_level("WARNING")
    """
    parsed = LogLevel.from_string(level)
    numeric = getattr(logging, parsed.value)

    if module is None:
        # Global change: update the config, the root GEPA logger, and every
        # handler attached to it.
        LoggerConfig.level = parsed
        root = logging.getLogger(GEPA_LOGGER_NAME)
        root.setLevel(numeric)
        for handler in root.handlers:
            handler.setLevel(numeric)
    else:
        # Targeted change: apply immediately and remember the override so
        # later configuration passes keep it.
        logging.getLogger(module).setLevel(numeric)
        LoggerConfig.module_levels[module] = parsed
260
+
src/gepa_optimizer/llms/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM module for GEPA Optimizer
3
+ """
4
+
5
+ from .base_llm import BaseLLMClient
6
+ from .vision_llm import VisionLLMClient
7
+ from .batch_llm import BatchLLMClient
8
+ from .llego_enhanced_llm import LLEGOEnhancedLLMClient
9
+
10
+ __all__ = ["BaseLLMClient", "VisionLLMClient", "BatchLLMClient", "LLEGOEnhancedLLMClient"]
src/gepa_optimizer/llms/base_llm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base LLM client class for all LLM providers.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Dict, Optional, Union
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class BaseLLMClient(ABC):
    """
    Abstract base class for all LLM clients.

    Provides a consistent interface for different LLM providers and models.
    """

    def __init__(self, provider: str, model_name: str, **kwargs):
        """
        Initialize LLM client.

        Args:
            provider: LLM provider (e.g., 'openai', 'anthropic')
            model_name: Specific model name
            **kwargs: Additional provider-specific parameters
        """
        self.provider = provider
        self.model_name = model_name
        # Per-instance logger namespaced by the concrete subclass name.
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

        # Keep provider-specific options around for subclasses to consume.
        self.config = kwargs

    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate response from LLM.

        Args:
            system_prompt: System-level instructions
            user_prompt: User's input prompt
            **kwargs: Additional generation parameters (e.g., image_base64)

        Returns:
            Dictionary with 'content' key containing the generated response
            and additional metadata
        """
        pass

    def get_model_info(self) -> Dict[str, str]:
        """Get model information for logging and debugging"""
        info = {
            'provider': self.provider,
            'model_name': self.model_name,
            'class': type(self).__name__,
        }
        return info
src/gepa_optimizer/llms/batch_llm.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch LLM Client for cost-effective processing using Gemini Batch API.
3
+
4
+ This client provides 50% cost savings by using Google's Gemini Batch API
5
+ instead of real-time API calls. Ideal for large-scale prompt optimization
6
+ where latency is acceptable.
7
+
8
+ Features:
9
+ - 50% cost reduction compared to standard API
10
+ - Automatic batching and job management
11
+ - Built-in retry and polling logic
12
+ - Thread-safe operation
13
+ - Comprehensive error handling
14
+
15
+ Author: GEPA Optimizer Team
16
+ """
17
+
18
+ import os
19
+ import json
20
+ import time
21
+ import logging
22
+ import tempfile
23
+ import io
24
+ from pathlib import Path
25
+ from typing import Dict, List, Any, Optional, Tuple
26
+ from .base_llm import BaseLLMClient
27
+
28
# Optional dependency: Pillow is used only to sniff uploaded image formats
# in _upload_images_for_batch(); when unavailable, images fall back to the
# PNG default path.
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    Image = None

# Optional dependency: the google-genai SDK is required at runtime —
# BatchLLMClient.__init__ raises ImportError when GENAI_AVAILABLE is False.
try:
    from google import genai
    from google.genai import types
    GENAI_AVAILABLE = True
except ImportError:
    GENAI_AVAILABLE = False
    genai = None
    types = None

# Module-level logger for this file.
logger = logging.getLogger(__name__)
45
+
46
+
47
+ class BatchLLMClient(BaseLLMClient):
48
+ """
49
+ Batch LLM client that uses Gemini Batch API for cost-effective processing.
50
+
51
+ This client processes multiple requests together in batch jobs, providing:
52
+ - 50% cost savings vs standard API
53
+ - No rate limit impact
54
+ - Automatic job management and polling
55
+
56
+ Usage:
57
+ >>> from gepa_optimizer.llms import BatchLLMClient
58
+ >>>
59
+ >>> client = BatchLLMClient(
60
+ ... provider="google",
61
+ ... model_name="gemini-2.5-flash",
62
+ ... api_key="your-key",
63
+ ... batch_size=20,
64
+ ... polling_interval=30
65
+ ... )
66
+ >>>
67
+ >>> # Use just like VisionLLMClient - adapter handles the rest!
68
+ >>> result = client.generate(
69
+ ... system_prompt="You are a helpful assistant",
70
+ ... user_prompt="Analyze this image",
71
+ ... image_base64="..."
72
+ ... )
73
+
74
+ Performance Note:
75
+ Batch processing adds latency (30s+ polling time) but reduces costs by 50%.
76
+ Choose this mode for large-scale optimization where cost > speed.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ provider: str,
82
+ model_name: str,
83
+ api_key: Optional[str] = None,
84
+ batch_size: int = 20,
85
+ polling_interval: int = 30,
86
+ max_polling_time: int = 3600,
87
+ temp_dir: str = ".gepa_batch_temp",
88
+ **kwargs
89
+ ):
90
+ """
91
+ Initialize Batch LLM Client.
92
+
93
+ Args:
94
+ provider: Must be "google" or "gemini"
95
+ model_name: Gemini model (e.g., "gemini-2.5-flash", "gemini-1.5-flash")
96
+ api_key: Google API key (defaults to GEMINI_API_KEY env var)
97
+ batch_size: Number of samples to process per batch job (1-100)
98
+ polling_interval: Seconds between job status checks (default: 30)
99
+ max_polling_time: Maximum seconds to wait for job completion (default: 3600)
100
+ temp_dir: Directory for temporary files (default: ".gepa_batch_temp")
101
+ **kwargs: Additional parameters
102
+
103
+ Raises:
104
+ ValueError: If provider is not Google/Gemini
105
+ ImportError: If google-genai is not installed
106
+ """
107
+ super().__init__(provider=provider, model_name=model_name, **kwargs)
108
+
109
+ # Validate provider
110
+ if provider.lower() not in ["google", "gemini"]:
111
+ raise ValueError(
112
+ f"BatchLLMClient only supports Google/Gemini provider. Got: {provider}"
113
+ )
114
+
115
+ # Check dependencies
116
+ if not GENAI_AVAILABLE:
117
+ raise ImportError(
118
+ "google-genai not installed. Install with: pip install google-genai"
119
+ )
120
+
121
+ # Configuration
122
+ self.batch_size = batch_size
123
+ self.polling_interval = polling_interval
124
+ self.max_polling_time = max_polling_time
125
+ self.temp_dir = Path(temp_dir)
126
+ self.temp_dir.mkdir(exist_ok=True)
127
+
128
+ # Initialize Gemini client
129
+ from ..utils.api_keys import APIKeyManager
130
+ self.api_key = api_key or APIKeyManager().get_api_key("google")
131
+
132
+ if not self.api_key:
133
+ raise ValueError(
134
+ "Google API key required. Provide via api_key parameter or "
135
+ "set GEMINI_API_KEY environment variable."
136
+ )
137
+
138
+ self.client = genai.Client(api_key=self.api_key)
139
+
140
+ logger.info(
141
+ f"✓ BatchLLMClient initialized: {model_name} "
142
+ f"(batch_size={batch_size}, polling={polling_interval}s)"
143
+ )
144
+
145
+ def generate(
146
+ self,
147
+ system_prompt: str,
148
+ user_prompt: str,
149
+ image_base64: Optional[str] = None,
150
+ **kwargs
151
+ ) -> Dict[str, Any]:
152
+ """
153
+ Generate response using batch API.
154
+
155
+ Note: This method is primarily for compatibility. For batch optimization,
156
+ the adapter will call generate_batch() directly with multiple requests.
157
+
158
+ Args:
159
+ system_prompt: System-level instructions
160
+ user_prompt: User's input prompt
161
+ image_base64: Optional base64 encoded image
162
+ **kwargs: Additional generation parameters
163
+
164
+ Returns:
165
+ Dict with 'content' key containing generated text
166
+ """
167
+ # Single request - process as a batch of 1
168
+ requests = [{
169
+ 'system_prompt': system_prompt,
170
+ 'user_prompt': user_prompt,
171
+ 'image_base64': image_base64
172
+ }]
173
+
174
+ results = self.generate_batch(requests)
175
+ return results[0] if results else {"content": "", "error": "No results"}
176
+
177
+ def generate_batch(
178
+ self,
179
+ requests: List[Dict[str, Any]],
180
+ timeout_override: Optional[int] = None
181
+ ) -> List[Dict[str, Any]]:
182
+ """
183
+ Process multiple requests in a single batch job.
184
+
185
+ This is the main method called by UniversalGepaAdapter during GEPA optimization.
186
+
187
+ Args:
188
+ requests: List of request dicts with keys:
189
+ - system_prompt: System instructions
190
+ - user_prompt: User input
191
+ - image_base64: Optional base64 image
192
+ timeout_override: Override max_polling_time for this batch
193
+
194
+ Returns:
195
+ List of response dicts with 'content' key
196
+
197
+ Raises:
198
+ RuntimeError: If batch job fails
199
+ TimeoutError: If polling exceeds timeout
200
+ """
201
+ logger.info(f"📦 Processing batch of {len(requests)} requests via Gemini Batch API...")
202
+
203
+ start_time = time.time()
204
+
205
+ try:
206
+ # Step 1: Upload images if needed
207
+ file_uris, mime_types = self._upload_images_for_batch(requests)
208
+
209
+ # Step 2: Create JSONL file
210
+ jsonl_path = self._create_batch_jsonl(requests, file_uris, mime_types)
211
+
212
+ # Step 3: Submit batch job
213
+ batch_job_name = self._submit_batch_job(jsonl_path)
214
+
215
+ # Step 4: Wait for completion
216
+ timeout = timeout_override or self.max_polling_time
217
+ self._wait_for_batch_completion(batch_job_name, timeout)
218
+
219
+ # Step 5: Retrieve results
220
+ results = self._retrieve_batch_results(batch_job_name)
221
+
222
+ # Cleanup
223
+ jsonl_path.unlink(missing_ok=True)
224
+
225
+ elapsed_time = time.time() - start_time
226
+ logger.info(
227
+ f"✓ Batch processing complete: {len(results)} results in {elapsed_time:.1f}s "
228
+ f"(~{elapsed_time/len(results):.1f}s per request)"
229
+ )
230
+
231
+ return results
232
+
233
+ except Exception as e:
234
+ elapsed_time = time.time() - start_time
235
+ logger.error(f"❌ Batch processing failed after {elapsed_time:.1f}s: {e}")
236
+ raise
237
+
238
+ def _upload_images_for_batch(self, requests: List[Dict]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
239
+ """
240
+ Upload images to Gemini and return file URIs and MIME types.
241
+
242
+ Args:
243
+ requests: List of request dicts
244
+
245
+ Returns:
246
+ Tuple of (file_uris, mime_types) - both are lists with None for requests without images
247
+ """
248
+ file_uris = []
249
+ mime_types = []
250
+ images_to_upload = sum(1 for r in requests if r.get('image_base64'))
251
+
252
+ if images_to_upload > 0:
253
+ logger.info(f" ⬆️ Uploading {images_to_upload} images to Gemini...")
254
+
255
+ for i, request in enumerate(requests):
256
+ image_base64 = request.get('image_base64')
257
+
258
+ if not image_base64:
259
+ file_uris.append(None)
260
+ mime_types.append(None)
261
+ continue
262
+
263
+ try:
264
+ # Decode image data
265
+ import base64
266
+ image_data = base64.b64decode(image_base64)
267
+
268
+ # Detect image format using Pillow
269
+ image_format = None
270
+ if PIL_AVAILABLE:
271
+ try:
272
+ img = Image.open(io.BytesIO(image_data))
273
+ image_format = img.format.lower() if img.format else None
274
+ except Exception as e:
275
+ logger.warning(f" ⚠️ Could not detect image format: {e}")
276
+
277
+ # Map format to extension and MIME type
278
+ format_map = {
279
+ 'jpeg': ('.jpg', 'image/jpeg'),
280
+ 'jpg': ('.jpg', 'image/jpeg'),
281
+ 'png': ('.png', 'image/png'),
282
+ 'gif': ('.gif', 'image/gif'),
283
+ 'webp': ('.webp', 'image/webp'),
284
+ 'bmp': ('.bmp', 'image/bmp'),
285
+ 'tiff': ('.tiff', 'image/tiff'),
286
+ 'tif': ('.tiff', 'image/tiff'),
287
+ }
288
+
289
+ # Get extension and MIME type (default to PNG if unknown)
290
+ ext, mime_type = format_map.get(image_format, ('.png', 'image/png'))
291
+
292
+ if image_format and image_format not in format_map:
293
+ logger.warning(f" ⚠️ Unknown image format '{image_format}' for image {i}, defaulting to PNG")
294
+ elif not image_format:
295
+ logger.debug(f" ℹ️ Could not detect format for image {i}, using PNG")
296
+
297
+ # Save to temp file with correct extension
298
+ temp_file = tempfile.NamedTemporaryFile(
299
+ delete=False,
300
+ suffix=ext,
301
+ dir=self.temp_dir
302
+ )
303
+ temp_file.write(image_data)
304
+ temp_file.close()
305
+
306
+ # Upload to Gemini with correct MIME type
307
+ uploaded_file = self.client.files.upload(
308
+ file=temp_file.name,
309
+ config=types.UploadFileConfig(
310
+ display_name=f"batch_image_{i}_{int(time.time())}{ext}",
311
+ mime_type=mime_type
312
+ )
313
+ )
314
+
315
+ logger.debug(f" ✓ Uploaded image {i} as {mime_type}")
316
+
317
+ # Wait for file to be active
318
+ self._wait_for_file_active(uploaded_file)
319
+ file_uris.append(uploaded_file.uri)
320
+ mime_types.append(mime_type)
321
+
322
+ # Cleanup temp file
323
+ Path(temp_file.name).unlink()
324
+
325
+ except Exception as e:
326
+ logger.error(f" ✗ Failed to upload image {i}: {e}")
327
+ file_uris.append(None)
328
+ mime_types.append(None)
329
+
330
+ if images_to_upload > 0:
331
+ successful = sum(1 for uri in file_uris if uri is not None)
332
+ logger.info(f" ✓ Uploaded {successful}/{images_to_upload} images successfully")
333
+
334
+ return file_uris, mime_types
335
+
336
+ def _create_batch_jsonl(
337
+ self,
338
+ requests: List[Dict],
339
+ file_uris: List[Optional[str]],
340
+ mime_types: List[Optional[str]]
341
+ ) -> Path:
342
+ """
343
+ Create JSONL file for batch job.
344
+
345
+ Args:
346
+ requests: List of request dicts
347
+ file_uris: List of uploaded file URIs
348
+ mime_types: List of MIME types for uploaded files
349
+
350
+ Returns:
351
+ Path to created JSONL file
352
+ """
353
+ timestamp = int(time.time())
354
+ jsonl_path = self.temp_dir / f"batch_{timestamp}.jsonl"
355
+
356
+ with open(jsonl_path, 'w', encoding='utf-8') as f:
357
+ for i, (request, file_uri, mime_type) in enumerate(zip(requests, file_uris, mime_types)):
358
+ # Combine system and user prompts
359
+ system_prompt = request.get('system_prompt', '')
360
+ user_prompt = request.get('user_prompt', '')
361
+ full_prompt = f"{system_prompt}\n\n{user_prompt}".strip()
362
+
363
+ # Build request parts
364
+ parts = [{"text": full_prompt}]
365
+
366
+ if file_uri:
367
+ parts.append({
368
+ "file_data": {
369
+ "file_uri": file_uri,
370
+ "mime_type": mime_type or "image/png" # Use actual MIME type
371
+ }
372
+ })
373
+
374
+ # Gemini Batch API format according to official docs
375
+ # Reference: https://ai.google.dev/gemini-api/docs/batch-inference
376
+ # NOTE: The "request" wrapper is REQUIRED for Gemini 2.5 batch API
377
+ batch_request = {
378
+ "custom_id": f"request-{i}",
379
+ "request": {
380
+ "contents": [{
381
+ "role": "user",
382
+ "parts": parts
383
+ }]
384
+ }
385
+ }
386
+
387
+ f.write(json.dumps(batch_request, ensure_ascii=False) + '\n')
388
+
389
+ logger.info(f" 📝 Created JSONL file: {jsonl_path.name} ({len(requests)} requests)")
390
+ return jsonl_path
391
+
392
    def _submit_batch_job(self, jsonl_path: Path) -> str:
        """
        Submit batch job to Gemini.

        Uploads the JSONL request file (with a two-method fallback, since the
        google-genai SDK upload path can be finicky), waits for the file to be
        processed, then creates the batch job against ``self.model_name``.

        Args:
            jsonl_path: Path to JSONL file

        Returns:
            Batch job name

        Raises:
            ValueError: If the JSONL file contains a line that is not valid JSON.
            RuntimeError: If upload or batch-job creation fails.
        """
        # Upload JSONL file
        # Try multiple methods as the google-genai SDK can be finicky
        try:
            logger.info(f" 📤 Uploading JSONL file: {jsonl_path.name}")

            # Read and validate file content
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                content = f.read()
                line_count = len(content.strip().split('\n'))
                logger.debug(f" 📄 JSONL: {len(content)} bytes, {line_count} lines")

            # Validate JSONL format up front so a malformed line fails fast
            # locally instead of as an opaque server-side batch error.
            for line_num, line in enumerate(content.strip().split('\n'), 1):
                try:
                    json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error(f" ❌ Invalid JSON at line {line_num}: {e}")
                    logger.error(f" Content: {line[:100]}...")
                    raise ValueError(f"Invalid JSONL format at line {line_num}") from e

            # Method 1: Try uploading with Path object
            logger.info(f" 🔄 Upload method 1: Using Path object...")
            try:
                jsonl_file = self.client.files.upload(
                    file=jsonl_path,
                    config=types.UploadFileConfig(
                        display_name=f'gepa-batch-{int(time.time())}',
                        mime_type='application/json'  # Try application/json instead of application/jsonl
                    )
                )
                logger.info(f" ✓ JSONL file uploaded: {jsonl_file.name}")

            except Exception as e1:
                logger.warning(f" ⚠️ Method 1 failed: {e1}")
                logger.info(f" 🔄 Upload method 2: Using string path...")

                # Method 2: Fallback to string path
                try:
                    jsonl_file = self.client.files.upload(
                        file=str(jsonl_path.absolute()),
                        config=types.UploadFileConfig(
                            display_name=f'gepa-batch-{int(time.time())}',
                            mime_type='application/json'
                        )
                    )
                    logger.info(f" ✓ JSONL file uploaded (method 2): {jsonl_file.name}")
                except Exception as e2:
                    logger.error(f" ❌ Method 2 also failed: {e2}")
                    raise e2

        except KeyError as e:
            # KeyError here historically indicates the SDK's response schema
            # changed out from under us, not a problem with our payload.
            logger.error(f"❌ KeyError during JSONL upload: {e}")
            logger.error(f" This suggests the Gemini API response format changed")
            logger.error(f" Try updating google-genai: pip install --upgrade google-genai")
            raise RuntimeError(f"Gemini Batch API response format error: {e}") from e
        except Exception as e:
            logger.error(f"❌ Failed to upload JSONL file: {e}")
            logger.error(f" File path: {jsonl_path}")
            logger.error(f" File exists: {jsonl_path.exists()}")
            logger.error(f" File size: {jsonl_path.stat().st_size if jsonl_path.exists() else 'N/A'} bytes")
            raise RuntimeError(f"Gemini Batch API file upload failed: {e}") from e

        # Wait for JSONL to be active
        try:
            logger.info(f" ⏳ Waiting for JSONL file to be processed...")
            self._wait_for_file_active(jsonl_file)
        except Exception as e:
            logger.error(f"❌ JSONL file processing failed: {e}")
            raise

        # Create batch job
        try:
            logger.info(f" 🚀 Creating batch job...")
            batch_job = self.client.batches.create(
                model=self.model_name,
                src=jsonl_file.name,
                config={'display_name': f'gepa-opt-{int(time.time())}'}
            )

            logger.info(f" ✓ Batch job submitted: {batch_job.name}")
            return batch_job.name

        except Exception as e:
            logger.error(f"❌ Failed to create batch job: {e}")
            raise RuntimeError(f"Batch job creation failed: {e}") from e
487
+
488
    def _wait_for_batch_completion(self, job_name: str, timeout: int):
        """
        Poll batch job until completion.

        Polls every ``self.polling_interval`` seconds. Both the bare and
        ``JOB_STATE_``-prefixed state names are accepted because different
        SDK versions report them differently.

        Args:
            job_name: Batch job name
            timeout: Maximum seconds to wait

        Raises:
            TimeoutError: If polling exceeds timeout
            RuntimeError: If batch job fails or is cancelled
        """
        logger.info(f" ⏳ Polling for completion (checking every {self.polling_interval}s)...")

        start_time = time.time()
        poll_count = 0

        while True:
            elapsed = time.time() - start_time

            # Timeout is checked before each poll so a hung API call can't
            # extend the wait indefinitely between checks.
            if elapsed > timeout:
                raise TimeoutError(
                    f"Batch job timeout after {elapsed:.0f}s "
                    f"(max: {timeout}s)"
                )

            try:
                batch_job = self.client.batches.get(name=job_name)
                state = batch_job.state.name

                # Success states
                if state in ['JOB_STATE_SUCCEEDED', 'SUCCEEDED']:
                    logger.info(f" ✓ Batch job completed in {elapsed:.0f}s")
                    return

                # Failure states
                if state in ['JOB_STATE_FAILED', 'FAILED']:
                    raise RuntimeError(f"Batch job failed with state: {state}")

                if state in ['JOB_STATE_CANCELLED', 'CANCELLED']:
                    raise RuntimeError(f"Batch job was cancelled: {state}")

                # Still processing
                poll_count += 1
                if poll_count % 5 == 0:  # Log every 5 polls
                    logger.info(f" ... still processing ({elapsed:.0f}s elapsed, state: {state})")

                time.sleep(self.polling_interval)

            except (TimeoutError, RuntimeError):
                # Our own terminal errors propagate; only transient API
                # errors below are retried.
                raise
            except Exception as e:
                logger.warning(f" ⚠️ Error checking job status: {e}, retrying...")
                time.sleep(5)
542
+
543
+ def _retrieve_batch_results(self, job_name: str) -> List[Dict[str, Any]]:
544
+ """
545
+ Retrieve and parse batch results.
546
+
547
+ Args:
548
+ job_name: Batch job name
549
+
550
+ Returns:
551
+ List of result dicts
552
+ """
553
+ batch_job = self.client.batches.get(name=job_name)
554
+
555
+ # Check for inline responses (preferred)
556
+ if hasattr(batch_job.dest, 'inlined_responses') and batch_job.dest.inlined_responses:
557
+ logger.info(f" 📥 Processing inline responses...")
558
+ return self._parse_inline_results(batch_job.dest.inlined_responses)
559
+
560
+ # Download results file (fallback)
561
+ if hasattr(batch_job.dest, 'file_name') and batch_job.dest.file_name:
562
+ logger.info(f" 📥 Downloading results file: {batch_job.dest.file_name}")
563
+ file_data = self.client.files.download(file=batch_job.dest.file_name)
564
+ return self._parse_file_results(file_data)
565
+
566
+ raise RuntimeError("No results available from batch job")
567
+
568
+ def _parse_inline_results(self, inline_responses) -> List[Dict[str, Any]]:
569
+ """Parse inline batch results."""
570
+ results = []
571
+
572
+ for response_obj in inline_responses:
573
+ if hasattr(response_obj, 'response') and response_obj.response:
574
+ text = self._extract_text_from_response(response_obj.response)
575
+ results.append({
576
+ "content": text,
577
+ "role": "assistant",
578
+ "model": self.model_name,
579
+ "provider": "google"
580
+ })
581
+ else:
582
+ error_msg = str(getattr(response_obj, 'error', 'Unknown error'))
583
+ logger.warning(f" ⚠️ Response error: {error_msg}")
584
+ results.append({
585
+ "content": "",
586
+ "error": error_msg
587
+ })
588
+
589
+ return results
590
+
591
+ def _parse_file_results(self, file_data) -> List[Dict[str, Any]]:
592
+ """Parse JSONL results file."""
593
+ if isinstance(file_data, bytes):
594
+ jsonl_content = file_data.decode('utf-8')
595
+ else:
596
+ jsonl_content = file_data
597
+
598
+ results = []
599
+
600
+ for line_num, line in enumerate(jsonl_content.strip().split('\n'), 1):
601
+ if not line.strip():
602
+ continue
603
+
604
+ try:
605
+ result = json.loads(line)
606
+
607
+ if 'response' in result:
608
+ text = self._extract_text_from_dict(result['response'])
609
+ results.append({
610
+ "content": text,
611
+ "role": "assistant",
612
+ "model": self.model_name,
613
+ "provider": "google"
614
+ })
615
+ else:
616
+ error_msg = result.get('error', 'Unknown error')
617
+ logger.warning(f" ⚠️ Line {line_num} error: {error_msg}")
618
+ results.append({
619
+ "content": "",
620
+ "error": error_msg
621
+ })
622
+
623
+ except json.JSONDecodeError as e:
624
+ logger.error(f" ✗ Line {line_num}: JSON decode error: {e}")
625
+ results.append({"content": "", "error": f"JSON decode error: {e}"})
626
+
627
+ return results
628
+
629
+ def _extract_text_from_response(self, response_obj) -> str:
630
+ """Extract text from response object."""
631
+ try:
632
+ # Direct text attribute
633
+ if hasattr(response_obj, 'text'):
634
+ return response_obj.text
635
+
636
+ # Navigate through candidates
637
+ if hasattr(response_obj, 'candidates') and response_obj.candidates:
638
+ candidate = response_obj.candidates[0]
639
+ if hasattr(candidate, 'content'):
640
+ content = candidate.content
641
+ if hasattr(content, 'parts') and content.parts:
642
+ part = content.parts[0]
643
+ if hasattr(part, 'text'):
644
+ return part.text
645
+
646
+ # Fallback to string representation
647
+ return str(response_obj)
648
+
649
+ except Exception as e:
650
+ logger.error(f"Error extracting text from response: {e}")
651
+ return ""
652
+
653
+ def _extract_text_from_dict(self, response_dict: Dict) -> str:
654
+ """Extract text from response dictionary."""
655
+ try:
656
+ # Direct text key
657
+ if 'text' in response_dict:
658
+ return response_dict['text']
659
+
660
+ # Navigate through candidates
661
+ if 'candidates' in response_dict and response_dict['candidates']:
662
+ candidate = response_dict['candidates'][0]
663
+ if 'content' in candidate and 'parts' in candidate['content']:
664
+ parts = candidate['content']['parts']
665
+ if parts and 'text' in parts[0]:
666
+ return parts[0]['text']
667
+
668
+ # Fallback to JSON string
669
+ return json.dumps(response_dict)
670
+
671
+ except Exception as e:
672
+ logger.error(f"Error extracting text from dict: {e}")
673
+ return ""
674
+
675
+ def _wait_for_file_active(self, uploaded_file, timeout: int = 60):
676
+ """
677
+ Wait for uploaded file to become active.
678
+
679
+ Args:
680
+ uploaded_file: Uploaded file object
681
+ timeout: Maximum seconds to wait
682
+
683
+ Raises:
684
+ TimeoutError: If file processing exceeds timeout
685
+ RuntimeError: If file processing fails
686
+ """
687
+ start_time = time.time()
688
+
689
+ while uploaded_file.state.name == "PROCESSING":
690
+ if time.time() - start_time > timeout:
691
+ raise TimeoutError(f"File processing timeout: {uploaded_file.name}")
692
+
693
+ time.sleep(1)
694
+ uploaded_file = self.client.files.get(name=uploaded_file.name)
695
+
696
+ if uploaded_file.state.name != "ACTIVE":
697
+ raise RuntimeError(
698
+ f"File processing failed: {uploaded_file.name} "
699
+ f"(state: {uploaded_file.state.name})"
700
+ )
701
+
702
+ def get_model_info(self) -> Dict[str, str]:
703
+ """Get model information for logging and debugging."""
704
+ return {
705
+ 'provider': self.provider,
706
+ 'model_name': self.model_name,
707
+ 'class': self.__class__.__name__,
708
+ 'mode': 'batch',
709
+ 'batch_size': str(self.batch_size),
710
+ 'polling_interval': f'{self.polling_interval}s'
711
+ }
712
+
src/gepa_optimizer/llms/llego_enhanced_llm.py ADDED
@@ -0,0 +1,1625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLEGO-Enhanced LLM Client Wrapper
3
+
4
+ This wrapper intercepts LLM calls and uses LLEGO genetic operators
5
+ when generating new prompt candidates during GEPA's reflection phase.
6
+ """
7
+
8
+ import logging
9
+ import re
10
+ from typing import Optional, Dict, Any, Callable, List
11
+ from .base_llm import BaseLLMClient
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
# Fallback system prompt for sequential generation (when JSON parsing fails).
# Uses a Linear Command structure for reliability when complex JSON generation
# fails; it instructs the LLM to emit ONLY raw prompt text (no JSON/meta-text),
# which downstream parsing depends on.
_FALLBACK_SYSTEM_PROMPT = """You are a Prompt Optimization Engine operating in **SAFE MODE**.

<task>
Rewrite the prompt based on the feedback provided below.
</task>

<output_rules>
1. Output **ONLY** the new prompt text.
2. No JSON. No Explanations. No "Here is the prompt".
3. The prompt must be fully functional and self-contained.
4. START directly with the prompt content (e.g., "You are a..." or task instructions).
5. Preserve the core task/domain - only improve HOW it's described.
</output_rules>

<quality_standards>
- Be specific and concrete (no vague instructions)
- Use clear, imperative language
- Include edge case handling if feedback identifies confusion
- Ensure the prompt is self-contained and unambiguous
- Add explicit constraints for format/output if needed
</quality_standards>

<forbidden_outputs>
- Analysis of what went wrong
- Explanations of your changes
- Meta-text like "Here's an improved version..."
- Anything other than the raw prompt text
</forbidden_outputs>

Start of New Prompt:"""
47
+
48
+
49
+ class LLEGOEnhancedLLMClient(BaseLLMClient):
50
+ """
51
+ Wrapper around BaseLLMClient that uses LLEGO for candidate generation.
52
+
53
+ This wrapper detects when GEPA is asking for new prompt candidates
54
+ and routes those requests through LLEGO's genetic operators instead
55
+ of standard LLM generation.
56
+ """
57
+
58
    def __init__(
        self,
        base_llm: BaseLLMClient,
        llego_layer,
        config=None,
        verbose: bool = True
    ) -> None:
        """
        Initialize LLEGO-enhanced LLM client.

        Args:
            base_llm: The underlying LLM client (VisionLLMClient, etc.)
            llego_layer: LLEGOIntegrationLayer instance
            config: Optional OptimizationConfig for hybrid mode settings
            verbose: Whether to log LLEGO operations
        """
        self.base_llm = base_llm
        self.llego = llego_layer
        self.config = config
        self.verbose = verbose

        # Get log level from config (default to INFO)
        self.log_level = getattr(config, 'log_level', 'INFO') if config else 'INFO'

        # Track context for detecting reflection calls
        self.reflection_context = {
            'current_prompt': None,
            'feedback': None,
            'in_reflection': False
        }

        # Queue for hybrid mode candidates (GEPA will call generate() multiple times)
        self._candidate_queue = []
        self._hybrid_generation_complete = False

        # 🔥 CRITICAL: Queue for adapter-generated candidates (from make_reflective_dataset)
        # When adapter generates candidates at adapter level, they're stored here
        # GEPA will call generate() for proposals, and we'll return these candidates
        self._adapter_generated_candidates = []

        # 🔥 FORMAT AWARENESS: Store format info from adapter for use in candidate generation
        self._detected_format = None  # Will be set by adapter after format detection

        # FIX #5: Circuit breaker for LLEGO failures — after the threshold of
        # consecutive failures LLEGO is disabled and calls fall through to base_llm.
        self._llego_failures = 0
        self._llego_disabled = False
        self._llego_failure_threshold = 3  # Disable after 3 consecutive failures

        logger.info("🧬 LLEGO-Enhanced LLM Client initialized")
        logger.info(f" Base LLM: {base_llm.__class__.__name__}")
        logger.info(f" LLEGO enabled: {llego_layer is not None}")
        if config and hasattr(config, 'enable_gepa_reflection_with_llego'):
            logger.info(f" Hybrid mode: {config.enable_gepa_reflection_with_llego}")
        logger.debug(f" Log level: {self.log_level}")
113
+
114
+ def _should_log_debug(self) -> bool:
115
+ """
116
+ Check if DEBUG logging is enabled.
117
+
118
+ Returns:
119
+ True if DEBUG level logging is enabled, False otherwise
120
+ """
121
+ return self.log_level == "DEBUG" or (
122
+ hasattr(logging, 'getLogger') and
123
+ logging.getLogger().isEnabledFor(logging.DEBUG)
124
+ )
125
+
126
    def _extract_clean_prompt_from_reflection(self, reflection_output: str) -> str:
        """
        🛡️ DEFENSIVE FALLBACK: Extract clean prompt if LLM adds analysis despite system prompt instructions.

        NOTE: The system prompt now explicitly instructs the LLM to output ONLY the prompt text.
        However, this extraction logic serves as a safety net in case the LLM still adds:
        "Based on the performance analysis...
        ### Recommendations...
        ### Revised Prompt Example:
        [THE ACTUAL PROMPT HERE]
        ### Conclusion..."

        This is now a defensive measure, not the primary mechanism. Extraction
        strategies are tried in order of confidence: explicit "Revised Prompt"
        headers, a "You are..." block, a short analysis-free output, then a
        looser "You are..." match, and finally the original text unchanged.

        Args:
            reflection_output: Full reflection output (should be clean prompt, but may contain analysis)

        Returns:
            str: Clean, extracted prompt (or original if extraction fails or not needed)
        """
        # Non-string / empty inputs are passed through untouched.
        if not reflection_output or not isinstance(reflection_output, str):
            return reflection_output

        # Pattern 1: Look for "Revised Prompt Example:" or "### Revised Prompt Example:"
        patterns = [
            r'(?:###\s*)?Revised\s+Prompt\s+(?:Example|:)?\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
            r'(?:###\s*)?Revised\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
            r'(?:###\s*)?Optimized\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
            r'(?:###\s*)?New\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
            r'(?:Here\s+is|Here\'s)\s+a?\s*refined?\s+(?:version\s+of\s+)?(?:the\s+)?prompt\s*[:\n](.*?)(?:\n###|\n##|\n---|\Z)',
        ]

        for pattern in patterns:
            match = re.search(pattern, reflection_output, re.IGNORECASE | re.DOTALL)
            if match:
                extracted = match.group(1).strip()
                # Clean up common artifacts (surrounding markdown code fences)
                extracted = re.sub(r'^```(?:plaintext|markdown|text)?\s*\n', '', extracted, flags=re.MULTILINE)
                extracted = re.sub(r'\n```\s*$', '', extracted, flags=re.MULTILINE)
                extracted = extracted.strip()

                if len(extracted) > 50:  # Reasonable minimum length for a prompt
                    logger.debug(f"✅ Extracted clean prompt using pattern: {pattern[:50]}...")
                    logger.debug(f" Original length: {len(reflection_output)} chars")
                    logger.debug(f" Extracted length: {len(extracted)} chars")
                    return extracted

        # Pattern 2: If output starts with a quote or prompt-like structure
        # Look for text that starts with "You are..." and is substantial
        if 'You are' in reflection_output:
            # Find the longest continuous block that starts with "You are"
            prompt_match = re.search(r'(You are[^#]*?)(?:\n###|\n##|###|##|Conclusion|\Z)',
                                     reflection_output, re.IGNORECASE | re.DOTALL)
            if prompt_match:
                extracted = prompt_match.group(1).strip()
                if len(extracted) > 50:
                    logger.debug(f"✅ Extracted prompt starting with 'You are...'")
                    return extracted

        # Pattern 3: If the reflection output is actually just a clean prompt (no analysis)
        # Check if it's relatively short and doesn't contain analysis keywords
        analysis_keywords = ['recommendation', 'suggestion', 'improvement', 'conclusion',
                             'optimization', 'analysis', 'feedback']
        if (len(reflection_output) < 2000 and
                not any(keyword in reflection_output.lower() for keyword in analysis_keywords)):
            # Likely a clean prompt, return as-is
            logger.debug(f"✅ Reflection output appears to be a clean prompt (no analysis detected)")
            return reflection_output.strip()

        # Fallback: Try to extract ANY valid prompt-like text
        # Look for text that might be a prompt even if not perfectly formatted
        if 'You are' in reflection_output:
            # Try to find a substantial block starting with "You are"
            potential_prompt = re.search(
                r'(You are(?:[^\.]|\.(?!\s*(?:Here|This|These|The above)))*?)(?:\n\n|\n###|Conclusion|\Z)',
                reflection_output,
                re.IGNORECASE | re.DOTALL
            )
            if potential_prompt and len(potential_prompt.group(1)) > 100:
                extracted = potential_prompt.group(1).strip()
                logger.warning(f"⚠️ Could not extract clean prompt using standard patterns")
                logger.warning(f" Falling back to 'You are...' block (length: {len(extracted)} chars)")
                logger.warning(f" This may still contain some analysis text")
                return extracted

        # Final fallback: If still nothing, return original but log strongly
        logger.warning(f"⚠️ Could not extract clean prompt from reflection output")
        logger.warning(f" Output length: {len(reflection_output)} chars")
        logger.warning(f" Output preview: {reflection_output[:200]}...")
        logger.warning(f" ⚠️ WARNING: Returning original output (may contain analysis text or be invalid)")
        logger.warning(f" This candidate may perform poorly - consider improving extraction logic")
        return reflection_output.strip()
218
+
219
+ def _parse_json_variations(self, response_text: str, num_expected: int) -> List[str]:
220
+ """
221
+ 🔥 OPTIMIZED: Parse N prompt variations from JSON format response.
222
+
223
+ Uses robust JSON parsing with multiple fallback strategies.
224
+
225
+ Handles common LLM output issues:
226
+ - Markdown code blocks (```json ... ```)
227
+ - Extra text before/after JSON
228
+ - Trailing commas
229
+ - Comments in JSON
230
+ - Newlines in strings
231
+ """
232
+ import json
233
+ import re
234
+
235
+ if not response_text or not isinstance(response_text, str):
236
+ raise ValueError("Empty or invalid response text")
237
+
238
+ # 🔥 PREPROCESSING: Clean LLM output
239
+ cleaned = response_text.strip()
240
+
241
+ # Remove BOM and invisible chars
242
+ cleaned = cleaned.lstrip('\ufeff\u200b\u200c\u200d')
243
+
244
+ # Strategy 0: Handle Python dict syntax (single quotes -> double quotes)
245
+ # LLMs sometimes return Python dict syntax {'key': 'value'} instead of JSON {"key": "value"}
246
+ if "'variations'" in cleaned or (cleaned.startswith("{'") or cleaned.startswith("{'variations'")):
247
+ try:
248
+ import ast
249
+ # Try to parse as Python literal (handles single quotes, True/False, None)
250
+ python_dict = ast.literal_eval(cleaned)
251
+ if isinstance(python_dict, dict) and 'variations' in python_dict:
252
+ # Convert to JSON-compatible format
253
+ json_str = json.dumps(python_dict)
254
+ data = json.loads(json_str)
255
+ if 'variations' in data:
256
+ # #region agent log
257
+ import json as _json_debug
258
+ import time as _time_debug
259
+ import os as _os_debug
260
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
261
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
262
+ with open(_debug_log_path, "a") as _f:
263
+ _f.write(_json_debug.dumps({"hypothesisId": "JSON_FIX", "location": "llego_enhanced_llm.py:python_dict_parse", "message": "Successfully parsed Python dict syntax", "data": {"num_expected": num_expected, "parsed_variations": len(data.get('variations', []))}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
264
+ # #endregion
265
+ return self._extract_variations_from_json(data, num_expected)
266
+ except (ValueError, SyntaxError, TypeError) as e:
267
+ # If ast.literal_eval fails, try string replacement as fallback
268
+ try:
269
+ # Simple conversion: replace single quotes with double quotes (with escaping)
270
+ # This is a heuristic and may not work for all cases
271
+ converted = cleaned.replace("'", '"')
272
+ data = json.loads(converted)
273
+ if 'variations' in data:
274
+ # #region agent log
275
+ import json as _json_debug
276
+ import time as _time_debug
277
+ import os as _os_debug
278
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
279
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
280
+ with open(_debug_log_path, "a") as _f:
281
+ _f.write(_json_debug.dumps({"hypothesisId": "JSON_FIX", "location": "llego_enhanced_llm.py:python_dict_string_replace", "message": "Parsed Python dict via string replacement", "data": {"num_expected": num_expected, "parsed_variations": len(data.get('variations', []))}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
282
+ # #endregion
283
+ return self._extract_variations_from_json(data, num_expected)
284
+ except json.JSONDecodeError:
285
+ pass
286
+
287
+ # Strategy 1: Direct JSON parse (cleanest case)
288
+ try:
289
+ data = json.loads(cleaned)
290
+ if 'variations' in data:
291
+ return self._extract_variations_from_json(data, num_expected)
292
+ except json.JSONDecodeError:
293
+ pass
294
+
295
+ # Strategy 2: Extract from markdown code block
296
+ # More permissive regex that handles various formats
297
+ code_block_patterns = [
298
+ r'```(?:json|JSON)?\s*(\{[\s\S]*?\})\s*```', # Standard markdown
299
+ r'```\s*(\{[\s\S]*"variations"[\s\S]*\})\s*```', # With "variations" keyword
300
+ ]
301
+
302
+ for pattern in code_block_patterns:
303
+ json_match = re.search(pattern, cleaned)
304
+ if json_match:
305
+ json_str = json_match.group(1)
306
+ try:
307
+ data = json.loads(json_str)
308
+ if 'variations' in data:
309
+ return self._extract_variations_from_json(data, num_expected)
310
+ except json.JSONDecodeError:
311
+ # Try repair
312
+ repaired = self._repair_json_string(json_str)
313
+ try:
314
+ data = json.loads(repaired)
315
+ if 'variations' in data:
316
+ return self._extract_variations_from_json(data, num_expected)
317
+ except json.JSONDecodeError:
318
+ pass
319
+
320
+ # Strategy 3: Balanced brace extraction (handles nested objects)
321
+ json_str = self._extract_balanced_json(cleaned)
322
+ if json_str:
323
+ try:
324
+ data = json.loads(json_str)
325
+ if 'variations' in data:
326
+ return self._extract_variations_from_json(data, num_expected)
327
+ except json.JSONDecodeError:
328
+ repaired = self._repair_json_string(json_str)
329
+ try:
330
+ data = json.loads(repaired)
331
+ if 'variations' in data:
332
+ return self._extract_variations_from_json(data, num_expected)
333
+ except json.JSONDecodeError:
334
+ pass
335
+
336
+ # Strategy 4: Find JSON object with "variations" keyword
337
+ # Use greedy matching to get the full object
338
+ json_match = re.search(r'(\{[\s\S]*"variations"[\s\S]*\})', cleaned)
339
+ if json_match:
340
+ json_str = json_match.group(1)
341
+ # Find the balanced JSON within
342
+ balanced = self._extract_balanced_json(json_str)
343
+ if balanced:
344
+ try:
345
+ data = json.loads(balanced)
346
+ if 'variations' in data:
347
+ return self._extract_variations_from_json(data, num_expected)
348
+ except json.JSONDecodeError:
349
+ repaired = self._repair_json_string(balanced)
350
+ try:
351
+ data = json.loads(repaired)
352
+ if 'variations' in data:
353
+ return self._extract_variations_from_json(data, num_expected)
354
+ except json.JSONDecodeError:
355
+ pass
356
+
357
+ # Strategy 5: Fallback to numbered sections
358
+ logger.warning(f"JSON parsing failed, trying numbered section fallback...")
359
+ try:
360
+ return self._parse_numbered_section_variations(response_text, num_expected)
361
+ except ValueError:
362
+ pass
363
+
364
+ # #region agent log
365
+ import json as _json_debug
366
+ import time as _time_debug
367
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
368
+ with open(_debug_log_path, "a") as _f:
369
+ _f.write(_json_debug.dumps({"hypothesisId": "D", "location": "llego_enhanced_llm.py:json_parse_fail", "message": "JSON parsing failed completely", "data": {"num_expected": num_expected, "response_preview": response_text[:500] if response_text else "EMPTY", "response_length": len(response_text) if response_text else 0}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
370
+ # #endregion
371
+
372
+ raise ValueError(f"Could not parse {num_expected} variations from response")
373
+
374
+ def _extract_balanced_json(self, text: str) -> Optional[str]:
375
+ """Extract JSON with balanced braces."""
376
+ brace_count = 0
377
+ start_idx = -1
378
+ in_string = False
379
+ escape_next = False
380
+
381
+ for i, char in enumerate(text):
382
+ # Handle string escaping
383
+ if escape_next:
384
+ escape_next = False
385
+ continue
386
+ if char == '\\' and in_string:
387
+ escape_next = True
388
+ continue
389
+ if char == '"' and not escape_next:
390
+ in_string = not in_string
391
+ continue
392
+
393
+ # Skip characters inside strings
394
+ if in_string:
395
+ continue
396
+
397
+ if char == '{':
398
+ if brace_count == 0:
399
+ start_idx = i
400
+ brace_count += 1
401
+ elif char == '}':
402
+ brace_count -= 1
403
+ if brace_count == 0 and start_idx >= 0:
404
+ return text[start_idx:i+1]
405
+
406
+ return None
407
+
408
+ def _repair_json_string(self, json_str: str) -> str:
409
+ """
410
+ Repair common JSON issues from LLM output.
411
+
412
+ Fixes:
413
+ - Trailing commas
414
+ - Comments
415
+ - Unescaped newlines in strings
416
+ """
417
+ repaired = json_str
418
+
419
+ # Remove trailing commas before } or ]
420
+ repaired = re.sub(r',\s*}', '}', repaired)
421
+ repaired = re.sub(r',\s*]', ']', repaired)
422
+
423
+ # Remove single-line comments
424
+ repaired = re.sub(r'//[^\n]*\n', '\n', repaired)
425
+
426
+ # Remove multi-line comments
427
+ repaired = re.sub(r'/\*[\s\S]*?\*/', '', repaired)
428
+
429
+ return repaired
430
+
431
+ def _extract_variations_from_json(self, data: Dict[str, Any], num_expected: int) -> List[str]:
432
+ """Extract and validate variations from parsed JSON data."""
433
+
434
+ if not isinstance(data, dict):
435
+ raise ValueError("JSON data is not a dictionary")
436
+
437
+ variations_list = data.get('variations', [])
438
+ if not isinstance(variations_list, list):
439
+ raise ValueError("'variations' field is not a list")
440
+
441
+ # Extract and sort by index
442
+ variations_with_index = []
443
+ for var in variations_list:
444
+ if not isinstance(var, dict):
445
+ continue
446
+ index = var.get('index', 0)
447
+ prompt = var.get('prompt', '')
448
+ if prompt and isinstance(prompt, str):
449
+ variations_with_index.append((index, prompt.strip()))
450
+
451
+ variations_with_index.sort(key=lambda x: x[0])
452
+ variations = [v[1] for v in variations_with_index]
453
+
454
+ # Validate count
455
+ if len(variations) < num_expected:
456
+ logger.warning(f"Only {len(variations)} valid variations found, expected {num_expected}")
457
+ while len(variations) < num_expected:
458
+ variations.append(variations[-1] if variations else "")
459
+
460
+ variations = variations[:num_expected]
461
+
462
+ if not all(v for v in variations):
463
+ raise ValueError(f"Some variations are empty after parsing")
464
+
465
+ return variations
466
+
467
+ def _parse_numbered_section_variations(self, response_text: str, num_expected: int) -> List[str]:
468
+ """Fallback parser: Extract variations from numbered sections."""
469
+ import re
470
+
471
+ variations = []
472
+
473
+ pattern1 = r'---\s*VARIATION\s+(\d+)\s*---\s*\n(.*?)(?=\n---\s*VARIATION|\Z)'
474
+ matches1 = re.findall(pattern1, response_text, re.DOTALL | re.IGNORECASE)
475
+
476
+ pattern2 = r'Variation\s+(\d+)\s*:?\s*\n(.*?)(?=\nVariation\s+\d+|$)'
477
+ matches2 = re.findall(pattern2, response_text, re.DOTALL | re.IGNORECASE)
478
+
479
+ pattern3 = r'(\d+)\.\s*\n(.*?)(?=\n\d+\.|$)'
480
+ matches3 = re.findall(pattern3, response_text, re.DOTALL)
481
+
482
+ matches = matches1 if len(matches1) >= num_expected else (matches2 if len(matches2) >= num_expected else matches3)
483
+
484
+ if len(matches) >= num_expected:
485
+ matches.sort(key=lambda x: int(x[0]))
486
+ variations = [match[1].strip() for match in matches[:num_expected]]
487
+
488
+ if len(variations) != num_expected:
489
+ raise ValueError(f"Numbered section parsing found {len(variations)} variations, expected {num_expected}")
490
+
491
+ return variations
492
+
493
+ def _is_valid_prompt(self, prompt: str) -> bool:
494
+ """
495
+ Validate that extracted text is actually a valid system prompt.
496
+
497
+ Uses minimal, conservative filtering: only rejects OBVIOUSLY wrong text.
498
+ Let evaluation decide on quality - false negatives (rejecting good prompts)
499
+ are worse than false positives (accepting bad prompts).
500
+
501
+ Args:
502
+ prompt: Extracted text to validate
503
+
504
+ Returns:
505
+ True if appears to be a valid prompt, False if obviously wrong
506
+ """
507
+ if not prompt or not prompt.strip():
508
+ return False
509
+
510
+ prompt_lower = prompt.lower().strip()
511
+
512
+ # STRONG indicators of analysis text (high confidence rejection)
513
+ # These are phrases that almost never appear in actual prompts
514
+ strong_analysis_patterns = [
515
+ 'in conclusion',
516
+ 'to summarize',
517
+ 'based on the analysis',
518
+ 'the analysis shows',
519
+ 'here are some suggestions',
520
+ 'it seems you\'re looking for',
521
+ ]
522
+
523
+ # Check first 200 characters for strong patterns
524
+ first_200 = prompt_lower[:200]
525
+ for pattern in strong_analysis_patterns:
526
+ if pattern in first_200:
527
+ if self._should_log_debug():
528
+ logger.debug(f"Rejected prompt: contains analysis pattern '{pattern}'")
529
+ return False
530
+
531
+ # POSITIVE indicators of valid prompt (high confidence acceptance)
532
+ # These are common prompt starters
533
+ valid_starters = [
534
+ 'you are',
535
+ 'you\'re',
536
+ 'your task',
537
+ 'your role',
538
+ 'analyze',
539
+ 'identify',
540
+ 'select',
541
+ 'determine',
542
+ 'given',
543
+ 'when',
544
+ ]
545
+
546
+ # If starts with valid prompt pattern, accept immediately
547
+ first_100 = prompt_lower[:100]
548
+ if any(first_100.startswith(starter) for starter in valid_starters):
549
+ return True
550
+
551
+ # DEFAULT: Accept everything else and let evaluation decide
552
+ # This is conservative - we'd rather evaluate a bad prompt than reject a good one
553
+ return True
554
+
555
+ def set_reflection_context(
556
+ self,
557
+ current_prompt: Optional[str] = None,
558
+ feedback: Optional[Any] = None,
559
+ in_reflection: bool = False
560
+ ):
561
+ """
562
+ Set context for the next generate() call.
563
+
564
+ Args:
565
+ current_prompt: The prompt being reflected upon
566
+ feedback: Evaluation feedback
567
+ in_reflection: Whether we're in reflection mode
568
+ """
569
+ self.reflection_context = {
570
+ 'current_prompt': current_prompt,
571
+ 'feedback': feedback,
572
+ 'in_reflection': in_reflection
573
+ }
574
+
575
+ # Reset candidate queue when entering new reflection phase
576
+ if in_reflection:
577
+ self._candidate_queue = []
578
+ self._hybrid_generation_complete = False
579
+ if self._should_log_debug():
580
+ logger.debug("🔄 Entering LLEGO reflection mode (queue reset)")
581
+ else:
582
+ logger.info("🔄 Entering LLEGO reflection mode")
583
+
584
+ def generate(
585
+ self,
586
+ system_prompt: str = "",
587
+ user_prompt: str = "",
588
+ image_base64: str = "",
589
+ **kwargs
590
+ ) -> Dict[str, Any]:
591
+ """
592
+ Generate response, using LLEGO for reflection calls.
593
+
594
+ 🔥 CRITICAL: This method intercepts ALL LLM calls. For candidate generation,
595
+ it checks if we have pre-generated candidates from hybrid mode and returns those.
596
+
597
+ Args:
598
+ system_prompt: System prompt
599
+ user_prompt: User prompt
600
+ image_base64: Base64-encoded image (if any)
601
+ **kwargs: Additional arguments
602
+
603
+ Returns:
604
+ Dict with 'content' key containing the generated text
605
+ """
606
+ # 🔍 DEBUG: Log generate calls (full details at DEBUG level)
607
+ if self._should_log_debug():
608
+ logger.debug(f"🔍 LLEGO Wrapper: generate() called")
609
+ logger.debug(f" system_prompt: '{system_prompt[:100]}...' (truncated)")
610
+ logger.debug(f" user_prompt length: {len(user_prompt)} chars")
611
+ logger.debug(f" in_reflection: {self.reflection_context['in_reflection']}")
612
+ logger.debug(f" has_image: {bool(image_base64)}")
613
+
614
+ # #region agent log
615
+ try:
616
+ import json as _json_debug
617
+ import time as _time_debug
618
+ import os as _os_debug
619
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
620
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
621
+ with open(_debug_log_path, "a") as _f:
622
+ _f.write(_json_debug.dumps({"hypothesisId": "INTERCEPTION", "location": "llego_enhanced_llm.py:generate", "message": "Generate called", "data": {"system_prompt_len": len(system_prompt), "user_prompt_len": len(user_prompt), "has_image": bool(image_base64), "has_candidates": len(getattr(self, '_adapter_generated_candidates', [])), "in_reflection": self.reflection_context.get('in_reflection', False)}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
623
+ except Exception:
624
+ pass
625
+ # #endregion
626
+
627
+ # 🔥 CRITICAL: Check if we have pre-generated candidates from adapter-level generation
628
+ # This happens when GEPA calls adapter.llm_client to generate candidates
629
+ # We intercept and return our pre-generated candidates instead
630
+ # 🔥 NEW: Select BEST candidate instead of FIFO
631
+ # 🔥 FIX: DON'T intercept evaluation calls (those have images!)
632
+ # Only intercept proposal calls (no images, just asking for new candidate)
633
+ # 🔥 FIX 2: DON'T intercept TEST EVALUATION calls!
634
+ # Test evaluation has no images but uses the OPTIMIZED prompt to execute tasks
635
+ # We detect test evaluation by checking if this is a TASK EXECUTION call (not reflection)
636
+ is_task_execution = (
637
+ # Task execution prompts contain task instructions, not optimization requests
638
+ not any(kw in system_prompt.lower() for kw in ['evolutionary', 'mutation', 'variation', 'optimize', 'improve prompt', 'rewrite', 'generate variations']) and
639
+ # Short prompts are usually task prompts, not optimization prompts
640
+ len(system_prompt) < 1000 and
641
+ # User prompt is the actual task input (short), not feedback (long)
642
+ len(user_prompt) < 2000
643
+ )
644
+
645
+ # Log task execution detection for debugging
646
+ if is_task_execution and hasattr(self, '_adapter_generated_candidates') and self._adapter_generated_candidates:
647
+ logger.info(f"🔒 NOT intercepting: Task execution detected (not optimization)")
648
+ logger.debug(f" system_prompt_len={len(system_prompt)}, user_prompt_len={len(user_prompt)}")
649
+
650
+ if hasattr(self, '_adapter_generated_candidates') and self._adapter_generated_candidates and not image_base64 and not is_task_execution:
651
+ # 🔥 BEST-CANDIDATE SELECTION: Find candidate with highest Dpareto score
652
+ # This ensures we use the best candidate for the current iteration
653
+ best_candidate = None
654
+ best_score = -float('inf')
655
+ best_idx = -1
656
+
657
+ # Check if candidates have scores stored
658
+ for idx, cand in enumerate(self._adapter_generated_candidates):
659
+ if isinstance(cand, dict):
660
+ # Try to get score from candidate dict
661
+ score = cand.get('score', -float('inf'))
662
+
663
+ # If score not in dict, try to get from Pareto logger
664
+ if score == -float('inf'):
665
+ from ..utils.pareto_logger import get_pareto_logger
666
+ pareto_log = get_pareto_logger()
667
+
668
+ # Look up score in Pareto front or evaluated candidates
669
+ cand_prompt = cand.get('prompt', '')
670
+ if cand_prompt:
671
+ normalized = cand_prompt.strip().strip('"\'')
672
+ # Check in Pareto front
673
+ for front_cand in pareto_log.pareto_front:
674
+ if front_cand.get('prompt', '').strip().strip('"\'') == normalized:
675
+ score = front_cand.get('score', -float('inf'))
676
+ break
677
+
678
+ # If not in front, check evaluated candidates
679
+ if score == -float('inf'):
680
+ for eval_cand in pareto_log.candidates_evaluated:
681
+ if eval_cand.get('prompt', '').strip().strip('"\'') == normalized:
682
+ score = eval_cand.get('score', -float('inf'))
683
+ break
684
+
685
+ if score > best_score:
686
+ best_score = score
687
+ best_candidate = cand
688
+ best_idx = idx
689
+
690
+ # If no scores found, fall back to FIFO (first candidate)
691
+ if best_candidate is None and self._adapter_generated_candidates:
692
+ best_candidate = self._adapter_generated_candidates[0]
693
+ best_idx = 0
694
+ logger.info(f"⚠️ No scores found for candidates - using FIFO selection")
695
+
696
+ # Remove selected candidate from queue
697
+ if best_idx >= 0:
698
+ self._adapter_generated_candidates.pop(best_idx)
699
+
700
+ # Important event - keep at INFO
701
+ if best_score > -float('inf'):
702
+ logger.info(f"🎯 INTERCEPTING GEPA PROPOSAL CALL - Returning BEST candidate (score: {best_score:.4f})!")
703
+ logger.info(f"🎯 Remaining candidates: {len(self._adapter_generated_candidates)}")
704
+ else:
705
+ logger.info(f"🎯 INTERCEPTING GEPA PROPOSAL CALL - Returning pre-generated candidate!")
706
+ logger.info(f"🎯 Remaining candidates: {len(self._adapter_generated_candidates)}")
707
+
708
+ if isinstance(best_candidate, dict) and 'prompt' in best_candidate:
709
+ prompt = best_candidate['prompt']
710
+
711
+ # Detailed logging only in DEBUG mode
712
+ if self._should_log_debug():
713
+ logger.debug(f"✅ Pre-generated candidate details:")
714
+ logger.debug(f"{'▓'*80}")
715
+ logger.debug(f"{prompt}")
716
+ logger.debug(f"{'▓'*80}")
717
+ else:
718
+ source = best_candidate.get('source', 'unknown')
719
+ score_info = f" (score: {best_score:.4f})" if best_score > -float('inf') else ""
720
+ logger.info(f"✅ Candidate length: {len(prompt)} chars, Source: {source}{score_info}")
721
+
722
+ return {'content': prompt, 'source': best_candidate.get('source', 'adapter_generated')}
723
+ elif isinstance(best_candidate, str):
724
+ if self._should_log_debug():
725
+ logger.debug(f"✅ Pre-generated candidate (string format):")
726
+ logger.debug(f"{'▓'*80}")
727
+ logger.debug(f"{best_candidate}")
728
+ logger.debug(f"{'▓'*80}")
729
+ else:
730
+ logger.info(f"✅ Candidate length: {len(best_candidate)} chars")
731
+ return {'content': best_candidate, 'source': 'adapter_generated'}
732
+
733
+ # 🔥 ENHANCED CALL TYPE DETECTION
734
+ # We need to distinguish between 4 types of calls:
735
+ # 1. Evaluation calls: Image + task command → identify element (pass through)
736
+ # 2. Judge calls: Image + "prompt engineer" → analyze feedback (pass through)
737
+ # 3. Proposal calls: No image + feedback → generate candidate (intercept)
738
+ # 4. JSON batch calls: JSON generation request (pass through)
739
+
740
+ # FIX: DON'T intercept JSON batch generation calls
741
+ is_json_batch_request = (
742
+ '"variations"' in system_prompt or
743
+ 'MUST BE VALID JSON' in system_prompt or
744
+ 'Output ONLY the JSON object' in system_prompt or
745
+ '```json' in system_prompt.lower()
746
+ )
747
+
748
+ # FIX: DON'T intercept LLM-as-Judge calls (they analyze feedback with images)
749
+ is_judge_call = (
750
+ 'prompt engineer' in system_prompt.lower() or
751
+ 'analyzing mobile ui automation' in system_prompt.lower() or
752
+ 'expert prompt engineer' in system_prompt.lower() or
753
+ ('analyze' in system_prompt.lower() and 'screenshot with numbered bounding boxes' in system_prompt.lower() and image_base64)
754
+ )
755
+
756
+ # Check if this is a reflection call (GEPA asking for new candidate)
757
+ is_reflection_call = (
758
+ self.reflection_context['in_reflection'] or
759
+ self._detect_reflection_call(system_prompt, user_prompt)
760
+ )
761
+
762
+ # Proposal calls are reflection calls WITHOUT images and NOT judge/JSON calls
763
+ # These are the calls we want to intercept with LLEGO
764
+ is_proposal_call = (
765
+ not is_json_batch_request and # Not a JSON generation request
766
+ not is_judge_call and # Not an LLM-as-Judge analysis
767
+ not image_base64 and # No image = not an evaluation/judge call
768
+ (
769
+ is_reflection_call or
770
+ 'improve' in system_prompt.lower() or
771
+ 'optimize' in system_prompt.lower() or
772
+ 'suggest' in system_prompt.lower() or
773
+ 'feedback' in system_prompt.lower() or
774
+ 'reflection' in system_prompt.lower()
775
+ ) and
776
+ len(user_prompt) > 100 # Proposal calls have substantial feedback
777
+ )
778
+
779
+ # Detailed call detection logging only in DEBUG mode
780
+ if self._should_log_debug():
781
+ logger.debug(f" is_json_batch_request: {is_json_batch_request}")
782
+ logger.debug(f" is_judge_call: {is_judge_call}")
783
+ logger.debug(f" is_reflection_call: {is_reflection_call}")
784
+ logger.debug(f" is_proposal_call: {is_proposal_call}")
785
+ logger.debug(f" has_image: {bool(image_base64)}")
786
+ logger.debug(f" has_llego: {self.llego is not None}")
787
+
788
+ # Only intercept proposal calls (not judge, not evaluation, not JSON)
789
+ if is_proposal_call and self.llego:
790
+ # FIX #5: Check if LLEGO is disabled due to repeated failures
791
+ if self._llego_disabled:
792
+ logger.warning("⚠️ LLEGO is disabled (circuit breaker), using base LLM")
793
+ return self.base_llm.generate(
794
+ system_prompt=system_prompt,
795
+ user_prompt=user_prompt,
796
+ image_base64=image_base64,
797
+ **kwargs
798
+ )
799
+
800
+ # Important event - keep at INFO
801
+ logger.info("🔥 INTERCEPTING REFLECTION/PROPOSAL CALL FOR CANDIDATE GENERATION")
802
+ return self._llego_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)
803
+ else:
804
+ # Standard LLM call (for evaluation, not reflection)
805
+ if self._should_log_debug():
806
+ logger.debug(" → Standard LLM call (evaluation, not reflection)")
807
+ return self.base_llm.generate(
808
+ system_prompt=system_prompt,
809
+ user_prompt=user_prompt,
810
+ image_base64=image_base64,
811
+ **kwargs
812
+ )
813
+
814
+ def _clean_reflection_feedback(self, feedback_text: str, max_length: int = 50000) -> str:
815
+ """
816
+ Clean reflection feedback by removing base64 images and truncating.
817
+
818
+ 🔥 CRITICAL: GEPA's feedback can include massive base64 images (7MB+).
819
+ This function removes them and keeps feedback concise.
820
+
821
+ Args:
822
+ feedback_text: Original feedback (may contain base64)
823
+ max_length: Maximum length after cleaning (default: 50K chars)
824
+
825
+ Returns:
826
+ Cleaned feedback without base64, within size limits
827
+ """
828
+ if not feedback_text:
829
+ return feedback_text
830
+
831
+ # Step 1: Remove very long base64-like sequences (50K+ chars of alphanumeric)
832
+ base64_pattern = r'[A-Za-z0-9+/=]{5000,}'
833
+ cleaned = re.sub(base64_pattern, '[IMAGE_DATA_REMOVED]', feedback_text)
834
+
835
+ # Step 2: Remove explicit image_base64 references and their values
836
+ cleaned = re.sub(r'image_base64["\']?\s*[:=]\s*["\']?[A-Za-z0-9+/=]+["\']?',
837
+ 'image_base64: [REMOVED]', cleaned, flags=re.IGNORECASE)
838
+
839
+ # Step 3: Remove detailed_scores sections that might contain base64
840
+ cleaned = re.sub(r'##\s+detailed_scores[^\n]*\n[^#]*(?:image_base64|base64)[^\n]*(?:\n[^#]*)*',
841
+ '## detailed_scores: [REMOVED_FOR_BREVITY]', cleaned, flags=re.IGNORECASE | re.MULTILINE)
842
+
843
+ # Step 4: Remove any remaining very long strings (likely base64)
844
+ cleaned = re.sub(r'"[A-Za-z0-9+/=]{10000,}"', '[LARGE_DATA_STRING_REMOVED]', cleaned)
845
+
846
+ # Step 5: Truncate if still too long (keep beginning which has most important info)
847
+ if len(cleaned) > max_length:
848
+ truncated_size = len(cleaned) - max_length
849
+ cleaned = cleaned[:max_length] + f"\n\n[TRUNCATED {truncated_size} characters - keeping essential feedback only]"
850
+ logger.warning(f"⚠️ Reflection feedback truncated: {len(feedback_text)} → {len(cleaned)} chars")
851
+
852
+ return cleaned
853
+
854
+ def _detect_reflection_call(self, system_prompt: str, user_prompt: str) -> bool:
855
+ """
856
+ Heuristic to detect if this is a reflection call from GEPA.
857
+
858
+ GEPA's reflection calls typically contain feedback/error analysis.
859
+ """
860
+ reflection_keywords = [
861
+ 'improve', 'feedback', 'error', 'failure', 'reflection',
862
+ 'better prompt', 'modify', 'enhance', 'optimize'
863
+ ]
864
+
865
+ combined = (system_prompt + " " + user_prompt).lower()
866
+ return any(keyword in combined for keyword in reflection_keywords)
867
+
868
    def _llego_generate(
        self,
        system_prompt: str,
        user_prompt: str,
        image_base64: str = "",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Use LLEGO (or Hybrid mode) to generate new prompt candidates.

        Dispatches to hybrid generation when the config enables
        ``enable_gepa_reflection_with_llego``, otherwise to LLEGO-only
        generation. Any exception from either path falls back to the base
        LLM and feeds a circuit breaker that can disable LLEGO entirely.

        Args:
            system_prompt: System prompt
            user_prompt: User prompt (contains reflection feedback)
            image_base64: Image data (for reflection, always empty)
            **kwargs: Additional arguments (may contain image_base64, will be removed)

        Returns:
            Dict with 'content' key containing a new prompt candidate
        """
        try:
            # 🔥 CRITICAL: Remove image_base64 from kwargs to avoid duplicate argument error
            # (it is also passed explicitly below as a keyword).
            kwargs.pop('image_base64', None)  # Remove if present to avoid conflict

            # 🔥 HYBRID MODE: Generate from BOTH GEPA reflection AND LLEGO
            if (self.config and
                hasattr(self.config, 'enable_gepa_reflection_with_llego') and
                self.config.enable_gepa_reflection_with_llego):

                return self._hybrid_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)

            # STANDARD LLEGO MODE (LLEGO only)
            return self._llego_only_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)

        except Exception as e:
            # FIX #5: Circuit breaker - track failures and disable LLEGO if needed.
            # NOTE(review): the counter is only ever incremented here; no reset on
            # success is visible in this method, so the "consecutive failures"
            # wording in the log below is effectively cumulative — confirm whether
            # a reset happens elsewhere.
            self._llego_failures += 1

            logger.error(f"❌ LLEGO generation failed ({self._llego_failures}/{self._llego_failure_threshold}): {e}")
            logger.error("⚠️ Falling back to base LLM")

            # Once the threshold is crossed, all future calls short-circuit to
            # the base LLM (checked by the caller via self._llego_disabled).
            if self._llego_failures >= self._llego_failure_threshold:
                self._llego_disabled = True
                logger.error(f"🚫 LLEGO DISABLED - {self._llego_failures} consecutive failures detected")
                logger.error(" All future requests will use base LLM only")

            import traceback
            logger.debug(traceback.format_exc())

            # Fallback to base LLM - ensure image_base64 is not in kwargs
            # (it may have been re-added by the failed call path).
            kwargs.pop('image_base64', None)
            return self.base_llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                image_base64=image_base64,
                **kwargs
            )
924
+
925
+ def _hybrid_generate(
926
+ self,
927
+ system_prompt: str,
928
+ user_prompt: str,
929
+ image_base64: str = "",
930
+ **kwargs
931
+ ) -> Dict[str, Any]:
932
+ """
933
+ 🔥 HYBRID MODE: Generate candidates from BOTH GEPA reflection AND LLEGO operators.
934
+
935
+ Smart Compensation Strategy:
936
+ - When crossover can't run (< 2 parents), compensates with extra GEPA reflection
937
+ - GEPA is smarter than mutation (uses semantic understanding of feedback)
938
+ - Crossover only runs when we have 2+ scored parents to combine
939
+
940
+ GEPA will call generate() multiple times. On first call, we generate all candidates
941
+ and queue them. Subsequent calls return from the queue.
942
+ """
943
+ # If we already generated candidates, return next from queue
944
+ if self._hybrid_generation_complete and self._candidate_queue:
945
+ candidate = self._candidate_queue.pop(0)
946
+ source = candidate.get('source', 'unknown')
947
+ logger.info(f"📦 Returning queued candidate (source: {source}, {len(self._candidate_queue)} remaining)")
948
+ return {'content': candidate['prompt'], 'source': source}
949
+
950
+ # First call: Generate ALL candidates
951
+ from ..utils.clean_logger import get_clean_logger
952
+ clean_log = get_clean_logger()
953
+
954
+ all_candidates = []
955
+
956
+ # ─────────────────────────────────────────────────────
957
+ # PHASE 0: Check if crossover will be possible
958
+ # ─────────────────────────────────────────────────────
959
+ from ..utils.pareto_logger import get_pareto_logger
960
+ pareto_log = get_pareto_logger()
961
+ gepa_pareto_front = pareto_log.pareto_front
962
+
963
+ # Determine if we need to compensate for crossover
964
+ crossover_possible = len(gepa_pareto_front) >= 2
965
+ n_crossover_config = self.config.n_crossover if hasattr(self.config, 'n_crossover') else 2
966
+ crossover_compensation = 0 if crossover_possible else n_crossover_config
967
+
968
+ if not crossover_possible:
969
+ logger.info(f"⚠️ Crossover NOT possible (have {len(gepa_pareto_front)} parents, need 2+)")
970
+ logger.info(f" → Smart compensation: +{crossover_compensation} extra GEPA reflection candidates")
971
+
972
+ # ─────────────────────────────────────────────────────
973
+ # PHASE 1: GEPA REFLECTION (Semantic Understanding)
974
+ # More GEPA = better, it understands WHY things fail
975
+ # ─────────────────────────────────────────────────────
976
+ base_gepa_count = self.config.num_gepa_reflection_candidates if hasattr(self.config, 'num_gepa_reflection_candidates') else 3
977
+
978
+ # 🔥 SMART COMPENSATION: More GEPA when crossover can't run
979
+ num_gepa = base_gepa_count + crossover_compensation
980
+
981
+ logger.info("─" * 80)
982
+ logger.info("PHASE 1: GEPA REFLECTION (Semantic Understanding)")
983
+ if crossover_compensation > 0:
984
+ logger.info(f"Generating {num_gepa} candidates ({base_gepa_count} base + {crossover_compensation} compensation for skipped crossover)")
985
+ else:
986
+ logger.info(f"Generating {num_gepa} candidates")
987
+ logger.info("─" * 80)
988
+
989
+ # 🔥 OPTIMIZED: Single call with JSON format for multiple variations
990
+ try:
991
+ # Clean user_prompt before sending to LLM
992
+ cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)
993
+
994
+ # Build diversity requirements based on num_gepa
995
+ diversity_requirements = self._build_diversity_requirements(num_gepa)
996
+
997
+ # 🔥 FORMAT AWARENESS: Get format constraint if available
998
+ format_constraint = ""
999
+ if self._detected_format and self._detected_format.get('format_constraint'):
1000
+ format_constraint = self._detected_format['format_constraint']
1001
+ logger.info(f"📐 Injecting format constraint into candidate generation")
1002
+ # #region agent log
1003
+ import json as _json_debug
1004
+ import time as _time_debug
1005
+ import os as _os_debug
1006
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
1007
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
1008
+ with open(_debug_log_path, "a") as _f:
1009
+ _f.write(_json_debug.dumps({"hypothesisId": "FORMAT_CONSTRAINT", "location": "llego_enhanced_llm.py:format_injection", "message": "Format constraint injected", "data": {"format_type": self._detected_format.get('format_type', 'unknown'), "constraint_length": len(format_constraint), "avg_length": self._detected_format.get('avg_length', 0)}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
1010
+ # #endregion
1011
+ else:
1012
+ format_constraint = "No specific format detected - ensure output is CONCISE and matches expected examples."
1013
+ # #region agent log
1014
+ import json as _json_debug
1015
+ import time as _time_debug
1016
+ import os as _os_debug
1017
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
1018
+ _os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
1019
+ with open(_debug_log_path, "a") as _f:
1020
+ _f.write(_json_debug.dumps({"hypothesisId": "FORMAT_CONSTRAINT", "location": "llego_enhanced_llm.py:format_injection", "message": "No format constraint available", "data": {"has_detected_format": bool(self._detected_format)}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
1021
+ # #endregion
1022
+
1023
+ # 🔥 EVOLUTIONARY PROMPT ENGINEER: Forces radically different mutations
1024
+ # Each variation MUST use a distinct genetic strategy to maximize search space
1025
+ optimization_system_prompt = f"""<system_core>
1026
+ You are an **Evolutionary Prompt Engineer**. Your task is to mutate a [FAILING_PROMPT] into a high-performance instruction set using genetic strategies.
1027
+ You must generate {num_gepa} radically different prompt variations based on the [FAILURE_FEEDBACK].
1028
+ </system_core>
1029
+
1030
+ <input_data>
1031
+ <failure_feedback_log>
1032
+ {cleaned_user_prompt}
1033
+ </failure_feedback_log>
1034
+ </input_data>
1035
+
1036
+ <mutation_strategies>
1037
+ You MUST use a different strategy for each variation. Assign strategies in order:
1038
+
1039
+ 1. **STRATEGY A: The Strict Auditor (Constraints)**
1040
+ - Focus: Add "Negative Constraints" (e.g., "Do NOT...", "NEVER...", "FORBIDDEN:").
1041
+ - Use strict XML tagging for the output schema.
1042
+ - Goal: Fix hallucinations and formatting errors.
1043
+
1044
+ 2. **STRATEGY B: The Reasoning Expert (Chain of Thought)**
1045
+ - Focus: Add a "Reasoning Steps" section.
1046
+ - Instruct the model to "Think step-by-step" before generating the final output.
1047
+ - Goal: Fix logic errors and complex multi-step reasoning failures.
1048
+
1049
+ 3. **STRATEGY C: The Few-Shot Teacher (Examples)**
1050
+ - Focus: Generate a *synthetic* example of Input -> Correct Output within the prompt.
1051
+ - Goal: Fix understanding of abstract concepts or strict schema requirements.
1052
+
1053
+ 4. **STRATEGY D: The Role-Player (Persona)**
1054
+ - Focus: Change the persona to a hyper-specific expert (e.g., "Senior Data Engineer at Fortune 500" vs "Coder").
1055
+ - Add domain-specific vocabulary and expertise markers.
1056
+ - Goal: Fix domain-specific terminology errors.
1057
+
1058
+ 5. **STRATEGY E: The Structure Architect (Format)**
1059
+ - Focus: Add explicit output schema with field-by-field instructions.
1060
+ - Use markdown or XML headers to organize the prompt.
1061
+ - Goal: Fix output structure and field naming errors.
1062
+ </mutation_strategies>
1063
+
1064
+ <output_constraints>
1065
+ 1. **Self-Contained**: Each variation must be the FULL prompt text (100-500 words), ready to run.
1066
+ 2. **No Meta-Talk**: Do not explain your strategy inside the prompt. Just output the optimized prompt.
1067
+ 3. **Preserve Core Task**: Keep the original task/domain - only improve HOW it's described.
1068
+ 4. **JSON Output**: Follow the schema below exactly.
1069
+ 5. **ENFORCE OUTPUT FORMAT**: The generated prompt MUST instruct the model to output in the EXACT format shown in examples.
1070
+ </output_constraints>
1071
+
1072
+ <critical_output_format_requirement>
1073
+ 🚨 THE GENERATED PROMPTS MUST INCLUDE EXPLICIT OUTPUT FORMAT INSTRUCTIONS!
1074
+ Common failure: The model generates explanations/prose instead of the required concise format.
1075
+
1076
+ {format_constraint}
1077
+
1078
+ Your generated prompts MUST include:
1079
+ - Explicit instruction to output ONLY in the required format
1080
+ - "Do NOT explain", "No reasoning", "Output ONLY [format]" constraints
1081
+ - Length constraint to prevent verbose responses
1082
+ </critical_output_format_requirement>
1083
+
1084
+ <response_format>
1085
+ You MUST output ONLY valid JSON. No comments, no explanations, no markdown code blocks.
1086
+
1087
+ Generate exactly {num_gepa} variations in this exact format:
1088
+
1089
+ {{
1090
+ "variations": [
1091
+ {{
1092
+ "index": 1,
1093
+ "strategy": "Strict Auditor",
1094
+ "prompt": "[FULL PROMPT TEXT - Complete, self-contained, ready to use]"
1095
+ }},
1096
+ {{
1097
+ "index": 2,
1098
+ "strategy": "Reasoning Expert",
1099
+ "prompt": "[FULL PROMPT TEXT - Complete, self-contained, ready to use]"
1100
+ }}
1101
+ ]
1102
+ }}
1103
+
1104
+ CRITICAL RULES:
1105
+ 1. Output ONLY the JSON object - no text before or after
1106
+ 2. Do NOT use markdown code blocks (no ```json)
1107
+ 3. Do NOT include comments (no // or /* */)
1108
+ 4. Ensure all strings are properly escaped
1109
+ 5. Generate exactly {num_gepa} variations
1110
+ 6. Each variation must have: index (number), strategy (string), prompt (string)
1111
+ </response_format>
1112
+ """
1113
+
1114
+ # Standard GEPA reflection call
1115
+ call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}
1116
+ result = self.base_llm.generate(
1117
+ system_prompt=optimization_system_prompt,
1118
+ user_prompt=cleaned_user_prompt,
1119
+ image_base64=image_base64,
1120
+ **call_kwargs
1121
+ )
1122
+
1123
+ if isinstance(result, dict):
1124
+ response_text = result.get("content", str(result))
1125
+ else:
1126
+ response_text = str(result)
1127
+
1128
+ # Parse JSON variations
1129
+ gepa_variations = self._parse_json_variations(response_text, num_gepa)
1130
+
1131
+ # Add all variations to candidates
1132
+ for idx, variation_prompt in enumerate(gepa_variations, 1):
1133
+ # 🛡️ DEFENSIVE FALLBACK: Extract clean prompt if LLM adds analysis
1134
+ gepa_candidate = self._extract_clean_prompt_from_reflection(variation_prompt)
1135
+
1136
+ # Validate extracted prompt before adding
1137
+ if not self._is_valid_prompt(gepa_candidate):
1138
+ logger.warning(f" ⚠️ Variation {idx} appears invalid, skipping")
1139
+ continue
1140
+
1141
+ # 🔍 DIAGNOSTIC: Log candidate length to help diagnose scoring issues
1142
+ if self._should_log_debug():
1143
+ logger.debug(f" Candidate {idx} length: {len(gepa_candidate)} chars")
1144
+ logger.debug(f" Candidate {idx} preview: {gepa_candidate[:100]}...")
1145
+
1146
+ all_candidates.append({
1147
+ 'prompt': gepa_candidate,
1148
+ 'source': 'gepa_reflection',
1149
+ 'index': idx
1150
+ })
1151
+
1152
+ clean_log.log_gepa_reflection_candidate(idx, gepa_candidate)
1153
+
1154
+ gepa_count = len(all_candidates)
1155
+ logger.info(f"✅ GEPA Reflection: {gepa_count} candidates generated in single optimized call")
1156
+
1157
+ except Exception as e:
1158
+ logger.error(f"❌ Error generating GEPA reflection candidates: {e}")
1159
+ logger.warning(f" Falling back to sequential generation...")
1160
+ import traceback
1161
+ logger.debug(traceback.format_exc())
1162
+
1163
+ # Fallback: Sequential generation (when JSON parsing fails)
1164
+ gepa_count = self._fallback_sequential_gepa_generation(
1165
+ num_gepa, user_prompt, image_base64, kwargs, all_candidates, clean_log
1166
+ )
1167
+
1168
+ if gepa_count > 0:
1169
+ logger.info(f"GEPA Reflection Complete: {gepa_count} candidates")
1170
+
1171
+ # ─────────────────────────────────────────────────────
1172
+ # PHASE 2: LLEGO GENETIC OPERATORS
1173
+ # ─────────────────────────────────────────────────────
1174
+ logger.info("─" * 80)
1175
+ logger.info("PHASE 2: LLEGO GENETIC OPERATORS")
1176
+ logger.info("─" * 80)
1177
+
1178
+ # Extract current prompt from context
1179
+ current_prompt = self.reflection_context.get('current_prompt', '')
1180
+ if not current_prompt:
1181
+ current_prompt = self._extract_prompt_from_feedback(user_prompt)
1182
+
1183
+ if not current_prompt and self.llego.population:
1184
+ current_prompt = self.llego.population[0].prompt
1185
+ logger.info(f" Using population prompt (length: {len(current_prompt)})")
1186
+
1187
+ # Convert GEPA Pareto front to PromptCandidate format (already fetched in Phase 0)
1188
+ pareto_candidates = self.llego._convert_gepa_pareto_to_candidates(gepa_pareto_front)
1189
+ pareto_front = pareto_candidates
1190
+
1191
+ logger.info(f" Pareto front: {len(pareto_front)} candidates with scores")
1192
+ for idx, p in enumerate(pareto_front, 1):
1193
+ notation = p.metadata.get('notation', 'S') if p.metadata else 'S'
1194
+ logger.info(f" {notation}: fitness={p.fitness:.3f}")
1195
+
1196
+ # Create LLM callable for LLEGO genetic operations (crossover/mutation)
1197
+ call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}
1198
+
1199
+ # LLEGO genetic prompt with SAFETY LOCKS to prevent task drift
1200
+ # Directed mutations ensure prompts improve without losing core functionality
1201
+ genetic_operator_system_prompt = """<system_role>
1202
+ You are a **Prompt Mutation Engine**. Your input is a [PARENT_PROMPT]. Your output is a [MUTATED_CHILD].
1203
+ </system_role>
1204
+
1205
+ <mutation_directives>
1206
+ Apply ONE of the following micro-mutations to improve the prompt:
1207
+
1208
+ 1. **COMPRESS**: Remove fluff words ("please", "ensure to", "kindly"). Make it telegraphic and efficient.
1209
+ 2. **INTENSIFY**: Capitalize key constraints (e.g., "must return JSON" -> "**MUST** return **VALID JSON**").
1210
+ 3. **STRUCTURIZE**: Add markdown headers or XML tags to organize a messy prompt.
1211
+ 4. **CLARIFY**: Expand vague nouns (e.g., "code" -> "production-ready Python code with type hints").
1212
+ 5. **CONSTRAIN**: Add negative constraints ("Do NOT include explanations", "NEVER output markdown").
1213
+ </mutation_directives>
1214
+
1215
+ <safety_locks>
1216
+ 1. **IMMUTABLE CORE**: You MUST NOT change the core task (e.g., do not change "Extract JSON" to "Write a Summary").
1217
+ 2. **NO EXPLANATION**: Output ONLY the new prompt string. No meta-commentary.
1218
+ 3. **VALIDITY**: The output must remain a functional system prompt.
1219
+ 4. **LENGTH LIMIT**: Keep mutations within 20% of original length (no excessive expansion).
1220
+ </safety_locks>"""
1221
+
1222
+ def llm_callable(genetic_prompt: str) -> str:
1223
+ result = self.base_llm.generate(
1224
+ system_prompt=genetic_operator_system_prompt,
1225
+ user_prompt=genetic_prompt,
1226
+ image_base64="",
1227
+ **call_kwargs
1228
+ )
1229
+ if isinstance(result, dict):
1230
+ return result.get('content', str(result))
1231
+ return str(result)
1232
+
1233
+ # Generate LLEGO offspring (crossover will be skipped if < 2 parents)
1234
+ llego_prompts = self.llego.evolve_generation(
1235
+ llm=llm_callable,
1236
+ pareto_front=pareto_front
1237
+ )
1238
+
1239
+ # Track actual crossover count from LLEGO (it tracks internally now)
1240
+ actual_crossover = getattr(self.llego, '_actual_crossover_count', 0)
1241
+ crossover_skipped = getattr(self.llego, '_crossover_skipped', False)
1242
+
1243
+ crossover_idx = 1
1244
+ mutation_idx = 1
1245
+
1246
+ for i, prompt in enumerate(llego_prompts):
1247
+ if i < actual_crossover:
1248
+ source = 'llego_crossover'
1249
+ clean_log.log_llego_crossover_candidate(crossover_idx, prompt)
1250
+ crossover_idx += 1
1251
+ else:
1252
+ source = 'llego_mutation'
1253
+ clean_log.log_llego_mutation_candidate(mutation_idx, prompt)
1254
+ mutation_idx += 1
1255
+
1256
+ all_candidates.append({
1257
+ 'prompt': prompt,
1258
+ 'source': source,
1259
+ 'index': i + 1
1260
+ })
1261
+
1262
+ mutation_count = len(llego_prompts) - actual_crossover
1263
+ logger.info(f"🧬 LLEGO: {actual_crossover} crossover + {mutation_count} mutation = {len(llego_prompts)} candidates")
1264
+ if crossover_skipped:
1265
+ logger.info(f" (Crossover was skipped - compensated with extra GEPA reflection)")
1266
+
1267
+ # ─────────────────────────────────────────────────────
1268
+ # SUMMARY
1269
+ # ─────────────────────────────────────────────────────
1270
+ total_gepa = len([c for c in all_candidates if c.get('source') == 'gepa_reflection'])
1271
+ total_crossover = len([c for c in all_candidates if c.get('source') == 'llego_crossover'])
1272
+ total_mutation = len([c for c in all_candidates if c.get('source') == 'llego_mutation'])
1273
+
1274
+ logger.info("─" * 80)
1275
+ logger.info("CANDIDATE GENERATION SUMMARY")
1276
+ logger.info("─" * 80)
1277
+ logger.info(f" GEPA Reflection: {total_gepa} candidates (semantic understanding)")
1278
+ logger.info(f" LLEGO Crossover: {total_crossover} candidates (combine best)")
1279
+ logger.info(f" LLEGO Mutation: {total_mutation} candidates (exploration)")
1280
+ logger.info(f" TOTAL: {len(all_candidates)} candidates")
1281
+ if crossover_skipped:
1282
+ logger.info(f" 📝 Note: Crossover skipped (waiting for 2+ scored parents)")
1283
+ logger.info("─" * 80)
1284
+
1285
+ clean_log.log_candidate_generation_summary()
1286
+
1287
+ # Store in queue (skip first one - return it now)
1288
+ self._candidate_queue = all_candidates[1:] if len(all_candidates) > 1 else []
1289
+ self._hybrid_generation_complete = True
1290
+
1291
+ # Return first candidate
1292
+ if all_candidates:
1293
+ first = all_candidates[0]
1294
+ logger.info(f"📤 Returning FIRST candidate (source: {first['source']})")
1295
+ return {'content': first['prompt'], 'source': first['source']}
1296
+ else:
1297
+ logger.error("❌ No candidates generated!")
1298
+ return {'content': '', 'source': 'error'}
1299
+
1300
def _llego_only_generate(
    self,
    system_prompt: str,
    user_prompt: str,
    image_base64: str = "",
    **kwargs
) -> Dict[str, Any]:
    """
    STANDARD LLEGO MODE: Generate candidates using only LLEGO operators.

    Args:
        system_prompt: Used only on the fallback path when LLEGO yields nothing.
        user_prompt: GEPA reflection feedback; may embed the current prompt.
        image_base64: Accepted for signature compatibility; genetic operations
            always send an empty image to the base LLM.
        **kwargs: Extra generation parameters forwarded to the base LLM
            (any 'image_base64' entry is stripped first to avoid a duplicate
            keyword argument).

    Returns:
        Dict with 'content' (the evolved prompt), 'source' ('llego') and
        'num_candidates'; or the raw base-LLM response on fallback.
    """
    # 🔥 CRITICAL: Remove image_base64 from kwargs to avoid duplicate argument error
    kwargs.pop('image_base64', None)

    # 🔥 FIX: Clean user_prompt if it contains feedback (might have base64)
    cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)

    # Extract current prompt from context or user_prompt
    current_prompt = self.reflection_context.get('current_prompt', '')

    if not current_prompt:
        # Try to extract from cleaned user_prompt
        current_prompt = self._extract_prompt_from_feedback(cleaned_user_prompt)

    logger.info(f"🧬 LLEGO: Evolving prompt...")
    if self._should_log_debug():
        logger.debug(f" Current prompt: '{current_prompt[:100]}...' (length: {len(current_prompt)} chars)")
    else:
        logger.info(f" Prompt length: {len(current_prompt)} chars")

    # 🔥 FIX 2: Get Pareto front from GEPA (not LLEGO population)
    # This ensures LLEGO operators use true non-dominated solutions
    from ..utils.pareto_logger import get_pareto_logger
    pareto_log = get_pareto_logger()
    gepa_pareto_front = pareto_log.pareto_front

    # Convert GEPA Pareto front to PromptCandidate format
    pareto_candidates = self.llego._convert_gepa_pareto_to_candidates(gepa_pareto_front)
    pareto_front = pareto_candidates

    logger.info(f" Using GEPA Pareto front (size: {len(gepa_pareto_front)})")
    logger.info(f" Converted to {len(pareto_front)} PromptCandidate objects")

    # Create LLM callable for LLEGO genetic operations
    # Uses Genetic Mutation Engine prompt for micro-mutations
    call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}

    genetic_system_prompt = """You are a **Genetic Mutation Engine** for Text Prompts.

<task>
Apply a specific micro-mutation to the provided prompt to increase its clarity, strictness, or effectiveness.
</task>

<mutation_types>
1. **Compress**: Shorten verbose instructions without losing meaning.
2. **Expand**: Add detail to vague nouns (e.g., "code" -> "production-ready Python 3.10 code").
3. **Emphasize**: Highlight CRITICAL constraints using caps, bold, or explicit markers.
4. **Constrain**: Add explicit boundaries (what NOT to do, format rules, length limits).
5. **Exemplify**: Add a brief example if the task is ambiguous.
</mutation_types>

<output_rules>
1. Output ONLY the mutated prompt text.
2. Do NOT change the core intent or task domain.
3. Do NOT add explanations or meta-commentary.
4. Apply ONE primary mutation type while preserving all existing strengths.
</output_rules>"""

    def llm_callable(prompt: str) -> str:
        # Bridge given to LLEGO: sends its genetic-operation prompt through
        # the base LLM and unwraps the text content from the response dict.
        # Clean prompt before sending (might contain base64 if from feedback)
        cleaned_prompt = self._clean_reflection_feedback(prompt)
        result = self.base_llm.generate(
            system_prompt=genetic_system_prompt,
            user_prompt=cleaned_prompt,
            image_base64="",  # Always empty for LLEGO genetic operations
            **call_kwargs
        )
        if isinstance(result, dict):
            return result.get('content', str(result))
        return str(result)

    # Generate offspring using LLEGO
    new_prompts = self.llego.evolve_generation(
        llm=llm_callable,
        pareto_front=pareto_front
    )

    if new_prompts:
        # Only the first offspring is surfaced to GEPA here (single-candidate
        # contract of this mode); num_candidates reports how many were bred.
        new_prompt = new_prompts[0]
        logger.info(f"✅ LLEGO generated new candidate (length: {len(new_prompt)} chars)")

        if self._should_log_debug():
            logger.debug(f" Full prompt:")
            logger.debug(f" '{new_prompt}'")

        return {
            'content': new_prompt,
            'source': 'llego',
            'num_candidates': len(new_prompts)
        }
    else:
        logger.warning("⚠️ LLEGO returned no candidates, falling back to base LLM")
        # NOTE: fallback forwards the ORIGINAL (uncleaned) user_prompt and the
        # full kwargs; image_base64 is still sent empty.
        return self.base_llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            image_base64="",
            **kwargs
        )
1407
+
1408
+ def _build_diversity_requirements(self, num_gepa: int) -> str:
1409
+ """
1410
+ Build diversity requirements using research-backed Prompt Design Patterns.
1411
+
1412
+ These are proven strategies from prompt engineering literature:
1413
+ - Chain-of-Thought (CoT)
1414
+ - Few-Shot Learning
1415
+ - Negative Constraints
1416
+ - Persona Pattern
1417
+
1418
+ Args:
1419
+ num_gepa: Number of GEPA variations to generate
1420
+
1421
+ Returns:
1422
+ String with diversity requirements for the optimization prompt
1423
+ """
1424
+ # Research-backed Prompt Design Patterns that solve specific classes of problems
1425
+ strategies = [
1426
+ """
1427
+ <variation_1>
1428
+ **STRATEGY: COGNITIVE DECOMPOSITION (Chain-of-Thought)**
1429
+ - **Goal**: Fixes logic/reasoning errors.
1430
+ - **Action**: Add a thinking process section that forces step-by-step reasoning.
1431
+ - **Implementation**: Include instructions like "First analyze..., then identify..., finally conclude..."
1432
+ - **Pattern**: Force the model to "Plan before executing".
1433
+ </variation_1>
1434
+ """,
1435
+
1436
+ """
1437
+ <variation_2>
1438
+ **STRATEGY: FEW-SHOT SIMULATION (In-Context Learning)**
1439
+ - **Goal**: Fixes formatting/syntax errors and output structure issues.
1440
+ - **Action**: Invent 1-2 realistic "Input -> Output" examples that mirror the expected format.
1441
+ - **Implementation**: Add "Example: Given [input], respond with: [expected output format]"
1442
+ - **Pattern**: Show, don't just tell. Demonstrate the gold standard.
1443
+ </variation_2>
1444
+ """,
1445
+
1446
+ """
1447
+ <variation_3>
1448
+ **STRATEGY: SEMANTIC CONSTRAINING (Negative Constraints)**
1449
+ - **Goal**: Fixes hallucinations, verbosity, and off-topic responses.
1450
+ - **Action**: Add explicit forbidden actions and boundaries.
1451
+ - **Implementation**: Include "Do NOT explain your reasoning", "Do NOT add preambles", "Do NOT include information not asked for"
1452
+ - **Pattern**: Define the walls, not just the path.
1453
+ </variation_3>
1454
+ """,
1455
+
1456
+ """
1457
+ <variation_4>
1458
+ **STRATEGY: PERSONA & ROLE HARDENING**
1459
+ - **Goal**: Fixes tone, domain knowledge gaps, and inconsistent behavior.
1460
+ - **Action**: Define a hyper-specific expert role with clear responsibilities.
1461
+ - **Implementation**: Instead of "You are a helpful assistant", use "You are a Senior Data Analyst with 10 years of experience in [domain]"
1462
+ - **Pattern**: Adopt the mental model and rigorous standards of a real expert.
1463
+ </variation_4>
1464
+ """,
1465
+
1466
+ """
1467
+ <variation_5>
1468
+ **STRATEGY: OUTPUT SCHEMA ENFORCEMENT**
1469
+ - **Goal**: Fixes structural and format compliance issues.
1470
+ - **Action**: Define an explicit output schema with field names and types.
1471
+ - **Implementation**: Include "Your response MUST follow this exact format: {field1: type, field2: type}"
1472
+ - **Pattern**: Leave no ambiguity about what the output should look like.
1473
+ </variation_5>
1474
+ """,
1475
+
1476
+ """
1477
+ <variation_6>
1478
+ **STRATEGY: SELF-VERIFICATION LOOP**
1479
+ - **Goal**: Fixes errors that could be caught by double-checking.
1480
+ - **Action**: Add instructions for the model to verify its own output.
1481
+ - **Implementation**: Include "Before responding, verify: 1) Does this match the required format? 2) Did I include all requested information?"
1482
+ - **Pattern**: Build in quality control before submission.
1483
+ </variation_6>
1484
+ """,
1485
+
1486
+ """
1487
+ <variation_7>
1488
+ **STRATEGY: TASK DECOMPOSITION**
1489
+ - **Goal**: Fixes complex tasks that overwhelm the model.
1490
+ - **Action**: Break the task into numbered sub-tasks.
1491
+ - **Implementation**: "Step 1: [subtask]. Step 2: [subtask]. Step 3: Combine results."
1492
+ - **Pattern**: Divide and conquer complexity.
1493
+ </variation_7>
1494
+ """
1495
+ ]
1496
+
1497
+ # Select strategies based on num_gepa
1498
+ selected = strategies[:min(num_gepa, len(strategies))]
1499
+
1500
+ requirements = "<required_strategies>\n"
1501
+ requirements += "Each variation MUST use a DIFFERENT strategy from the list below:\n"
1502
+ requirements += "\n".join(selected)
1503
+ requirements += "\n</required_strategies>"
1504
+
1505
+ requirements += """
1506
+
1507
+ <strategy_application_rules>
1508
+ 1. Each variation must apply its assigned strategy comprehensively.
1509
+ 2. Each variation must ALSO address ALL issues mentioned in the feedback.
1510
+ 3. The strategies are not mutually exclusive - but the PRIMARY focus of each variation should be its assigned strategy.
1511
+ 4. Do not just add a single line - transform the prompt structure according to the strategy.
1512
+ </strategy_application_rules>
1513
+ """
1514
+
1515
+ return requirements
1516
+
1517
def _fallback_sequential_gepa_generation(
    self,
    num_gepa: int,
    user_prompt: str,
    image_base64: str,
    kwargs: dict,
    all_candidates: list,
    clean_log
) -> int:
    """
    Sequentially produce GEPA reflection candidates when the batched
    JSON-variations call could not be parsed.

    Each attempt cycles through a fixed set of strategy directives, asks the
    base LLM for one rewritten prompt in SAFE MODE, sanitizes the reply,
    validates it, and appends it to ``all_candidates``. A failing attempt is
    logged and skipped; it never aborts the remaining attempts.

    Args:
        num_gepa: Number of candidates to generate.
        user_prompt: The feedback/context from GEPA.
        image_base64: Image data (if any) forwarded to the base LLM.
        kwargs: Additional generation kwargs ('image_base64' is stripped).
        all_candidates: Shared list that accepted candidates are appended to.
        clean_log: Logger for clean progress output.

    Returns:
        Number of valid candidates actually produced.
    """
    produced = 0

    for attempt in range(1, num_gepa + 1):
        logger.debug(f"Generating Reflection Candidate #{attempt}/{num_gepa} (fallback mode)...")
        try:
            cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)

            # Research-backed strategy directives, cycled across attempts.
            directives = [
                "<optimization_rule>\nApply CHAIN-OF-THOUGHT: Add step-by-step reasoning instructions. Force the model to 'think before answering'.\n</optimization_rule>",
                "<optimization_rule>\nApply FEW-SHOT LEARNING: Add 1-2 concrete input/output examples within the prompt. Show, don't just tell.\n</optimization_rule>",
                "<optimization_rule>\nApply NEGATIVE CONSTRAINTS: Add explicit 'Do NOT' rules. Define what the model must avoid.\n</optimization_rule>",
                "<optimization_rule>\nApply PERSONA HARDENING: Define a specific expert role with clear responsibilities and standards.\n</optimization_rule>",
                "<optimization_rule>\nApply OUTPUT SCHEMA: Define the exact output format with field names and types. Leave no ambiguity.\n</optimization_rule>",
            ]
            directive = directives[(attempt - 1) % len(directives)]

            fallback_prompt = f"""You are a Prompt Optimization Engine in **SAFE MODE**.

{directive}

{_FALLBACK_SYSTEM_PROMPT}"""

            call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}
            result = self.base_llm.generate(
                system_prompt=fallback_prompt,
                user_prompt=cleaned_user_prompt,
                image_base64=image_base64,
                **call_kwargs
            )

            raw_reply = result.get("content", str(result)) if isinstance(result, dict) else str(result)
            candidate = self._extract_clean_prompt_from_reflection(raw_reply)

            if not self._is_valid_prompt(candidate):
                logger.warning(f" ⚠️ Fallback candidate #{attempt} appears invalid, skipping")
                continue

            all_candidates.append({
                'prompt': candidate,
                'source': 'gepa_reflection',
                'index': attempt
            })
            clean_log.log_gepa_reflection_candidate(attempt, candidate)
            produced += 1

        except Exception as fallback_error:
            # Per-attempt isolation: log and move on to the next attempt.
            logger.error(f"❌ Error in fallback generation #{attempt}: {fallback_error}")

    return produced
1596
+
1597
+ def _extract_prompt_from_feedback(self, user_prompt: str) -> str:
1598
+ """
1599
+ Try to extract the current prompt from GEPA's reflection feedback.
1600
+
1601
+ Args:
1602
+ user_prompt: The feedback text from GEPA
1603
+
1604
+ Returns:
1605
+ Extracted prompt or empty string
1606
+ """
1607
+ # Look for common patterns in GEPA's feedback
1608
+ if "current prompt:" in user_prompt.lower():
1609
+ lines = user_prompt.split('\n')
1610
+ for i, line in enumerate(lines):
1611
+ if "current prompt:" in line.lower():
1612
+ # Return the next line(s) as the prompt
1613
+ return '\n'.join(lines[i+1:i+10])
1614
+
1615
+ return ""
1616
+
1617
+ # Forward other methods to base LLM
1618
def get_model_info(self) -> str:
    """Describe this wrapper: the base model's info tagged as LLEGO."""
    inner = self.base_llm.get_model_info()
    return "LLEGO({})".format(inner)
1621
+
1622
+ def __getattr__(self, name):
1623
+ """Forward unknown attributes to base LLM."""
1624
+ return getattr(self.base_llm, name)
1625
+
src/gepa_optimizer/llms/vision_llm.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vision LLM Client for GEPA Optimizer
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import time
8
+ from enum import Enum
9
+ import requests
10
+ from typing import Dict, Optional, Any, TYPE_CHECKING, Union
11
+
12
+ # Assuming APIKeyManager is available from utils
13
+ from ..utils.api_keys import APIKeyManager
14
+
15
+ # Import ModelConfig only for type checking to avoid circular imports
16
+ if TYPE_CHECKING:
17
+ from ..models.config import ModelConfig
18
+
19
+ from .base_llm import BaseLLMClient
20
+
21
class ProviderType(str, Enum):
    """Supported LLM API providers.

    Mixes in ``str`` so members compare equal to their raw values
    (e.g. ``ProviderType.OPENAI == "openai"``), which lookup code relies on.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    HUGGINGFACE = "huggingface"
    VLLM = "vllm"
    GOOGLE = "google"
    GEMINI = "gemini"  # alias for Google; from_model_string normalizes gemini -> google
28
+
29
class ErrorType(str, Enum):
    """Categories of failure the LLM client can report."""

    API_ERROR = "api_error"
    VALIDATION_ERROR = "validation_error"
    NETWORK_ERROR = "network_error"
    RATE_LIMIT = "rate_limit"
    TIMEOUT = "timeout"

class GepaLLMError(Exception):
    """Base exception for GEPA LLM related errors"""

    def __init__(self, message: str, error_type: ErrorType, status_code: Optional[int] = None):
        # Keep the raw pieces on the instance so callers can branch on them.
        self.message = message
        self.error_type = error_type
        self.status_code = status_code
        super().__init__(self.message)

    def __str__(self):
        # Render as "<type>: <msg>" or "<type> (HTTP <code>): <msg>" when a
        # truthy status code is present.
        label = self.error_type.value
        if not self.status_code:
            return f"{label}: {self.message}"
        return f"{label} (HTTP {self.status_code}): {self.message}"
48
+
49
# Module-level logger for this client.
logger = logging.getLogger(__name__)

# Default OpenAI chat-completions endpoint; used when no base_url is supplied.
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
52
+
53
+ class VisionLLMClient(BaseLLMClient):
54
+ """
55
+ A client for interacting with multi-modal Vision LLMs (e.g., OpenAI GPT-4 Vision).
56
+
57
+ Example:
58
+ ```python
59
+ # Basic usage
60
+ client = VisionLLMClient(
61
+ provider="openai",
62
+ model_name="gpt-4-vision-preview",
63
+ temperature=0.7,
64
+ max_tokens=2048
65
+ )
66
+
67
+ # With custom configuration
68
+ config = ModelConfig(
69
+ provider="openai",
70
+ model_name="gpt-4-vision-preview",
71
+ temperature=0.5,
72
+ max_tokens=1024
73
+ )
74
+ client = VisionLLMClient.from_config(config)
75
+ ```
76
+ """
77
+
78
def __init__(
    self,
    provider: Union[str, ProviderType],
    model_name: str,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    timeout: int = 120,  # 2-minute default so large prompts don't time out
    max_retries: int = 3
):
    """
    Initializes the VisionLLMClient with model configuration.

    Args:
        provider: The provider of the model (e.g., 'openai', 'anthropic').
        model_name: The multi-modal LLM model to use (e.g., "gpt-4-vision-preview").
        api_key: Optional API key. If not provided, it is fetched from APIKeyManager.
        base_url: Optional base URL for the API endpoint.
        temperature: Controls randomness in the response generation.
        max_tokens: Maximum number of tokens to generate.
        top_p: Controls diversity via nucleus sampling.
        frequency_penalty: Penalizes repeated tokens.
        presence_penalty: Penalizes new tokens based on their presence so far.
        timeout: Per-request timeout in seconds.
        max_retries: Retry budget for transient HTTP failures.
    """
    # Register the configuration with the shared base class.
    super().__init__(
        provider=str(provider),
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        timeout=timeout,
        max_retries=max_retries
    )

    # Validate inputs and build the HTTP session / provider state.
    self._initialize_client(
        provider, model_name, api_key, base_url, temperature,
        max_tokens, top_p, frequency_penalty, presence_penalty,
        timeout, max_retries
    )
123
+
124
def _initialize_client(self, provider, model_name, api_key, base_url, temperature,
                       max_tokens, top_p, frequency_penalty, presence_penalty,
                       timeout, max_retries):
    """Validate inputs, resolve the API key, and build the retrying HTTP session.

    Raises:
        GepaLLMError: VALIDATION_ERROR for an unsupported provider, empty
            model name, or out-of-range temperature/max_tokens; API_ERROR
            when API-key resolution fails (see note below).
    """
    # Input validation — provider must match a ProviderType member value.
    try:
        self.provider = ProviderType(provider.lower())
    except ValueError:
        raise GepaLLMError(
            f"Unsupported provider: {provider}. "
            f"Supported providers: {[p.value for p in ProviderType]}",
            ErrorType.VALIDATION_ERROR
        )

    if not model_name:
        raise GepaLLMError("model_name cannot be empty", ErrorType.VALIDATION_ERROR)

    if not isinstance(temperature, (int, float)) or not 0 <= temperature <= 2:
        raise GepaLLMError(
            f"temperature must be between 0 and 2, got {temperature}",
            ErrorType.VALIDATION_ERROR
        )

    if not isinstance(max_tokens, int) or max_tokens <= 0:
        raise GepaLLMError(
            f"max_tokens must be a positive integer, got {max_tokens}",
            ErrorType.VALIDATION_ERROR
        )

    # Initialize API key (explicit argument wins over APIKeyManager lookup).
    # NOTE(review): the VALIDATION_ERROR raised for a missing key is itself
    # caught by the blanket `except Exception` below and re-wrapped as
    # API_ERROR — callers therefore see API_ERROR for a missing key. Confirm
    # this is intended.
    try:
        self.api_key = api_key or APIKeyManager().get_api_key(self.provider.value)
        if not self.api_key:
            raise GepaLLMError(
                f"No API key found for provider: {self.provider}",
                ErrorType.VALIDATION_ERROR
            )
    except Exception as e:
        raise GepaLLMError(
            f"Failed to initialize API key: {str(e)}",
            ErrorType.API_ERROR
        ) from e

    # Store generation settings on the instance.
    self.model_name = model_name
    self.base_url = base_url or OPENAI_API_URL
    self.temperature = temperature
    self.max_tokens = max_tokens
    self.top_p = top_p
    self.frequency_penalty = frequency_penalty
    self.presence_penalty = presence_penalty
    self.timeout = timeout
    self.max_retries = max_retries
    self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    # Configure session with retry on transient server/rate-limit statuses.
    # NOTE: requests re-exports urllib3's Retry; `allowed_methods` requires
    # urllib3 >= 1.26 — confirm the pinned dependency version.
    self.session = requests.Session()
    retry_strategy = requests.adapters.Retry(
        total=max_retries,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["POST"]
    )
    adapter = requests.adapters.HTTPAdapter(max_retries=retry_strategy)
    self.session.mount("https://", adapter)
    self.session.mount("http://", adapter)

    # No hardcoded model restrictions - user can specify any model name
    # The API provider will validate if the model exists and supports vision
192
+
193
def _get_api_key(self) -> Optional[str]:
    """Resolve the stored credential for the active provider.

    Relies on ProviderType mixing in ``str``, so membership tests against
    plain strings work whether ``self.provider`` is an enum member or a str.
    Returns None for providers without a key-manager slot.
    """
    key_slots = (
        (('openai',), 'openai'),
        (('anthropic',), 'anthropic'),
        (('google', 'gemini'), 'google'),  # both names share the Google key
    )
    for aliases, slot in key_slots:
        if self.provider in aliases:
            return APIKeyManager().get_api_key(slot)
    # Add other providers as needed
    return None
203
+
204
@classmethod
def from_config(cls, config: 'ModelConfig') -> 'VisionLLMClient':
    """Create a VisionLLMClient from a ModelConfig object.

    Args:
        config: ModelConfig instance with provider and model settings.

    Returns:
        Configured VisionLLMClient instance.

    Example:
        ```python
        config = ModelConfig(
            provider="openai",
            model_name="gpt-4-vision-preview",
            temperature=0.7
        )
        client = VisionLLMClient.from_config(config)
        ```
    """
    # Mirror the config fields onto constructor keywords one-for-one.
    settings = {
        'provider': config.provider,
        'model_name': config.model_name,
        'api_key': config.api_key,
        'base_url': config.base_url,
        'temperature': config.temperature,
        'max_tokens': config.max_tokens,
        'top_p': config.top_p,
        'frequency_penalty': config.frequency_penalty,
        'presence_penalty': config.presence_penalty,
    }
    return cls(**settings)
235
+
236
@classmethod
def from_model_string(cls, model_string: str, **kwargs) -> 'VisionLLMClient':
    """Create a VisionLLMClient from a model string like "provider/model-name".

    Args:
        model_string: Model identifier in the form "provider/model-name" or
            just "model-name" (provider defaults to "openai").
            Examples: "google/gemini-2.0-flash", "openai/gpt-4o", "gemini-1.5-pro"
        **kwargs: Additional configuration options (temperature, max_tokens,
            api_key, etc.).

    Returns:
        Configured VisionLLMClient instance.

    Example:
        ```python
        client = VisionLLMClient.from_model_string("google/gemini-2.0-flash")
        client = VisionLLMClient.from_model_string("gpt-4o")
        client = VisionLLMClient.from_model_string(
            "google/gemini-2.0-flash", temperature=0.5, max_tokens=4096
        )
        ```
    """
    import os

    # Split "provider/model-name"; without a slash the whole string is the
    # model and the provider defaults to openai.
    head, slash, tail = model_string.partition("/")
    if slash:
        provider, model_name = head, tail
    else:
        provider, model_name = "openai", model_string

    # Normalize provider names ("gemini" is an alias for "google").
    provider = provider.lower()
    if provider == "gemini":
        provider = "google"

    # Fall back to the provider's conventional environment variable when no
    # explicit api_key was passed.
    api_key = kwargs.pop('api_key', None)
    if not api_key:
        known_env_vars = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "google": "GOOGLE_API_KEY",
        }
        env_var = known_env_vars.get(provider, f"{provider.upper()}_API_KEY")
        api_key = os.getenv(env_var)

    return cls(
        provider=provider,
        model_name=model_name,
        api_key=api_key,
        **kwargs
    )
296
+
297
+ def generate(
298
+ self,
299
+ system_prompt: str,
300
+ user_prompt: str,
301
+ image_base64: Optional[str] = None,
302
+ **generation_kwargs
303
+ ) -> Dict[str, Any]:
304
+ """
305
+ Generates a response from the Vision LLM.
306
+
307
+ Args:
308
+ system_prompt: The system-level instructions for the LLM.
309
+ user_prompt: The user's query or task.
310
+ image_base64: Optional Base64 encoded image string.
311
+ **generation_kwargs: Additional model-specific generation parameters
312
+
313
+ Returns:
314
+ A dictionary containing the generated response and metadata.
315
+
316
+ Raises:
317
+ GepaLLMError: If there's an error during generation
318
+
319
+ Example:
320
+ ```python
321
+ response = client.generate(
322
+ system_prompt="You are a helpful assistant.",
323
+ user_prompt="What's in this image?",
324
+ image_base64="base64_encoded_image"
325
+ )
326
+ ```
327
+ """
328
+ if not system_prompt or not user_prompt:
329
+ raise GepaLLMError(
330
+ "system_prompt and user_prompt are required",
331
+ ErrorType.VALIDATION_ERROR
332
+ )
333
+
334
+ try:
335
+ if self.provider == ProviderType.OPENAI:
336
+ return self._generate_openai(system_prompt, user_prompt, image_base64, **generation_kwargs)
337
+ elif self.provider in [ProviderType.GOOGLE, ProviderType.GEMINI]:
338
+ return self._generate_google(system_prompt, user_prompt, image_base64, **generation_kwargs)
339
+ else:
340
+ raise GepaLLMError(
341
+ f"Provider {self.provider} is not yet supported",
342
+ ErrorType.VALIDATION_ERROR
343
+ )
344
+ except requests.exceptions.RequestException as e:
345
+ self.logger.error(f"Network error during generation: {str(e)}")
346
+ raise GepaLLMError(
347
+ f"Network error: {str(e)}",
348
+ ErrorType.NETWORK_ERROR,
349
+ getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
350
+ ) from e
351
+ except GepaLLMError:
352
+ raise
353
+ except Exception as e:
354
+ self.logger.error(f"Unexpected error during generation: {str(e)}")
355
+ raise GepaLLMError(
356
+ f"Generation failed: {str(e)}",
357
+ ErrorType.API_ERROR
358
+ ) from e
359
+
360
+ def _generate_openai(
361
+ self,
362
+ system_prompt: str,
363
+ user_prompt: str,
364
+ image_base64: Optional[str] = None,
365
+ **generation_kwargs
366
+ ) -> Dict[str, Any]:
367
+ """
368
+ Generate response using OpenAI's API with configured parameters.
369
+
370
+ Args:
371
+ system_prompt: System instructions for the model
372
+ user_prompt: User's input prompt
373
+ image_base64: Optional base64 encoded image
374
+
375
+ Returns:
376
+ Dictionary containing the API response
377
+
378
+ Raises:
379
+ GepaDependencyError: If API call fails
380
+ """
381
+ headers = {
382
+ "Content-Type": "application/json",
383
+ "Authorization": f"Bearer {self.api_key}",
384
+ "User-Agent": "GepaOptimizer/1.0 (Python)"
385
+ }
386
+
387
+ messages = [
388
+ {"role": "system", "content": system_prompt},
389
+ {
390
+ "role": "user",
391
+ "content": [
392
+ {"type": "text", "text": user_prompt}
393
+ ]
394
+ }
395
+ ]
396
+
397
+ if image_base64:
398
+ # #region agent log
399
+ import json as _json_debug
400
+ import time as _time_debug
401
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
402
+ try:
403
+ with open(_debug_log_path, "a") as _f:
404
+ _f.write(_json_debug.dumps({
405
+ "id": f"log_{int(_time_debug.time() * 1000)}",
406
+ "timestamp": int(_time_debug.time() * 1000),
407
+ "location": "vision_llm.py:_generate_openai",
408
+ "message": "Image base64 BEFORE processing",
409
+ "data": {
410
+ "image_base64_length": len(image_base64) if image_base64 else 0,
411
+ "has_data_uri_prefix": image_base64.startswith("data:image") if image_base64 else False,
412
+ "prefix": image_base64[:50] if image_base64 and len(image_base64) > 50 else image_base64,
413
+ "is_none": image_base64 is None,
414
+ "is_empty": image_base64 == "" if image_base64 else True
415
+ },
416
+ "sessionId": "debug-session",
417
+ "runId": "run1",
418
+ "hypothesisId": "A,C,D"
419
+ }) + "\n")
420
+ except Exception:
421
+ pass
422
+ # #endregion
423
+
424
+ # Detect and extract image format
425
+ detected_format = "jpeg" # Default fallback
426
+ clean_base64 = image_base64
427
+
428
+ # Extract format from data URI prefix if present
429
+ if image_base64.startswith("data:image"):
430
+ # Parse format from prefix: data:image/png;base64,...
431
+ if "," in image_base64:
432
+ prefix_part = image_base64.split(",", 1)[0]
433
+ clean_base64 = image_base64.split(",", 1)[1]
434
+ # Extract format from "data:image/PNG;base64" or "data:image/png"
435
+ if "/" in prefix_part and ";" in prefix_part:
436
+ detected_format = prefix_part.split("/")[1].split(";")[0].lower()
437
+ elif "/" in prefix_part:
438
+ detected_format = prefix_part.split("/")[1].lower()
439
+ else:
440
+ # Fallback: try to extract format
441
+ if "/" in image_base64:
442
+ detected_format = image_base64.split("/")[1].split(";")[0].lower() if ";" in image_base64 else "jpeg"
443
+ clean_base64 = image_base64.replace("data:image/", "").replace(";base64", "")
444
+
445
+ # If no format detected from prefix, try to detect from image data
446
+ if detected_format == "jpeg" or not detected_format:
447
+ try:
448
+ import base64 as b64
449
+ from PIL import Image
450
+ import io
451
+ image_data = b64.b64decode(clean_base64)
452
+ img = Image.open(io.BytesIO(image_data))
453
+ if img.format:
454
+ detected_format = img.format.lower()
455
+ # Normalize format names
456
+ if detected_format in ["jpg", "jpeg"]:
457
+ detected_format = "jpeg"
458
+ except Exception:
459
+ # If detection fails, keep default
460
+ pass
461
+
462
+ # Normalize format for data URI (OpenAI accepts: jpeg, png, gif, webp)
463
+ format_map = {
464
+ "jpg": "jpeg",
465
+ "jpeg": "jpeg",
466
+ "png": "png",
467
+ "gif": "gif",
468
+ "webp": "webp",
469
+ "bmp": "png", # Convert BMP to PNG (OpenAI doesn't support BMP)
470
+ "tiff": "png", # Convert TIFF to PNG
471
+ "tif": "png"
472
+ }
473
+ final_format = format_map.get(detected_format, "jpeg")
474
+
475
+ final_url = f"data:image/{final_format};base64,{clean_base64}"
476
+
477
+ # #region agent log
478
+ try:
479
+ with open(_debug_log_path, "a") as _f:
480
+ _f.write(_json_debug.dumps({
481
+ "id": f"log_{int(_time_debug.time() * 1000)}",
482
+ "timestamp": int(_time_debug.time() * 1000),
483
+ "location": "vision_llm.py:_generate_openai",
484
+ "message": "Image URL AFTER processing",
485
+ "data": {
486
+ "detected_format": detected_format,
487
+ "final_format": final_format,
488
+ "clean_base64_length": len(clean_base64),
489
+ "final_url_length": len(final_url),
490
+ "final_url_prefix": final_url[:60]
491
+ },
492
+ "sessionId": "debug-session",
493
+ "runId": "run1",
494
+ "hypothesisId": "A,B"
495
+ }) + "\n")
496
+ except Exception:
497
+ pass
498
+ # #endregion
499
+
500
+ messages[1]["content"].append({
501
+ "type": "image_url",
502
+ "image_url": {
503
+ "url": final_url
504
+ }
505
+ })
506
+
507
+ payload = {
508
+ "model": self.model_name,
509
+ "messages": messages,
510
+ # "temperature": self.temperature,
511
+ # "max_tokens": self.max_tokens,
512
+ "top_p": self.top_p,
513
+ "frequency_penalty": self.frequency_penalty,
514
+ "presence_penalty": self.presence_penalty
515
+ }
516
+
517
+ self.logger.debug(f"Sending request to {self.base_url} with model {self.model_name}")
518
+
519
+ try:
520
+ self.logger.debug(f"Sending request to {self.model_name}")
521
+
522
+ # Make the API request with retry
523
+ response = self.session.post(
524
+ self.base_url,
525
+ headers=headers,
526
+ json=payload,
527
+ timeout=300
528
+ )
529
+
530
+ # Handle rate limiting
531
+ if response.status_code == 429:
532
+ retry_after = int(response.headers.get('Retry-After', 5))
533
+ self.logger.warning(f"Rate limited. Retrying after {retry_after} seconds...")
534
+ time.sleep(retry_after)
535
+ return self._generate_openai(system_prompt, user_prompt, image_base64, **generation_kwargs)
536
+
537
+ response.raise_for_status()
538
+
539
+ result = response.json()
540
+ self.logger.debug(f"Received response from {self.model_name}")
541
+
542
+ # Extract and validate the response
543
+ try:
544
+ message = result["choices"][0]["message"]
545
+ llm_response_content = message["content"]
546
+
547
+ # Log token usage if available
548
+ if "usage" in result:
549
+ usage = result["usage"]
550
+ self.logger.info(
551
+ f"Tokens used - Prompt: {usage.get('prompt_tokens', 'N/A')}, "
552
+ f"Completion: {usage.get('completion_tokens', 'N/A')}, "
553
+ f"Total: {usage.get('total_tokens', 'N/A')}"
554
+ )
555
+
556
+ # Try to parse JSON if the response looks like JSON
557
+ if isinstance(llm_response_content, str) and (
558
+ llm_response_content.startswith('{') or
559
+ llm_response_content.startswith('[')
560
+ ):
561
+ try:
562
+ return json.loads(llm_response_content)
563
+ except json.JSONDecodeError:
564
+ pass
565
+
566
+ # Default response format
567
+ return {
568
+ "content": llm_response_content,
569
+ "role": message.get("role", "assistant"),
570
+ "model": self.model_name,
571
+ "provider": self.provider.value
572
+ }
573
+
574
+ except (KeyError, IndexError) as e:
575
+ self.logger.error(f"Unexpected response format: {result}")
576
+ raise GepaLLMError(
577
+ f"Unexpected response format from {self.provider} API",
578
+ ErrorType.API_ERROR,
579
+ response.status_code
580
+ ) from e
581
+
582
+ except requests.exceptions.HTTPError as e:
583
+ status_code = e.response.status_code if hasattr(e, 'response') else None
584
+ error_msg = f"HTTP error {status_code} from {self.provider} API"
585
+
586
+ try:
587
+ error_data = e.response.json()
588
+ error_msg = error_data.get('error', {}).get('message', error_msg)
589
+ except Exception:
590
+ error_msg = str(e)
591
+
592
+ self.logger.error(f"{error_msg}: {error_data if 'error_data' in locals() else str(e)}")
593
+ raise GepaLLMError(
594
+ error_msg,
595
+ ErrorType.RATE_LIMIT if status_code == 429 else ErrorType.API_ERROR,
596
+ status_code
597
+ ) from e
598
+
599
+ except requests.exceptions.Timeout:
600
+ self.logger.error(f"Request to {self.provider} API timed out after {self.timeout} seconds")
601
+ raise GepaLLMError(
602
+ f"Request timed out after {self.timeout} seconds",
603
+ ErrorType.TIMEOUT
604
+ )
605
+
606
+ except requests.exceptions.RequestException as e:
607
+ self.logger.error(f"Network error: {str(e)}")
608
+ raise GepaLLMError(
609
+ f"Network error: {str(e)}",
610
+ ErrorType.NETWORK_ERROR
611
+ ) from e
612
+
613
+ except Exception as e:
614
+ self.logger.error(f"Unexpected error: {str(e)}", exc_info=True)
615
+ raise GepaLLMError(
616
+ f"Unexpected error: {str(e)}",
617
+ ErrorType.API_ERROR
618
+ ) from e
619
+
620
+ def _generate_google(
621
+ self,
622
+ system_prompt: str,
623
+ user_prompt: str,
624
+ image_base64: Optional[str] = None,
625
+ **generation_kwargs
626
+ ) -> Dict[str, Any]:
627
+ """
628
+ Generate response using Google Gemini API with configured parameters.
629
+
630
+ Args:
631
+ system_prompt: System instructions for the model
632
+ user_prompt: User's input prompt
633
+ image_base64: Optional base64 encoded image
634
+
635
+ Returns:
636
+ Dictionary containing the API response
637
+
638
+ Raises:
639
+ GepaLLMError: If API call fails
640
+ """
641
+ try:
642
+ import google.generativeai as genai
643
+ import base64
644
+ from PIL import Image
645
+ import io
646
+ except ImportError as e:
647
+ raise GepaLLMError(
648
+ f"Required dependencies for Google Gemini not installed: {str(e)}. "
649
+ f"Please install: pip install google-generativeai Pillow",
650
+ ErrorType.VALIDATION_ERROR
651
+ ) from e
652
+
653
+ # Configure Gemini
654
+ genai.configure(api_key=self.api_key)
655
+
656
+ # Use the model name directly as specified by the user
657
+ # No hardcoded mappings or restrictions - fully configurable
658
+ # The Gemini API will validate if the model exists
659
+ gemini_model_name = self.model_name
660
+
661
+ try:
662
+ model = genai.GenerativeModel(gemini_model_name)
663
+ except Exception as e:
664
+ raise GepaLLMError(
665
+ f"Failed to initialize Gemini model {gemini_model_name}: {str(e)}",
666
+ ErrorType.API_ERROR
667
+ ) from e
668
+
669
+ # Prepare content
670
+ content_parts = []
671
+
672
+ # Add system prompt and user prompt
673
+ full_prompt = f"{system_prompt}\n\n{user_prompt}"
674
+ content_parts.append(full_prompt)
675
+
676
+ # Add image if provided
677
+ if image_base64:
678
+ # #region agent log
679
+ import json as _json_debug
680
+ import time as _time_debug
681
+ _debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
682
+ try:
683
+ with open(_debug_log_path, "a") as _f:
684
+ _f.write(_json_debug.dumps({
685
+ "id": f"log_{int(_time_debug.time() * 1000)}",
686
+ "timestamp": int(_time_debug.time() * 1000),
687
+ "location": "vision_llm.py:_generate_google",
688
+ "message": "Image base64 BEFORE processing (Google)",
689
+ "data": {
690
+ "image_base64_length": len(image_base64) if image_base64 else 0,
691
+ "has_data_uri_prefix": image_base64.startswith("data:image") if image_base64 else False,
692
+ "prefix": image_base64[:50] if image_base64 and len(image_base64) > 50 else image_base64,
693
+ "is_none": image_base64 is None,
694
+ "is_empty": image_base64 == "" if image_base64 else True
695
+ },
696
+ "sessionId": "debug-session",
697
+ "runId": "run1",
698
+ "hypothesisId": "A,C,D"
699
+ }) + "\n")
700
+ except Exception:
701
+ pass
702
+ # #endregion
703
+
704
+ try:
705
+ # Strip data URI prefix if present (hypothesis A fix)
706
+ clean_base64 = image_base64
707
+ if image_base64.startswith("data:image"):
708
+ # Extract just the base64 part after the comma
709
+ if "," in image_base64:
710
+ clean_base64 = image_base64.split(",", 1)[1]
711
+ else:
712
+ clean_base64 = image_base64.replace("data:image/", "").replace(";base64", "")
713
+
714
+ # Decode base64 image
715
+ image_data = base64.b64decode(clean_base64)
716
+ image = Image.open(io.BytesIO(image_data))
717
+ content_parts.append(image)
718
+ self.logger.debug(f"Added image to Gemini request")
719
+ except Exception as e:
720
+ self.logger.warning(f"Failed to process image for Gemini: {str(e)}")
721
+ # Continue without image rather than failing
722
+
723
+ self.logger.debug(f"Sending request to Gemini model {gemini_model_name}")
724
+
725
+ try:
726
+ # Generate response with retry logic
727
+ max_retries = 3
728
+ for attempt in range(max_retries):
729
+ try:
730
+ # Configure generation parameters
731
+ generation_config = genai.types.GenerationConfig(
732
+ temperature=self.temperature,
733
+ max_output_tokens=self.max_tokens,
734
+ top_p=self.top_p,
735
+ )
736
+
737
+ response = model.generate_content(
738
+ content_parts,
739
+ generation_config=generation_config
740
+ )
741
+
742
+ # Check if response was blocked
743
+ if response.prompt_feedback and response.prompt_feedback.block_reason:
744
+ raise GepaLLMError(
745
+ f"Gemini blocked the prompt: {response.prompt_feedback.block_reason}",
746
+ ErrorType.VALIDATION_ERROR
747
+ )
748
+
749
+ # Check if response was blocked
750
+ if not response.text:
751
+ if response.candidates and response.candidates[0].finish_reason:
752
+ finish_reason = response.candidates[0].finish_reason
753
+ if finish_reason == genai.types.FinishReason.SAFETY:
754
+ raise GepaLLMError(
755
+ "Gemini response blocked due to safety concerns",
756
+ ErrorType.VALIDATION_ERROR
757
+ )
758
+ elif finish_reason == genai.types.FinishReason.RECITATION:
759
+ raise GepaLLMError(
760
+ "Gemini response blocked due to recitation concerns",
761
+ ErrorType.VALIDATION_ERROR
762
+ )
763
+ raise GepaLLMError(
764
+ "Gemini returned empty response",
765
+ ErrorType.API_ERROR
766
+ )
767
+
768
+ self.logger.debug(f"Received response from Gemini model {gemini_model_name}")
769
+
770
+ # Log usage information if available
771
+ if hasattr(response, 'usage_metadata') and response.usage_metadata:
772
+ usage = response.usage_metadata
773
+ self.logger.info(
774
+ f"Tokens used - Prompt: {usage.prompt_token_count}, "
775
+ f"Completion: {usage.candidates_token_count}, "
776
+ f"Total: {usage.total_token_count}"
777
+ )
778
+
779
+ # Try to parse JSON if the response looks like JSON
780
+ response_text = response.text
781
+ if isinstance(response_text, str) and (
782
+ response_text.startswith('{') or
783
+ response_text.startswith('[')
784
+ ):
785
+ try:
786
+ return json.loads(response_text)
787
+ except json.JSONDecodeError:
788
+ pass
789
+
790
+ # Default response format
791
+ return {
792
+ "content": response_text,
793
+ "role": "assistant",
794
+ "model": gemini_model_name,
795
+ "provider": "google"
796
+ }
797
+
798
+ except Exception as e:
799
+ if attempt < max_retries - 1:
800
+ self.logger.warning(f"Gemini API attempt {attempt + 1} failed: {str(e)}. Retrying...")
801
+ time.sleep(2 ** attempt) # Exponential backoff
802
+ continue
803
+ else:
804
+ raise
805
+
806
+ except GepaLLMError:
807
+ raise
808
+ except Exception as e:
809
+ self.logger.error(f"Unexpected error with Gemini API: {str(e)}")
810
+ raise GepaLLMError(
811
+ f"Gemini API error: {str(e)}",
812
+ ErrorType.API_ERROR
813
+ ) from e
src/gepa_optimizer/models/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Models module for GEPA Optimizer
"""

# Re-export the public data models at package level so callers can write
# `from gepa_optimizer.models import ModelConfig` etc.
from .config import ModelConfig, OptimizationConfig
from .dataset import DatasetItem
from .result import OptimizationResult, OptimizedResult

# Explicit public API of the models package.
__all__ = [
    "ModelConfig",
    "OptimizationConfig",
    "DatasetItem",
    "OptimizationResult",
    "OptimizedResult"
]
src/gepa_optimizer/models/config.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration models for GEPA Optimizer
3
+ """
4
+
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from typing import List, Optional, Dict, Any, Union, Tuple
8
+
9
@dataclass
class ModelConfig:
    """Provider-agnostic configuration for a single LLM endpoint."""

    provider: str                    # e.g. "openai", "anthropic", "huggingface", "vllm"
    model_name: str                  # concrete model identifier for the provider
    api_key: str                     # credential used to authenticate requests
    base_url: Optional[str] = None   # custom endpoint URL, when not the provider default
    temperature: float = 0.7
    max_tokens: int = 2048
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Names serialized by to_dict(), in their canonical order.
    _FIELD_ORDER = (
        'provider', 'model_name', 'api_key', 'base_url', 'temperature',
        'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty',
    )

    def __post_init__(self):
        """Reject configurations missing any of the three mandatory fields."""
        if not self.provider:
            raise ValueError("Provider is required (e.g., 'openai', 'anthropic', 'huggingface')")
        if not self.model_name:
            raise ValueError("Model name is required (e.g., 'gpt-4', 'claude-3-opus')")
        if not self.api_key:
            raise ValueError(f"API key is required for {self.provider} provider")

    @classmethod
    def from_string(cls, model_string: str) -> 'ModelConfig':
        """Build a config from 'provider/model' (provider defaults to openai)."""
        provider, sep, model_name = model_string.partition("/")
        if not sep:
            # Bare model name: assume OpenAI.
            provider, model_name = "openai", model_string

        # Resolve the credential from the environment.
        api_key = cls._get_api_key_for_provider(provider)
        if not api_key:
            raise ValueError(
                f"No API key found for {provider}. Please set {provider.upper()}_API_KEY environment variable"
            )

        return cls(
            provider=provider,
            model_name=model_name,
            api_key=api_key
        )

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'ModelConfig':
        """Build a config from a keyword-style dictionary."""
        return cls(**config_dict)

    def to_dict(self) -> dict:
        """Serialize every field (including the API key) to a plain dict."""
        return {name: getattr(self, name) for name in self._FIELD_ORDER}

    @staticmethod
    def _get_api_key_for_provider(provider: str) -> Optional[str]:
        """Look up the provider's API key from its conventional env variable."""
        env_var_map = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "huggingface": "HUGGINGFACE_API_KEY",
            "cohere": "COHERE_API_KEY",
            "ai21": "AI21_API_KEY",
            "together": "TOGETHER_API_KEY",
            "replicate": "REPLICATE_API_TOKEN",
            "groq": "GROQ_API_KEY",
            "ollama": "OLLAMA_API_KEY"
        }

        env_var = env_var_map.get(provider.lower())
        if env_var:
            return os.getenv(env_var)

        # Unknown provider: fall back to the generic naming pattern.
        return os.getenv(f"{provider.upper()}_API_KEY")
94
+
95
@dataclass
class DataSplitConfig:
    """Configuration for splitting a dataset into train/val/test sets.

    With the default 'adaptive' strategy the configured ratios are overridden
    per dataset size: small datasets (< 15) favor validation (70/25/5) so
    candidate ranking stays reliable, medium ones (15-50) use a balanced
    60/20/20, and large ones (50+) shift toward training (70/15/15).
    """

    # Default ratios; the adaptive strategy overrides these by dataset size.
    train_ratio: float = 0.6   # Dfeedback — reflection examples
    val_ratio: float = 0.2     # Dpareto — Pareto selection
    test_ratio: float = 0.2    # held-out final evaluation

    # Hard floors per split.
    min_train_samples: int = 3
    min_val_samples: int = 3   # validation drives every candidate evaluation
    min_test_samples: int = 1  # test is used only once, so less critical

    # How to behave when the dataset is too small for the configured split.
    small_dataset_strategy: str = 'adaptive'  # or 'duplicate_val', 'no_test', 'error'

    def __post_init__(self):
        """Validate ratio sum, positivity, and the strategy name."""
        total = self.train_ratio + self.val_ratio + self.test_ratio
        # Tolerate tiny floating point drift around 1.0.
        if not (0.99 <= total <= 1.01):
            raise ValueError(
                f"Split ratios must sum to 1.0, got {total:.3f} "
                f"(train={self.train_ratio}, val={self.val_ratio}, test={self.test_ratio})"
            )

        if self.train_ratio <= 0 or self.val_ratio <= 0 or self.test_ratio < 0:
            raise ValueError("Split ratios must be positive (test_ratio can be 0 to disable)")

        if self.small_dataset_strategy not in {'adaptive', 'duplicate_val', 'no_test', 'error'}:
            raise ValueError(
                f"Invalid small_dataset_strategy: {self.small_dataset_strategy}. "
                f"Must be 'adaptive', 'duplicate_val', 'no_test', or 'error'"
            )

    def get_adaptive_ratios(self, dataset_size: int) -> Tuple[float, float, float]:
        """Return (train, val, test) ratios appropriate for *dataset_size*.

        Small (< 15) prioritizes validation for reliable candidate ranking;
        medium (< 50) is balanced; large datasets favor training data.
        """
        if dataset_size < 15:
            return (0.70, 0.25, 0.05)
        if dataset_size < 50:
            return (0.60, 0.20, 0.20)
        return (0.70, 0.15, 0.15)

    def get_split_indices(self, dataset_size: int) -> Tuple[int, int, int, int]:
        """Compute split boundaries for a dataset of *dataset_size* samples.

        Uses adaptive ratios when the strategy is 'adaptive', otherwise the
        configured ratios. Returns (train_end, val_end, test_end,
        dataset_size) indices into a dataset of that length.

        Raises:
            ValueError: If the dataset is too small and the strategy is
                'error'.
        """
        strategy = self.small_dataset_strategy
        if strategy == 'adaptive':
            train_ratio, val_ratio, _ = self.get_adaptive_ratios(dataset_size)
        else:
            train_ratio, val_ratio = self.train_ratio, self.val_ratio

        # Under 'error', refuse datasets below the combined floor outright.
        floor = self.min_train_samples + self.min_val_samples
        if dataset_size < floor and strategy == 'error':
            raise ValueError(
                f"Dataset too small ({dataset_size} samples). "
                f"Need at least {floor} samples."
            )

        # Ideal boundaries, clamped to the per-split minimums.
        train_end = max(self.min_train_samples, int(dataset_size * train_ratio))
        val_end = train_end + max(self.min_val_samples, int(dataset_size * val_ratio))
        test_end = dataset_size

        # Small-dataset adjustment when validation would spill past the end.
        if val_end >= dataset_size:
            if strategy in ('adaptive', 'duplicate_val'):
                # Keep the minimum validation samples; test gets the remainder.
                val_end = min(dataset_size, train_end + self.min_val_samples)
            elif strategy == 'no_test':
                # Give everything after training to validation.
                val_end = dataset_size
            else:
                raise ValueError(
                    f"Dataset too small ({dataset_size} samples) for train/val/test split. "
                    f"Need at least {self.min_train_samples + self.min_val_samples + self.min_test_samples} samples."
                )

        return train_end, val_end, test_end, dataset_size
217
+
218
+ @dataclass
219
+ class OptimizationConfig:
220
+ """Configuration class for GEPA optimization process"""
221
+
222
+ # Core models - REQUIRED by user
223
+ model: Union[str, ModelConfig] # No default - user must specify
224
+ reflection_model: Union[str, ModelConfig] # No default - user must specify
225
+
226
+ # Optimization parameters - REQUIRED by user
227
+ max_iterations: int # No default - user decides their budget
228
+ max_metric_calls: int # No default - user sets their budget
229
+ batch_size: int # No default - user decides based on memory
230
+
231
+ # Dataset splitting configuration
232
+ data_split: DataSplitConfig = field(default_factory=DataSplitConfig)
233
+
234
+ # Reflection settings (separate from evaluation batch_size)
235
+ reflection_examples: int = 3 # Number of examples for each reflection (small!)
236
+
237
+ # Optional optimization settings with sensible fallbacks
238
+ early_stopping: bool = True
239
+ learning_rate: float = 0.01
240
+
241
+ # Multi-objective optimization
242
+ multi_objective: bool = False
243
+ objectives: List[str] = field(default_factory=lambda: ["accuracy"])
244
+
245
+ # Advanced settings
246
+ custom_metrics: Optional[Dict[str, Any]] = None
247
+ use_cache: bool = True
248
+ parallel_evaluation: bool = False
249
+
250
+ # Backwards compatibility (deprecated)
251
+ train_split_ratio: Optional[float] = None # Use data_split instead
252
+ min_dataset_size: int = 2
253
+
254
+ # Cost and budget - user controlled
255
+ max_cost_usd: Optional[float] = None
256
+ timeout_seconds: Optional[int] = None
257
+
258
+ # GEPA-specific optimization parameters (based on actual GEPA library)
259
+ candidate_selection_strategy: str = 'pareto' # Use Pareto selection strategy
260
+ skip_perfect_score: bool = False # Don't skip perfect scores (set to True for early stopping)
261
+ reflection_minibatch_size: Optional[int] = None # Will use reflection_examples if None
262
+ perfect_score: float = 1.0 # Perfect score threshold
263
+ module_selector: str = 'round_robin' # Component selection strategy
264
+ verbose: bool = True # Enable detailed GEPA logging
265
+
266
+ # Test set evaluation
267
+ evaluate_on_test: bool = True # Evaluate final prompt on held-out test set
268
+
269
+ # 🆕 LLEGO Genetic Operator Parameters (Optional - for faster convergence)
270
+ # Based on ICLR 2025 paper: "Decision Tree Induction Through LLMs via Semantically-Aware Evolution"
271
+ # Optimized for small datasets (6-10 samples)
272
+ use_llego_operators: bool = False # Enable LLEGO genetic operators
273
+
274
+ # 🔥 HYBRID MODE: Combine GEPA Reflection + LLEGO Operators
275
+ # When both enabled, candidates are generated from BOTH sources for maximum diversity
276
+ enable_gepa_reflection_with_llego: bool = False # Enable hybrid GEPA+LLEGO mode
277
+ num_gepa_reflection_candidates: int = 3 # Number of GEPA reflection candidates per iteration (default: 3 for better exploration, range: 2-5)
278
+
279
+ # Fitness-guided crossover parameters (FIX #3: Conservative alpha)
280
+ alpha: float = 0.05 # FIX #3: Fitness extrapolation (0.05 = 5% above best parent, realistic for prompt optimization)
281
+ n_crossover: int = 2 # Number of offspring from crossover per iteration
282
+
283
+ # Diversity-guided mutation parameters
284
+ tau: float = 8.0 # Diversity temperature (8.0 = moderate diversity, balanced exploration/exploitation)
285
+ nu: int = 3 # Parent arity (3 parents optimal for small populations ~6 samples)
286
+ n_mutation: int = 2 # Number of offspring from mutation per iteration (total 4 offspring with crossover)
287
+
288
+ # Population management (for genetic operators)
289
+ population_size: int = 8 # Size of prompt population (small but diverse for 6-sample dataset)
290
+
291
+ # 🆕 LLM-as-Judge configuration (Phase 2)
292
+ use_llm_as_judge: bool = True # Enable LLM-as-Judge feedback for detailed, actionable analysis
293
+ llm_as_judge_threshold: float = 0.8 # Use LLM-as-Judge for scores below this threshold
294
+ llm_as_judge_model: Optional[ModelConfig] = None # Optional: use different model (defaults to reflection_model)
295
+
296
+ # 🆕 Logging configuration (Phase 3)
297
+ log_level: str = "INFO" # Logging level: "DEBUG", "INFO", "WARNING", "ERROR"
298
+
299
    def __post_init__(self):
        """Validate and normalize the configuration after dataclass initialization.

        Order matters here: the deprecated ``train_split_ratio`` is first
        converted into a 3-way ``DataSplitConfig``, string model specs are
        parsed into ``ModelConfig`` objects, derived defaults are filled in,
        and finally required fields and numeric ranges are validated.
        """
        # Backwards compatibility: a non-default train_split_ratio predates the
        # 3-way data_split API, so warn the caller and convert it.
        # (0.8 is treated as "left at default" and therefore not converted.)
        if self.train_split_ratio is not None and self.train_split_ratio != 0.8:
            import warnings
            warnings.warn(
                "train_split_ratio is deprecated. Use data_split=DataSplitConfig(...) instead. "
                "Converting to 3-way split with your ratio.",
                DeprecationWarning,
                stacklevel=2
            )
            # Convert 2-way split to 3-way: keep the train ratio and split the
            # remainder evenly between validation and test.
            remainder = 1.0 - self.train_split_ratio
            self.data_split = DataSplitConfig(
                train_ratio=self.train_split_ratio,
                val_ratio=remainder * 0.5,
                test_ratio=remainder * 0.5
            )

        # Accept both "provider/model-name" strings and ModelConfig objects.
        self.model = self._parse_model_config(self.model, "model")
        self.reflection_model = self._parse_model_config(self.reflection_model, "reflection_model")

        # Default the reflection minibatch size to the reflection example count.
        if self.reflection_minibatch_size is None:
            self.reflection_minibatch_size = self.reflection_examples

        # Fail fast on missing required parameters...
        self._validate_required_params()

        # ...and on out-of-range values.
        self._validate_ranges()
331
+
332
+ def _parse_model_config(self, model: Union[str, ModelConfig], field_name: str) -> ModelConfig:
333
+ """Parse string model specification into ModelConfig"""
334
+ if isinstance(model, ModelConfig):
335
+ return model
336
+
337
+ if isinstance(model, str):
338
+ # Parse "provider/model-name" format
339
+ if "/" in model:
340
+ provider, model_name = model.split("/", 1)
341
+ else:
342
+ # Default to openai if no provider specified
343
+ provider = "openai"
344
+ model_name = model
345
+
346
+ # Try to get API key from environment
347
+ api_key = self._get_api_key_for_provider(provider)
348
+ if not api_key:
349
+ raise ValueError(
350
+ f"No API key found for {provider}. Please set environment variable "
351
+ f"or provide ModelConfig with api_key for {field_name}"
352
+ )
353
+
354
+ return ModelConfig(
355
+ provider=provider,
356
+ model_name=model_name,
357
+ api_key=api_key
358
+ )
359
+
360
+ raise ValueError(f"{field_name} must be either a string or ModelConfig object")
361
+
362
    def _get_api_key_for_provider(self, provider: str) -> Optional[str]:
        """Return the API key for *provider* from environment variables.

        Delegates to ``ModelConfig._get_api_key_for_provider`` so the
        env-var lookup logic lives in a single place.
        """
        return ModelConfig._get_api_key_for_provider(provider)
365
+
366
+ def _validate_required_params(self):
367
+ """Validate that all required parameters are provided"""
368
+ required_fields = {
369
+ "max_iterations": self.max_iterations,
370
+ "max_metric_calls": self.max_metric_calls,
371
+ "batch_size": self.batch_size,
372
+ }
373
+
374
+ for field_name, value in required_fields.items():
375
+ if value is None:
376
+ raise ValueError(f"{field_name} is required and must be specified by user")
377
+
378
+ def _validate_ranges(self):
379
+ """Validate parameter ranges"""
380
+ if self.max_iterations <= 0:
381
+ raise ValueError("max_iterations must be positive")
382
+
383
+ if self.max_metric_calls <= 0:
384
+ raise ValueError("max_metric_calls must be positive")
385
+
386
+ if self.batch_size <= 0:
387
+ raise ValueError("batch_size must be positive")
388
+
389
+ if self.reflection_examples <= 0 or self.reflection_examples > 10:
390
+ raise ValueError("reflection_examples must be between 1 and 10 (recommended: 2-5)")
391
+
392
+ if self.reflection_minibatch_size <= 0:
393
+ raise ValueError("reflection_minibatch_size must be positive")
394
+
395
+ if hasattr(self.model, 'max_tokens') and self.model.max_tokens <= 0:
396
+ raise ValueError("model.max_tokens must be a positive integer")
397
+
398
+ # Validate hybrid mode parameters
399
+ if self.enable_gepa_reflection_with_llego and not self.use_llego_operators:
400
+ raise ValueError("enable_gepa_reflection_with_llego requires use_llego_operators=True")
401
+
402
+ if self.num_gepa_reflection_candidates <= 0 or self.num_gepa_reflection_candidates > 5:
403
+ raise ValueError("num_gepa_reflection_candidates must be between 1 and 5 (recommended: 3 for balanced exploration)")
404
+
405
+ # Validate log_level
406
+ valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
407
+ if self.log_level.upper() not in valid_log_levels:
408
+ raise ValueError(f"log_level must be one of {valid_log_levels}, got: {self.log_level}")
409
+
410
+ def validate_api_connectivity(self) -> Dict[str, bool]:
411
+ """Test API connectivity for both models"""
412
+ results = {}
413
+
414
+ for model_name, model_config in [("model", self.model), ("reflection_model", self.reflection_model)]:
415
+ try:
416
+ # This would be implemented to actually test the API
417
+ # For now, just check if we have the required info
418
+ if model_config.api_key and model_config.provider and model_config.model_name:
419
+ results[model_name] = True
420
+ else:
421
+ results[model_name] = False
422
+ except Exception:
423
+ results[model_name] = False
424
+
425
+ return results
426
+
427
+ def get_estimated_cost(self) -> Dict[str, Any]:
428
+ """Estimate cost based on configuration"""
429
+ # This would calculate estimated costs based on:
430
+ # - max_metric_calls
431
+ # - model pricing
432
+ # - expected tokens per call
433
+ return {
434
+ "max_calls": self.max_metric_calls,
435
+ "estimated_cost_range": "To be calculated based on provider pricing",
436
+ "cost_factors": {
437
+ "model_calls": self.max_metric_calls,
438
+ "reflection_calls": self.max_iterations,
439
+ "batch_size": self.batch_size
440
+ }
441
+ }
442
+
443
    @classmethod
    def create_example_config(cls, provider: str = "openai") -> str:
        """Return example configuration code for *provider* as a string.

        Supported keys: ``"openai"``, ``"anthropic"``, ``"mixed"``.
        Any other value falls back to the OpenAI example.
        """
        examples = {
            "openai": '''
# Example OpenAI Configuration
config = OptimizationConfig(
    model="openai/gpt-4-turbo", # or ModelConfig(...)
    reflection_model="openai/gpt-4-turbo",
    max_iterations=50, # Your choice based on budget
    max_metric_calls=300, # Your choice based on budget
    batch_size=8, # Your choice based on memory
    early_stopping=True,
    learning_rate=0.01
)
''',
            "anthropic": '''
# Example Anthropic Configuration
config = OptimizationConfig(
    model=ModelConfig(
        provider="anthropic",
        model_name="claude-3-opus-20240229",
        api_key="your-anthropic-key",
        temperature=0.7
    ),
    reflection_model="anthropic/claude-3-sonnet-20240229",
    max_iterations=30,
    max_metric_calls=200,
    batch_size=4
)
''',
            "mixed": '''
# Example Mixed Providers Configuration
config = OptimizationConfig(
    model="openai/gpt-4-turbo", # Main model
    reflection_model="anthropic/claude-3-opus", # Reflection model
    max_iterations=25,
    max_metric_calls=250,
    batch_size=6,
    max_cost_usd=100.0, # Budget limit
    timeout_seconds=3600 # 1 hour limit
)
'''
        }

        # Unknown providers get the OpenAI sample rather than raising.
        return examples.get(provider, examples["openai"])
src/gepa_optimizer/models/dataset.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset models for GEPA Optimizer
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, List, Optional
7
+ import uuid
8
+
9
@dataclass
class DatasetItem:
    """A single dataset example with provenance and quality metadata."""

    # Identifiers: a random UUID string per item unless the caller supplies one.
    item_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # Core data: the raw input, the optional gold answer, and an optional image.
    input_data: Any = ""
    expected_output: Optional[str] = None
    image_base64: Optional[str] = None

    # Free-form metadata and tags.
    metadata: Dict[str, Any] = field(default_factory=dict)
    tags: List[str] = field(default_factory=list)

    # Paths of any files this item references.
    file_paths: List[str] = field(default_factory=list)

    # Quality indicators: score in [0, 1], validation flag, and reviewer notes.
    quality_score: float = 1.0
    is_validated: bool = False
    validation_notes: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Reject quality scores outside the [0, 1] range."""
        if self.quality_score < 0 or self.quality_score > 1:
            raise ValueError("quality_score must be between 0 and 1")

    def add_tag(self, tag: str):
        """Attach *tag* to this item; duplicates are silently ignored."""
        if tag not in self.tags:
            self.tags.append(tag)

    def mark_validated(self, notes: Optional[List[str]] = None):
        """Flag this item as validated, optionally appending reviewer *notes*."""
        self.is_validated = True
        if notes:
            self.validation_notes.extend(notes)
48
+
49
@dataclass
class ProcessedDataset:
    """A dataset prepared for GEPA optimization, with splits and quality stats."""

    # Identifiers
    dataset_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Untitled Dataset"

    # Data: all items, plus the train/validation partitions.
    items: List[DatasetItem] = field(default_factory=list)
    train_split: List[DatasetItem] = field(default_factory=list)
    val_split: List[DatasetItem] = field(default_factory=list)

    # Metadata about where the data came from and how it was processed.
    source_info: Dict[str, Any] = field(default_factory=dict)
    processing_stats: Dict[str, Any] = field(default_factory=dict)

    # Quality metrics — derived in __post_init__ from `items`; do not set manually.
    total_items: int = 0
    validated_items: int = 0
    avg_quality_score: float = 0.0

    def __post_init__(self):
        """Derive item counts and the mean quality score from ``items``."""
        self.total_items = len(self.items)
        if self.items:
            self.validated_items = sum(1 for entry in self.items if entry.is_validated)
            self.avg_quality_score = sum(entry.quality_score for entry in self.items) / self.total_items

    def get_stats(self) -> Dict[str, Any]:
        """Return summary statistics as a plain, serializable dict."""
        count = self.total_items
        return {
            'total_items': count,
            'validated_items': self.validated_items,
            # Guard against division by zero for an empty dataset.
            'validation_rate': self.validated_items / count if count > 0 else 0,
            'avg_quality_score': self.avg_quality_score,
            'train_size': len(self.train_split),
            'val_size': len(self.val_split),
            'has_expected_outputs': sum(1 for entry in self.items if entry.expected_output),
        }
src/gepa_optimizer/models/result.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Result models for GEPA Optimizer
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from typing import Dict, Any, Optional, List
8
+ import uuid
9
+
10
@dataclass
class OptimizationResult:
    """Full optimization record: prompts, metrics, timing, and status."""

    # Unique identifier for this optimization session.
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # Core results: the seed prompt and the final optimized prompt.
    original_prompt: str = ""
    optimized_prompt: str = ""

    # Performance metrics before/after, plus improvement summary data.
    improvement_data: Dict[str, Any] = field(default_factory=dict)
    baseline_metrics: Dict[str, float] = field(default_factory=dict)
    final_metrics: Dict[str, float] = field(default_factory=dict)

    # Process metadata.
    optimization_time: float = 0.0
    dataset_size: int = 0
    total_iterations: int = 0

    # Lifecycle state: pending -> running -> completed | failed.
    status: str = "pending"
    error_message: Optional[str] = None

    # Timestamps (completed_at set by mark_completed/mark_failed).
    created_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None

    # Per-iteration reflection records.
    reflection_history: List[Dict[str, Any]] = field(default_factory=list)

    # Cost and resource usage.
    estimated_cost: Optional[float] = None
    api_calls_made: int = 0

    def mark_completed(self):
        """Record successful completion and stamp the finish time."""
        self.status = "completed"
        self.completed_at = datetime.now()

    def mark_failed(self, error: str):
        """Record failure with *error* and stamp the finish time."""
        self.status = "failed"
        self.error_message = error
        self.completed_at = datetime.now()
56
+
57
class OptimizedResult:
    """
    User-facing result class that provides a clean, read-only interface
    over an internal :class:`OptimizationResult`.
    """

    def __init__(self,
                 original_prompt: str = "",
                 optimized_prompt: str = "",
                 improvement_data: Optional[Dict[str, Any]] = None,
                 optimization_time: float = 0.0,
                 dataset_size: int = 0,
                 total_iterations: int = 0,
                 status: str = "pending",
                 error_message: Optional[str] = None,
                 detailed_result: Optional[OptimizationResult] = None,
                 session_id: Optional[str] = None):
        """
        Initialize OptimizedResult with individual parameters.

        Args:
            original_prompt: Original seed prompt
            optimized_prompt: Optimized prompt
            improvement_data: Performance improvement data (defaults to {})
            optimization_time: Time taken for optimization in seconds
            dataset_size: Size of dataset used
            total_iterations: Number of optimization iterations
            status: Optimization status (pending/running/completed/failed)
            error_message: Error message if failed
            detailed_result: Optional detailed OptimizationResult; when given,
                it is used directly and the other arguments are ignored
            session_id: Optional session ID (a UUID is generated if omitted)
        """
        if detailed_result is not None:
            # FIX: use the supplied result directly instead of first building a
            # throwaway OptimizationResult and immediately discarding it.
            self._result = detailed_result
            return

        self._result = OptimizationResult(
            session_id=session_id or str(uuid.uuid4()),
            original_prompt=original_prompt,
            optimized_prompt=optimized_prompt,
            # Avoid a shared mutable default: each instance gets its own dict.
            improvement_data=improvement_data if improvement_data is not None else {},
            optimization_time=optimization_time,
            dataset_size=dataset_size,
            total_iterations=total_iterations,
            status=status,
            error_message=error_message
        )

    @property
    def prompt(self) -> str:
        """The optimized prompt ready for production use."""
        return self._result.optimized_prompt

    @property
    def original_prompt(self) -> str:
        """The original seed prompt for reference."""
        return self._result.original_prompt

    @property
    def session_id(self) -> str:
        """Unique session identifier."""
        return self._result.session_id

    @property
    def improvement_data(self) -> Dict[str, Any]:
        """Performance improvement data."""
        return self._result.improvement_data

    @property
    def status(self) -> str:
        """Optimization status."""
        return self._result.status

    @property
    def error_message(self) -> Optional[str]:
        """Error message if optimization failed."""
        return self._result.error_message

    @property
    def is_successful(self) -> bool:
        """Whether optimization completed successfully (no error recorded)."""
        return (
            self._result.status == "completed" and
            self._result.error_message is None
        )

    @property
    def optimization_time(self) -> float:
        """Time taken for optimization in seconds."""
        return self._result.optimization_time

    @property
    def dataset_size(self) -> int:
        """Size of dataset used for optimization."""
        return self._result.dataset_size

    @property
    def total_iterations(self) -> int:
        """Total optimization iterations performed."""
        return self._result.total_iterations

    @property
    def estimated_cost(self) -> Optional[float]:
        """Estimated cost in USD (None when unknown)."""
        return self._result.estimated_cost

    def get_improvement_summary(self) -> Dict[str, Any]:
        """Get a summary of improvements made during optimization."""
        summary = {
            'has_improvement': bool(self._result.improvement_data),
            'optimization_time': self.optimization_time,
            'iterations': self.total_iterations,
            'dataset_size': self.dataset_size
        }

        # Surface the headline improvement percentage when the optimizer recorded one.
        if 'improvement_percent' in self._result.improvement_data:
            summary['improvement_percent'] = self._result.improvement_data['improvement_percent']

        return summary

    def get_reflection_summary(self) -> Dict[str, Any]:
        """Get a summary of the reflection process (first 3 reflection points)."""
        if not self._result.reflection_history:
            return {'total_reflections': 0}

        return {
            'total_reflections': len(self._result.reflection_history),
            'reflection_points': [
                r.get('summary', 'No summary')
                for r in self._result.reflection_history[:3]  # First 3
            ]
        }

    def get_detailed_result(self) -> OptimizationResult:
        """Get the full detailed result for advanced users."""
        return self._result

    def __str__(self) -> str:
        """Human-readable one-line summary with a status emoji."""
        status_emoji = "✅" if self.is_successful else "❌" if self.status == "failed" else "⏳"
        return f"OptimizedResult({status_emoji} {self.status}, time={self.optimization_time:.2f}s)"

    def __repr__(self) -> str:
        return self.__str__()
src/gepa_optimizer/operators/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLEGO Genetic Operators for GEPA.
3
+
4
+ This module provides genetic operators for prompt optimization:
5
+ - FitnessGuidedCrossover: Combines high-performing prompts
6
+ - DiversityGuidedMutation: Explores diverse variations
7
+ - LLEGOIntegrationLayer: Manages the genetic algorithm workflow
8
+
9
+ Based on: Decision Tree Induction Through LLMs via Semantically-Aware Evolution (ICLR 2025)
10
+ """
11
+
12
+ # Base interfaces (SOLID: Interface Segregation)
13
+ from .base_operator import (
14
+ BaseGeneticOperator,
15
+ BaseCrossoverOperator,
16
+ BaseMutationOperator,
17
+ )
18
+
19
+ # Data models
20
+ from .models import (
21
+ PromptCandidate,
22
+ PromptMetadata,
23
+ )
24
+
25
+ # Concrete operators (SOLID: Single Responsibility)
26
+ from .crossover import FitnessGuidedCrossover
27
+ from .mutation import DiversityGuidedMutation
28
+
29
+ # Integration layer
30
+ from .llego_operators import LLEGOIntegrationLayer
31
+
32
+ __all__ = [
33
+ # Base interfaces
34
+ 'BaseGeneticOperator',
35
+ 'BaseCrossoverOperator',
36
+ 'BaseMutationOperator',
37
+ # Data models
38
+ 'PromptCandidate',
39
+ 'PromptMetadata',
40
+ # Operators
41
+ 'FitnessGuidedCrossover',
42
+ 'DiversityGuidedMutation',
43
+ # Integration
44
+ 'LLEGOIntegrationLayer',
45
+ ]
src/gepa_optimizer/operators/base_operator.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base Genetic Operator Interface.
3
+
4
+ Defines the abstract interface for all genetic operators following
5
+ the Interface Segregation Principle (ISP) of SOLID.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import List, Callable
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class BaseGeneticOperator(ABC):
    """
    Abstract contract shared by every genetic operator.

    Concrete operators (crossover, mutation, ...) subclass this and implement
    both ``__call__`` (execute the operation) and ``_build_prompt`` (compose
    the LLM instruction the operation sends).

    Design Principles:
    - Single Responsibility: each operator does exactly one thing
    - Open/Closed: extend via inheritance rather than modification
    - Liskov Substitution: any operator works where the base is expected
    - Interface Segregation: the required surface is minimal
    - Dependency Inversion: operators depend on an abstract LLM callable
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Run the genetic operation and return the newly generated prompt."""
        ...

    @abstractmethod
    def _build_prompt(self, *args, **kwargs) -> str:
        """Compose and return the LLM instruction for this operation."""
        ...
49
+
50
+
51
class BaseCrossoverOperator(BaseGeneticOperator):
    """
    Abstract base class for crossover operators.

    A crossover operator blends two or more parent prompts into offspring
    that inherit desirable traits from each parent.
    """

    @abstractmethod
    def __call__(
        self,
        parents: List,  # List[PromptCandidate]
        target_fitness: float,
        llm: Callable[[str], str]
    ) -> str:
        """
        Combine parent prompts to create offspring.

        Args:
            parents: Parent PromptCandidate objects to recombine
            target_fitness: Fitness level the offspring should aim for
            llm: Language model callable used to perform the blend

        Returns:
            str: Offspring prompt
        """
        ...
78
+
79
+
80
class BaseMutationOperator(BaseGeneticOperator):
    """
    Abstract base class for mutation operators.

    A mutation operator derives a variation of a single parent prompt,
    letting the search explore new regions of the prompt space.
    """

    @abstractmethod
    def __call__(
        self,
        parent,  # PromptCandidate
        population: List,  # List[PromptCandidate]
        llm: Callable[[str], str]
    ) -> str:
        """
        Mutate a parent prompt to create a variation.

        Args:
            parent: Parent PromptCandidate to mutate
            population: Current population, used to steer diversity
            llm: Language model callable used to produce the variation

        Returns:
            str: Mutated prompt
        """
        ...
107
+
src/gepa_optimizer/operators/crossover.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fitness-Guided Crossover Operator.
3
+
4
+ Adapts LLEGO's fitness-guided crossover for text prompts.
5
+ Based on: Decision Tree Induction Through LLMs via Semantically-Aware Evolution (ICLR 2025)
6
+ """
7
+
8
+ from typing import List, Callable, TYPE_CHECKING
9
+ import logging
10
+
11
+ from .base_operator import BaseCrossoverOperator
12
+
13
+ if TYPE_CHECKING:
14
+ from .models import PromptCandidate
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class FitnessGuidedCrossover(BaseCrossoverOperator):
    """
    Fitness-guided crossover for text prompts.

    Blends high-performing parent prompts into offspring that target a
    specific fitness level, relying on the LLM's semantic understanding.

    From the LLEGO paper:
    "Fitness-guided crossover exploits high-performing regions of the search space
    by combining parent trees targeting a desired fitness level f* = f_max + α(f_max - f_min)"

    Reference: https://github.com/nicolashuynh/LLEGO
    """

    def __init__(self, alpha: float = 0.1):
        """
        Initialize the crossover operator.

        Args:
            alpha: Fitness extrapolation parameter. Larger values push the
                target fitness further above the best parent. The LLEGO
                paper's default is 0.1 (10% above the best parent).
        """
        self.alpha = alpha
        logger.debug(f"FitnessGuidedCrossover initialized with α={alpha}")

    def __call__(
        self,
        parents: List["PromptCandidate"],
        target_fitness: float,
        llm: Callable[[str], str]
    ) -> str:
        """
        Combine parent prompts targeting a specific fitness.

        Args:
            parents: PromptCandidate objects to recombine (at least 2)
            target_fitness: Desired fitness for the offspring
            llm: Language model callable

        Returns:
            str: Offspring prompt produced by the LLM

        Raises:
            ValueError: If fewer than 2 parents are provided
        """
        if len(parents) < 2:
            raise ValueError("Crossover requires at least 2 parents")

        # Rank parents so the strongest candidate leads the LLM prompt.
        ranked = sorted(parents, key=lambda p: p.fitness, reverse=True)

        logger.debug(f"Crossover: {len(parents)} parents, target fitness={target_fitness:.3f}")

        # Delegate the actual blending to the LLM via the crossover instruction.
        return llm(self._build_prompt(ranked, target_fitness))

    def _build_prompt(
        self,
        parents: List["PromptCandidate"],
        target_fitness: float
    ) -> str:
        """
        Build the LLM instruction for the crossover operation.

        Args:
            parents: Parents sorted by fitness, best first
            target_fitness: Target fitness for the offspring

        Returns:
            str: Prompt for the LLM
        """
        # Cap each parent excerpt to keep the prompt short and avoid
        # triggering LLM safety filters on very long inputs.
        MAX_PARENT_LENGTH = 350

        # Describe only the top two parents.
        descriptions = []
        for rank, candidate in enumerate(parents[:2], start=1):
            excerpt = candidate.prompt[:MAX_PARENT_LENGTH]
            if len(candidate.prompt) > MAX_PARENT_LENGTH:
                excerpt += "..."
            descriptions.append(
                f"P{rank} (f={candidate.fitness:.2f}): {excerpt}\n"
            )

        return f"""Combine these prompts into ONE improved version (target fitness: {target_fitness:.2f}).

{' '.join(descriptions)}
Instructions:
1. Merge the best rules/principles from both parents
2. Organize logic clearly (e.g., "For X tasks: do Y", "If Z: then A")
3. Add structure to handle different cases systematically
4. Keep output format (Element: X, Description:, Reason:)
5. Max 600 chars

Output ONLY the combined prompt:"""
120
+