Spaces:
Sleeping
Sleeping
Commit
·
cacd4d0
0
Parent(s):
Deploy Universal Prompt Optimizer to HF Spaces (clean)
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +27 -0
- README.md +44 -0
- app.py +1563 -0
- requirements.txt +23 -0
- src/gepa_optimizer.egg-info/PKG-INFO +439 -0
- src/gepa_optimizer.egg-info/SOURCES.txt +65 -0
- src/gepa_optimizer.egg-info/dependency_links.txt +1 -0
- src/gepa_optimizer.egg-info/entry_points.txt +2 -0
- src/gepa_optimizer.egg-info/requires.txt +29 -0
- src/gepa_optimizer.egg-info/top_level.txt +1 -0
- src/gepa_optimizer/__init__.py +295 -0
- src/gepa_optimizer/cli.py +239 -0
- src/gepa_optimizer/core/__init__.py +8 -0
- src/gepa_optimizer/core/base_adapter.py +85 -0
- src/gepa_optimizer/core/custom_adapter.py +389 -0
- src/gepa_optimizer/core/optimizer.py +1279 -0
- src/gepa_optimizer/core/result.py +180 -0
- src/gepa_optimizer/core/universal_adapter.py +0 -0
- src/gepa_optimizer/data/__init__.py +27 -0
- src/gepa_optimizer/data/converters.py +265 -0
- src/gepa_optimizer/data/index_caching_loader.py +278 -0
- src/gepa_optimizer/data/loaders.py +237 -0
- src/gepa_optimizer/data/scroll_dataset_loader.py +334 -0
- src/gepa_optimizer/data/validation_dataset_loader.py +376 -0
- src/gepa_optimizer/data/validators.py +207 -0
- src/gepa_optimizer/evaluation/__init__.py +28 -0
- src/gepa_optimizer/evaluation/base_evaluator.py +51 -0
- src/gepa_optimizer/evaluation/index_caching_evaluator.py +357 -0
- src/gepa_optimizer/evaluation/scroll_evaluator.py +251 -0
- src/gepa_optimizer/evaluation/ui_evaluator.py +297 -0
- src/gepa_optimizer/evaluation/universal_evaluator.py +911 -0
- src/gepa_optimizer/evaluation/validation_evaluator.py +495 -0
- src/gepa_optimizer/infrastructure/__init__.py +15 -0
- src/gepa_optimizer/infrastructure/logging/__init__.py +43 -0
- src/gepa_optimizer/infrastructure/logging/context.py +257 -0
- src/gepa_optimizer/infrastructure/logging/formatters.py +259 -0
- src/gepa_optimizer/infrastructure/logging/logger.py +260 -0
- src/gepa_optimizer/llms/__init__.py +10 -0
- src/gepa_optimizer/llms/base_llm.py +56 -0
- src/gepa_optimizer/llms/batch_llm.py +712 -0
- src/gepa_optimizer/llms/llego_enhanced_llm.py +1625 -0
- src/gepa_optimizer/llms/vision_llm.py +813 -0
- src/gepa_optimizer/models/__init__.py +15 -0
- src/gepa_optimizer/models/config.py +488 -0
- src/gepa_optimizer/models/dataset.py +89 -0
- src/gepa_optimizer/models/result.py +204 -0
- src/gepa_optimizer/operators/__init__.py +45 -0
- src/gepa_optimizer/operators/base_operator.py +107 -0
- src/gepa_optimizer/operators/crossover.py +120 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
|
| 8 |
+
# Virtual environments
|
| 9 |
+
venv/
|
| 10 |
+
env/
|
| 11 |
+
ENV/
|
| 12 |
+
|
| 13 |
+
# IDE
|
| 14 |
+
.vscode/
|
| 15 |
+
.idea/
|
| 16 |
+
*.swp
|
| 17 |
+
*.swo
|
| 18 |
+
|
| 19 |
+
# OS
|
| 20 |
+
.DS_Store
|
| 21 |
+
Thumbs.db
|
| 22 |
+
|
| 23 |
+
# Build artifacts
|
| 24 |
+
*.egg-info/
|
| 25 |
+
dist/
|
| 26 |
+
build/
|
| 27 |
+
|
README.md
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
---
|
| 4 |
+
title: Universal Prompt Optimizer
|
| 5 |
+
emoji: 🧬
|
| 6 |
+
colorFrom: blue
|
| 7 |
+
colorTo: cyan
|
| 8 |
+
sdk: gradio
|
| 9 |
+
sdk_version: 4.0.0
|
| 10 |
+
app_file: app.py
|
| 11 |
+
pinned: false
|
| 12 |
+
license: mit
|
| 13 |
+
---
|
| 14 |
+
# Universal Prompt Optimizer
|
| 15 |
+
|
| 16 |
+
A powerful genetic evolutionary prompt optimization tool built with GEPA (Genetic Evolutionary Prompt Agent). Optimize your prompts using genetic algorithms with optional LLEGO crossover for faster convergence.
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- 🧬 **Genetic Algorithm Optimization**: Evolve prompts through multiple iterations
|
| 21 |
+
- 🎯 **Multi-Model Support**: Works with OpenAI, Anthropic, Google, and custom models
|
| 22 |
+
- 📊 **Real-time Metrics**: Track optimization progress and improvements
|
| 23 |
+
- 🖼️ **Multi-modal Support**: Include images in your training examples
|
| 24 |
+
- ⚡ **LLEGO Crossover**: Advanced genetic operations for faster convergence
|
| 25 |
+
|
| 26 |
+
## How to Use
|
| 27 |
+
|
| 28 |
+
1. **Select Model**: Choose your target LLM (GPT-4, Claude, Gemini, or custom)
|
| 29 |
+
2. **Enter Seed Prompt**: Describe your task, constraints, and desired output format
|
| 30 |
+
3. **Add Training Examples**: Provide input/output pairs (images optional)
|
| 31 |
+
4. **Configure Optimization**: Set evolution rounds, batch size, and enable LLEGO
|
| 32 |
+
5. **Start Optimization**: Watch as the genetic algorithm evolves your prompt
|
| 33 |
+
|
| 34 |
+
## API Keys
|
| 35 |
+
|
| 36 |
+
API keys are stored in-session only and never logged. You can provide them in the UI or set them as environment variables:
|
| 37 |
+
|
| 38 |
+
- `OPENAI_API_KEY`
|
| 39 |
+
- `ANTHROPIC_API_KEY`
|
| 40 |
+
- `GOOGLE_API_KEY`
|
| 41 |
+
|
| 42 |
+
## License
|
| 43 |
+
|
| 44 |
+
MIT License
|
app.py
ADDED
|
@@ -0,0 +1,1563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
🚀 Universal Prompt Optimizer - Enhanced Production UI v8.0
|
| 3 |
+
Principal Engineer Edition: Linear/Vercel-style Dark Mode with Premium UX
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
# Add src directory to Python path for gepa_optimizer imports
|
| 9 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import json
|
| 13 |
+
import base64
|
| 14 |
+
import io
|
| 15 |
+
import os
|
| 16 |
+
import logging
|
| 17 |
+
import traceback
|
| 18 |
+
import html
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image as PILImage
|
| 21 |
+
from typing import List, Dict, Optional, Any, Tuple
|
| 22 |
+
import threading
|
| 23 |
+
from collections import deque
|
| 24 |
+
|
| 25 |
+
# Optional import for URL image downloads
|
| 26 |
+
try:
|
| 27 |
+
import requests
|
| 28 |
+
REQUESTS_AVAILABLE = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
REQUESTS_AVAILABLE = False
|
| 31 |
+
|
| 32 |
+
# ==========================================
|
| 33 |
+
# 0. LOGGING & BACKEND UTILS
|
| 34 |
+
# ==========================================
|
| 35 |
+
logging.basicConfig(
|
| 36 |
+
level=logging.INFO,
|
| 37 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 38 |
+
)
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# Thread-safe global store for candidate prompts produced during optimization.
_candidates_store = {
    'candidates': deque(maxlen=100),  # bounded history of candidate records
    'lock': threading.Lock(),         # guards every read/write below
    'iteration': 0                    # current optimization round counter
}

def add_candidate_to_store(candidate: Dict[str, Any]):
    """Record one candidate prompt under the store lock."""
    with _candidates_store['lock']:
        entry = {
            'iteration': _candidates_store['iteration'],
            'source': candidate.get('source', 'unknown'),
            'prompt': candidate.get('prompt', ''),
            'timestamp': candidate.get('timestamp', ''),
            # 1-based position the record will occupy once appended
            'index': len(_candidates_store['candidates']) + 1
        }
        _candidates_store['candidates'].append(entry)

def get_candidates_from_store() -> List[Dict[str, Any]]:
    """Return a snapshot copy of all stored candidate records."""
    with _candidates_store['lock']:
        return list(_candidates_store['candidates'])

def clear_candidates_store():
    """Drop all stored candidates and reset the iteration counter."""
    with _candidates_store['lock']:
        _candidates_store['candidates'].clear()
        _candidates_store['iteration'] = 0

def increment_iteration():
    """Advance the iteration counter by one (called once per optimization round)."""
    with _candidates_store['lock']:
        _candidates_store['iteration'] += 1
| 71 |
+
# ==========================================
|
| 72 |
+
# 1. MOCK BACKEND (Kept as provided)
|
| 73 |
+
# ==========================================
|
| 74 |
+
try:
    # Preferred path: import the real optimizer package (added to sys.path above).
    from gepa_optimizer import quick_optimize_sync, OptimizedResult
    BACKEND_AVAILABLE = True
except ImportError:
    # Fallback: define a mock with the same public surface so the UI stays
    # fully demoable when the gepa_optimizer package is not installed.
    BACKEND_AVAILABLE = False
    from dataclasses import dataclass

    @dataclass
    class OptimizedResult:
        # Mirrors the fields the UI reads from the real backend's result.
        optimized_prompt: str
        improvement_metrics: dict
        iteration_history: list

    def quick_optimize_sync(seed_prompt, dataset, model, **kwargs):
        """Mock optimizer: sleeps to simulate work and returns canned results.

        Accepts the same keyword arguments as the real backend
        (max_iterations, batch_size, use_llego, ...); any other kwargs
        (e.g. verbose, max_metric_calls) are silently ignored.
        """
        import time
        iterations = kwargs.get('max_iterations', 5)
        batch_size = kwargs.get('batch_size', 4)
        use_llego = kwargs.get('use_llego', True)

        # Simulate processing time based on iterations
        time.sleep(0.5 * iterations)

        llego_note = "with LLEGO crossover" if use_llego else "standard mutation only"

        return OptimizedResult(
            optimized_prompt=f"""# OPTIMIZED PROMPT FOR {model}
# ----------------------------------------
# Optimization: {iterations} iterations, batch size {batch_size}, {llego_note}

## Task Context
{seed_prompt}

## Refined Instructions
1. Analyse the input constraints strictly.
2. Verify output format against expected schema.
3. Apply chain-of-thought reasoning before answering.
4. Cross-reference with provided examples for consistency.

## Safety & Edge Cases
- If input is ambiguous, ask for clarification.
- Maintain a professional, neutral tone.
- Handle edge cases gracefully with informative responses.""",
            # Fixed demo numbers; only the iteration-derived fields vary.
            improvement_metrics={
                "baseline_score": 0.45,
                "final_score": 0.92,
                "improvement": "+104.4%",
                "iterations_run": iterations,
                "candidates_evaluated": iterations * batch_size,
            },
            # Canned 5-entry history, truncated to the requested iteration count.
            iteration_history=[
                f"Iter 1: Baseline evaluation - Score: 0.45",
                f"Iter 2: Added Chain-of-Thought constraints - Score: 0.62",
                f"Iter 3: Refined output formatting rules - Score: 0.78",
                f"Iter 4: {'LLEGO crossover applied' if use_llego else 'Mutation applied'} - Score: 0.88",
                f"Iter 5: Final refinement - Score: 0.92",
            ][:iterations],
        )
|
| 132 |
+
# ==========================================
|
| 133 |
+
# 2. HELPER FUNCTIONS
|
| 134 |
+
# ==========================================
|
| 135 |
+
def gradio_image_to_base64(image_input) -> Optional[str]:
    """Convert a Gradio image input into a PNG data-URL string.

    Handles the three shapes Gradio image components may hand us: a numpy
    array, a ``PIL.Image.Image`` instance, or a filesystem path.

    Args:
        image_input: numpy array, PIL image, file path string, or None.

    Returns:
        ``"data:image/png;base64,..."`` on success, or None when the input
        is missing, unreadable, or of an unsupported type. Never raises.
    """
    if image_input is None:
        return None

    try:
        pil_image = None

        if isinstance(image_input, np.ndarray):
            # Reject degenerate arrays before asking PIL to decode them.
            if image_input.size == 0:
                logger.warning("Empty image array provided")
                return None
            try:
                pil_image = PILImage.fromarray(image_input)
            except (ValueError, TypeError) as e:
                logger.error(f"Failed to convert numpy array to PIL Image: {str(e)}")
                return None
        elif isinstance(image_input, PILImage.Image):
            # Already decoded in memory; nothing to validate here.
            pil_image = image_input
        elif isinstance(image_input, str):
            if not os.path.exists(image_input):
                logger.warning(f"Image file not found: {image_input}")
                return None
            # BUGFIX: PIL's verify() leaves the image object unusable, and the
            # old code then tried PILImage.open(BytesIO(pil_image.tobytes())).
            # tobytes() yields raw pixel data, not an encoded stream, so that
            # open() always failed and the silent except could leave a dead
            # handle behind. Verify a throwaway handle, then reopen the path.
            try:
                with PILImage.open(image_input) as probe:
                    probe.verify()
                pil_image = PILImage.open(image_input)
                pil_image.load()  # decode eagerly so errors surface here
            except (IOError, OSError, SyntaxError) as e:
                logger.error(f"Failed to open image file: {str(e)}")
                return None
        else:
            logger.warning(f"Unsupported image input type: {type(image_input)}")
            return None

        if pil_image is None:
            return None

        try:
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return f"data:image/png;base64,{img_str}"
        except (IOError, OSError, ValueError) as e:
            logger.error(f"Failed to encode image to base64: {str(e)}")
            return None
    except Exception as e:
        # Last-resort guard: the UI must never crash on a bad image upload.
        logger.error(f"Unexpected error in image conversion: {str(e)}\n{traceback.format_exc()}")
        return None
| 192 |
+
def validate_dataset(dataset: List[Dict]) -> Tuple[bool, str]:
    """Check that *dataset* is a non-empty list of {'input','output'} string pairs.

    Returns:
        (True, "") when valid, otherwise (False, human-readable reason).
    """
    if not isinstance(dataset, list):
        return False, "Dataset must be a list of examples."

    if not dataset:
        return False, "Dataset is empty. Add at least one example."

    # Every example must be a mapping carrying non-empty string fields.
    for idx, example in enumerate(dataset, start=1):
        if not isinstance(example, dict):
            return False, f"Dataset item {idx} must be a dictionary with 'input' and 'output' keys."

        if "input" not in example or "output" not in example:
            return False, f"Dataset item {idx} is missing required 'input' or 'output' field."

        if not isinstance(example.get("input"), str) or not isinstance(example.get("output"), str):
            return False, f"Dataset item {idx} has invalid 'input' or 'output' type (must be strings)."

        if not example.get("input", "").strip() or not example.get("output", "").strip():
            return False, f"Dataset item {idx} has empty 'input' or 'output' field."

    return True, ""
|
| 216 |
+
def validate_model(model: str, custom_model: str) -> Tuple[bool, str]:
    """Validate the model dropdown selection plus the custom-model text field.

    Returns:
        (True, "") when the selection is usable, otherwise (False, reason).
    """
    if not model:
        return False, "Please select a foundation model."

    # Only the "custom" choice needs the free-form model ID validated.
    if model != "custom":
        return True, ""

    if not custom_model or not custom_model.strip():
        return False, "Custom model selected but no model ID provided."

    # Expect exactly one "/" separating provider and model name.
    pieces = custom_model.strip().split("/")
    if len(pieces) != 2:
        return False, "Custom model ID must be in format 'provider/model_name' (e.g., 'openai/gpt-4')."

    provider, model_name = pieces
    if not provider.strip() or not model_name.strip():
        return False, "Custom model ID provider and model name cannot be empty."

    return True, ""
|
| 235 |
+
def validate_api_keys(model: str, api_keys: Dict[str, str]) -> Tuple[bool, str]:
    """Confirm an API key is available for the selected model's provider.

    A key may come from the UI (*api_keys*) or from the provider's
    environment variable; when *api_keys* is empty the environment is
    assumed to be the source of truth and no check is made.

    Returns:
        (True, "") when a key is available or not needed, else (False, reason).
    """
    # No keys supplied via the UI -> rely on environment configuration.
    if not api_keys:
        return True, ""

    # "provider/model" ids keep the prefix as-is; bare names are lowercased.
    model_provider = model.split("/")[0] if "/" in model else model.lower()

    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
    }

    # Providers we don't recognize are not key-checked here.
    if model_provider not in env_var_by_provider:
        return True, ""

    supplied = api_keys.get(model_provider)
    key_value = supplied.strip() if supplied else ""

    # Fall back to the environment variable before rejecting.
    if not key_value and not os.environ.get(env_var_by_provider[model_provider]):
        return False, f"API key for {model_provider.capitalize()} is required for model '{model}' but not provided."

    return True, ""
|
| 265 |
+
def safe_optimize(seed_prompt, dataset, model, custom_model="", max_iterations=5, max_metric_calls=50, batch_size=4, use_llego=True, api_keys=None):
    """Safely run optimization with comprehensive error handling.

    Validates every input, exports UI-supplied API keys into the process
    environment, then delegates to quick_optimize_sync (real backend when
    installed, mock otherwise).

    Args:
        seed_prompt: Non-empty starting prompt string.
        dataset: List of {'input','output'} examples (see validate_dataset).
        model: Dropdown model id; may be "custom".
        custom_model: Optional "provider/model_name" override.
        max_iterations: Evolution rounds, 1-50.
        max_metric_calls: Evaluation budget, 10-500.
        batch_size: Candidates per round, 1-20.
        use_llego: Enable LLEGO crossover in the backend.
        api_keys: Optional {provider: key} dict from the UI.

    Returns:
        (ok, message, result): ok is True on success, message is a
        user-facing status string, result is the OptimizedResult or None.
        Never raises; every failure is reported through the tuple.
    """
    try:
        # Validate seed prompt
        if not seed_prompt or not isinstance(seed_prompt, str):
            return False, "Seed prompt is required and must be a string.", None

        if not seed_prompt.strip():
            return False, "Seed prompt cannot be empty.", None

        # Validate dataset
        is_valid, msg = validate_dataset(dataset)
        if not is_valid:
            return False, msg, None

        # Determine final model
        # NOTE(review): a non-empty custom_model overrides the dropdown even
        # when "custom" was not selected -- confirm this is intended.
        final_model = custom_model.strip() if custom_model and custom_model.strip() else model

        # Validate model
        model_valid, model_msg = validate_model(model, custom_model)
        if not model_valid:
            return False, model_msg, None

        # Validate API keys (against the resolved model, not the raw dropdown)
        api_valid, api_msg = validate_api_keys(final_model, api_keys or {})
        if not api_valid:
            return False, api_msg, None

        # Validate optimization parameters (bounds mirror the UI sliders)
        if not isinstance(max_iterations, int) or max_iterations < 1 or max_iterations > 50:
            return False, "Max iterations must be between 1 and 50.", None

        if not isinstance(max_metric_calls, int) or max_metric_calls < 10 or max_metric_calls > 500:
            return False, "Max metric calls must be between 10 and 500.", None

        if not isinstance(batch_size, int) or batch_size < 1 or batch_size > 20:
            return False, "Batch size must be between 1 and 20.", None

        # Check backend availability (mock still runs, so only warn)
        if not BACKEND_AVAILABLE:
            logger.warning("Backend not available, using mock optimizer")

        # Set API keys from UI if provided; the backend reads them from env
        if api_keys:
            try:
                key_mapping = {
                    "openai": "OPENAI_API_KEY",
                    "google": "GOOGLE_API_KEY",
                    "anthropic": "ANTHROPIC_API_KEY",
                }
                for provider, env_var in key_mapping.items():
                    if api_keys.get(provider) and api_keys[provider].strip():
                        os.environ[env_var] = api_keys[provider].strip()
                        # Log only the provider name, never the key material
                        logger.info(f"Set {provider} API key from UI")
            except Exception as e:
                logger.error(f"Failed to set API keys: {str(e)}")
                return False, f"Failed to configure API keys: {str(e)}", None

        # Run optimization
        try:
            result = quick_optimize_sync(
                seed_prompt=seed_prompt,
                dataset=dataset,
                model=final_model,
                max_iterations=max_iterations,
                max_metric_calls=max_metric_calls,
                batch_size=batch_size,
                use_llego=use_llego,
                verbose=True,
            )

            # Validate result structure
            if not result:
                return False, "Optimization returned no result.", None

            if not hasattr(result, 'optimized_prompt'):
                return False, "Optimization result is missing required fields.", None

            return True, "Success", result

        # Map known failure modes to actionable user-facing messages,
        # from most specific to least specific.
        except KeyboardInterrupt:
            logger.warning("Optimization interrupted by user")
            return False, "Optimization was interrupted.", None
        except TimeoutError:
            logger.error("Optimization timed out")
            return False, "Optimization timed out. Try reducing max_iterations or max_metric_calls.", None
        except ConnectionError as e:
            logger.error(f"Connection error during optimization: {str(e)}")
            return False, f"Connection error: {str(e)}. Check your internet connection and API keys.", None
        except ValueError as e:
            logger.error(f"Invalid parameter in optimization: {str(e)}")
            return False, f"Invalid configuration: {str(e)}", None
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Optimization failed: {error_msg}\n{traceback.format_exc()}")
            # Provide user-friendly error messages
            # (heuristic keyword matching on the exception text)
            if "api" in error_msg.lower() or "key" in error_msg.lower():
                return False, f"API error: {error_msg}. Please check your API keys.", None
            elif "rate limit" in error_msg.lower():
                return False, "Rate limit exceeded. Please wait a moment and try again.", None
            elif "quota" in error_msg.lower():
                return False, "API quota exceeded. Please check your account limits.", None
            else:
                return False, f"Optimization failed: {error_msg}", None

    except Exception as e:
        # Final guard so the Gradio handler always gets a (ok, msg, result) tuple.
        logger.error(f"Unexpected error in safe_optimize: {str(e)}\n{traceback.format_exc()}")
        return False, f"Unexpected error: {str(e)}", None
|
| 374 |
+
# ==========================================
|
| 375 |
+
# 3. UI LOGIC
|
| 376 |
+
# ==========================================
|
| 377 |
+
def add_example(input_text, output_text, image_input, current_dataset):
    """Append a single (input, output, optional image) example to the dataset.

    Returns the updated dataset plus cleared values for the three input
    widgets so the form resets. Raises gr.Error with a user-facing message
    whenever validation fails.
    """
    try:
        # Guard clauses: both text fields are mandatory.
        if not input_text:
            raise gr.Error("Input text is required.")

        if not output_text:
            raise gr.Error("Output text is required.")

        if not isinstance(input_text, str) or not isinstance(output_text, str):
            raise gr.Error("Input and Output must be text strings.")

        input_text = input_text.strip()
        output_text = output_text.strip()

        # Re-check after stripping: whitespace-only entries are rejected too.
        if not input_text:
            raise gr.Error("Input text cannot be empty.")

        if not output_text:
            raise gr.Error("Output text cannot be empty.")

        # The Gradio session state must still be a list; anything else means
        # the state was corrupted and the page needs a reload.
        if not isinstance(current_dataset, list):
            raise gr.Error("Dataset state is invalid. Please refresh the page.")

        # The image is optional — a conversion failure is logged and ignored.
        encoded_image = None
        try:
            encoded_image = gradio_image_to_base64(image_input)
        except Exception as e:
            logger.warning(f"Image processing failed, continuing without image: {str(e)}")

        try:
            record = {
                "input": input_text,
                "output": output_text,
                "image": encoded_image,
                "image_preview": "🖼️ Image" if encoded_image else "-"
            }

            # Defensive re-check of the stored types before mutating state.
            if not isinstance(record["input"], str) or not isinstance(record["output"], str):
                raise gr.Error("Failed to create dataset item: invalid data types.")

            current_dataset.append(record)

            # Updated dataset plus empty values that reset the form widgets.
            return current_dataset, "", "", None

        except Exception as e:
            logger.error(f"Failed to add example to dataset: {str(e)}")
            raise gr.Error(f"Failed to add example: {str(e)}")

    except gr.Error:
        # User-facing errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Unexpected error in add_example: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")
|
| 438 |
+
|
| 439 |
+
def update_table(dataset):
    """Render the dataset as rows for the preview table.

    Each row is [1-based index, truncated input, truncated output,
    image marker]. Malformed entries are skipped; any unexpected failure
    yields an empty table rather than propagating.
    """
    try:
        # Empty or missing dataset → empty table.
        if not dataset:
            return []

        if not isinstance(dataset, list):
            logger.error(f"Invalid dataset type: {type(dataset)}")
            return []

        rows = []
        for idx, entry in enumerate(dataset):
            try:
                if not isinstance(entry, dict):
                    logger.warning(f"Skipping invalid dataset item {idx+1}: not a dictionary")
                    continue

                # Truncate long texts to 50 chars for a compact preview;
                # falsy values render as the empty string.
                shown_input = str(entry.get("input", ""))[:50] if entry.get("input") else ""
                shown_output = str(entry.get("output", ""))[:50] if entry.get("output") else ""
                marker = str(entry.get("image_preview", "-"))

                rows.append([idx + 1, shown_input, shown_output, marker])
            except Exception as e:
                logger.warning(f"Error processing dataset item {idx+1}: {str(e)}")
                continue

        return rows

    except Exception as e:
        logger.error(f"Error updating table: {str(e)}\n{traceback.format_exc()}")
        return []
|
| 470 |
+
|
| 471 |
+
def clear_dataset():
    """Reset both the dataset state and the preview table to empty lists."""
    try:
        empty_state, empty_table = [], []
        return empty_state, empty_table
    except Exception as e:
        # Unreachable in practice, but keep the safe fallback.
        logger.error(f"Error clearing dataset: {str(e)}")
        return [], []
|
| 478 |
+
|
| 479 |
+
def get_candidates_display():
    """Build the HTML stream of recently explored prompt candidates.

    Shows up to the 10 most recent candidates, newest first. All dynamic
    text is HTML-escaped before rendering. Returns placeholder markup when
    there is nothing to show, and error markup on failure.
    """
    try:
        candidates = get_candidates_from_store()

        if not candidates:
            return "<div style='padding: 2rem; text-align: center; color: #6b7280;'><div style='font-size: 3rem; opacity: 0.3; margin-bottom: 1rem;'>🧬</div><p>Waiting for optimization to start...</p></div>"

        if not isinstance(candidates, list):
            logger.error(f"Invalid candidates type: {type(candidates)}")
            return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>"

        pieces = ["<div style='display: flex; flex-direction: column; gap: 12px;'>"]

        # Render only the most recent 10 candidates, newest first.
        recent = list(candidates)[-10:]
        for entry in reversed(recent):
            try:
                if not isinstance(entry, dict):
                    continue

                # Escape every dynamic value to prevent XSS in the rendered HTML.
                iteration = html.escape(str(entry.get('iteration', '?')))
                source = html.escape(str(entry.get('source', 'unknown')).upper())
                prompt = html.escape(str(entry.get('prompt', ''))[:200])

                pieces.append(f"""
            <div style='background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%); border: 1px solid #334155; border-radius: 8px; padding: 16px; position: relative; overflow: hidden;'>
                <div style='position: absolute; top: 0; left: 0; width: 100%; height: 2px; background: linear-gradient(90deg, #06b6d4, #3b82f6);'></div>
                <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px;'>
                    <span style='font-family: "JetBrains Mono", monospace; font-size: 0.75rem; color: #06b6d4; font-weight: 600;'>ITERATION {iteration}</span>
                    <span style='background: #1e293b; border: 1px solid #334155; padding: 2px 8px; border-radius: 4px; font-size: 0.7rem; color: #94a3b8;'>{source}</span>
                </div>
                <div style='font-family: "JetBrains Mono", monospace; font-size: 0.85rem; color: #cbd5e1; line-height: 1.6;'>{prompt}...</div>
            </div>
            """)
            except Exception as e:
                logger.warning(f"Error rendering candidate: {str(e)}")
                continue

        pieces.append("</div>")
        return "".join(pieces)

    except Exception as e:
        logger.error(f"Error generating candidates display: {str(e)}\n{traceback.format_exc()}")
        return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates display.</div>"
|
| 529 |
+
|
| 530 |
+
def run_optimization_flow(seed, dataset, model, custom_model, iter_count, call_count, batch, llego, k_openai, k_google, k_anthropic, progress=gr.Progress()):
    """Drive the full optimization flow from the Gradio UI.

    Generator handler: yields tuples of UI updates
    (status panel visibility, middle panel visibility, results panel
    visibility, status markdown, optimized prompt, metrics dict, history
    text, candidates HTML) as the run progresses.

    Raises gr.Error with a user-facing message on any failure.
    """
    import time

    try:
        # --- Input validation -------------------------------------------
        if not seed:
            raise gr.Error("Seed prompt is required.")

        if not dataset:
            raise gr.Error("Dataset is required. Add at least one example.")

        if not model:
            raise gr.Error("Model selection is required.")

        # Numeric knobs fall back to sane defaults when left blank.
        try:
            iter_count = int(iter_count) if iter_count else 5
            call_count = int(call_count) if call_count else 50
            batch = int(batch) if batch else 4
        except (ValueError, TypeError) as e:
            raise gr.Error(f"Invalid optimization parameters: {str(e)}")

        # A non-empty custom model name overrides the dropdown choice.
        # NOTE(review): final_model is currently unused — safe_optimize
        # receives `model` and `custom_model` separately and presumably
        # resolves the override itself; confirm and remove if truly dead.
        try:
            final_model = custom_model.strip() if custom_model and custom_model.strip() else model
        except Exception as e:
            logger.warning(f"Error processing custom model: {str(e)}")
            final_model = model

        # Start each run from an empty candidates stream.
        try:
            clear_candidates_store()
        except Exception as e:
            logger.warning(f"Error clearing candidates store: {str(e)}")

        # Collect provider API keys (empty string when not supplied).
        api_keys = {}
        try:
            api_keys = {
                "openai": k_openai if k_openai else "",
                "google": k_google if k_google else "",
                "anthropic": k_anthropic if k_anthropic else ""
            }
        except Exception as e:
            logger.warning(f"Error processing API keys: {str(e)}")

        # --- Initial UI state -------------------------------------------
        try:
            yield (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                "🚀 Initializing Genetic Algorithm...",
                "", {}, "", ""
            )
            time.sleep(0.5)  # brief pause so the UI update is visible
        except Exception as e:
            logger.error(f"Error in initial UI update: {str(e)}")
            raise gr.Error(f"Failed to initialize UI: {str(e)}")

        # --- Evolution loop ---------------------------------------------
        # Visual progress only; the real optimization happens in
        # safe_optimize below.
        try:
            for i in range(1, iter_count + 1):
                try:
                    increment_iteration()
                    add_candidate_to_store({
                        "source": "evolution_step",
                        "prompt": f"Candidate {i}: Optimizing instruction clarity and task alignment...",
                        "timestamp": "now"
                    })

                    progress(i/iter_count, desc=f"Evolution Round {i}/{iter_count}")
                    yield (
                        gr.update(), gr.update(), gr.update(),
                        f"🧬 **Evolution Round {i}/{iter_count}**\n\n• Generating {batch} prompt mutations\n• Evaluating fitness scores\n• Selecting top candidates",
                        "", {}, "", get_candidates_display()
                    )
                    time.sleep(0.3)  # pause to show progress
                except Exception as e:
                    logger.warning(f"Error in evolution step {i}: {str(e)}")
                    continue  # keep animating the remaining rounds
        except Exception as e:
            logger.error(f"Error in evolution loop: {str(e)}")
            # The animation is cosmetic; still attempt the real optimization.

        # --- Final optimization -----------------------------------------
        try:
            success, msg, result = safe_optimize(
                seed_prompt=seed,
                dataset=dataset,
                model=model,
                custom_model=custom_model,
                max_iterations=iter_count,
                max_metric_calls=call_count,
                batch_size=batch,
                use_llego=llego,
                api_keys=api_keys
            )

            if not success:
                # Surface the failure in the status panel before raising.
                yield (
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    f"❌ **Optimization Failed**\n\n{msg}",
                    "", {}, "", get_candidates_display()
                )
                raise gr.Error(msg)

            # Validate the result before touching its attributes.
            if not result:
                raise gr.Error("Optimization completed but returned no result.")

            if not hasattr(result, 'optimized_prompt'):
                raise gr.Error("Optimization result is missing required fields.")

            # --- Render the results panel --------------------------------
            try:
                optimized_prompt = result.optimized_prompt if result.optimized_prompt else ""
                improvement_metrics = getattr(result, 'improvement_metrics', {})
                iteration_history = getattr(result, 'iteration_history', [])

                # BUG FIX: history entries are not guaranteed to be strings,
                # and str.join raises TypeError on non-str items — stringify
                # each entry before joining so a successful run cannot fail
                # at display time.
                if isinstance(iteration_history, list):
                    history_text = "\n".join(str(entry) for entry in iteration_history)
                else:
                    history_text = str(iteration_history)

                yield (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True),
                    "✅ Optimization Complete",
                    optimized_prompt,
                    improvement_metrics,
                    history_text,
                    get_candidates_display()
                )
            except Exception as e:
                logger.error(f"Error displaying results: {str(e)}")
                raise gr.Error(f"Failed to display results: {str(e)}")

        except gr.Error:
            # Re-raise Gradio errors unchanged.
            raise
        except Exception as e:
            logger.error(f"Error in optimization: {str(e)}\n{traceback.format_exc()}")
            raise gr.Error(f"Optimization error: {str(e)}")

    except gr.Error:
        # Re-raise Gradio errors as-is.
        raise
    except KeyboardInterrupt:
        logger.warning("Optimization interrupted by user")
        raise gr.Error("Optimization was interrupted.")
    except Exception as e:
        logger.error(f"Unexpected error in optimization flow: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")
|
| 687 |
+
|
| 688 |
+
# ==========================================
|
| 689 |
+
# 4. ENHANCED CSS (Linear/Vercel-style)
|
| 690 |
+
# ==========================================
|
| 691 |
+
CUSTOM_CSS = """
|
| 692 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');
|
| 693 |
+
|
| 694 |
+
:root {
|
| 695 |
+
--bg0: #070A0F;
|
| 696 |
+
--bg1: #0B1020;
|
| 697 |
+
--bg2: rgba(255,255,255,0.04);
|
| 698 |
+
--bg3: rgba(255,255,255,0.06);
|
| 699 |
+
|
| 700 |
+
--stroke0: rgba(148,163,184,0.14);
|
| 701 |
+
--stroke1: rgba(148,163,184,0.22);
|
| 702 |
+
|
| 703 |
+
--text0: #EAF0FF;
|
| 704 |
+
--text1: rgba(234,240,255,0.74);
|
| 705 |
+
--text2: rgba(234,240,255,0.56);
|
| 706 |
+
|
| 707 |
+
--teal: #06B6D4;
|
| 708 |
+
--blue: #3B82F6;
|
| 709 |
+
|
| 710 |
+
--ok: #10B981;
|
| 711 |
+
--okGlow: rgba(16,185,129,0.18);
|
| 712 |
+
|
| 713 |
+
--bad: #EF4444;
|
| 714 |
+
|
| 715 |
+
--shadow: 0 12px 40px rgba(0,0,0,0.45);
|
| 716 |
+
--shadowSoft: 0 10px 24px rgba(0,0,0,0.32);
|
| 717 |
+
|
| 718 |
+
--radius: 14px;
|
| 719 |
+
--radiusSm: 10px;
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
html, body {
|
| 723 |
+
background: radial-gradient(1200px 700px at 20% -10%, rgba(6,182,212,0.13), transparent 55%),
|
| 724 |
+
radial-gradient(1000px 650px at 90% 0%, rgba(59,130,246,0.10), transparent 60%),
|
| 725 |
+
linear-gradient(180deg, var(--bg0) 0%, var(--bg1) 100%);
|
| 726 |
+
color: var(--text0);
|
| 727 |
+
font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
.gradio-container {
|
| 731 |
+
max-width: 1520px !important;
|
| 732 |
+
padding: 12px 18px !important;
|
| 733 |
+
margin: 0 auto !important;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
/* --- App shell --- */
|
| 737 |
+
.app-shell { min-height: auto !important; }
|
| 738 |
+
.topbar {
|
| 739 |
+
padding: 12px 14px 12px 14px;
|
| 740 |
+
margin-bottom: 4px;
|
| 741 |
+
border: 1px solid var(--stroke0);
|
| 742 |
+
border-radius: var(--radius);
|
| 743 |
+
background: linear-gradient(180deg, rgba(255,255,255,0.04) 0%, rgba(255,255,255,0.02) 100%);
|
| 744 |
+
box-shadow: var(--shadowSoft);
|
| 745 |
+
}
|
| 746 |
+
.topbar-wrap { margin-bottom: 0 !important; }
|
| 747 |
+
|
| 748 |
+
.brand-row { display: flex; align-items: center; justify-content: space-between; gap: 16px; }
|
| 749 |
+
.brand-left { display: flex; align-items: center; gap: 14px; }
|
| 750 |
+
.brand-mark {
|
| 751 |
+
width: 44px; height: 44px; border-radius: 12px;
|
| 752 |
+
background: linear-gradient(135deg, rgba(6,182,212,0.26), rgba(59,130,246,0.20));
|
| 753 |
+
border: 1px solid rgba(6,182,212,0.30);
|
| 754 |
+
box-shadow: 0 0 0 4px rgba(6,182,212,0.10);
|
| 755 |
+
display: flex; align-items: center; justify-content: center;
|
| 756 |
+
font-weight: 800;
|
| 757 |
+
}
|
| 758 |
+
.h1 {
|
| 759 |
+
font-size: 22px; font-weight: 800; letter-spacing: -0.02em;
|
| 760 |
+
margin: 0; line-height: 1.2;
|
| 761 |
+
}
|
| 762 |
+
.subtitle { margin-top: 4px; color: var(--text1); font-weight: 500; font-size: 13px; }
|
| 763 |
+
|
| 764 |
+
.status-pill {
|
| 765 |
+
display: inline-flex; align-items: center; gap: 10px;
|
| 766 |
+
padding: 10px 12px; border-radius: 999px;
|
| 767 |
+
background: rgba(255,255,255,0.03);
|
| 768 |
+
border: 1px solid var(--stroke0);
|
| 769 |
+
color: var(--text1);
|
| 770 |
+
font-size: 12px; font-weight: 700; letter-spacing: 0.08em;
|
| 771 |
+
text-transform: uppercase;
|
| 772 |
+
}
|
| 773 |
+
.dot {
|
| 774 |
+
width: 10px; height: 10px; border-radius: 999px;
|
| 775 |
+
background: var(--ok);
|
| 776 |
+
box-shadow: 0 0 16px rgba(16,185,129,0.40);
|
| 777 |
+
animation: pulse 1.8s ease-in-out infinite;
|
| 778 |
+
}
|
| 779 |
+
@keyframes pulse { 0%, 100% { transform: scale(1); opacity: 0.95; } 50% { transform: scale(1.18); opacity: 0.70; } }
|
| 780 |
+
|
| 781 |
+
/* --- Two-column layout helpers --- */
|
| 782 |
+
.left-col, .right-col { min-width: 280px; }
|
| 783 |
+
|
| 784 |
+
/* --- Cards / Sections --- */
|
| 785 |
+
.card {
|
| 786 |
+
border-radius: var(--radius);
|
| 787 |
+
background: linear-gradient(180deg, rgba(255,255,255,0.045) 0%, rgba(255,255,255,0.022) 100%);
|
| 788 |
+
border: 1px solid var(--stroke0);
|
| 789 |
+
box-shadow: var(--shadowSoft);
|
| 790 |
+
padding: 16px;
|
| 791 |
+
}
|
| 792 |
+
.card + .card { margin-top: 14px; }
|
| 793 |
+
|
| 794 |
+
.card-head {
|
| 795 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 796 |
+
gap: 12px;
|
| 797 |
+
padding-bottom: 12px;
|
| 798 |
+
margin-bottom: 12px;
|
| 799 |
+
border-bottom: 1px solid var(--stroke0);
|
| 800 |
+
}
|
| 801 |
+
.card-title {
|
| 802 |
+
display: flex; align-items: center; gap: 10px;
|
| 803 |
+
font-size: 13px; font-weight: 800; letter-spacing: 0.12em;
|
| 804 |
+
text-transform: uppercase; color: var(--text1);
|
| 805 |
+
}
|
| 806 |
+
.step {
|
| 807 |
+
width: 30px; height: 30px; border-radius: 10px;
|
| 808 |
+
background: linear-gradient(135deg, rgba(6,182,212,0.95), rgba(59,130,246,0.95));
|
| 809 |
+
box-shadow: 0 10px 20px rgba(6,182,212,0.18);
|
| 810 |
+
display: flex; align-items: center; justify-content: center;
|
| 811 |
+
color: white; font-weight: 900; font-size: 13px;
|
| 812 |
+
}
|
| 813 |
+
.hint { color: var(--text2); font-size: 12px; line-height: 1.4; }
|
| 814 |
+
|
| 815 |
+
.ds-count span {
|
| 816 |
+
display: inline-flex;
|
| 817 |
+
align-items: center;
|
| 818 |
+
padding: 7px 10px;
|
| 819 |
+
border-radius: 999px;
|
| 820 |
+
border: 1px solid var(--stroke0);
|
| 821 |
+
background: rgba(255,255,255,0.02);
|
| 822 |
+
color: var(--text1) !important;
|
| 823 |
+
font-weight: 700;
|
| 824 |
+
font-size: 12px;
|
| 825 |
+
}
|
| 826 |
+
|
| 827 |
+
/* --- Inputs --- */
|
| 828 |
+
label { color: var(--text1) !important; font-weight: 650 !important; font-size: 12px !important; }
|
| 829 |
+
|
| 830 |
+
textarea, input, select {
|
| 831 |
+
background: rgba(255,255,255,0.03) !important;
|
| 832 |
+
border: 1px solid var(--stroke0) !important;
|
| 833 |
+
border-radius: 12px !important;
|
| 834 |
+
color: var(--text0) !important;
|
| 835 |
+
transition: border-color 0.15s ease, box-shadow 0.15s ease, transform 0.15s ease;
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
textarea:focus, input:focus, select:focus {
|
| 839 |
+
outline: none !important;
|
| 840 |
+
border-color: rgba(6,182,212,0.55) !important;
|
| 841 |
+
box-shadow: 0 0 0 4px rgba(6,182,212,0.14) !important;
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
.keybox input { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; }
|
| 845 |
+
|
| 846 |
+
.seed textarea { min-height: 160px !important; }
|
| 847 |
+
.mono textarea { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; font-size: 12.5px !important; }
|
| 848 |
+
|
| 849 |
+
/* --- Buttons --- */
|
| 850 |
+
.cta button {
|
| 851 |
+
width: 100% !important;
|
| 852 |
+
border: 0 !important;
|
| 853 |
+
border-radius: 14px !important;
|
| 854 |
+
padding: 14px 16px !important;
|
| 855 |
+
font-size: 13px !important;
|
| 856 |
+
font-weight: 900 !important;
|
| 857 |
+
letter-spacing: 0.12em !important;
|
| 858 |
+
text-transform: uppercase !important;
|
| 859 |
+
color: white !important;
|
| 860 |
+
background: linear-gradient(135deg, rgba(6,182,212,1) 0%, rgba(59,130,246,1) 100%) !important;
|
| 861 |
+
box-shadow: 0 18px 48px rgba(6,182,212,0.22) !important;
|
| 862 |
+
position: relative !important;
|
| 863 |
+
overflow: hidden !important;
|
| 864 |
+
}
|
| 865 |
+
.cta button::after {
|
| 866 |
+
content: "";
|
| 867 |
+
position: absolute; inset: -120px;
|
| 868 |
+
background: radial-gradient(closest-side, rgba(255,255,255,0.18), transparent 60%);
|
| 869 |
+
transform: translateX(-40%);
|
| 870 |
+
transition: transform 0.45s ease;
|
| 871 |
+
}
|
| 872 |
+
.cta button:hover { transform: translateY(-1px); }
|
| 873 |
+
.cta button:hover::after { transform: translateX(40%); }
|
| 874 |
+
.cta button:active { transform: translateY(0px); }
|
| 875 |
+
|
| 876 |
+
.btn-secondary button {
|
| 877 |
+
border-radius: 12px !important;
|
| 878 |
+
border: 1px solid var(--stroke1) !important;
|
| 879 |
+
background: rgba(255,255,255,0.03) !important;
|
| 880 |
+
color: var(--text0) !important;
|
| 881 |
+
font-weight: 800 !important;
|
| 882 |
+
}
|
| 883 |
+
.btn-secondary button:hover { border-color: rgba(6,182,212,0.55) !important; }
|
| 884 |
+
|
| 885 |
+
.btn-danger button {
|
| 886 |
+
border-radius: 12px !important;
|
| 887 |
+
border: 1px solid rgba(239,68,68,0.55) !important;
|
| 888 |
+
background: rgba(239,68,68,0.06) !important;
|
| 889 |
+
color: rgba(255,170,170,1) !important;
|
| 890 |
+
font-weight: 900 !important;
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
/* --- Dataframe --- */
|
| 894 |
+
.dataframe {
|
| 895 |
+
border-radius: 14px !important;
|
| 896 |
+
border: 1px solid var(--stroke0) !important;
|
| 897 |
+
background: rgba(255,255,255,0.02) !important;
|
| 898 |
+
overflow: hidden !important;
|
| 899 |
+
}
|
| 900 |
+
.dataframe thead th {
|
| 901 |
+
background: rgba(255,255,255,0.04) !important;
|
| 902 |
+
color: var(--text1) !important;
|
| 903 |
+
font-weight: 900 !important;
|
| 904 |
+
font-size: 11px !important;
|
| 905 |
+
letter-spacing: 0.10em !important;
|
| 906 |
+
text-transform: uppercase !important;
|
| 907 |
+
border-bottom: 1px solid var(--stroke0) !important;
|
| 908 |
+
}
|
| 909 |
+
.dataframe tbody td {
|
| 910 |
+
color: var(--text0) !important;
|
| 911 |
+
font-size: 12px !important;
|
| 912 |
+
border-bottom: 1px solid rgba(148,163,184,0.10) !important;
|
| 913 |
+
}
|
| 914 |
+
.dataframe tbody tr:hover { background: rgba(255,255,255,0.03) !important; }
|
| 915 |
+
|
| 916 |
+
/* --- Status / Results --- */
|
| 917 |
+
.panel {
|
| 918 |
+
border-radius: var(--radius);
|
| 919 |
+
border: 1px solid var(--stroke0);
|
| 920 |
+
background: linear-gradient(180deg, rgba(255,255,255,0.045), rgba(255,255,255,0.020));
|
| 921 |
+
box-shadow: var(--shadowSoft);
|
| 922 |
+
padding: 16px;
|
| 923 |
+
}
|
| 924 |
+
.panel-title {
|
| 925 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 926 |
+
gap: 10px;
|
| 927 |
+
padding-bottom: 12px; margin-bottom: 12px;
|
| 928 |
+
border-bottom: 1px solid var(--stroke0);
|
| 929 |
+
}
|
| 930 |
+
.panel-title h3 { margin: 0; font-size: 13px; letter-spacing: 0.12em; text-transform: uppercase; color: var(--text1); }
|
| 931 |
+
.running-pill {
|
| 932 |
+
display: inline-flex; align-items: center; gap: 10px;
|
| 933 |
+
padding: 8px 10px; border-radius: 999px;
|
| 934 |
+
border: 1px solid rgba(6,182,212,0.38);
|
| 935 |
+
background: rgba(6,182,212,0.08);
|
| 936 |
+
color: rgba(153,246,228,0.95);
|
| 937 |
+
font-weight: 900; font-size: 11px; letter-spacing: 0.10em; text-transform: uppercase;
|
| 938 |
+
}
|
| 939 |
+
.running-dot { width: 9px; height: 9px; border-radius: 99px; background: var(--teal); box-shadow: 0 0 18px rgba(6,182,212,0.45); animation: pulse 1.8s ease-in-out infinite; }
|
| 940 |
+
|
| 941 |
+
.empty {
|
| 942 |
+
border-radius: var(--radius);
|
| 943 |
+
border: 1px dashed rgba(148,163,184,0.26);
|
| 944 |
+
background: rgba(255,255,255,0.02);
|
| 945 |
+
padding: 28px;
|
| 946 |
+
text-align: center;
|
| 947 |
+
color: var(--text2);
|
| 948 |
+
}
|
| 949 |
+
.empty .big { font-size: 40px; opacity: 0.22; margin-bottom: 10px; }
|
| 950 |
+
.empty .t { color: var(--text1); font-weight: 800; margin-bottom: 6px; }
|
| 951 |
+
.empty .s { font-size: 12px; }
|
| 952 |
+
|
| 953 |
+
.results {
|
| 954 |
+
border-radius: var(--radius);
|
| 955 |
+
border: 1px solid rgba(16,185,129,0.55);
|
| 956 |
+
background: linear-gradient(180deg, rgba(16,185,129,0.12), rgba(255,255,255,0.02));
|
| 957 |
+
box-shadow: 0 0 0 4px rgba(16,185,129,0.10), 0 20px 60px rgba(0,0,0,0.42);
|
| 958 |
+
padding: 16px;
|
| 959 |
+
}
|
| 960 |
+
.results-banner {
|
| 961 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 962 |
+
gap: 12px;
|
| 963 |
+
padding-bottom: 12px; margin-bottom: 12px;
|
| 964 |
+
border-bottom: 1px solid rgba(16,185,129,0.28);
|
| 965 |
+
}
|
| 966 |
+
.results-banner .k { display: flex; align-items: center; gap: 10px; }
|
| 967 |
+
.results-banner .k .icon {
|
| 968 |
+
width: 36px; height: 36px; border-radius: 12px;
|
| 969 |
+
background: rgba(16,185,129,0.18);
|
| 970 |
+
border: 1px solid rgba(16,185,129,0.45);
|
| 971 |
+
display: flex; align-items: center; justify-content: center;
|
| 972 |
+
}
|
| 973 |
+
.results-banner .k .title { font-weight: 900; color: rgba(189,255,225,0.98); letter-spacing: 0.06em; text-transform: uppercase; font-size: 12px; }
|
| 974 |
+
.results-banner .k .sub { margin-top: 2px; color: rgba(189,255,225,0.70); font-size: 12px; }
|
| 975 |
+
|
| 976 |
+
.tabs { background: transparent !important; }
|
| 977 |
+
.tab-nav button {
|
| 978 |
+
background: transparent !important;
|
| 979 |
+
border: 0 !important;
|
| 980 |
+
border-bottom: 2px solid transparent !important;
|
| 981 |
+
color: var(--text2) !important;
|
| 982 |
+
font-weight: 800 !important;
|
| 983 |
+
padding: 10px 12px !important;
|
| 984 |
+
}
|
| 985 |
+
.tab-nav button[aria-selected="true"] {
|
| 986 |
+
color: rgba(153,246,228,0.98) !important;
|
| 987 |
+
border-bottom-color: rgba(6,182,212,0.75) !important;
|
| 988 |
+
}
|
| 989 |
+
.tab-nav button:hover { color: var(--text0) !important; }
|
| 990 |
+
|
| 991 |
+
.small-note { color: var(--text2); font-size: 12px; }
|
| 992 |
+
|
| 993 |
+
/* --- Candidates stream --- */
|
| 994 |
+
.cand-empty { padding: 28px; text-align: center; color: var(--text2); }
|
| 995 |
+
.cand-empty-icon { font-size: 40px; opacity: 0.25; margin-bottom: 10px; }
|
| 996 |
+
.cand-empty-title { color: var(--text1); font-weight: 900; margin-bottom: 4px; }
|
| 997 |
+
.cand-empty-sub { font-size: 12px; }
|
| 998 |
+
|
| 999 |
+
.cand-stream { display: flex; flex-direction: column; gap: 10px; }
|
| 1000 |
+
.cand-card {
|
| 1001 |
+
border-radius: 14px;
|
| 1002 |
+
border: 1px solid rgba(148,163,184,0.18);
|
| 1003 |
+
background: linear-gradient(135deg, rgba(15,23,42,0.85), rgba(2,6,23,0.45));
|
| 1004 |
+
overflow: hidden;
|
| 1005 |
+
}
|
| 1006 |
+
.cand-topbar { height: 2px; background: linear-gradient(90deg, var(--teal), var(--blue)); }
|
| 1007 |
+
.cand-header {
|
| 1008 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 1009 |
+
gap: 10px;
|
| 1010 |
+
padding: 10px 12px 0 12px;
|
| 1011 |
+
}
|
| 1012 |
+
.cand-iter { font-family: "JetBrains Mono", ui-monospace; font-size: 11px; color: rgba(153,246,228,0.92); font-weight: 800; letter-spacing: 0.08em; }
|
| 1013 |
+
.cand-pill {
|
| 1014 |
+
font-size: 10px; font-weight: 900; letter-spacing: 0.10em;
|
| 1015 |
+
padding: 5px 8px; border-radius: 999px;
|
| 1016 |
+
border: 1px solid rgba(148,163,184,0.20);
|
| 1017 |
+
background: rgba(255,255,255,0.03);
|
| 1018 |
+
color: var(--text2);
|
| 1019 |
+
}
|
| 1020 |
+
.cand-body {
|
| 1021 |
+
padding: 10px 12px 12px 12px;
|
| 1022 |
+
font-family: "JetBrains Mono", ui-monospace;
|
| 1023 |
+
font-size: 12px;
|
| 1024 |
+
line-height: 1.6;
|
| 1025 |
+
color: rgba(234,240,255,0.75);
|
| 1026 |
+
}
|
| 1027 |
+
|
| 1028 |
+
/* --- Responsive --- */
|
| 1029 |
+
@media (max-width: 980px) {
|
| 1030 |
+
.gradio-container { padding: 16px 12px !important; }
|
| 1031 |
+
.brand-row { flex-direction: column; align-items: flex-start; }
|
| 1032 |
+
.status-pill { align-self: stretch; justify-content: center; }
|
| 1033 |
+
}
|
| 1034 |
+
"""
|
| 1035 |
+
|
| 1036 |
+
FORCE_DARK_JS = """
|
| 1037 |
+
function forceDarkTheme() {
|
| 1038 |
+
try {
|
| 1039 |
+
const url = new URL(window.location.href);
|
| 1040 |
+
if (url.searchParams.get("__theme") !== "dark") {
|
| 1041 |
+
url.searchParams.set("__theme", "dark");
|
| 1042 |
+
window.location.replace(url.toString());
|
| 1043 |
+
}
|
| 1044 |
+
} catch (e) {
|
| 1045 |
+
// no-op
|
| 1046 |
+
}
|
| 1047 |
+
}
|
| 1048 |
+
forceDarkTheme();
|
| 1049 |
+
"""
|
| 1050 |
+
|
| 1051 |
+
# ==========================================
# 5. UI CONSTRUCTION (Redesigned)
# ==========================================
APP_TITLE = "Universal Prompt Optimizer"  # main heading shown in the top bar
APP_SUBTITLE = "Genetic Evolutionary Prompt Agent (GEPA)"  # tagline under the title
STATUS_READY = "System Ready"  # label of the idle status pill in the top bar
|
| 1057 |
+
|
| 1058 |
+
with gr.Blocks(
|
| 1059 |
+
title="Universal Prompt Optimizer",
|
| 1060 |
+
theme=gr.themes.Base()
|
| 1061 |
+
) as app:
|
| 1062 |
+
dataset_state = gr.State([])
|
| 1063 |
+
|
| 1064 |
+
# TOP BAR
|
| 1065 |
+
gr.HTML(
|
| 1066 |
+
f"""
|
| 1067 |
+
<div class="topbar">
|
| 1068 |
+
<div class="brand-row">
|
| 1069 |
+
<div class="brand-left">
|
| 1070 |
+
<div class="brand-mark">GE</div>
|
| 1071 |
+
<div>
|
| 1072 |
+
<div class="h1">{APP_TITLE}</div>
|
| 1073 |
+
<div class="subtitle">{APP_SUBTITLE}</div>
|
| 1074 |
+
</div>
|
| 1075 |
+
</div>
|
| 1076 |
+
<div class="status-pill"><span class="dot"></span> {STATUS_READY}</div>
|
| 1077 |
+
</div>
|
| 1078 |
+
</div>
|
| 1079 |
+
""",
|
| 1080 |
+
elem_classes=["topbar-wrap"]
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
# MAIN LAYOUT
|
| 1084 |
+
with gr.Row():
|
| 1085 |
+
|
| 1086 |
+
# LEFT COLUMN: Configuration
|
| 1087 |
+
with gr.Column(scale=5):
|
| 1088 |
+
|
| 1089 |
+
# Step 1
|
| 1090 |
+
with gr.Group(elem_classes=["card"]):
|
| 1091 |
+
gr.HTML(
|
| 1092 |
+
"""
|
| 1093 |
+
<div class="card-head">
|
| 1094 |
+
<div class="card-title"><div class="step">1</div> Model & Credentials</div>
|
| 1095 |
+
<div class="hint">Select a target model, then provide keys (stored in-session only).</div>
|
| 1096 |
+
</div>
|
| 1097 |
+
"""
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
with gr.Row():
|
| 1101 |
+
model_select = gr.Dropdown(
|
| 1102 |
+
label="Foundation Model",
|
| 1103 |
+
choices=[
|
| 1104 |
+
"openai/gpt-4o",
|
| 1105 |
+
"openai/gpt-4-turbo",
|
| 1106 |
+
"anthropic/claude-3-5-sonnet",
|
| 1107 |
+
"google/gemini-1.5-pro",
|
| 1108 |
+
"custom"
|
| 1109 |
+
],
|
| 1110 |
+
value="openai/gpt-4o",
|
| 1111 |
+
scale=2
|
| 1112 |
+
)
|
| 1113 |
+
custom_model_input = gr.Textbox(
|
| 1114 |
+
label="Custom Model ID",
|
| 1115 |
+
placeholder="provider/model_name",
|
| 1116 |
+
scale=1
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
gr.HTML('<div class="subsection-title">API Access Keys</div>')
|
| 1120 |
+
gr.Markdown("*Keys are stored in-session only and never logged*", elem_classes=["text-xs"])
|
| 1121 |
+
|
| 1122 |
+
with gr.Row():
|
| 1123 |
+
key_openai = gr.Textbox(
|
| 1124 |
+
label="OpenAI API Key",
|
| 1125 |
+
type="password",
|
| 1126 |
+
placeholder="sk-...",
|
| 1127 |
+
scale=1
|
| 1128 |
+
)
|
| 1129 |
+
key_google = gr.Textbox(
|
| 1130 |
+
label="Google API Key",
|
| 1131 |
+
type="password",
|
| 1132 |
+
placeholder="AIza...",
|
| 1133 |
+
scale=1
|
| 1134 |
+
)
|
| 1135 |
+
key_anthropic = gr.Textbox(
|
| 1136 |
+
label="Anthropic API Key",
|
| 1137 |
+
type="password",
|
| 1138 |
+
placeholder="sk-ant...",
|
| 1139 |
+
scale=1
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
# Step 2
|
| 1143 |
+
with gr.Group(elem_classes=["card"]):
|
| 1144 |
+
gr.HTML(
|
| 1145 |
+
"""
|
| 1146 |
+
<div class="card-head">
|
| 1147 |
+
<div class="card-title"><div class="step">2</div> Seed Prompt</div>
|
| 1148 |
+
<div class="hint">Describe the task, constraints, output format, and tone.</div>
|
| 1149 |
+
</div>
|
| 1150 |
+
"""
|
| 1151 |
+
)
|
| 1152 |
+
seed_input = gr.Textbox(
|
| 1153 |
+
label="Task Description",
|
| 1154 |
+
placeholder="Example: You are a code reviewer that identifies security vulnerabilities in Python code. Return a JSON report with severity and fixes...",
|
| 1155 |
+
lines=7,
|
| 1156 |
+
max_lines=14,
|
| 1157 |
+
elem_classes=["seed", "mono"]
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
# Step 3
|
| 1161 |
+
with gr.Group(elem_classes=["card"]):
|
| 1162 |
+
gr.HTML(
|
| 1163 |
+
"""
|
| 1164 |
+
<div class="card-head">
|
| 1165 |
+
<div class="card-title"><div class="step">3</div> Training Examples</div>
|
| 1166 |
+
<div class="hint">Add a few high-quality I/O pairs (images optional) to shape the optimizer.</div>
|
| 1167 |
+
</div>
|
| 1168 |
+
"""
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
with gr.Tabs():
|
| 1172 |
+
with gr.Tab("Manual Entry"):
|
| 1173 |
+
with gr.Row():
|
| 1174 |
+
with gr.Column(scale=2):
|
| 1175 |
+
d_in = gr.Textbox(
|
| 1176 |
+
label="Input / User Prompt",
|
| 1177 |
+
placeholder="Example user input...",
|
| 1178 |
+
lines=3
|
| 1179 |
+
)
|
| 1180 |
+
d_out = gr.Textbox(
|
| 1181 |
+
label="Ideal Output",
|
| 1182 |
+
placeholder="Expected AI response...",
|
| 1183 |
+
lines=3
|
| 1184 |
+
)
|
| 1185 |
+
with gr.Column(scale=1):
|
| 1186 |
+
d_img = gr.Image(
|
| 1187 |
+
label="Attach Image (Optional)",
|
| 1188 |
+
type="numpy",
|
| 1189 |
+
height=170
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
btn_add = gr.Button(
|
| 1193 |
+
"Add Example",
|
| 1194 |
+
elem_classes=["btn-secondary"]
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
with gr.Tab("Bulk Import (JSON)"):
|
| 1198 |
+
gr.Markdown(
|
| 1199 |
+
"Paste a JSON array like: `[{\"input\": \"...\", \"output\": \"...\"}]`",
|
| 1200 |
+
elem_classes=["small-note"]
|
| 1201 |
+
)
|
| 1202 |
+
bulk_json = gr.Textbox(
|
| 1203 |
+
show_label=False,
|
| 1204 |
+
placeholder='[{"input": "...", "output": "..."}]',
|
| 1205 |
+
lines=6
|
| 1206 |
+
)
|
| 1207 |
+
btn_import = gr.Button(
|
| 1208 |
+
"Import JSON",
|
| 1209 |
+
elem_classes=["btn-secondary"]
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
with gr.Row():
|
| 1213 |
+
gr.HTML("<div class='hint'>Current dataset</div>")
|
| 1214 |
+
ds_count = gr.HTML(
|
| 1215 |
+
"<span style='color: var(--text-secondary);'>0 examples loaded</span>",
|
| 1216 |
+
elem_classes=["ds-count"]
|
| 1217 |
+
)
|
| 1218 |
+
|
| 1219 |
+
ds_table = gr.Dataframe(
|
| 1220 |
+
headers=["ID", "Input", "Output", "Media"],
|
| 1221 |
+
datatype=["number", "str", "str", "str"],
|
| 1222 |
+
row_count=6,
|
| 1223 |
+
column_count=(4, "fixed"),
|
| 1224 |
+
interactive=False
|
| 1225 |
+
)
|
| 1226 |
+
|
| 1227 |
+
with gr.Row():
|
| 1228 |
+
btn_clear = gr.Button(
|
| 1229 |
+
"Clear All",
|
| 1230 |
+
elem_classes=["btn-danger"],
|
| 1231 |
+
size="sm"
|
| 1232 |
+
)
|
| 1233 |
+
|
| 1234 |
+
# Step 4 (Prominent, not buried)
|
| 1235 |
+
with gr.Group(elem_classes=["card"]):
|
| 1236 |
+
gr.HTML(
|
| 1237 |
+
"""
|
| 1238 |
+
<div class="card-head">
|
| 1239 |
+
<div class="card-title"><div class="step">4</div> Optimization Controls</div>
|
| 1240 |
+
<div class="hint">Tune evolution budget. Defaults are safe for quick runs.</div>
|
| 1241 |
+
</div>
|
| 1242 |
+
"""
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
with gr.Row():
|
| 1246 |
+
slider_iter = gr.Slider(
|
| 1247 |
+
minimum=1,
|
| 1248 |
+
maximum=20,
|
| 1249 |
+
value=5,
|
| 1250 |
+
step=1,
|
| 1251 |
+
label="Evolution Rounds",
|
| 1252 |
+
info="Number of genetic iterations"
|
| 1253 |
+
)
|
| 1254 |
+
slider_calls = gr.Slider(
|
| 1255 |
+
minimum=10,
|
| 1256 |
+
maximum=200,
|
| 1257 |
+
value=50,
|
| 1258 |
+
step=10,
|
| 1259 |
+
label="Max LLM Calls",
|
| 1260 |
+
info="Total API call budget"
|
| 1261 |
+
)
|
| 1262 |
+
|
| 1263 |
+
with gr.Row():
|
| 1264 |
+
slider_batch = gr.Slider(
|
| 1265 |
+
minimum=1,
|
| 1266 |
+
maximum=10,
|
| 1267 |
+
value=4,
|
| 1268 |
+
step=1,
|
| 1269 |
+
label="Batch Size",
|
| 1270 |
+
info="Candidates per iteration"
|
| 1271 |
+
)
|
| 1272 |
+
check_llego = gr.Checkbox(
|
| 1273 |
+
value=True,
|
| 1274 |
+
label="Enable LLEGO Crossover",
|
| 1275 |
+
info="Use advanced genetic operations"
|
| 1276 |
+
)
|
| 1277 |
+
|
| 1278 |
+
btn_optimize = gr.Button(
|
| 1279 |
+
"Start Optimization",
|
| 1280 |
+
elem_classes=["cta", "mt-6"]
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
+
# RIGHT: STATUS + RESULTS
|
| 1284 |
+
with gr.Column(scale=5, elem_classes=["right-col"]):
|
| 1285 |
+
# STATUS PANEL (Hidden by default)
|
| 1286 |
+
status_panel = gr.Group(visible=False, elem_classes=["panel"])
|
| 1287 |
+
with status_panel:
|
| 1288 |
+
gr.HTML(
|
| 1289 |
+
"""
|
| 1290 |
+
<div class="panel-title">
|
| 1291 |
+
<h3>Optimization status</h3>
|
| 1292 |
+
<div class="running-pill"><span class="running-dot"></span> Running</div>
|
| 1293 |
+
</div>
|
| 1294 |
+
"""
|
| 1295 |
+
)
|
| 1296 |
+
txt_status = gr.Markdown("Initializing genetic algorithm...")
|
| 1297 |
+
|
| 1298 |
+
# EMPTY STATE
|
| 1299 |
+
empty_state = gr.HTML(
|
| 1300 |
+
"""
|
| 1301 |
+
<div class="empty">
|
| 1302 |
+
<div class="big">🧬</div>
|
| 1303 |
+
<div class="t">Ready to optimize</div>
|
| 1304 |
+
<div class="s">Fill Steps 1–3, then click <b>Start Optimization</b> to begin prompt evolution.</div>
|
| 1305 |
+
</div>
|
| 1306 |
+
""",
|
| 1307 |
+
visible=True
|
| 1308 |
+
)
|
| 1309 |
+
|
| 1310 |
+
# RESULTS PANEL (Hidden by default)
|
| 1311 |
+
results_panel = gr.Group(visible=False, elem_classes=["results"])
|
| 1312 |
+
with results_panel:
|
| 1313 |
+
gr.HTML(
|
| 1314 |
+
"""
|
| 1315 |
+
<div class="results-banner">
|
| 1316 |
+
<div class="k">
|
| 1317 |
+
<div class="icon">✓</div>
|
| 1318 |
+
<div>
|
| 1319 |
+
<div class="title">Optimization successful</div>
|
| 1320 |
+
<div class="sub">Review the optimized prompt, metrics, and evolution traces.</div>
|
| 1321 |
+
</div>
|
| 1322 |
+
</div>
|
| 1323 |
+
</div>
|
| 1324 |
+
"""
|
| 1325 |
+
)
|
| 1326 |
+
|
| 1327 |
+
with gr.Tabs():
|
| 1328 |
+
with gr.Tab("Optimized Prompt"):
|
| 1329 |
+
res_prompt = gr.Textbox(
|
| 1330 |
+
label="Optimized Prompt",
|
| 1331 |
+
lines=18,
|
| 1332 |
+
max_lines=28,
|
| 1333 |
+
interactive=False,
|
| 1334 |
+
show_label=True,
|
| 1335 |
+
elem_classes=["mono"]
|
| 1336 |
+
)
|
| 1337 |
+
|
| 1338 |
+
with gr.Tab("Metrics & Log"):
|
| 1339 |
+
res_metrics = gr.JSON(label="Performance Gains")
|
| 1340 |
+
res_history = gr.TextArea(
|
| 1341 |
+
label="Evolution Log",
|
| 1342 |
+
interactive=False,
|
| 1343 |
+
lines=10
|
| 1344 |
+
)
|
| 1345 |
+
|
| 1346 |
+
with gr.Tab("🧬 Live Candidates"):
|
| 1347 |
+
gr.Markdown("Real-time stream of generated prompt candidates during optimization:")
|
| 1348 |
+
live_candidates = gr.HTML()
|
| 1349 |
+
btn_refresh_cand = gr.Button(
|
| 1350 |
+
"🔄 Refresh Stream",
|
| 1351 |
+
elem_classes=["secondary-btn"],
|
| 1352 |
+
size="sm"
|
| 1353 |
+
)
|
| 1354 |
+
|
| 1355 |
+
# ==========================================
|
| 1356 |
+
# 6. EVENT HANDLERS
|
| 1357 |
+
# ==========================================
|
| 1358 |
+
|
| 1359 |
+
# Dataset Management
|
| 1360 |
+
def update_dataset_count(dataset):
    """Render the 'N examples loaded' badge HTML for the dataset state.

    Any non-list state (corrupted gr.State) is reported as zero examples;
    unexpected failures are logged and rendered as a generic error badge.
    """
    try:
        if not isinstance(dataset, list):
            return "<span style='color: var(--text-secondary);'>0 examples loaded</span>"
        n = len(dataset)
        plural = 's' if n != 1 else ''
        return f"<span style='color: var(--text-secondary);'>{n} example{plural} loaded</span>"
    except Exception as exc:
        logger.error(f"Error updating dataset count: {str(exc)}")
        return "<span style='color: var(--text-secondary);'>Error</span>"
|
| 1370 |
+
|
| 1371 |
+
# Wrap event handlers with error handling
|
| 1372 |
+
def safe_add_example(*args):
    """Delegate to add_example, surfacing unexpected failures as gr.Error.

    gr.Error raised by add_example itself is propagated untouched so the
    original user-facing message is preserved.
    """
    try:
        result = add_example(*args)
    except gr.Error:
        raise
    except Exception as exc:
        logger.error(f"Unexpected error in add_example: {str(exc)}")
        raise gr.Error(f"Failed to add example: {str(exc)}")
    return result
|
| 1381 |
+
|
| 1382 |
+
def safe_update_table(dataset):
    """Build table rows via update_table; fall back to an empty table on error."""
    try:
        rows = update_table(dataset)
    except Exception as exc:
        # Keep the UI alive with an empty dataframe rather than crashing.
        logger.error(f"Error updating table: {str(exc)}")
        return []
    return rows
|
| 1389 |
+
|
| 1390 |
+
def safe_clear_dataset():
    """Reset the dataset via clear_dataset; on failure return empty state + table."""
    try:
        cleared = clear_dataset()
    except Exception as exc:
        logger.error(f"Error clearing dataset: {str(exc)}")
        # (state, table) pair expected by the click handler's outputs.
        return [], []
    return cleared
|
| 1397 |
+
|
| 1398 |
+
btn_add.click(
|
| 1399 |
+
safe_add_example,
|
| 1400 |
+
inputs=[d_in, d_out, d_img, dataset_state],
|
| 1401 |
+
outputs=[dataset_state, d_in, d_out, d_img]
|
| 1402 |
+
).then(
|
| 1403 |
+
safe_update_table,
|
| 1404 |
+
inputs=[dataset_state],
|
| 1405 |
+
outputs=[ds_table]
|
| 1406 |
+
).then(
|
| 1407 |
+
update_dataset_count,
|
| 1408 |
+
inputs=[dataset_state],
|
| 1409 |
+
outputs=[ds_count]
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
btn_clear.click(
|
| 1413 |
+
safe_clear_dataset,
|
| 1414 |
+
outputs=[dataset_state, ds_table]
|
| 1415 |
+
).then(
|
| 1416 |
+
lambda: "<span style='color: var(--text-secondary);'>0 examples loaded</span>",
|
| 1417 |
+
outputs=[ds_count]
|
| 1418 |
+
)
|
| 1419 |
+
|
| 1420 |
+
# Bulk Import
|
| 1421 |
+
def import_bulk_json(json_text, current_dataset):
    """Import examples from JSON with comprehensive error handling.

    Parses *json_text* as a JSON array of {"input": str, "output": str}
    objects and appends each valid item to *current_dataset* IN PLACE
    (the list is the gr.State value, so mutation is intentional here).

    Args:
        json_text: Raw JSON text pasted by the user.
        current_dataset: The list held in dataset_state.

    Returns:
        (current_dataset, ""): updated state plus an empty string that
        clears the bulk-import textbox on success.

    Raises:
        gr.Error: for empty/malformed JSON, a non-list payload, or when
        no item at all could be imported.
    """
    try:
        # Validate inputs
        if not json_text or not json_text.strip():
            raise gr.Error("JSON input is empty. Please provide a JSON array.")

        if not isinstance(current_dataset, list):
            raise gr.Error("Dataset state is invalid. Please refresh the page.")

        # Parse JSON
        try:
            data = json.loads(json_text.strip())
        except json.JSONDecodeError as e:
            raise gr.Error(f"Invalid JSON format: {str(e)}. Please check your JSON syntax.")

        # Validate structure
        if not isinstance(data, list):
            raise gr.Error("JSON must be an array of objects. Example: [{\"input\": \"...\", \"output\": \"...\"}]")

        if len(data) == 0:
            raise gr.Error("JSON array is empty. Add at least one example object.")

        # Validate and import items; invalid items are skipped (recorded in
        # `errors`) rather than aborting the whole import.
        imported_count = 0
        errors = []

        for i, item in enumerate(data):
            try:
                if not isinstance(item, dict):
                    errors.append(f"Item {i+1}: not a dictionary")
                    continue

                if "input" not in item or "output" not in item:
                    errors.append(f"Item {i+1}: missing 'input' or 'output' field")
                    continue

                input_val = item["input"]
                output_val = item["output"]

                if not isinstance(input_val, str) or not isinstance(output_val, str):
                    errors.append(f"Item {i+1}: 'input' and 'output' must be strings")
                    continue

                if not input_val.strip() or not output_val.strip():
                    errors.append(f"Item {i+1}: 'input' and 'output' cannot be empty")
                    continue

                # Add valid item (schema mirrors what manual entry produces)
                current_dataset.append({
                    "input": input_val.strip(),
                    "output": output_val.strip(),
                    "image": item.get("image"),  # Optional
                    "image_preview": "🖼️ Image" if item.get("image") else "-"
                })
                imported_count += 1

            except Exception as e:
                errors.append(f"Item {i+1}: {str(e)}")
                logger.warning(f"Error importing item {i+1}: {str(e)}")
                continue

        # Report results: hard-fail only when nothing was imported; partial
        # successes are logged as warnings (they are not surfaced in the UI).
        if imported_count == 0:
            error_msg = "No valid examples imported. "
            if errors:
                error_msg += "Errors: " + "; ".join(errors[:3])
                if len(errors) > 3:
                    error_msg += f" (and {len(errors) - 3} more)"
            raise gr.Error(error_msg)

        if errors:
            warning_msg = f"Imported {imported_count} example(s). "
            if len(errors) <= 3:
                warning_msg += f"Warnings: {'; '.join(errors)}"
            else:
                warning_msg += f"{len(errors)} items had errors."
            logger.warning(warning_msg)

        return current_dataset, ""

    except gr.Error:
        # Re-raise Gradio errors
        raise
    except Exception as e:
        logger.error(f"Unexpected error in import_bulk_json: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Failed to import JSON: {str(e)}")
|
| 1508 |
+
|
| 1509 |
+
btn_import.click(
|
| 1510 |
+
import_bulk_json,
|
| 1511 |
+
inputs=[bulk_json, dataset_state],
|
| 1512 |
+
outputs=[dataset_state, bulk_json]
|
| 1513 |
+
).then(
|
| 1514 |
+
safe_update_table,
|
| 1515 |
+
inputs=[dataset_state],
|
| 1516 |
+
outputs=[ds_table]
|
| 1517 |
+
).then(
|
| 1518 |
+
update_dataset_count,
|
| 1519 |
+
inputs=[dataset_state],
|
| 1520 |
+
outputs=[ds_count]
|
| 1521 |
+
)
|
| 1522 |
+
|
| 1523 |
+
# Main Optimization Flow
|
| 1524 |
+
btn_optimize.click(
|
| 1525 |
+
run_optimization_flow,
|
| 1526 |
+
inputs=[
|
| 1527 |
+
seed_input, dataset_state, model_select, custom_model_input,
|
| 1528 |
+
slider_iter, slider_calls, slider_batch, check_llego,
|
| 1529 |
+
key_openai, key_google, key_anthropic
|
| 1530 |
+
],
|
| 1531 |
+
outputs=[
|
| 1532 |
+
status_panel, empty_state, results_panel,
|
| 1533 |
+
txt_status, res_prompt, res_metrics, res_history, live_candidates
|
| 1534 |
+
]
|
| 1535 |
+
)
|
| 1536 |
+
|
| 1537 |
+
# Refresh Candidates
|
| 1538 |
+
def safe_get_candidates_display():
    """Fetch the live-candidates HTML, degrading to an error placeholder."""
    try:
        html = get_candidates_display()
    except Exception as exc:
        logger.error(f"Error refreshing candidates: {str(exc)}")
        return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>"
    return html
|
| 1545 |
+
|
| 1546 |
+
btn_refresh_cand.click(
|
| 1547 |
+
safe_get_candidates_display,
|
| 1548 |
+
outputs=[live_candidates]
|
| 1549 |
+
)
|
| 1550 |
+
|
| 1551 |
+
# ==========================================
|
| 1552 |
+
# 7. LAUNCH
|
| 1553 |
+
# ==========================================
|
| 1554 |
+
# ==========================================
# 7. LAUNCH
# ==========================================
if __name__ == "__main__":
    # BUG FIX: `css=` and `js=` are constructor arguments of gr.Blocks(),
    # not of Blocks.launch(); on Gradio 4.x passing them here raises
    # TypeError and the app never starts. They were removed from launch();
    # pass CUSTOM_CSS / FORCE_DARK_JS to gr.Blocks(...) to apply styling.
    # TODO(review): confirm against the pinned Gradio version on the Space.
    app.queue().launch(
        server_name="0.0.0.0",  # bind all interfaces (required for HF Spaces)
        server_port=7860,       # port HF Spaces routes traffic to
        share=False,            # no public Gradio tunnel on HF Spaces
        show_error=True,        # surface tracebacks in the UI for debugging
    )
|
| 1563 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies - gepa from git
|
| 2 |
+
git+https://github.com/gepa-ai/gepa.git
|
| 3 |
+
numpy>=1.21.0
|
| 4 |
+
pandas>=1.5.0
|
| 5 |
+
pydantic>=2.0.0
|
| 6 |
+
python-dotenv>=1.0.0
|
| 7 |
+
|
| 8 |
+
# HTTP/API clients
|
| 9 |
+
requests>=2.31.0
|
| 10 |
+
aiohttp>=3.8.0
|
| 11 |
+
asyncio-throttle>=1.0.0
|
| 12 |
+
|
| 13 |
+
# LLM Provider SDKs
|
| 14 |
+
openai>=1.0.0
|
| 15 |
+
anthropic>=0.18.0
|
| 16 |
+
google-generativeai>=0.3.0
|
| 17 |
+
google-genai>=0.2.0
|
| 18 |
+
|
| 19 |
+
# Image processing
|
| 20 |
+
Pillow>=9.0.0
|
| 21 |
+
|
| 22 |
+
# Gradio UI (version will be set by README.md sdk_version)
|
| 23 |
+
gradio>=4.0.0
|
src/gepa_optimizer.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: gepa-optimizer
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Universal prompt optimization framework based on GEPA
|
| 5 |
+
Home-page: https://github.com/suhasb-dev/Prompt-Optimizer
|
| 6 |
+
Author: Suhas
|
| 7 |
+
Author-email: Suhas <s8hasgrylls@gmail.com>
|
| 8 |
+
License: MIT
|
| 9 |
+
Project-URL: Homepage, https://github.com/suhasb-dev/Prompt-Optimizer
|
| 10 |
+
Project-URL: Repository, https://github.com/suhasb-dev/Prompt-Optimizer
|
| 11 |
+
Project-URL: Documentation, https://suhasb-dev.gitbook.io/gepa-universal-prompt-optimizer/
|
| 12 |
+
Project-URL: Bug Reports, https://github.com/suhasb-dev/Prompt-Optimizer/issues
|
| 13 |
+
Keywords: prompt-optimization,llm,gepa,ai,machine-learning,ui-tree-extraction
|
| 14 |
+
Classifier: Development Status :: 3 - Alpha
|
| 15 |
+
Classifier: Intended Audience :: Developers
|
| 16 |
+
Classifier: Intended Audience :: Science/Research
|
| 17 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 18 |
+
Classifier: Programming Language :: Python :: 3
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 22 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 23 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 24 |
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
| 25 |
+
Requires-Python: >=3.8
|
| 26 |
+
Description-Content-Type: text/markdown
|
| 27 |
+
License-File: LICENSE
|
| 28 |
+
Requires-Dist: gepa>=0.0.12
|
| 29 |
+
Requires-Dist: pandas>=1.5.0
|
| 30 |
+
Requires-Dist: pydantic>=2.0.0
|
| 31 |
+
Requires-Dist: python-dotenv>=1.0.0
|
| 32 |
+
Requires-Dist: requests>=2.31.0
|
| 33 |
+
Requires-Dist: aiohttp>=3.8.0
|
| 34 |
+
Requires-Dist: asyncio-throttle>=1.0.0
|
| 35 |
+
Requires-Dist: google-generativeai>=0.3.0
|
| 36 |
+
Requires-Dist: Pillow>=9.0.0
|
| 37 |
+
Provides-Extra: dev
|
| 38 |
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
| 39 |
+
Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
|
| 40 |
+
Requires-Dist: black>=23.0.0; extra == "dev"
|
| 41 |
+
Requires-Dist: flake8>=6.0.0; extra == "dev"
|
| 42 |
+
Requires-Dist: mypy>=1.0.0; extra == "dev"
|
| 43 |
+
Provides-Extra: docs
|
| 44 |
+
Requires-Dist: sphinx>=5.0.0; extra == "docs"
|
| 45 |
+
Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == "docs"
|
| 46 |
+
Provides-Extra: all
|
| 47 |
+
Requires-Dist: pytest>=7.0.0; extra == "all"
|
| 48 |
+
Requires-Dist: pytest-asyncio>=0.21.0; extra == "all"
|
| 49 |
+
Requires-Dist: black>=23.0.0; extra == "all"
|
| 50 |
+
Requires-Dist: flake8>=6.0.0; extra == "all"
|
| 51 |
+
Requires-Dist: mypy>=1.0.0; extra == "all"
|
| 52 |
+
Requires-Dist: sphinx>=5.0.0; extra == "all"
|
| 53 |
+
Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == "all"
|
| 54 |
+
Dynamic: author
|
| 55 |
+
Dynamic: home-page
|
| 56 |
+
Dynamic: license-file
|
| 57 |
+
Dynamic: requires-python
|
| 58 |
+
|
| 59 |
+
# GEPA Optimizer
|
| 60 |
+
|
| 61 |
+
[](https://badge.fury.io/py/gepa-optimizer)
|
| 62 |
+
[](https://www.python.org/downloads/)
|
| 63 |
+
[](https://opensource.org/licenses/MIT)
|
| 64 |
+
|
| 65 |
+
A universal prompt optimization framework built on [GEPA](https://arxiv.org/abs/2507.19457) with optional [LLEGO](https://arxiv.org/abs/2503.14217) genetic operators for accelerated convergence.
|
| 66 |
+
|
| 67 |
+
## Overview
|
| 68 |
+
|
| 69 |
+
GEPA Optimizer provides a modular architecture for optimizing prompts through reflective evolution. It requires custom evaluators and LLM clients, enabling domain-specific optimization for any use case.
|
| 70 |
+
|
| 71 |
+
**Key capabilities:**
|
| 72 |
+
- Multi-modal support (text + vision models)
|
| 73 |
+
- Hybrid GEPA + LLEGO optimization modes
|
| 74 |
+
- Configurable train/val/test data splitting
|
| 75 |
+
- Batch API support for cost reduction
|
| 76 |
+
- Async-first architecture
|
| 77 |
+
|
| 78 |
+
## Installation
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
pip install gepa-optimizer
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
**From source:**
|
| 85 |
+
```bash
|
| 86 |
+
git clone https://github.com/suhasb-dev/Prompt-Optimizer.git
|
| 87 |
+
cd Prompt-Optimizer
|
| 88 |
+
pip install -e .
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Quick Start
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
import asyncio
|
| 95 |
+
from gepa_optimizer import (
|
| 96 |
+
GepaOptimizer,
|
| 97 |
+
OptimizationConfig,
|
| 98 |
+
BaseEvaluator,
|
| 99 |
+
BaseLLMClient
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Define custom evaluator
|
| 103 |
+
class MyEvaluator(BaseEvaluator):
|
| 104 |
+
def evaluate(self, predicted: str, expected: str) -> dict:
|
| 105 |
+
score = 1.0 if predicted.strip() == expected.strip() else 0.0
|
| 106 |
+
return {"accuracy": score, "composite_score": score}
|
| 107 |
+
|
| 108 |
+
# Define custom LLM client
|
| 109 |
+
class MyLLMClient(BaseLLMClient):
|
| 110 |
+
def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> dict:
|
| 111 |
+
# Your LLM integration here
|
| 112 |
+
return {"content": "response"}
|
| 113 |
+
|
| 114 |
+
async def main():
|
| 115 |
+
config = OptimizationConfig(
|
| 116 |
+
model="openai/gpt-4o",
|
| 117 |
+
reflection_model="openai/gpt-4o",
|
| 118 |
+
max_iterations=5,
|
| 119 |
+
max_metric_calls=50,
|
| 120 |
+
batch_size=8
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
optimizer = GepaOptimizer(
|
| 124 |
+
config=config,
|
| 125 |
+
llm_client=MyLLMClient("openai", "gpt-4o"),
|
| 126 |
+
evaluator=MyEvaluator()
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
result = await optimizer.train(
|
| 130 |
+
seed_prompt="Your initial prompt",
|
| 131 |
+
dataset=your_dataset
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
print(f"Optimized: {result.prompt}")
|
| 135 |
+
print(f"Score: {result.improvement_data}")
|
| 136 |
+
|
| 137 |
+
asyncio.run(main())
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## Project Structure
|
| 141 |
+
|
| 142 |
+
```
|
| 143 |
+
src/gepa_optimizer/
|
| 144 |
+
├── core/ # Core optimization logic
|
| 145 |
+
│ ├── optimizer.py # GepaOptimizer main class
|
| 146 |
+
│ ├── base_adapter.py # BaseGepaAdapter interface
|
| 147 |
+
│ └── universal_adapter.py
|
| 148 |
+
├── evaluation/ # Evaluator implementations
|
| 149 |
+
│ ├── base_evaluator.py # BaseEvaluator abstract class
|
| 150 |
+
│ ├── scroll_evaluator.py
|
| 151 |
+
│ ├── validation_evaluator.py
|
| 152 |
+
│ └── index_caching_evaluator.py
|
| 153 |
+
├── llms/ # LLM client implementations
|
| 154 |
+
│ ├── base_llm.py # BaseLLMClient abstract class
|
| 155 |
+
│ ├── vision_llm.py # VisionLLMClient (OpenAI, Google, Anthropic)
|
| 156 |
+
│ └── batch_llm.py # BatchLLMClient (50% cost savings)
|
| 157 |
+
├── operators/ # LLEGO genetic operators
|
| 158 |
+
│ └── llego_operators.py # FitnessGuidedCrossover, DiversityGuidedMutation
|
| 159 |
+
├── data/ # Dataset loaders and converters
|
| 160 |
+
├── models/ # Configuration and result models
|
| 161 |
+
└── utils/ # Utilities and helpers
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
## Configuration
|
| 165 |
+
|
| 166 |
+
### Basic Configuration
|
| 167 |
+
|
| 168 |
+
```python
|
| 169 |
+
from gepa_optimizer import OptimizationConfig, ModelConfig
|
| 170 |
+
|
| 171 |
+
config = OptimizationConfig(
|
| 172 |
+
# Required parameters
|
| 173 |
+
model="openai/gpt-4o", # or ModelConfig instance
|
| 174 |
+
reflection_model="openai/gpt-4o",
|
| 175 |
+
max_iterations=10,
|
| 176 |
+
max_metric_calls=100,
|
| 177 |
+
batch_size=8,
|
| 178 |
+
|
| 179 |
+
# Data splitting (train/val/test)
|
| 180 |
+
data_split=DataSplitConfig(
|
| 181 |
+
train_ratio=0.6,
|
| 182 |
+
val_ratio=0.2,
|
| 183 |
+
test_ratio=0.2
|
| 184 |
+
),
|
| 185 |
+
|
| 186 |
+
# Optional settings
|
| 187 |
+
reflection_examples=3, # Examples per reflection (2-5 recommended)
|
| 188 |
+
evaluate_on_test=True, # Final evaluation on held-out test set
|
| 189 |
+
log_level="INFO" # DEBUG, INFO, WARNING, ERROR
|
| 190 |
+
)
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
### LLEGO Genetic Operators
|
| 194 |
+
|
| 195 |
+
Enable LLEGO for faster convergence through fitness-guided crossover and diversity-guided mutation:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
config = OptimizationConfig(
|
| 199 |
+
model="openai/gpt-4o",
|
| 200 |
+
reflection_model="openai/gpt-4o",
|
| 201 |
+
max_iterations=5,
|
| 202 |
+
max_metric_calls=50,
|
| 203 |
+
batch_size=8,
|
| 204 |
+
|
| 205 |
+
# Enable LLEGO
|
| 206 |
+
use_llego_operators=True,
|
| 207 |
+
alpha=0.15, # Fitness extrapolation factor
|
| 208 |
+
tau=10.0, # Diversity temperature
|
| 209 |
+
nu=4, # Parent arity
|
| 210 |
+
n_crossover=2, # Crossover offspring per iteration
|
| 211 |
+
n_mutation=3, # Mutation offspring per iteration
|
| 212 |
+
population_size=15
|
| 213 |
+
)
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
### Hybrid Mode (GEPA + LLEGO)
|
| 217 |
+
|
| 218 |
+
Combine GEPA's semantic reflection with LLEGO's structural diversity:
|
| 219 |
+
|
| 220 |
+
```python
|
| 221 |
+
config = OptimizationConfig(
|
| 222 |
+
model="openai/gpt-4o",
|
| 223 |
+
reflection_model="openai/gpt-4o",
|
| 224 |
+
max_iterations=6,
|
| 225 |
+
max_metric_calls=200,
|
| 226 |
+
batch_size=10,
|
| 227 |
+
|
| 228 |
+
# Hybrid mode
|
| 229 |
+
use_llego_operators=True,
|
| 230 |
+
enable_gepa_reflection_with_llego=True,
|
| 231 |
+
num_gepa_reflection_candidates=3,
|
| 232 |
+
n_crossover=3,
|
| 233 |
+
n_mutation=3
|
| 234 |
+
# Total: 9 candidates per iteration (3 GEPA + 3 crossover + 3 mutation)
|
| 235 |
+
)
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
### Batch API (Cost Optimization)
|
| 239 |
+
|
| 240 |
+
Use batch processing for 50% cost reduction:
|
| 241 |
+
|
| 242 |
+
```python
|
| 243 |
+
from gepa_optimizer.llms import BatchLLMClient
|
| 244 |
+
|
| 245 |
+
llm_client = BatchLLMClient(
|
| 246 |
+
provider="google",
|
| 247 |
+
model_name="gemini-2.5-flash",
|
| 248 |
+
batch_size=20,
|
| 249 |
+
polling_interval=30
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
optimizer = GepaOptimizer(
|
| 253 |
+
config=config,
|
| 254 |
+
llm_client=llm_client,
|
| 255 |
+
evaluator=evaluator
|
| 256 |
+
)
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
## Built-in Components
|
| 260 |
+
|
| 261 |
+
### LLM Clients
|
| 262 |
+
|
| 263 |
+
| Client | Description | Use Case |
|
| 264 |
+
|--------|-------------|----------|
|
| 265 |
+
| `VisionLLMClient` | Multi-modal client for OpenAI, Google, Anthropic | Real-time requests |
|
| 266 |
+
| `BatchLLMClient` | Batch processing client | Cost-sensitive workloads |
|
| 267 |
+
|
| 268 |
+
### Evaluators
|
| 269 |
+
|
| 270 |
+
| Evaluator | Description |
|
| 271 |
+
|-----------|-------------|
|
| 272 |
+
| `ScrollElementEvaluator` | UI element detection scoring |
|
| 273 |
+
| `ValidationEvaluator` | Screen validation tasks |
|
| 274 |
+
| `IndexCachingEvaluator` | Index-based element selection |
|
| 275 |
+
| `UITreeEvaluator` | UI tree extraction |
|
| 276 |
+
|
| 277 |
+
### Dataset Loaders
|
| 278 |
+
|
| 279 |
+
| Loader | Description |
|
| 280 |
+
|--------|-------------|
|
| 281 |
+
| `load_scroll_dataset()` | Load scroll detection datasets |
|
| 282 |
+
| `load_validation_split()` | Load validation datasets with splits |
|
| 283 |
+
| `load_index_caching_split()` | Load index caching datasets |
|
| 284 |
+
|
| 285 |
+
## Creating Custom Components
|
| 286 |
+
|
| 287 |
+
### Custom Evaluator
|
| 288 |
+
|
| 289 |
+
```python
|
| 290 |
+
from gepa_optimizer import BaseEvaluator
|
| 291 |
+
|
| 292 |
+
class CustomEvaluator(BaseEvaluator):
|
| 293 |
+
def __init__(self):
|
| 294 |
+
super().__init__(metric_weights={
|
| 295 |
+
"accuracy": 0.5,
|
| 296 |
+
"completeness": 0.3,
|
| 297 |
+
"format": 0.2
|
| 298 |
+
})
|
| 299 |
+
|
| 300 |
+
def evaluate(self, predicted: str, expected: str) -> dict:
|
| 301 |
+
accuracy = self._compute_accuracy(predicted, expected)
|
| 302 |
+
completeness = self._compute_completeness(predicted, expected)
|
| 303 |
+
format_score = self._compute_format(predicted)
|
| 304 |
+
|
| 305 |
+
composite = (
|
| 306 |
+
accuracy * 0.5 +
|
| 307 |
+
completeness * 0.3 +
|
| 308 |
+
format_score * 0.2
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
return {
|
| 312 |
+
"accuracy": accuracy,
|
| 313 |
+
"completeness": completeness,
|
| 314 |
+
"format": format_score,
|
| 315 |
+
"composite_score": composite # Required key
|
| 316 |
+
}
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
### Custom LLM Client
|
| 320 |
+
|
| 321 |
+
```python
|
| 322 |
+
from gepa_optimizer import BaseLLMClient
|
| 323 |
+
|
| 324 |
+
class CustomLLMClient(BaseLLMClient):
|
| 325 |
+
def __init__(self, api_key: str):
|
| 326 |
+
super().__init__(provider="custom", model_name="my-model")
|
| 327 |
+
self.api_key = api_key
|
| 328 |
+
|
| 329 |
+
def generate(
|
| 330 |
+
self,
|
| 331 |
+
system_prompt: str,
|
| 332 |
+
user_prompt: str,
|
| 333 |
+
image_base64: str = None,
|
| 334 |
+
**kwargs
|
| 335 |
+
) -> dict:
|
| 336 |
+
# Your API call here
|
| 337 |
+
response = call_your_api(system_prompt, user_prompt, image_base64)
|
| 338 |
+
return {"content": response}
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
## Examples
|
| 342 |
+
|
| 343 |
+
| File | Description |
|
| 344 |
+
|------|-------------|
|
| 345 |
+
| [`examples/basic_usage.py`](examples/basic_usage.py) | Basic optimization workflow |
|
| 346 |
+
| [`examples/advanced_usage.py`](examples/advanced_usage.py) | Advanced configuration |
|
| 347 |
+
| [`examples/batch_api_example.py`](examples/batch_api_example.py) | Batch API usage |
|
| 348 |
+
| [`examples/gemini_usage.py`](examples/gemini_usage.py) | Google Gemini integration |
|
| 349 |
+
|
| 350 |
+
**Run examples:**
|
| 351 |
+
```bash
|
| 352 |
+
python examples/basic_usage.py
|
| 353 |
+
```
|
| 354 |
+
|
| 355 |
+
## Testing
|
| 356 |
+
|
| 357 |
+
```bash
|
| 358 |
+
# Run all tests
|
| 359 |
+
pytest tests/
|
| 360 |
+
|
| 361 |
+
# Run unit tests only
|
| 362 |
+
pytest tests/unit/
|
| 363 |
+
|
| 364 |
+
# Run integration tests
|
| 365 |
+
pytest tests/integration/
|
| 366 |
+
```
|
| 367 |
+
|
| 368 |
+
## API Reference
|
| 369 |
+
|
| 370 |
+
### GepaOptimizer
|
| 371 |
+
|
| 372 |
+
```python
|
| 373 |
+
class GepaOptimizer:
|
| 374 |
+
def __init__(
|
| 375 |
+
self,
|
| 376 |
+
config: OptimizationConfig,
|
| 377 |
+
llm_client: BaseLLMClient,
|
| 378 |
+
evaluator: BaseEvaluator,
|
| 379 |
+
adapter_type: str = "universal"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
async def train(
|
| 383 |
+
self,
|
| 384 |
+
seed_prompt: str,
|
| 385 |
+
dataset: Union[List, Dict],
|
| 386 |
+
**kwargs
|
| 387 |
+
) -> OptimizedResult
|
| 388 |
+
```
|
| 389 |
+
|
| 390 |
+
### OptimizationConfig
|
| 391 |
+
|
| 392 |
+
| Parameter | Type | Default | Description |
|
| 393 |
+
|-----------|------|---------|-------------|
|
| 394 |
+
| `model` | `str \| ModelConfig` | Required | Target model |
|
| 395 |
+
| `reflection_model` | `str \| ModelConfig` | Required | Reflection model |
|
| 396 |
+
| `max_iterations` | `int` | Required | Maximum optimization iterations |
|
| 397 |
+
| `max_metric_calls` | `int` | Required | Maximum evaluation calls |
|
| 398 |
+
| `batch_size` | `int` | Required | Samples per evaluation batch |
|
| 399 |
+
| `use_llego_operators` | `bool` | `False` | Enable LLEGO genetic operators |
|
| 400 |
+
| `enable_gepa_reflection_with_llego` | `bool` | `False` | Enable hybrid mode |
|
| 401 |
+
| `use_llm_as_judge` | `bool` | `True` | Enable LLM-as-Judge feedback |
|
| 402 |
+
| `log_level` | `str` | `"INFO"` | Logging verbosity |
|
| 403 |
+
|
| 404 |
+
### OptimizedResult
|
| 405 |
+
|
| 406 |
+
| Attribute | Type | Description |
|
| 407 |
+
|-----------|------|-------------|
|
| 408 |
+
| `prompt` | `str` | Optimized prompt |
|
| 409 |
+
| `original_prompt` | `str` | Initial seed prompt |
|
| 410 |
+
| `improvement_data` | `dict` | Score improvements |
|
| 411 |
+
| `optimization_time` | `float` | Total time in seconds |
|
| 412 |
+
| `is_successful` | `bool` | Optimization success status |
|
| 413 |
+
|
| 414 |
+
## Environment Variables
|
| 415 |
+
|
| 416 |
+
| Variable | Description |
|
| 417 |
+
|----------|-------------|
|
| 418 |
+
| `OPENAI_API_KEY` | OpenAI API key |
|
| 419 |
+
| `ANTHROPIC_API_KEY` | Anthropic API key |
|
| 420 |
+
| `GOOGLE_API_KEY` | Google AI API key |
|
| 421 |
+
|
| 422 |
+
## References
|
| 423 |
+
|
| 424 |
+
- **GEPA Paper:** [Reflective Prompt Evolution Can Outperform Reinforcement Learning](https://arxiv.org/abs/2507.19457)
|
| 425 |
+
- **LLEGO Paper:** [Decision Tree Induction Through LLMs via Semantically-Aware Evolution](https://arxiv.org/abs/2503.14217)
|
| 426 |
+
- **GEPA Library:** [github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa)
|
| 427 |
+
|
| 428 |
+
## License
|
| 429 |
+
|
| 430 |
+
MIT License - see [LICENSE](LICENSE) for details.
|
| 431 |
+
|
| 432 |
+
## Contributing
|
| 433 |
+
|
| 434 |
+
Contributions welcome. Please open an issue or submit a pull request.
|
| 435 |
+
|
| 436 |
+
## Support
|
| 437 |
+
|
| 438 |
+
- **Issues:** [GitHub Issues](https://github.com/suhasb-dev/Prompt-Optimizer/issues)
|
| 439 |
+
- **Documentation:** [GitBook](https://suhasb-dev.gitbook.io/gepa-universal-prompt-optimizer/)
|
src/gepa_optimizer.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
setup.py
|
| 5 |
+
src/gepa_optimizer/__init__.py
|
| 6 |
+
src/gepa_optimizer/cli.py
|
| 7 |
+
src/gepa_optimizer/types.py
|
| 8 |
+
src/gepa_optimizer/version.py
|
| 9 |
+
src/gepa_optimizer.egg-info/PKG-INFO
|
| 10 |
+
src/gepa_optimizer.egg-info/SOURCES.txt
|
| 11 |
+
src/gepa_optimizer.egg-info/dependency_links.txt
|
| 12 |
+
src/gepa_optimizer.egg-info/entry_points.txt
|
| 13 |
+
src/gepa_optimizer.egg-info/requires.txt
|
| 14 |
+
src/gepa_optimizer.egg-info/top_level.txt
|
| 15 |
+
src/gepa_optimizer/core/__init__.py
|
| 16 |
+
src/gepa_optimizer/core/base_adapter.py
|
| 17 |
+
src/gepa_optimizer/core/custom_adapter.py
|
| 18 |
+
src/gepa_optimizer/core/optimizer.py
|
| 19 |
+
src/gepa_optimizer/core/result.py
|
| 20 |
+
src/gepa_optimizer/core/universal_adapter.py
|
| 21 |
+
src/gepa_optimizer/data/__init__.py
|
| 22 |
+
src/gepa_optimizer/data/converters.py
|
| 23 |
+
src/gepa_optimizer/data/index_caching_loader.py
|
| 24 |
+
src/gepa_optimizer/data/loaders.py
|
| 25 |
+
src/gepa_optimizer/data/scroll_dataset_loader.py
|
| 26 |
+
src/gepa_optimizer/data/validation_dataset_loader.py
|
| 27 |
+
src/gepa_optimizer/data/validators.py
|
| 28 |
+
src/gepa_optimizer/evaluation/__init__.py
|
| 29 |
+
src/gepa_optimizer/evaluation/base_evaluator.py
|
| 30 |
+
src/gepa_optimizer/evaluation/index_caching_evaluator.py
|
| 31 |
+
src/gepa_optimizer/evaluation/scroll_evaluator.py
|
| 32 |
+
src/gepa_optimizer/evaluation/ui_evaluator.py
|
| 33 |
+
src/gepa_optimizer/evaluation/universal_evaluator.py
|
| 34 |
+
src/gepa_optimizer/evaluation/validation_evaluator.py
|
| 35 |
+
src/gepa_optimizer/infrastructure/__init__.py
|
| 36 |
+
src/gepa_optimizer/infrastructure/logging/__init__.py
|
| 37 |
+
src/gepa_optimizer/infrastructure/logging/context.py
|
| 38 |
+
src/gepa_optimizer/infrastructure/logging/formatters.py
|
| 39 |
+
src/gepa_optimizer/infrastructure/logging/logger.py
|
| 40 |
+
src/gepa_optimizer/llms/__init__.py
|
| 41 |
+
src/gepa_optimizer/llms/base_llm.py
|
| 42 |
+
src/gepa_optimizer/llms/batch_llm.py
|
| 43 |
+
src/gepa_optimizer/llms/llego_enhanced_llm.py
|
| 44 |
+
src/gepa_optimizer/llms/vision_llm.py
|
| 45 |
+
src/gepa_optimizer/models/__init__.py
|
| 46 |
+
src/gepa_optimizer/models/config.py
|
| 47 |
+
src/gepa_optimizer/models/dataset.py
|
| 48 |
+
src/gepa_optimizer/models/result.py
|
| 49 |
+
src/gepa_optimizer/operators/__init__.py
|
| 50 |
+
src/gepa_optimizer/operators/base_operator.py
|
| 51 |
+
src/gepa_optimizer/operators/crossover.py
|
| 52 |
+
src/gepa_optimizer/operators/llego_operators.py
|
| 53 |
+
src/gepa_optimizer/operators/models.py
|
| 54 |
+
src/gepa_optimizer/operators/mutation.py
|
| 55 |
+
src/gepa_optimizer/utils/__init__.py
|
| 56 |
+
src/gepa_optimizer/utils/api_keys.py
|
| 57 |
+
src/gepa_optimizer/utils/candidate_collector.py
|
| 58 |
+
src/gepa_optimizer/utils/clean_logger.py
|
| 59 |
+
src/gepa_optimizer/utils/exceptions.py
|
| 60 |
+
src/gepa_optimizer/utils/helpers.py
|
| 61 |
+
src/gepa_optimizer/utils/llm_judge_prompt.py
|
| 62 |
+
src/gepa_optimizer/utils/log_parser.py
|
| 63 |
+
src/gepa_optimizer/utils/logging.py
|
| 64 |
+
src/gepa_optimizer/utils/metrics.py
|
| 65 |
+
src/gepa_optimizer/utils/pareto_logger.py
|
src/gepa_optimizer.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/gepa_optimizer.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
gepa-optimize = gepa_optimizer.cli:main
|
src/gepa_optimizer.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gepa>=0.0.12
|
| 2 |
+
pandas>=1.5.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
python-dotenv>=1.0.0
|
| 5 |
+
requests>=2.31.0
|
| 6 |
+
aiohttp>=3.8.0
|
| 7 |
+
asyncio-throttle>=1.0.0
|
| 8 |
+
google-generativeai>=0.3.0
|
| 9 |
+
Pillow>=9.0.0
|
| 10 |
+
|
| 11 |
+
[all]
|
| 12 |
+
pytest>=7.0.0
|
| 13 |
+
pytest-asyncio>=0.21.0
|
| 14 |
+
black>=23.0.0
|
| 15 |
+
flake8>=6.0.0
|
| 16 |
+
mypy>=1.0.0
|
| 17 |
+
sphinx>=5.0.0
|
| 18 |
+
sphinx-rtd-theme>=1.2.0
|
| 19 |
+
|
| 20 |
+
[dev]
|
| 21 |
+
pytest>=7.0.0
|
| 22 |
+
pytest-asyncio>=0.21.0
|
| 23 |
+
black>=23.0.0
|
| 24 |
+
flake8>=6.0.0
|
| 25 |
+
mypy>=1.0.0
|
| 26 |
+
|
| 27 |
+
[docs]
|
| 28 |
+
sphinx>=5.0.0
|
| 29 |
+
sphinx-rtd-theme>=1.2.0
|
src/gepa_optimizer.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
gepa_optimizer
|
src/gepa_optimizer/__init__.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GEPA Universal Prompt Optimizer
|
| 3 |
+
|
| 4 |
+
A modern, modular Python library for universal prompt optimization powered by GEPA.
|
| 5 |
+
|
| 6 |
+
Quick Start (No custom evaluator needed!):
|
| 7 |
+
|
| 8 |
+
from gepa_optimizer import quick_optimize
|
| 9 |
+
|
| 10 |
+
result = await quick_optimize(
|
| 11 |
+
seed_prompt="Your initial prompt",
|
| 12 |
+
dataset=[
|
| 13 |
+
{"input": "task1", "output": "expected1"},
|
| 14 |
+
{"input": "task2", "output": "expected2"},
|
| 15 |
+
],
|
| 16 |
+
model="openai/gpt-4o" # or any: "google/gemini-1.5-pro", "anthropic/claude-3-5-sonnet-20241022"
|
| 17 |
+
)
|
| 18 |
+
print(result.optimized_prompt)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# Core functionality
|
| 22 |
+
from .core import GepaOptimizer
|
| 23 |
+
from .core.base_adapter import BaseGepaAdapter
|
| 24 |
+
from .core.universal_adapter import UniversalGepaAdapter
|
| 25 |
+
|
| 26 |
+
# Configuration and models
|
| 27 |
+
from .models import OptimizationConfig, OptimizationResult, OptimizedResult, ModelConfig
|
| 28 |
+
|
| 29 |
+
# Data processing
|
| 30 |
+
from .data import UniversalConverter, DataLoader, DataValidator
|
| 31 |
+
from .data.scroll_dataset_loader import ScrollDatasetLoader, load_scroll_dataset
|
| 32 |
+
from .data.validation_dataset_loader import ValidationDatasetLoader, load_validation_dataset, load_validation_split
|
| 33 |
+
from .data.index_caching_loader import IndexCachingDatasetLoader, load_index_caching_dataset, load_index_caching_split
|
| 34 |
+
|
| 35 |
+
# LLM clients
|
| 36 |
+
from .llms import VisionLLMClient
|
| 37 |
+
from .llms.base_llm import BaseLLMClient
|
| 38 |
+
from .llms.batch_llm import BatchLLMClient
|
| 39 |
+
|
| 40 |
+
# Evaluators - including Universal Semantic Evaluator (works for ANY task!)
|
| 41 |
+
from .evaluation import (
|
| 42 |
+
BaseEvaluator,
|
| 43 |
+
UniversalSemanticEvaluator,
|
| 44 |
+
create_universal_evaluator,
|
| 45 |
+
UITreeEvaluator,
|
| 46 |
+
ScrollElementEvaluator,
|
| 47 |
+
ValidationEvaluator,
|
| 48 |
+
IndexCachingEvaluator
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# LLEGO Genetic Operators
|
| 52 |
+
from .operators import (
|
| 53 |
+
# Base interfaces
|
| 54 |
+
BaseGeneticOperator,
|
| 55 |
+
BaseCrossoverOperator,
|
| 56 |
+
BaseMutationOperator,
|
| 57 |
+
# Concrete operators
|
| 58 |
+
FitnessGuidedCrossover,
|
| 59 |
+
DiversityGuidedMutation,
|
| 60 |
+
LLEGOIntegrationLayer,
|
| 61 |
+
# Data models
|
| 62 |
+
PromptCandidate,
|
| 63 |
+
PromptMetadata
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Utilities
|
| 67 |
+
from .utils import setup_logging, calculate_metrics, sanitize_prompt, APIKeyManager
|
| 68 |
+
from .utils.exceptions import GepaOptimizerError, GepaDependencyError, InvalidInputError, DatasetError
|
| 69 |
+
|
| 70 |
+
# Logging infrastructure
|
| 71 |
+
from .infrastructure.logging import get_logger, configure_logging, LogContext
|
| 72 |
+
|
| 73 |
+
# Type definitions (for type hints in user code)
|
| 74 |
+
from .types import (
|
| 75 |
+
DatasetItem,
|
| 76 |
+
EvaluationResult,
|
| 77 |
+
LLMResponse,
|
| 78 |
+
CandidateDict,
|
| 79 |
+
LLMClientProtocol,
|
| 80 |
+
EvaluatorProtocol,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
__version__ = "0.1.0"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 87 |
+
# CONVENIENCE FUNCTION: quick_optimize
|
| 88 |
+
# No evaluator needed - uses Universal Semantic Evaluator automatically
|
| 89 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 90 |
+
|
| 91 |
+
async def quick_optimize(
    seed_prompt: str,
    dataset: list,
    model: str,
    max_iterations: int = 5,
    max_metric_calls: int = 50,
    batch_size: int = 4,
    use_llego: bool = True,
    verbose: bool = True
) -> OptimizedResult:
    """Optimize a prompt with zero setup — no custom evaluator required.

    Wires together a VisionLLMClient, the UniversalSemanticEvaluator
    (task-agnostic, LLM-backed), and a GepaOptimizer, then runs training.

    Args:
        seed_prompt: The initial prompt to improve.
        dataset: List of dicts with 'input' and 'output' (expected) keys;
            an optional 'image' key (base64) enables multi-modal tasks.
        model: Model string in "provider/model-name" form (REQUIRED), e.g.
            "openai/gpt-4o", "google/gemini-1.5-pro",
            "anthropic/claude-3-5-sonnet-20241022".
        max_iterations: Maximum optimization iterations (default: 5).
        max_metric_calls: Maximum evaluation calls (default: 50).
        batch_size: Samples per evaluation batch (default: 4).
        use_llego: Enable LLEGO genetic operators plus hybrid GEPA
            reflection (default: True).
        verbose: Emit INFO-level progress logs (default: True).

    Returns:
        OptimizedResult carrying the optimized prompt and improvement
        metrics.

    Example:
        >>> result = await quick_optimize(
        ...     seed_prompt="Count the objects in the image",
        ...     dataset=[
        ...         {"input": "image1.jpg", "output": "5 objects", "image": "base64..."},
        ...     ],
        ...     model="openai/gpt-4o",
        ...     max_iterations=3,
        ... )
    """
    import logging

    if verbose:
        logging.basicConfig(level=logging.INFO)

    # A single multi-modal client serves both task execution and the
    # LLM-based semantic analysis inside the evaluator.
    client = VisionLLMClient.from_model_string(model)
    semantic_evaluator = UniversalSemanticEvaluator(
        llm_client=client,
        use_llm_analysis=True,
    )

    # LLEGO settings mirror `use_llego`: hybrid mode layers GEPA reflection
    # candidates on top of crossover/mutation offspring.
    run_config = OptimizationConfig(
        model=model,
        reflection_model=model,
        max_iterations=max_iterations,
        max_metric_calls=max_metric_calls,
        batch_size=batch_size,
        use_llego_operators=use_llego,
        enable_gepa_reflection_with_llego=use_llego,
        num_gepa_reflection_candidates=3,
        n_crossover=2,
        n_mutation=2,
        verbose=verbose,
    )

    engine = GepaOptimizer(
        config=run_config,
        llm_client=client,
        evaluator=semantic_evaluator,
    )
    return await engine.train(seed_prompt=seed_prompt, dataset=dataset)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def quick_optimize_sync(
    seed_prompt: str,
    dataset: list,
    model: str,
    max_iterations: int = 5,
    max_metric_calls: int = 50,
    batch_size: int = 4,
    use_llego: bool = True,
    verbose: bool = True
) -> OptimizedResult:
    """Blocking wrapper around :func:`quick_optimize`.

    Runs the async optimization to completion on a fresh event loop and
    returns the result. See quick_optimize for full parameter docs.

    Args:
        model: LLM model in "provider/model-name" form (REQUIRED), e.g.
            "openai/gpt-4o", "google/gemini-1.5-pro",
            "anthropic/claude-3-5-sonnet-20241022".

    Returns:
        OptimizedResult from the underlying quick_optimize call.
    """
    import asyncio

    coroutine = quick_optimize(
        seed_prompt=seed_prompt,
        dataset=dataset,
        model=model,
        max_iterations=max_iterations,
        max_metric_calls=max_metric_calls,
        batch_size=batch_size,
        use_llego=use_llego,
        verbose=verbose,
    )
    # asyncio.run creates (and closes) its own loop, so this must be
    # called from synchronous code only.
    return asyncio.run(coroutine)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Public API surface of the package. `from gepa_optimizer import *`
# exposes exactly these names; keep in sync with the imports above.
__all__ = [
    # 🚀 Quick Start (recommended for new users) — no custom evaluator needed
    "quick_optimize",
    "quick_optimize_sync",

    # Core functionality: the optimizer and its adapter interfaces
    "GepaOptimizer",
    "BaseGepaAdapter",
    "UniversalGepaAdapter",

    # Configuration and result models
    "OptimizationConfig",
    "OptimizationResult",
    "OptimizedResult",
    "ModelConfig",

    # Data processing
    "UniversalConverter",
    "DataLoader",
    "DataValidator",

    # Dataset loaders (classes and their convenience functions)
    "ScrollDatasetLoader",
    "load_scroll_dataset",
    "ValidationDatasetLoader",
    "load_validation_dataset",
    "load_validation_split",
    "IndexCachingDatasetLoader",
    "load_index_caching_dataset",
    "load_index_caching_split",

    # LLM clients
    "VisionLLMClient",
    "BaseLLMClient",
    "BatchLLMClient",

    # Evaluators (UniversalSemanticEvaluator recommended for general use)
    "UniversalSemanticEvaluator",
    "create_universal_evaluator",
    "BaseEvaluator",
    "UITreeEvaluator",
    "ScrollElementEvaluator",
    "ValidationEvaluator",
    "IndexCachingEvaluator",

    # LLEGO Genetic Operators - Base interfaces
    "BaseGeneticOperator",
    "BaseCrossoverOperator",
    "BaseMutationOperator",
    # LLEGO Genetic Operators - Concrete implementations
    "FitnessGuidedCrossover",
    "DiversityGuidedMutation",
    "LLEGOIntegrationLayer",
    "PromptCandidate",
    "PromptMetadata",

    # Utilities and exception hierarchy
    "APIKeyManager",
    "GepaOptimizerError",
    "GepaDependencyError",
    "InvalidInputError",
    "DatasetError",
    "setup_logging",
    "calculate_metrics",
    "sanitize_prompt",

    # Logging infrastructure
    "get_logger",
    "configure_logging",
    "LogContext",

    # Type definitions (for type hints in user code)
    "DatasetItem",
    "EvaluationResult",
    "LLMResponse",
    "CandidateDict",
    "LLMClientProtocol",
    "EvaluatorProtocol",
]
|
src/gepa_optimizer/cli.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Command Line Interface for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import asyncio
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from .core import GepaOptimizer
|
| 13 |
+
from .models import OptimizationConfig, ModelConfig
|
| 14 |
+
from .utils import setup_logging, APIKeyManager
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main():
    """Main CLI entry point.

    Parses command-line arguments, builds an OptimizationConfig (from a
    --config JSON file or from individual flags), validates provider API
    keys, runs the optimization synchronously, and writes the results to
    --output or stdout. Exits with status 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="GEPA Universal Prompt Optimizer CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  gepa-optimize --model openai/gpt-4-turbo --prompt "Extract UI elements" --dataset data.json
  gepa-optimize --config config.json --prompt "Analyze interface" --dataset images/
        """
    )

    # Required arguments
    parser.add_argument(
        "--prompt",
        required=True,
        help="Initial seed prompt to optimize"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        help="Path to dataset file or directory"
    )

    # Model configuration (either --model or --config must be supplied;
    # enforced later in create_config_from_args)
    parser.add_argument(
        "--model",
        help="Model specification (e.g., 'openai/gpt-4-turbo')"
    )
    parser.add_argument(
        "--reflection-model",
        help="Reflection model specification"
    )
    parser.add_argument(
        "--config",
        help="Path to configuration JSON file"
    )

    # Optimization parameters
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=10,
        help="Maximum optimization iterations (default: 10)"
    )
    parser.add_argument(
        "--max-metric-calls",
        type=int,
        default=100,
        help="Maximum metric evaluation calls (default: 100)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size for evaluation (default: 4)"
    )

    # GEPA-specific parameters
    # NOTE(review): the flags below are parsed but create_config_from_args
    # only forwards model/reflection-model/max-iterations/max-metric-calls/
    # batch-size into the config — confirm whether these should be wired in.
    parser.add_argument(
        "--candidate-selection-strategy",
        type=str,
        default="pareto",
        choices=["pareto", "best"],
        help="Strategy for selecting candidates (default: pareto)"
    )
    parser.add_argument(
        "--skip-perfect-score",
        action="store_true",
        help="Skip updating candidates with perfect scores"
    )
    parser.add_argument(
        "--reflection-minibatch-size",
        type=int,
        default=None,
        help="Number of examples to use for reflection (default: use batch_size)"
    )
    parser.add_argument(
        "--perfect-score",
        type=float,
        default=1.0,
        help="Perfect score threshold (default: 1.0)"
    )
    parser.add_argument(
        "--module-selector",
        type=str,
        default="round_robin",
        choices=["round_robin", "all"],
        help="Component selection strategy (default: round_robin)"
    )

    # Output options
    parser.add_argument(
        "--output",
        help="Output file path for results (default: stdout)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Setup logging before any work so early failures are captured
    setup_logging(level="DEBUG" if args.verbose else "INFO")

    try:
        # Load configuration: a JSON config file takes precedence over
        # individual CLI flags
        if args.config:
            config = load_config_from_file(args.config)
        else:
            config = create_config_from_args(args)

        # Validate API keys (exits with status 1 if any are missing)
        validate_api_keys(config)

        # Create optimizer
        optimizer = GepaOptimizer(config=config)

        # Run optimization (async entry point driven synchronously)
        print(f"🚀 Starting optimization with model: {config.model.model_name}")
        result = asyncio.run(optimizer.train(
            seed_prompt=args.prompt,
            dataset=args.dataset
        ))

        # Output results to --output file or stdout
        output_results(result, args.output)

        print("✅ Optimization completed successfully!")

    except Exception as e:
        # Top-level boundary: report and exit non-zero for shell callers
        print(f"❌ Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def load_config_from_file(config_path: str) -> OptimizationConfig:
|
| 155 |
+
"""Load configuration from JSON file"""
|
| 156 |
+
path = Path(config_path)
|
| 157 |
+
if not path.exists():
|
| 158 |
+
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
| 159 |
+
|
| 160 |
+
with open(path, 'r') as f:
|
| 161 |
+
config_data = json.load(f)
|
| 162 |
+
|
| 163 |
+
# Convert model configs
|
| 164 |
+
if 'model' in config_data and isinstance(config_data['model'], dict):
|
| 165 |
+
config_data['model'] = ModelConfig(**config_data['model'])
|
| 166 |
+
|
| 167 |
+
if 'reflection_model' in config_data and isinstance(config_data['reflection_model'], dict):
|
| 168 |
+
config_data['reflection_model'] = ModelConfig(**config_data['reflection_model'])
|
| 169 |
+
|
| 170 |
+
return OptimizationConfig(**config_data)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_config_from_args(args) -> OptimizationConfig:
    """Build an OptimizationConfig from parsed command line arguments.

    Args:
        args: argparse namespace carrying model, reflection_model,
            max_iterations, max_metric_calls and batch_size.

    Returns:
        The assembled OptimizationConfig.

    Raises:
        ValueError: When no model was supplied on the command line.
    """
    if not args.model:
        raise ValueError("Either --model or --config must be specified")

    # Model specs arrive as strings; ModelConfig knows how to parse them.
    primary = ModelConfig.from_string(args.model)
    reflection = (
        ModelConfig.from_string(args.reflection_model)
        if args.reflection_model
        else None
    )

    return OptimizationConfig(
        model=primary,
        reflection_model=reflection,
        max_iterations=args.max_iterations,
        max_metric_calls=args.max_metric_calls,
        batch_size=args.batch_size
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def validate_api_keys(config: OptimizationConfig):
    """Validate that required API keys are available.

    Checks the primary model provider and, when configured, the reflection
    model provider. If any key is missing, prints the expected environment
    variable names and exits the process with status 1.

    Args:
        config: The optimization configuration whose providers are checked.
    """
    api_manager = APIKeyManager()

    providers = [config.model.provider]
    if config.reflection_model:
        providers.append(config.reflection_model.provider)

    missing_keys = api_manager.get_missing_keys(providers)

    if missing_keys:
        # Diagnostics go to stderr so stdout stays clean for results
        # (consistent with main(), which reports errors on stderr).
        print("❌ Missing API keys for the following providers:", file=sys.stderr)
        for provider in missing_keys:
            print(f" - {provider.upper()}_API_KEY", file=sys.stderr)
        print("\nPlease set the required environment variables or use a .env file", file=sys.stderr)
        sys.exit(1)
|
| 210 |
+
|
| 211 |
+
def output_results(result, output_path: Optional[str]):
    """Report optimization results to a JSON file or to the console.

    Args:
        result: Optimization result object; reads .prompt, .original_prompt,
            .improvement_data, .optimization_time, .status and .session_id.
        output_path: Destination JSON file path. When falsy, a human-readable
            summary is printed to stdout instead.
    """
    output_data = {
        "optimized_prompt": result.prompt,
        "original_prompt": result.original_prompt,
        "improvement_metrics": result.improvement_data,
        "optimization_time": result.optimization_time,
        "status": result.status,
        "session_id": result.session_id
    }

    if output_path:
        # Explicit UTF-8 so prompts containing emoji / non-ASCII text
        # round-trip on every platform (the locale default may reject them),
        # and ensure_ascii=False keeps the saved JSON human-readable.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"📄 Results saved to: {output_path}")
    else:
        print("\n📊 Optimization Results:")
        print(f"Session ID: {result.session_id}")
        print(f"Status: {result.status}")
        print(f"Time: {result.optimization_time:.2f}s")
        print(f"\nOriginal Prompt:\n{result.original_prompt}")
        print(f"\nOptimized Prompt:\n{result.prompt}")

        if 'improvement_percent' in result.improvement_data:
            print(f"\nImprovement: {result.improvement_data['improvement_percent']:.2f}%")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
src/gepa_optimizer/core/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core functionality for GEPA Universal Prompt Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .optimizer import GepaOptimizer
|
| 6 |
+
from .result import ResultProcessor
|
| 7 |
+
|
| 8 |
+
__all__ = ["GepaOptimizer", "ResultProcessor"]
|
src/gepa_optimizer/core/base_adapter.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base adapter class for all GEPA adapters.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
import logging
|
| 8 |
+
from gepa.core.adapter import GEPAAdapter, EvaluationBatch
|
| 9 |
+
|
| 10 |
+
from ..llms.base_llm import BaseLLMClient
|
| 11 |
+
from ..evaluation.base_evaluator import BaseEvaluator
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class BaseGepaAdapter(GEPAAdapter, ABC):
    """Foundation for task-specific GEPA adapters.

    Subclasses plug an LLM client and an evaluator into the GEPA
    framework by implementing :meth:`evaluate` and
    :meth:`make_reflective_dataset`.
    """

    def __init__(self, llm_client: BaseLLMClient, evaluator: BaseEvaluator):
        """Wire the adapter up with its LLM client and evaluator.

        Args:
            llm_client: LLM client for generating responses
            evaluator: Evaluator for scoring predictions

        Raises:
            TypeError: When either collaborator is of the wrong type.
        """
        if not isinstance(llm_client, BaseLLMClient):
            raise TypeError("llm_client must be an instance of BaseLLMClient")
        if not isinstance(evaluator, BaseEvaluator):
            raise TypeError("evaluator must be an instance of BaseEvaluator")

        self.llm_client = llm_client
        self.evaluator = evaluator
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

        # Running statistics subclasses update while optimizing.
        self._evaluation_count = 0
        self._best_score = 0.0
        self._best_candidate = None

    @abstractmethod
    def evaluate(self, batch: List[Dict[str, Any]], candidate: Dict[str, str],
                 capture_traces: bool = False) -> EvaluationBatch:
        """Score a prompt candidate against a batch of data items.

        Args:
            batch: List of data items to evaluate
            candidate: Prompt candidate to evaluate
            capture_traces: Whether to capture detailed traces

        Returns:
            EvaluationBatch with outputs, scores, and optional trajectories
        """
        ...

    @abstractmethod
    def make_reflective_dataset(self, candidate: Dict[str, str],
                                eval_batch: EvaluationBatch,
                                components_to_update: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Build the dataset GEPA uses for its reflection step.

        Args:
            candidate: Current prompt candidate
            eval_batch: Results from evaluation
            components_to_update: List of components to update

        Returns:
            Dictionary mapping components to reflection data
        """
        ...

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics for monitoring"""
        stats = {
            'evaluation_count': self._evaluation_count,
            'best_score': self._best_score,
            'model_info': self.llm_client.get_model_info(),
            'evaluator_class': self.evaluator.__class__.__name__,
        }
        return stats
|
src/gepa_optimizer/core/custom_adapter.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom GEPA Adapter for the GEPA Universal Prompt Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import re
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
# Import ModelConfig
|
| 11 |
+
from ..models import ModelConfig
|
| 12 |
+
|
| 13 |
+
from gepa.core.adapter import GEPAAdapter, EvaluationBatch
|
| 14 |
+
from ..llms.vision_llm import VisionLLMClient
|
| 15 |
+
from ..evaluation.ui_evaluator import UITreeEvaluator
|
| 16 |
+
from .base_adapter import BaseGepaAdapter
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
class CustomGepaAdapter(BaseGepaAdapter):
    """
    Custom adapter for the GEPA Universal Prompt Optimizer.

    Wires a VisionLLMClient and a UITreeEvaluator into the GEPA loop:
    `evaluate` runs the candidate system prompt over image+text samples
    and scores the generated UI-tree JSON against ground truth;
    `make_reflective_dataset` packages those results (with textual
    feedback) for GEPA's reflection step.
    """

    def __init__(self, model_config: 'ModelConfig', metric_weights: Optional[Dict[str, float]] = None):
        """Initialize the custom GEPA adapter with model configuration."""
        # Convert string model to ModelConfig if needed
        # (defaults to the OpenAI provider with no explicit API key).
        if not isinstance(model_config, ModelConfig):
            model_config = ModelConfig(
                provider='openai',
                model_name=str(model_config),
                api_key=None
            )

        # Initialize components
        llm_client = VisionLLMClient(
            provider=model_config.provider,
            model_name=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=model_config.temperature,
            max_tokens=model_config.max_tokens,
            top_p=model_config.top_p,
            frequency_penalty=model_config.frequency_penalty,
            presence_penalty=model_config.presence_penalty
        )

        evaluator = UITreeEvaluator(metric_weights=metric_weights)

        # Initialize parent class
        super().__init__(llm_client, evaluator)

        # Track candidates for logging
        # (_evaluation_count re-initialized here; counts distinct candidates seen).
        self._last_candidate = None
        self._evaluation_count = 0

        self.logger.info(f"🚀 Initialized UI Tree adapter with {model_config.provider}/{model_config.model_name}")

    def _parse_json_safely(self, json_str: str) -> Dict[str, Any]:
        """Safely parse JSON string to dictionary with enhanced parsing and repair.

        Tries, in order: direct parse, fenced-code-block extraction,
        brace-delimited extraction, then a heuristic repair pass.
        Returns {} when every strategy fails (never raises).
        """
        if not json_str or not isinstance(json_str, str):
            return {}

        # Try direct parsing first
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass

        # Try to extract JSON from markdown code blocks
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # Try to find JSON object in the string
        json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(0))
            except json.JSONDecodeError:
                pass

        # Try repair and parse
        repaired_json = self._repair_json(json_str)
        if repaired_json:
            try:
                return json.loads(repaired_json)
            except json.JSONDecodeError:
                pass

        self.logger.warning(f"Failed to parse JSON: {json_str[:100]}...")
        return {}

    def _repair_json(self, json_str: str) -> str:
        """Attempt to repair common JSON issues.

        Strips markdown fences, isolates the outermost brace span, removes
        trailing commas and quotes bare keys. Returns "" on any failure.
        """
        try:
            # Remove markdown formatting
            json_str = re.sub(r'```(?:json)?\s*', '', json_str)
            json_str = re.sub(r'```\s*$', '', json_str)

            # Remove extra text before/after JSON
            json_match = re.search(r'\{.*\}', json_str, re.DOTALL)
            if json_match:
                json_str = json_match.group(0)

            # Fix common issues
            json_str = re.sub(r',\s*}', '}', json_str)  # Remove trailing commas
            json_str = re.sub(r',\s*]', ']', json_str)  # Remove trailing commas in arrays
            json_str = re.sub(r'([{,]\s*)(\w+):', r'\1"\2":', json_str)  # Quote unquoted keys
            return json_str
        except Exception as e:
            self.logger.warning(f"🔧 JSON repair failed: {e}")
            return ""

    def evaluate(
        self,
        batch: List[Dict[str, Any]],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch:
        """Evaluate the candidate on a batch of data.

        For each item, generates a UI-tree JSON via the LLM, parses both
        the output and ground truth, and scores them with the evaluator.
        Unparseable results get small fallback scores (0.1 both empty,
        0.05 one empty) instead of failing the batch.
        """
        outputs = []
        scores = []
        trajectories = [] if capture_traces else None

        system_prompt = candidate.get('system_prompt', '')

        # Check if this is a new candidate (different from last one)
        if self._last_candidate != system_prompt:
            self._evaluation_count += 1
            self.log_proposed_candidate(candidate, self._evaluation_count)
            self._last_candidate = system_prompt

        self.logger.info(f"📊 Evaluating {len(batch)} samples with prompt: '{system_prompt[:50]}...'")

        for i, item in enumerate(batch):
            input_text = item.get('input', '')
            image_base64 = item.get('image', '')
            ground_truth_json = item.get('output', '')

            # Call the LLM client
            llm_response = self.llm_client.generate(system_prompt, input_text, image_base64=image_base64)

            # Extract content from the response dictionary
            # (falls back to str() so downstream parsing still gets a string).
            if isinstance(llm_response, dict):
                llm_output_json_str = llm_response.get("content", "")
                if not llm_output_json_str:
                    llm_output_json_str = str(llm_response)
            else:
                llm_output_json_str = str(llm_response) if llm_response else ""

            # 🔍 DEBUG: Log essential info only (removed verbose JSON content)
            self.logger.debug(f"🔍 Sample {i+1} - LLM Response Type: {type(llm_response)}")
            self.logger.debug(f"🔍 Sample {i+1} - Response Length: {len(llm_output_json_str)} chars")

            outputs.append(llm_output_json_str)

            # Parse JSON strings to dictionaries for evaluation
            llm_output_dict = self._parse_json_safely(llm_output_json_str)
            ground_truth_dict = self._parse_json_safely(ground_truth_json)

            # Initialize evaluation_results with default values
            evaluation_results = {
                "composite_score": 0.0,
                "element_completeness": 0.0,
                "element_type_accuracy": 0.0,
                "text_content_accuracy": 0.0,
                "hierarchy_accuracy": 0.0,
                "style_accuracy": 0.0
            }

            # Calculate composite score and evaluation results
            if not llm_output_dict and not ground_truth_dict:
                composite_score = 0.1
                evaluation_results = {k: 0.1 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Empty results - using default score: {composite_score}")
            elif not llm_output_dict or not ground_truth_dict:
                composite_score = 0.05
                evaluation_results = {k: 0.05 for k in evaluation_results.keys()}
                self.logger.warning(f"⚠️ Sample {i+1}: Incomplete results - using low score: {composite_score}")
            else:
                # Calculate score using evaluator with parsed dictionaries
                evaluation_results = self.evaluator.evaluate(llm_output_dict, ground_truth_dict)
                composite_score = evaluation_results["composite_score"]

                # Clean, readable logging (removed verbose JSON dumps)
                llm_children = len(llm_output_dict.get('children', []))
                gt_children = len(ground_truth_dict.get('children', []))

                if composite_score < 0.1:
                    self.logger.warning(f"⚠️ Sample {i+1}: Low score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")
                    self.logger.debug(f" Score breakdown: {evaluation_results}")
                else:
                    self.logger.info(f"✅ Sample {i+1}: Score {composite_score:.4f} - LLM: {llm_children} elements, GT: {gt_children} elements")

            scores.append(composite_score)

            if capture_traces:
                trajectories.append({
                    'input_text': input_text,
                    'image_base64': image_base64,
                    'ground_truth_json': ground_truth_json,
                    'llm_output_json': llm_output_json_str,
                    'evaluation_results': evaluation_results
                })

        avg_score = sum(scores) / len(scores) if scores else 0.0

        # Update performance tracking (handled by parent class)
        if avg_score > self._best_score:
            self._best_score = avg_score
            self._best_candidate = candidate.copy()
            self.logger.info(f"🎯 New best candidate found with score: {avg_score:.4f}")

        self.logger.info(f"📈 Batch evaluation complete - Average score: {avg_score:.4f}")

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: Dict[str, str],
        eval_batch: EvaluationBatch,
        components_to_update: List[str],
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Create a reflective dataset from the evaluation results.

        NOTE(review): assumes eval_batch was produced with
        capture_traces=True (trajectories is not None) — TODO confirm GEPA
        always calls it that way.
        """
        reflective_dataset = {}
        system_prompt = candidate.get('system_prompt', '')

        # 🎯 NEW: Log the proposed new prompt being evaluated
        self.logger.info(f"📝 Creating reflection dataset for prompt: '{system_prompt[:100]}...'")

        # Pretty print reflection dataset creation
        self._log_reflection_dataset_creation(candidate, eval_batch, components_to_update)

        for component in components_to_update:
            reflective_dataset[component] = []
            for i, trace in enumerate(eval_batch.trajectories):
                feedback = self._generate_feedback(trace['evaluation_results'])
                reflective_dataset[component].append({
                    "current_prompt": system_prompt,
                    "input_text": trace['input_text'],
                    "image_base64": trace['image_base64'],
                    "generated_json": trace['llm_output_json'],
                    "ground_truth_json": trace['ground_truth_json'],
                    "score": trace['evaluation_results']["composite_score"],
                    "feedback": feedback,
                    "detailed_scores": trace['evaluation_results']
                })

        # 🎯 NEW: Log reflection dataset summary
        total_samples = sum(len(data) for data in reflective_dataset.values())
        avg_score = sum(trace['score'] for data in reflective_dataset.values() for trace in data) / total_samples if total_samples > 0 else 0.0
        self.logger.info(f"📝 Reflection dataset created - {total_samples} samples, avg score: {avg_score:.4f}")

        return reflective_dataset

    def _generate_feedback(self, evaluation_results: Dict[str, float]) -> str:
        """Generate textual feedback based on evaluation results.

        Thresholds: composite >= 0.8 "good", >= 0.5 "moderate", else "low";
        each sub-metric below 0.7 adds a targeted improvement hint.
        """
        composite_score = evaluation_results.get("composite_score", 0.0)

        feedback_parts = []

        # Overall quality assessment
        if composite_score >= 0.8:
            feedback_parts.append("The overall quality is good.")
        elif composite_score >= 0.5:
            feedback_parts.append("The overall quality is moderate.")
        else:
            feedback_parts.append("The overall quality is low. Focus on fundamental accuracy.")

        # Specific metric feedback
        if evaluation_results.get("element_completeness", 0.0) < 0.7:
            feedback_parts.append("Element completeness is low. Ensure all UI elements are captured.")

        if evaluation_results.get("element_type_accuracy", 0.0) < 0.7:
            feedback_parts.append("Element type accuracy is low. Verify correct UI element identification (Button, Text, Image, etc.).")

        if evaluation_results.get("text_content_accuracy", 0.0) < 0.7:
            feedback_parts.append("Text content accuracy is low. Improve text extraction fidelity.")

        if evaluation_results.get("hierarchy_accuracy", 0.0) < 0.7:
            feedback_parts.append("Hierarchy accuracy is low. Ensure correct parent-child relationships.")

        if evaluation_results.get("style_accuracy", 0.0) < 0.7:
            feedback_parts.append("Style accuracy is low. Capture more styling properties (colors, sizes, positioning).")

        return " ".join(feedback_parts)

    def get_best_candidate(self) -> Optional[Dict[str, str]]:
        """Get the best candidate found so far."""
        return self._best_candidate

    def get_best_score(self) -> float:
        """Get the best score found so far."""
        return self._best_score

    def log_proposed_candidate(self, candidate: Dict[str, str], iteration: int = 0):
        """
        Log the new proposed candidate prompt.

        Args:
            candidate: The new candidate prompt from GEPA
            iteration: Current optimization iteration
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("="*80)
        logger.info(f"NEW PROPOSED CANDIDATE (Iteration {iteration})")
        logger.info("="*80)
        logger.info(f"PROPOSED PROMPT:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)
        logger.info(f"Prompt Length: {len(system_prompt)} characters")
        logger.info(f"Word Count: {len(system_prompt.split())} words")
        logger.info("="*80)

    def _log_reflection_dataset_creation(self, candidate: Dict[str, str], eval_batch: EvaluationBatch,
                                         components_to_update: List[str]):
        """
        Log the reflection dataset creation process.

        Args:
            candidate: Current candidate being evaluated
            eval_batch: Evaluation results
            components_to_update: Components being updated
        """
        system_prompt = candidate.get('system_prompt', '')

        logger.info("="*80)
        logger.info("REFLECTION DATASET CREATION")
        logger.info("="*80)

        logger.info(f"CURRENT PROMPT BEING ANALYZED:")
        logger.info("-" * 40)
        logger.debug(f'"{system_prompt}"')
        logger.info("-" * 40)

        logger.info(f"EVALUATION SUMMARY:")
        logger.info("-" * 40)
        if eval_batch.scores:
            avg_score = sum(eval_batch.scores) / len(eval_batch.scores)
            min_score = min(eval_batch.scores)
            max_score = max(eval_batch.scores)
            logger.info(f" Average Score: {avg_score:.4f}")
            logger.info(f" Min Score: {min_score:.4f}")
            logger.info(f" Max Score: {max_score:.4f}")
            logger.info(f" Total Samples: {len(eval_batch.scores)}")

        logger.info(f"COMPONENTS TO UPDATE:")
        logger.info("-" * 40)
        for i, component in enumerate(components_to_update, 1):
            logger.info(f" {i}. {component}")

        if eval_batch.trajectories:
            logger.debug(f"DETAILED ANALYSIS:")
            logger.debug("-" * 40)
            for i, trace in enumerate(eval_batch.trajectories[:3], 1):  # Show first 3 samples
                evaluation_results = trace['evaluation_results']
                composite_score = evaluation_results.get("composite_score", 0.0)

                logger.debug(f" Sample {i} (Score: {composite_score:.4f}):")

                # Show input data (truncated)
                input_text = trace['input_text'][:100] + "..." if len(trace['input_text']) > 100 else trace['input_text']
                logger.debug(f" Input: \"{input_text}\"")

                # Show predicted output (truncated)
                predicted_output = trace['llm_output_json'][:100] + "..." if len(trace['llm_output_json']) > 100 else trace['llm_output_json']
                logger.debug(f" Output: \"{predicted_output}\"")

                # Show detailed scores
                logger.debug(f" Detailed Scores:")
                for metric, score in evaluation_results.items():
                    if metric != "composite_score":
                        logger.debug(f" {metric.replace('_', ' ').title()}: {score:.4f}")

                # Show generated feedback
                feedback = self._generate_feedback(evaluation_results)
                logger.debug(f" Feedback: \"{feedback}\"")

            if len(eval_batch.trajectories) > 3:
                logger.debug(f" ... and {len(eval_batch.trajectories) - 3} more samples")

        logger.info("="*80)
|
src/gepa_optimizer/core/optimizer.py
ADDED
|
@@ -0,0 +1,1279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main GepaOptimizer class - the heart of the optimization system
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Any, Dict, List, Optional, Union
|
| 8 |
+
import asyncio
|
| 9 |
+
import io
|
| 10 |
+
import sys
|
| 11 |
+
from contextlib import redirect_stdout, redirect_stderr
|
| 12 |
+
|
| 13 |
+
import gepa
|
| 14 |
+
from ..utils.api_keys import APIKeyManager
|
| 15 |
+
from .result import ResultProcessor
|
| 16 |
+
from ..data.converters import UniversalConverter
|
| 17 |
+
from ..models.result import OptimizationResult, OptimizedResult
|
| 18 |
+
from ..models.config import OptimizationConfig, ModelConfig
|
| 19 |
+
from ..utils.helpers import sanitize_prompt
|
| 20 |
+
from ..utils.exceptions import GepaDependencyError, InvalidInputError, DatasetError, GepaOptimizerError
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
class GepaOptimizer:
|
| 25 |
+
"""
|
| 26 |
+
Main class for prompt optimization using GEPA
|
| 27 |
+
|
| 28 |
+
This is the primary interface that users interact with.
|
| 29 |
+
Provides both simple and advanced optimization capabilities.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, config: Optional[OptimizationConfig] = None,
             adapter_type: str = "universal",
             custom_adapter: Optional[Any] = None,
             llm_model_name: Optional[str] = None,
             metric_weights: Optional[Dict[str, float]] = None,
             **kwargs):
    """
    Initialize the optimizer.

    Args:
        config: Optimization configuration (required).
        adapter_type: Type of adapter to use ("universal" only - fully configurable).
        custom_adapter: Custom adapter instance (overrides adapter_type).
        llm_model_name: [Deprecated] Use config.model instead. Will be removed in future versions.
        metric_weights: [Deprecated] Not used - evaluator handles metrics. Will be removed in future versions.
        **kwargs: Additional parameters for the universal adapter
            (llm_client, evaluator, data_converter, ...).

    Raises:
        ValueError: If required configuration is missing or adapter_type is unknown.
        TypeError: If custom_adapter is not a BaseGepaAdapter instance.
        GepaDependencyError: If the GEPA library is not available.
    """
    # FIX: fail fast on the missing GEPA dependency. Originally this check
    # ran LAST, after converters/adapters were already constructed, so a
    # broken install did expensive work before erroring out.
    if gepa is None:
        raise GepaDependencyError("GEPA library is not available. Please install it with: pip install gepa")

    if config is None:
        raise ValueError("config parameter is required. Use OptimizationConfig to configure the optimizer.")

    # Initialize logger first so all subsequent steps can log.
    self.logger = logging.getLogger(__name__)

    self.config = config
    self.converter = UniversalConverter(data_split_config=config.data_split)
    self.api_manager = APIKeyManager()
    self.result_processor = ResultProcessor()

    # Initialize adapter based on configuration.
    if custom_adapter:
        # User provided a custom adapter - it must implement the base contract.
        from .base_adapter import BaseGepaAdapter
        if not isinstance(custom_adapter, BaseGepaAdapter):
            raise TypeError("custom_adapter must be an instance of BaseGepaAdapter")
        self.adapter = custom_adapter
        self.logger.info("Using user-provided custom adapter")
    elif adapter_type == "universal":
        # Universal adapter requires the user to supply its components.
        llm_client = kwargs.get('llm_client')
        evaluator = kwargs.get('evaluator')

        # FIX: explicit `is None` checks so falsy-but-valid client/evaluator
        # objects (e.g. ones defining __bool__/__len__) are not rejected.
        if llm_client is None or evaluator is None:
            raise ValueError(
                "llm_client and evaluator are required for universal adapter. "
                "Example: GepaOptimizer(config=config, adapter_type='universal', "
                "llm_client=llm_client, evaluator=evaluator)"
            )

        from .universal_adapter import UniversalGepaAdapter
        self.adapter = UniversalGepaAdapter(
            llm_client=llm_client,
            evaluator=evaluator,
            data_converter=kwargs.get('data_converter')
        )
        self.logger.info("Using universal adapter")
    else:
        raise ValueError(
            f"Unknown adapter_type: {adapter_type}. "
            f"Only 'universal' is supported. "
            f"Provide llm_client and evaluator when using universal adapter."
        )

    # Keep backward compatibility with callers that read .custom_adapter.
    self.custom_adapter = self.adapter

    # Log adapter/model configuration for traceability.
    model_info = self.adapter.get_performance_stats()
    self.logger.info(f"Initialized adapter: {model_info}")

    # Set up logging.
    # NOTE(review): calling basicConfig() from library code mutates global
    # logging state; kept for backward compatibility with existing users,
    # but applications should ideally configure logging themselves.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
| 115 |
+
async def train(self,
                seed_prompt: str,
                dataset: Union[List[Any], str, Dict, Any],
                **kwargs) -> OptimizedResult:
    """
    Main training method for prompt optimization.

    Args:
        seed_prompt: Initial prompt to optimize.
        dataset: Training data in any format. A dict containing all of
            'train'/'val'/'test' keys is treated as a pre-split dataset
            and used as-is (no re-splitting).
        **kwargs: Additional parameters that can override config.

    Returns:
        OptimizedResult: Optimization result with improved prompt. On any
        internal failure (InvalidInputError, DatasetError,
        GepaOptimizerError, ...) a result with status='failed' and the
        original prompt is returned instead of propagating the exception.
    """
    start_time = time.time()
    session_id = f"opt_{int(start_time)}_{id(self)}"

    try:
        self.logger.info(f"Starting optimization session: {session_id}")
        self.logger.info(f"Using model: {self.config.model.model_name} (provider: {self.config.model.provider})")

        # FIX: removed leftover "agent log" debug instrumentation that
        # unconditionally appended JSON to a hardcoded absolute path
        # (/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log).
        # That write raised FileNotFoundError on every machine except the
        # original developer's, aborting train() immediately.

        # Reset the Pareto logger at the start of each optimization run so
        # state from a previous run cannot leak into this one.
        from ..utils.pareto_logger import reset_pareto_logger
        reset_pareto_logger()
        self.logger.info("✅ Reset Pareto logger for new optimization run")

        # Update config with any overrides from kwargs.
        self._update_config_from_kwargs(kwargs)

        # Step 1: Validate inputs.
        self._validate_inputs(seed_prompt)

        # Step 2: Convert dataset to GEPA format with a 3-way split.
        # Pre-split datasets (user-provided train/val/test) are honored.
        if isinstance(dataset, dict) and all(k in dataset for k in ['train', 'val', 'test']):
            # User provided a pre-split dataset - use it directly.
            self.logger.info("✅ Detected pre-split dataset - using user's split (no re-splitting)")
            trainset_raw = dataset.get('train', [])
            valset_raw = dataset.get('val', [])
            testset_raw = dataset.get('test', [])

            # Still standardize each split into GEPA's expected format.
            trainset = self.converter._standardize(trainset_raw)
            valset = self.converter._standardize(valset_raw)
            testset = self.converter._standardize(testset_raw) if testset_raw else []

            self.logger.info(
                f"Using pre-split dataset: {len(trainset)} train (Dfeedback), "
                f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)"
            )
        else:
            # Standard path: convert and split automatically.
            self.logger.info("Converting dataset to GEPA format with 3-way split...")
            trainset, valset, testset = self.converter.convert(
                dataset,
                split_config=self.config.data_split
            )

            # Log the split, including adaptive-strategy ratios when used.
            split_strategy = self.config.data_split.small_dataset_strategy
            strategy_note = ""
            if split_strategy == 'adaptive':
                total_size = len(trainset) + len(valset) + len(testset)
                train_ratio, val_ratio, test_ratio = self.config.data_split.get_adaptive_ratios(total_size)
                strategy_note = f" (adaptive: {train_ratio*100:.0f}%/{val_ratio*100:.0f}%/{test_ratio*100:.0f}% ratios)"
            self.logger.info(
                f"Dataset split{strategy_note}: {len(trainset)} train (Dfeedback), "
                f"{len(valset)} val (Dpareto), {len(testset)} test (held-out)"
            )

        if not trainset:
            raise DatasetError("Dataset appears to be empty after conversion")

        # Step 3: Create seed candidate.
        seed_candidate = self._create_seed_candidate(seed_prompt)

        # Set valset info in the adapter BEFORE baseline evaluation so it
        # correctly detects the 'dpareto' dataset type. Direct assignment
        # is used; adapters without these attributes are tolerated.
        try:
            self.adapter._valset_size = len(valset) if valset else 0
            self.logger.info(f"✅ Set valset_size in adapter: {len(valset) if valset else 0} for Dpareto detection")
        except AttributeError:
            self.logger.warning("⚠️ Could not set _valset_size in adapter - attribute not supported")

        try:
            self.adapter._valset = valset
            self.logger.info(f"✅ Stored valset in adapter ({len(valset) if valset else 0} samples)")
        except AttributeError:
            self.logger.warning("⚠️ Could not set _valset in adapter - attribute not supported")

        # Step 3.5: Baseline score on the VALIDATION set (not test set).
        # Optimization uses the validation set for Pareto selection, so the
        # baseline must come from the same set for a fair comparison.
        baseline_val_score = None
        if valset:
            self.logger.info("📊 Evaluating seed prompt on validation set for baseline...")
            # Flag the adapter so it knows this is the baseline pass.
            try:
                self.adapter._is_baseline_evaluation = True
                self.logger.info("✅ Set baseline evaluation flag in adapter")
            except AttributeError:
                self.logger.warning("⚠️ Could not set _is_baseline_evaluation in adapter")

            try:
                # Evaluate on the same validation set GEPA will use.
                eval_result = self.adapter.evaluate(
                    batch=valset,
                    candidate=seed_candidate,
                    capture_traces=False
                )
                baseline_val_score = sum(eval_result.scores) / len(eval_result.scores) if eval_result.scores else 0.0
                self.logger.info(f"📊 Baseline validation score: {baseline_val_score:.4f} (on {len(valset)} samples)")

                # Store the baseline in the adapter for later use.
                if hasattr(self.adapter, '_baseline_score'):
                    self.adapter._baseline_score = baseline_val_score

                # Also record the baseline in the Pareto logger so that
                # candidates can be properly compared against it.
                from ..utils.pareto_logger import get_pareto_logger
                pareto_log = get_pareto_logger()
                pareto_log.set_baseline(baseline_val_score)
                self.logger.info(f"✅ Baseline set in Pareto logger: {baseline_val_score:.4f}")

            except Exception as e:
                # Best effort: a failed baseline should not abort training.
                self.logger.warning(f"Baseline evaluation failed: {e}")
                import traceback
                self.logger.debug(f"Baseline evaluation error: {traceback.format_exc()}")
            finally:
                try:
                    self.adapter._is_baseline_evaluation = False
                    self.logger.debug("✅ Reset baseline evaluation flag - optimization can begin")
                except AttributeError:
                    pass  # Ignore if attribute not supported

        # Step 4: Run GEPA optimization.
        self.logger.info("Starting GEPA optimization...")
        gepa_result, actual_iterations = await self._run_gepa_optimization(
            adapter=self.adapter,
            seed_candidate=seed_candidate,
            trainset=trainset,
            valset=valset,
            **kwargs
        )

        # Step 5: Extract the best candidate found by GEPA.
        best_candidate = self._extract_best_candidate(gepa_result)

        self.logger.info(f"\n{'═'*80}")
        self.logger.info(f"📝 EXTRACTING OPTIMIZED PROMPT FROM GEPA RESULT")
        self.logger.info(f"{'═'*80}")
        self.logger.info(f"best_candidate keys: {list(best_candidate.keys()) if isinstance(best_candidate, dict) else 'N/A'}")

        optimized_prompt = best_candidate.get('system_prompt', seed_prompt)
        if not optimized_prompt or optimized_prompt.strip() == '':
            # Fallback: try other keys or use the seed prompt.
            optimized_prompt = best_candidate.get('prompt', best_candidate.get('text', seed_prompt))

        # FIX: the original expression
        #   best_candidate.get('fitness') or self.adapter.get_best_score() if hasattr(...) else None
        # parsed as "(a or b) if hasattr(...) else None", so a candidate's
        # fitness was discarded whenever the adapter lacked get_best_score(),
        # and `or` also treated a legitimate 0.0 fitness as missing.
        best_fitness = best_candidate.get('fitness')
        if best_fitness is None and hasattr(self.adapter, 'get_best_score'):
            best_fitness = self.adapter.get_best_score()
        candidate_source = best_candidate.get('source', 'unknown')

        self.logger.info(f"\n✅ EXTRACTED OPTIMIZED PROMPT:")
        self.logger.info(f"   Source: {candidate_source}")
        if best_fitness is not None:
            self.logger.info(f"   Fitness: f={best_fitness:.4f}")
        self.logger.info(f"   Length: {len(optimized_prompt)} characters")
        self.logger.info(f"   Words: {len(optimized_prompt.split())} words")
        self.logger.info(f"\n📝 FULL OPTIMIZED PROMPT TEXT:")
        self.logger.info(f"{'─'*80}")
        self.logger.info(optimized_prompt)
        self.logger.info(f"{'─'*80}")

        if optimized_prompt != seed_prompt:
            self.logger.info(f"\n✅ SUCCESS: Prompt WAS OPTIMIZED!")
            self.logger.info(f"   Seed length: {len(seed_prompt)} chars")
            self.logger.info(f"   Optimized length: {len(optimized_prompt)} chars")
            self.logger.info(f"   Difference: {len(optimized_prompt) - len(seed_prompt):+d} chars")
            if best_fitness is not None:
                baseline_fitness = 0.5  # Default baseline, could be improved
                improvement = best_fitness - baseline_fitness
                improvement_pct = (improvement / baseline_fitness * 100) if baseline_fitness > 0 else 0
                self.logger.info(f"   Fitness: f={best_fitness:.4f} (improvement: {improvement:+.4f} ({improvement_pct:+.1f}%))")
        else:
            self.logger.warning(f"\n⚠️ WARNING: Optimized prompt is IDENTICAL to seed prompt")
            self.logger.warning(f"   This means GEPA didn't modify the prompt during optimization")
            if best_fitness is not None:
                self.logger.warning(f"   Best fitness found: f={best_fitness:.4f}")
            self.logger.warning(f"   💡 Check if LLEGO best candidate is being properly extracted")

        self.logger.info(f"{'═'*80}\n")

        # Step 5.5: Calculate improvement metrics (validation vs validation).
        optimized_test_score = None
        improvement_data = {}

        # Improvement is based on VALIDATION scores: the best candidate's
        # fitness comes from the validation set (Dpareto), so comparing it
        # against the validation baseline keeps both on the same data.
        optimized_val_score = best_fitness

        if baseline_val_score is not None and optimized_val_score is not None:
            absolute_improvement = optimized_val_score - baseline_val_score
            relative_improvement = (
                (absolute_improvement / baseline_val_score * 100)
                if baseline_val_score > 0 else 0
            )

            improvement_data = {
                'baseline_val_score': baseline_val_score,
                'optimized_val_score': optimized_val_score,
                'absolute_improvement': absolute_improvement,
                'relative_improvement_percent': relative_improvement
            }

            self.logger.info(
                f"📈 Validation improvement: {relative_improvement:+.2f}% "
                f"(baseline val: {baseline_val_score:.4f} → optimized val: {optimized_val_score:.4f})"
            )

        # Step 5.6: Evaluate the optimized prompt on the held-out test set
        # (if available) for final reporting only.
        if testset and self.config.evaluate_on_test:
            self.logger.info("📊 Evaluating optimized prompt on test set...")

            # Clear the LLEGO candidate queue before test evaluation so the
            # LLEGO wrapper cannot intercept test calls and return queued
            # candidates instead of actually running the optimized prompt.
            from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient
            if hasattr(self.adapter, 'llm_client') and isinstance(self.adapter.llm_client, LLEGOEnhancedLLMClient):
                if hasattr(self.adapter.llm_client, '_adapter_generated_candidates'):
                    self.adapter.llm_client._adapter_generated_candidates = []
                    self.logger.info("✅ Cleared LLEGO candidate queue for clean test evaluation")
                if hasattr(self.adapter.llm_client, '_candidate_queue'):
                    self.adapter.llm_client._candidate_queue = []
                    self.logger.info("✅ Cleared LLEGO hybrid candidate queue for clean test evaluation")

            # Test-set score is informational; improvement stays validation-based.
            try:
                optimized_test_score = self._evaluate_candidate_on_testset(
                    best_candidate,
                    testset
                )
                self.logger.info(f"📊 Optimized test score: {optimized_test_score:.4f}")

                improvement_data['optimized_test_score'] = optimized_test_score

                if baseline_val_score is not None:
                    test_vs_baseline = (
                        ((optimized_test_score - baseline_val_score) / baseline_val_score * 100)
                        if baseline_val_score > 0 else 0
                    )
                    self.logger.info(
                        f"📊 Test set vs validation baseline: {test_vs_baseline:+.2f}% "
                        f"(baseline val: {baseline_val_score:.4f} → optimized test: {optimized_test_score:.4f})"
                    )
            except Exception as e:
                # Test evaluation is optional; log and continue.
                self.logger.warning(f"Test evaluation failed: {e}")

        # Step 6: Process results.
        optimization_time = time.time() - start_time

        processed_result = self.result_processor.process_full_result(
            result=gepa_result,
            original_prompt=seed_prompt,
            optimization_time=optimization_time,
            actual_iterations=actual_iterations,
            test_metrics=improvement_data
        )

        # Merge improvement data (explicit metrics win over processed ones).
        final_improvement_data = {**processed_result.get('improvement_data', {}), **improvement_data}

        # Step 7: Create result objects. The prompt extracted from the best
        # candidate above is authoritative, not the processed result's.
        result = OptimizedResult(
            original_prompt=seed_prompt,
            optimized_prompt=optimized_prompt,
            improvement_data=final_improvement_data,
            optimization_time=optimization_time,
            dataset_size=len(trainset) + len(valset) + len(testset),
            total_iterations=processed_result.get('total_iterations', 0),
            status=processed_result.get('status', 'completed'),
            error_message=processed_result.get('error_message'),
            detailed_result=OptimizationResult(
                session_id=session_id,
                original_prompt=seed_prompt,
                optimized_prompt=optimized_prompt,
                improvement_data=final_improvement_data,
                optimization_time=optimization_time,
                dataset_size=len(trainset) + len(valset) + len(testset),
                total_iterations=processed_result.get('total_iterations', 0),
                status=processed_result.get('status', 'completed'),
                error_message=processed_result.get('error_message')
            )
        )

        self.logger.info(f"✅ Optimization completed in {optimization_time:.2f}s")
        return result

    except Exception as e:
        # Any failure is converted into a failed result carrying the
        # original prompt, so callers never lose their seed.
        optimization_time = time.time() - start_time
        error_msg = f"Optimization failed: {str(e)}"
        self.logger.error(error_msg)

        return OptimizedResult(
            original_prompt=seed_prompt,
            optimized_prompt=seed_prompt,  # Return original on failure
            improvement_data={'error': error_msg},
            optimization_time=optimization_time,
            dataset_size=0,
            total_iterations=0,
            status='failed',
            error_message=error_msg
        )
| 448 |
+
def _update_config_from_kwargs(self, kwargs: Dict[str, Any]) -> None:
|
| 449 |
+
"""Update configuration with runtime overrides from kwargs."""
|
| 450 |
+
updated_params = []
|
| 451 |
+
|
| 452 |
+
for key, value in kwargs.items():
|
| 453 |
+
if hasattr(self.config, key):
|
| 454 |
+
setattr(self.config, key, value)
|
| 455 |
+
updated_params.append(f"{key}={value}")
|
| 456 |
+
else:
|
| 457 |
+
self.logger.warning(f"Unknown parameter '{key}' ignored")
|
| 458 |
+
|
| 459 |
+
if updated_params:
|
| 460 |
+
self.logger.info(f"Updated config parameters: {', '.join(updated_params)}")
|
| 461 |
+
|
| 462 |
+
def _validate_inputs(self, seed_prompt: str) -> None:
|
| 463 |
+
"""
|
| 464 |
+
Validate input parameters for optimization
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
seed_prompt: The seed prompt to validate
|
| 468 |
+
|
| 469 |
+
Raises:
|
| 470 |
+
InvalidInputError: If validation fails
|
| 471 |
+
"""
|
| 472 |
+
if not seed_prompt or not isinstance(seed_prompt, str):
|
| 473 |
+
raise InvalidInputError("Seed prompt must be a non-empty string")
|
| 474 |
+
|
| 475 |
+
if len(seed_prompt.strip()) < 10:
|
| 476 |
+
raise InvalidInputError("Seed prompt is too short (minimum 10 characters)")
|
| 477 |
+
|
| 478 |
+
# Validate model configuration
|
| 479 |
+
model_config = self.config.model
|
| 480 |
+
if not hasattr(model_config, 'model_name') or not model_config.model_name:
|
| 481 |
+
raise InvalidInputError("Model name is required")
|
| 482 |
+
|
| 483 |
+
reflection_config = self.config.reflection_model
|
| 484 |
+
if not hasattr(reflection_config, 'model_name') or not reflection_config.model_name:
|
| 485 |
+
raise InvalidInputError("Reflection model name is required")
|
| 486 |
+
|
| 487 |
+
def _clean_reflection_prompt(self, prompt: str, max_length: int = 50000) -> str:
|
| 488 |
+
"""
|
| 489 |
+
Clean reflection prompt by removing base64 images and truncating if too long.
|
| 490 |
+
|
| 491 |
+
🔥 CRITICAL: GEPA's reflective dataset includes base64 images which create
|
| 492 |
+
massive prompts (7MB+) that exceed token limits. This function:
|
| 493 |
+
1. Strips all base64 image data
|
| 494 |
+
2. Removes excessive detailed_scores entries
|
| 495 |
+
3. Truncates to reasonable size
|
| 496 |
+
4. Preserves essential feedback information
|
| 497 |
+
|
| 498 |
+
Args:
|
| 499 |
+
prompt: Original prompt from GEPA (may contain base64)
|
| 500 |
+
max_length: Maximum length after cleaning (default: 50K chars)
|
| 501 |
+
|
| 502 |
+
Returns:
|
| 503 |
+
Cleaned prompt without base64, within size limits
|
| 504 |
+
"""
|
| 505 |
+
import re
|
| 506 |
+
|
| 507 |
+
# Step 1: Remove base64 image strings (typically very long alphanumeric strings)
|
| 508 |
+
# Base64 images are usually 50K+ characters of A-Za-z0-9+/= pattern
|
| 509 |
+
# Look for very long base64-like sequences
|
| 510 |
+
base64_pattern = r'[A-Za-z0-9+/=]{5000,}' # Sequences of 5000+ base64 chars
|
| 511 |
+
cleaned = re.sub(base64_pattern, '[IMAGE_DATA_REMOVED]', prompt)
|
| 512 |
+
|
| 513 |
+
# Step 2: Remove detailed_scores sections that might contain base64 references
|
| 514 |
+
# These are usually in markdown format: "### detailed_scores\n...base64..."
|
| 515 |
+
detailed_scores_pattern = r'### detailed_scores[^\n]*\n[^#]*(?:image_base64|base64)[^\n]*(?:\n[^#]*)*'
|
| 516 |
+
cleaned = re.sub(detailed_scores_pattern, '### detailed_scores: [REMOVED_FOR_BREVITY]', cleaned, flags=re.IGNORECASE | re.MULTILINE)
|
| 517 |
+
|
| 518 |
+
# Step 3: Remove any remaining image_base64 references
|
| 519 |
+
cleaned = re.sub(r'image_base64[^\n]*', 'image_base64: [REMOVED]', cleaned, flags=re.IGNORECASE)
|
| 520 |
+
cleaned = re.sub(r'"[A-Za-z0-9+/=]{10000,}"', '[LARGE_DATA_STRING_REMOVED]', cleaned) # Very long strings likely base64
|
| 521 |
+
|
| 522 |
+
# Step 4: Truncate if still too long (keep the beginning which usually has the most important info)
|
| 523 |
+
if len(cleaned) > max_length:
|
| 524 |
+
# Keep first part (usually contains prompt and key feedback)
|
| 525 |
+
# Add truncation notice
|
| 526 |
+
truncated_size = len(cleaned) - max_length
|
| 527 |
+
cleaned = cleaned[:max_length] + f"\n\n[TRUNCATED {truncated_size} characters of detailed evaluation data]"
|
| 528 |
+
self.logger.warning(f"⚠️ Prompt truncated: {len(prompt)} → {len(cleaned)} chars")
|
| 529 |
+
|
| 530 |
+
return cleaned
|
| 531 |
+
|
| 532 |
+
def _validate_models(self, task_lm, reflection_lm):
|
| 533 |
+
"""
|
| 534 |
+
Validate if specified models are supported.
|
| 535 |
+
|
| 536 |
+
Note: No hardcoded restrictions - the API provider will validate model existence.
|
| 537 |
+
This method is kept for potential future validation logic but doesn't restrict users.
|
| 538 |
+
"""
|
| 539 |
+
# No hardcoded model restrictions - users can specify any model
|
| 540 |
+
# The API provider will handle validation and return errors if model doesn't exist
|
| 541 |
+
self.logger.debug(f"Using task model: {task_lm}, reflection model: {reflection_lm}")
|
| 542 |
+
|
| 543 |
+
def _create_seed_candidate(self, seed_prompt: str) -> Dict[str, str]:
    """Build the initial GEPA candidate mapping from the sanitized seed prompt."""
    return {'system_prompt': sanitize_prompt(seed_prompt)}
| 548 |
+
async def _run_gepa_optimization(self, adapter, seed_candidate: Any, trainset: List[Any], valset: List[Any], **kwargs) -> tuple:
    """
    Run GEPA optimization with the given adapter and data.

    Args:
        adapter: Custom adapter for GEPA
        seed_candidate: Initial prompt candidate (dict with 'system_prompt')
        trainset: Training dataset
        valset: Validation dataset (Dpareto)
        **kwargs: Additional optimization parameters that can override config

    Returns:
        Tuple of (raw GEPA result, actual iteration count parsed from GEPA's logs).
        On a recoverable failure, the first element is a partial-result dict and
        the iteration count is 0 — still a tuple, so callers can always unpack.

    Raises:
        GepaOptimizerError: If optimization fails and no partial result was cached.

    Note:
        The following parameters are required in the config:
        - max_metric_calls: Maximum number of metric evaluations
        - batch_size: Batch size for evaluation
        - max_iterations: Maximum number of optimization iterations
    """

    def _agent_debug_log(location: str, message: str, data: Dict[str, Any], hypothesis_id: str = "A") -> None:
        """Best-effort structured debug trace.

        Fix: the original wrote unconditionally to a hardcoded personal path
        (/Users/suhas/...), which raised FileNotFoundError on every other
        machine. Tracing is now opt-in via the GEPA_DEBUG_LOG env var and
        can never break the optimization run.
        """
        import json as _json_debug
        import os as _os
        debug_log_path = _os.environ.get("GEPA_DEBUG_LOG")
        if not debug_log_path:
            return
        try:
            with open(debug_log_path, "a") as _f:
                _f.write(_json_debug.dumps({
                    "hypothesisId": hypothesis_id,
                    "location": location,
                    "message": message,
                    "data": data,
                    "timestamp": int(time.time() * 1000),
                    "sessionId": "debug-session",
                }) + "\n")
        except OSError:
            pass  # diagnostics must never take down the optimization

    try:
        # Required config fields controlling the optimization budget
        max_metric_calls = self.config.max_metric_calls
        batch_size = self.config.batch_size
        max_iterations = self.config.max_iterations

        # Create the reflection model client
        from ..llms.vision_llm import VisionLLMClient
        base_reflection_lm_client = VisionLLMClient(
            provider=self.config.reflection_model.provider,
            model_name=self.config.reflection_model.model_name,
            api_key=self.config.reflection_model.api_key,
            base_url=self.config.reflection_model.base_url,
            temperature=self.config.reflection_model.temperature,
            max_tokens=self.config.reflection_model.max_tokens,
            top_p=self.config.reflection_model.top_p,
            frequency_penalty=self.config.reflection_model.frequency_penalty,
            presence_penalty=self.config.reflection_model.presence_penalty
        )
        # May be wrapped with LLEGO below (hybrid mode)
        reflection_lm_client = base_reflection_lm_client

        # System prompt used by both reflection callables below.
        # Fix: this text was previously duplicated verbatim in the LLEGO and
        # standard branches — a single definition keeps them in sync.
        # It must instruct the LLM to emit an improved PROMPT, not feedback.
        optimization_system_prompt = """You are an expert prompt engineer specializing in iterative prompt optimization.

Your task: Given the CURRENT PROMPT and its EVALUATION FEEDBACK, generate an IMPROVED version of the prompt that addresses all identified issues.

Core Requirements:
1. OUTPUT ONLY the improved prompt text (no explanations, no analysis, no meta-commentary)
2. START directly with the prompt (e.g., "You are a mobile GUI agent..." or similar task-appropriate opening)
3. PRESERVE the core task domain and output format requirements
4. INTEGRATE improvements from feedback naturally into the prompt structure
5. MAINTAIN clarity, specificity, and actionability

Quality Standards:
- Be specific and concrete (avoid vague instructions)
- Use clear, imperative language for task instructions
- Include edge case handling if feedback identifies confusion
- Ensure the prompt is self-contained and unambiguous

DO NOT include:
- Analysis of what went wrong
- Explanations of your changes
- Meta-text like "Here's an improved version..." or "Based on feedback..."
- Recommendations or suggestions (those are already in the feedback)

Output the improved prompt directly and only the prompt."""

        if self.config.use_llego_operators:
            self.logger.info("🧬 LLEGO genetic operators ENABLED")
            self.logger.info(f" α={self.config.alpha}, τ={self.config.tau}, ν={self.config.nu}")
            self.logger.info(f" Crossover offspring: {self.config.n_crossover}, Mutation offspring: {self.config.n_mutation}")

            # Fix: import hoisted to the top of this branch. Previously the
            # name LLEGOEnhancedLLMClient was only bound inside sub-branches,
            # so the isinstance() checks below could hit a NameError when
            # hybrid mode was off and self.adapter was unset.
            # (Unused PromptCandidate import removed.)
            from ..operators.llego_operators import LLEGOIntegrationLayer
            from ..llms.llego_enhanced_llm import LLEGOEnhancedLLMClient

            # Initialize the LLEGO integration layer
            llego = LLEGOIntegrationLayer(
                alpha=self.config.alpha,
                tau=self.config.tau,
                nu=self.config.nu,
                population_size=self.config.population_size,
                n_crossover=self.config.n_crossover,
                n_mutation=self.config.n_mutation
            )

            # Seed the genetic population with the starting prompt
            llego.initialize_population(
                seed_prompt=seed_candidate.get('system_prompt', ''),
                initial_fitness=0.5
            )

            # HYBRID MODE: wrap the reflection client with LLEGO so reflection
            # calls trigger hybrid (GEPA + LLEGO) candidate generation.
            if self.config.enable_gepa_reflection_with_llego:
                self.logger.info("🔥 HYBRID MODE: Wrapping reflection_lm_client with LLEGO")
                reflection_lm_client = LLEGOEnhancedLLMClient(
                    base_llm=base_reflection_lm_client,
                    llego_layer=llego,
                    config=self.config,  # config enables hybrid generation
                    verbose=True
                )
                self.logger.info("✅ reflection_lm_client wrapped with LLEGO (hybrid mode enabled)")

            # --- Select the adapter GEPA will actually use -------------------
            if self.config.enable_gepa_reflection_with_llego:
                # HYBRID MODE: make sure the task-side LLM client is
                # LLEGO-wrapped with the current config.
                self.logger.info("🔥 HYBRID MODE: Enabling hybrid candidate generation in LLEGO wrapper")
                llm_client = self.adapter.llm_client
                if isinstance(llm_client, LLEGOEnhancedLLMClient):
                    # Already wrapped — just refresh the config
                    llm_client.config = self.config
                    self.logger.info("✅ Updated LLEGO wrapper with hybrid mode config")
                else:
                    self.adapter.llm_client = LLEGOEnhancedLLMClient(
                        base_llm=llm_client,
                        llego_layer=llego,
                        config=self.config,
                        verbose=True
                    )
                    self.logger.info("✅ Wrapped LLM client with LLEGO (hybrid mode enabled)")
                adapter = self.adapter
            else:
                # LLEGO-ONLY MODE: recreate the adapter around a LLEGO-wrapped client
                self.logger.info("🧬 LLEGO-ONLY MODE: Recreating adapter with LLEGO integration...")
                if hasattr(self, 'adapter') and self.adapter:
                    from .universal_adapter import UniversalGepaAdapter

                    # Unwrap if the client was already LLEGO-wrapped
                    original_llm = self.adapter.llm_client
                    if hasattr(original_llm, 'base_llm'):
                        original_llm = original_llm.base_llm

                    llego_wrapped_llm = LLEGOEnhancedLLMClient(
                        base_llm=original_llm,
                        llego_layer=llego,
                        config=None,  # no hybrid mode
                        verbose=True
                    )
                    adapter = UniversalGepaAdapter(
                        llm_client=llego_wrapped_llm,
                        evaluator=self.adapter.evaluator,
                        data_converter=self.adapter.data_converter,
                        llego_layer=llego
                    )
                    self.logger.info("✅ Adapter recreated with LLEGO-enhanced LLM client")
                else:
                    adapter = self.adapter

            # --- Store references on the adapter GEPA will use ---------------
            # Fix: the original assigned these BEFORE swapping `adapter` for
            # self.adapter / a freshly recreated adapter, so the final adapter
            # could be missing _reflection_lm_client/_config (both required
            # for propose_new_texts and hybrid generation). The original's
            # hasattr-guarded if/else pairs also performed the identical
            # assignment in both branches, so they are collapsed here.
            adapter.reflection_lm_client = reflection_lm_client
            adapter._config = self.config
            adapter._reflection_lm_client = reflection_lm_client
            self.logger.info("✅ Stored reflection client and config references in adapter")

            # CRITICAL: without the LLEGO layer on the adapter, population
            # updates are silently skipped.
            if getattr(adapter, 'llego', None) is None:
                adapter.llego = llego
                self.logger.info("✅ CRITICAL: Set LLEGO layer in adapter")
            else:
                self.logger.debug("✅ LLEGO layer already set in adapter")

            def reflection_lm_callable(prompt: str) -> str:
                """
                Reflection callable that delegates to the (possibly LLEGO-wrapped) client.
                In hybrid mode, the wrapper generates candidates from both GEPA and LLEGO.

                The incoming prompt is cleaned first to strip base64 images and
                truncate excessive data before it reaches the reflection model.
                """
                cleaned_prompt = self._clean_reflection_prompt(prompt)

                self.logger.info(f"\n{'🔥'*40}")
                self.logger.info("🔥 reflection_lm_callable CALLED (delegating to LLEGO wrapper)")
                self.logger.info(f"🔥 Original prompt length: {len(prompt)} chars")
                self.logger.info(f"🔥 Cleaned prompt length: {len(cleaned_prompt)} chars")
                self.logger.info(f"🔥 Truncation: {len(prompt) - len(cleaned_prompt)} chars removed")
                self.logger.info(f"🔥 First 200 chars (cleaned): {cleaned_prompt[:200]}...")
                self.logger.info(f"{'🔥'*40}\n")

                try:
                    # Signal reflection mode to the LLEGO wrapper BEFORE generating,
                    # so it knows to produce hybrid candidates.
                    if isinstance(reflection_lm_client, LLEGOEnhancedLLMClient):
                        reflection_lm_client.set_reflection_context(
                            current_prompt=cleaned_prompt,
                            feedback=None,
                            in_reflection=True
                        )
                        self.logger.info("✅ Reflection context set on reflection_lm_client")

                    result = reflection_lm_client.generate(
                        system_prompt=optimization_system_prompt,
                        user_prompt=cleaned_prompt,  # cleaned: no base64, truncated
                        image_base64=""
                    )

                    # The wrapper may return a dict ({'content', 'source', ...}) or raw text
                    if isinstance(result, dict):
                        candidate = result.get("content", str(result))
                        source = result.get("source", "unknown")
                        self.logger.info(f"✅ Candidate from {source} (FULL TEXT):")
                    else:
                        candidate = str(result)
                        self.logger.info("✅ Candidate generated (FULL TEXT):")
                    self.logger.info(f" '{candidate}'")
                    return candidate

                except Exception as e:
                    self.logger.error(f"❌ Error in reflection_lm_callable: {e}")
                    import traceback
                    self.logger.error(traceback.format_exc())
                    return prompt  # fall back to the unmodified prompt

            # Prime the reflection context with the seed prompt (hybrid mode)
            if self.config.enable_gepa_reflection_with_llego and isinstance(reflection_lm_client, LLEGOEnhancedLLMClient):
                reflection_lm_client.set_reflection_context(
                    current_prompt=seed_candidate.get('system_prompt', ''),
                    feedback=None,
                    in_reflection=True
                )

        else:
            # Standard GEPA reflection (no LLEGO)
            adapter = self.adapter

            # Required for propose_new_texts() to work
            if getattr(adapter, '_reflection_lm_client', None) is None:
                adapter._reflection_lm_client = reflection_lm_client
                self.logger.info("✅ Set _reflection_lm_client in adapter (required for propose_new_texts)")

            def reflection_lm_callable(prompt: str) -> str:
                """Standard callable wrapper for the reflection model that GEPA expects."""
                try:
                    # Reflection is text-only — no image payload
                    result = reflection_lm_client.generate(
                        system_prompt=optimization_system_prompt,
                        user_prompt=prompt,
                        image_base64=""
                    )
                    # Extract string content from a dict-shaped result
                    if isinstance(result, dict):
                        return result.get("content", str(result))
                    return str(result)
                except Exception as e:
                    self.logger.error(f"Reflection model error: {e}")
                    return prompt  # return original prompt on error

        self.logger.info(
            f"Starting GEPA optimization with {max_iterations} iterations, "
            f"batch size {batch_size}, max metric calls: {max_metric_calls}"
        )
        self.logger.info(
            f"GEPA parameters: candidate_selection_strategy=pareto, "
            f"reflection_minibatch_size={batch_size}, "
            f"skip_perfect_score=False, "
            f"module_selector=round_robin"
        )

        # In adapter mode GEPA may ignore reflection_lm; only pass it when
        # LLEGO operators are enabled.
        reflection_lm_passed = reflection_lm_callable if self.config.use_llego_operators else None
        if reflection_lm_passed:
            self.logger.debug("reflection_lm_callable passed to GEPA (may be ignored in adapter mode)")

        _agent_debug_log(
            location="optimizer.py:gepa_params",
            message="GEPA params construction",
            data={
                "max_iterations_from_config": max_iterations,
                "max_metric_calls": max_metric_calls,
                "batch_size": batch_size,
            },
        )

        gepa_params = {
            'adapter': adapter,  # adapter selected above (with or without LLEGO)
            'seed_candidate': seed_candidate,
            'trainset': trainset,
            'valset': valset,
            # GEPA has no num_iterations — iteration count is bounded by max_metric_calls
            'max_metric_calls': max_metric_calls,
            # With an adapter, GEPA expects task_lm=None (the adapter runs the task model)
            'task_lm': None,
            'reflection_lm': reflection_lm_passed,
            # Valid GEPA parameters based on the actual library
            'candidate_selection_strategy': 'pareto',   # Pareto selection
            'skip_perfect_score': False,                # don't stop on perfect scores
            'reflection_minibatch_size': batch_size,
            'perfect_score': 1.0,
            'module_selector': 'round_robin',           # cycle through components
            'display_progress_bar': self.config.verbose,
            'raise_on_exception': True,                 # surface failures for debugging
        }

        # GEPA rejects unknown kwargs (e.g. num_iterations) — forward only the
        # parameters its signature actually accepts.
        VALID_GEPA_PARAMS = {
            'seed_candidate', 'trainset', 'valset', 'adapter', 'task_lm', 'reflection_lm',
            'candidate_selection_strategy', 'skip_perfect_score', 'batch_sampler',
            'reflection_minibatch_size', 'perfect_score', 'reflection_prompt_template',
            'module_selector', 'use_merge', 'max_merge_invocations', 'merge_val_overlap_floor',
            'max_metric_calls', 'stop_callbacks', 'logger', 'run_dir', 'use_wandb',
            'wandb_api_key', 'wandb_init_kwargs', 'use_mlflow', 'mlflow_tracking_uri',
            'mlflow_experiment_name', 'track_best_outputs', 'display_progress_bar',
            'use_cloudpickle', 'seed', 'raise_on_exception', 'val_evaluation_policy'
        }
        for key, value in kwargs.items():
            if key in VALID_GEPA_PARAMS and key not in gepa_params:
                gepa_params[key] = value
            elif key not in VALID_GEPA_PARAMS:
                self.logger.debug(f"⚠️ Filtering out invalid GEPA parameter: {key}")

        _agent_debug_log(
            location="optimizer.py:gepa_params_final",
            message="Final GEPA params keys",
            data={
                "params_keys": list(gepa_params.keys()),
                "max_metric_calls": gepa_params.get('max_metric_calls', 'NOT_PASSED'),
            },
        )

        # Capture GEPA's internal stdout/stderr so pareto-front info can be parsed
        gepa_output = io.StringIO()

        # Log iteration start
        from ..utils.clean_logger import get_clean_logger
        clean_log = get_clean_logger()
        clean_log.log_iteration_start(1, seed_prompt=seed_candidate.get('system_prompt', ''))

        # Give the adapter the validation set so generated candidates are
        # evaluated on Dpareto for Pareto selection.
        if hasattr(adapter, '_valset_size'):
            adapter._valset_size = len(valset)
            self.logger.debug(f"✅ Set valset_size in adapter: {len(valset)} for Dpareto detection")
        adapter._valset = valset
        self.logger.debug(f"✅ Stored valset in adapter ({len(valset)} samples) for Dpareto evaluation of generated candidates")

        # gepa.optimize() is synchronous — run it in a worker thread
        result = await asyncio.get_event_loop().run_in_executor(
            None,
            lambda: self._run_gepa_with_logging(gepa_params, gepa_output)
        )

        # Parse GEPA's captured output for pareto info and the real iteration count
        gepa_logs = gepa_output.getvalue()
        actual_iterations = self._log_pareto_front_info(gepa_logs)

        return result, actual_iterations

    except Exception as e:
        self.logger.warning(f"GEPA optimization failed: {e}")

        # Try to salvage the best result the adapter cached before the failure.
        # Guarded: the adapter may not support result caching, and an
        # AttributeError here would mask the original error.
        best_candidate = None
        best_score = 0.0
        try:
            best_candidate = adapter.get_best_candidate()
            best_score = adapter.get_best_score()
        except Exception:
            pass

        if best_candidate and best_score > 0:
            self.logger.info(f"🎯 Using cached best result with score: {best_score:.4f}")
            # Fix: return a (result, iterations) tuple like the success path;
            # the original returned a bare dict here, which broke callers
            # unpacking `result, actual_iterations`.
            return ({
                'best_candidate': best_candidate,
                'best_score': best_score,
                'partial_result': True,
                'error': f'GEPA failed but returning best result found: {str(e)}'
            }, 0)

        # No cached results — re-raise with the original cause chained
        raise GepaOptimizerError(f"GEPA optimization failed: {str(e)}") from e
|
| 1021 |
+
|
| 1022 |
+
def _run_gepa_with_logging(self, gepa_params: Dict[str, Any], output_buffer: io.StringIO) -> Any:
    """Invoke gepa.optimize, capturing everything it prints into *output_buffer*.

    GEPA reports progress (iterations, pareto-front updates) via stdout/stderr;
    the captured text is later parsed by _log_pareto_front_info.
    """
    with redirect_stdout(output_buffer):
        with redirect_stderr(output_buffer):
            return gepa.optimize(**gepa_params)
|
| 1027 |
+
|
| 1028 |
+
def _log_pareto_front_info(self, gepa_logs: str) -> int: # Return int instead of None
|
| 1029 |
+
"""Extract and log pareto front information from GEPA logs. Returns max iteration count."""
|
| 1030 |
+
lines = gepa_logs.split('\n')
|
| 1031 |
+
current_iteration = 0
|
| 1032 |
+
max_iteration = 0 # Track max iteration
|
| 1033 |
+
|
| 1034 |
+
for line in lines:
|
| 1035 |
+
# Look for iteration information
|
| 1036 |
+
if 'iteration' in line.lower():
|
| 1037 |
+
# Try to extract iteration number
|
| 1038 |
+
import re
|
| 1039 |
+
iteration_match = re.search(r'iteration\s+(\d+)', line.lower())
|
| 1040 |
+
if iteration_match:
|
| 1041 |
+
current_iteration = int(iteration_match.group(1))
|
| 1042 |
+
max_iteration = max(max_iteration, current_iteration) # Track max
|
| 1043 |
+
# Log iteration change
|
| 1044 |
+
from ..utils.clean_logger import get_clean_logger
|
| 1045 |
+
clean_log = get_clean_logger()
|
| 1046 |
+
if current_iteration > clean_log.current_iteration:
|
| 1047 |
+
clean_log.current_iteration = current_iteration
|
| 1048 |
+
|
| 1049 |
+
# Look for pareto front information
|
| 1050 |
+
if 'pareto front' in line.lower() or 'new program' in line.lower():
|
| 1051 |
+
self.logger.info(f"GEPA Pareto Update: {line.strip()}")
|
| 1052 |
+
elif 'iteration' in line.lower() and ('score' in line.lower() or 'program' in line.lower()):
|
| 1053 |
+
self.logger.debug(f"{line.strip()}")
|
| 1054 |
+
elif 'best' in line.lower() and 'score' in line.lower():
|
| 1055 |
+
self.logger.info(f"{line.strip()}")
|
| 1056 |
+
|
| 1057 |
+
# Look for evaluation information
|
| 1058 |
+
if 'evaluating' in line.lower() and 'candidate' in line.lower():
|
| 1059 |
+
self.logger.debug(f"{line.strip()}")
|
| 1060 |
+
|
| 1061 |
+
self.logger.info(f"GEPA Optimization Complete: {max_iteration} iterations")
|
| 1062 |
+
|
| 1063 |
+
# #region agent log
|
| 1064 |
+
import json as _json_debug
|
| 1065 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 1066 |
+
with open(_debug_log_path, "a") as _f:
|
| 1067 |
+
_f.write(_json_debug.dumps({"hypothesisId": "F", "location": "optimizer.py:gepa_complete", "message": "GEPA optimization complete - iteration count", "data": {"max_iteration_from_logs": max_iteration, "expected_iterations": self.config.max_iterations, "off_by_one": max_iteration != self.config.max_iterations, "gepa_logs_length": len(gepa_logs)}, "timestamp": int(time.time() * 1000), "sessionId": "debug-session"}) + "\n")
|
| 1068 |
+
# #endregion
|
| 1069 |
+
|
| 1070 |
+
return max_iteration # Return the max iteration count
|
| 1071 |
+
|
| 1072 |
+
def _extract_best_candidate(self, gepa_result: Any) -> Dict[str, str]:
|
| 1073 |
+
"""
|
| 1074 |
+
Extract the best candidate from GEPA Pareto front (single source of truth).
|
| 1075 |
+
|
| 1076 |
+
GEPA Pareto front is the single source of truth because:
|
| 1077 |
+
- All candidates (GEPA reflection, LLEGO crossover, LLEGO mutation) are evaluated on Dpareto
|
| 1078 |
+
- All non-dominated candidates are added to GEPA Pareto front
|
| 1079 |
+
- Therefore, the best candidate MUST be in GEPA Pareto front
|
| 1080 |
+
|
| 1081 |
+
Args:
|
| 1082 |
+
gepa_result: Raw result from gepa.optimize() (used only as fallback edge case)
|
| 1083 |
+
|
| 1084 |
+
Returns:
|
| 1085 |
+
Best candidate dictionary with prompt components from GEPA Pareto front
|
| 1086 |
+
"""
|
| 1087 |
+
try:
|
| 1088 |
+
self.logger.info(f"\n{'═'*80}")
|
| 1089 |
+
self.logger.info(f"🔍 EXTRACTING BEST CANDIDATE FROM GEPA PARETO FRONT")
|
| 1090 |
+
self.logger.info(f"{'═'*80}")
|
| 1091 |
+
|
| 1092 |
+
# ========================================================================
|
| 1093 |
+
# PRIMARY: Get best candidate from GEPA Pareto front (single source of truth)
|
| 1094 |
+
# ========================================================================
|
| 1095 |
+
from ..utils.pareto_logger import get_pareto_logger
|
| 1096 |
+
pareto_log = get_pareto_logger()
|
| 1097 |
+
|
| 1098 |
+
if pareto_log.pareto_front:
|
| 1099 |
+
try:
|
| 1100 |
+
# Get best candidate from GEPA Pareto front (highest score = best)
|
| 1101 |
+
gepa_pareto_best = max(pareto_log.pareto_front, key=lambda x: x['score'])
|
| 1102 |
+
gepa_pareto_fitness = gepa_pareto_best['score']
|
| 1103 |
+
gepa_pareto_prompt = gepa_pareto_best['prompt']
|
| 1104 |
+
gepa_pareto_type = gepa_pareto_best.get('type', 'unknown')
|
| 1105 |
+
gepa_pareto_notation = gepa_pareto_best.get('notation', 'S')
|
| 1106 |
+
|
| 1107 |
+
best_candidate = {
|
| 1108 |
+
'system_prompt': gepa_pareto_prompt,
|
| 1109 |
+
'fitness': gepa_pareto_fitness,
|
| 1110 |
+
'source': 'gepa_pareto_front',
|
| 1111 |
+
'candidate_type': gepa_pareto_type,
|
| 1112 |
+
'notation': gepa_pareto_notation
|
| 1113 |
+
}
|
| 1114 |
+
|
| 1115 |
+
self.logger.info(f"✅ SELECTED: Best candidate from GEPA Pareto front")
|
| 1116 |
+
self.logger.info(f" Notation: {gepa_pareto_notation}")
|
| 1117 |
+
self.logger.info(f" Fitness: f({gepa_pareto_notation})={gepa_pareto_fitness:.4f}")
|
| 1118 |
+
self.logger.info(f" Type: {gepa_pareto_type}")
|
| 1119 |
+
self.logger.info(f" Prompt length: {len(gepa_pareto_prompt)} chars")
|
| 1120 |
+
self.logger.info(f" 💡 GEPA Pareto front is single source of truth (all candidates evaluated on Dpareto)")
|
| 1121 |
+
|
| 1122 |
+
return best_candidate
|
| 1123 |
+
|
| 1124 |
+
except Exception as e:
|
| 1125 |
+
self.logger.error(f"❌ Failed to extract from GEPA Pareto front: {e}")
|
| 1126 |
+
import traceback
|
| 1127 |
+
self.logger.error(traceback.format_exc())
|
| 1128 |
+
|
| 1129 |
+
# ========================================================================
|
| 1130 |
+
# EDGE CASE FALLBACK: Pareto front empty (shouldn't happen, but handle gracefully)
|
| 1131 |
+
# ========================================================================
|
| 1132 |
+
self.logger.warning(f"⚠️ GEPA Pareto front is empty - using gepa_result as fallback")
|
| 1133 |
+
self.logger.warning(f" This should not happen if all candidates are evaluated on Dpareto")
|
| 1134 |
+
|
| 1135 |
+
# Try to extract from gepa_result (last resort)
|
| 1136 |
+
if hasattr(gepa_result, 'best_candidate'):
|
| 1137 |
+
gepa_candidate = gepa_result.best_candidate
|
| 1138 |
+
gepa_prompt = gepa_candidate.get('system_prompt') if isinstance(gepa_candidate, dict) else str(gepa_candidate)
|
| 1139 |
+
gepa_fitness = getattr(gepa_result, 'best_score', None)
|
| 1140 |
+
|
| 1141 |
+
if gepa_prompt:
|
| 1142 |
+
self.logger.info(f"✅ Using gepa_result.best_candidate as fallback")
|
| 1143 |
+
return {
|
| 1144 |
+
'system_prompt': gepa_prompt,
|
| 1145 |
+
'fitness': float(gepa_fitness) if gepa_fitness is not None else None,
|
| 1146 |
+
'source': 'gepa_result_fallback',
|
| 1147 |
+
'candidate_type': 'unknown',
|
| 1148 |
+
'notation': 'S'
|
| 1149 |
+
}
|
| 1150 |
+
|
| 1151 |
+
# Last resort: return empty prompt
|
| 1152 |
+
self.logger.error(f"❌ No candidates found anywhere - returning empty prompt")
|
| 1153 |
+
return {'system_prompt': ''}
|
| 1154 |
+
|
| 1155 |
+
except Exception as e:
|
| 1156 |
+
self.logger.error(f"❌ Error extracting best candidate: {e}")
|
| 1157 |
+
import traceback
|
| 1158 |
+
self.logger.error(traceback.format_exc())
|
| 1159 |
+
return {'system_prompt': ''}
|
| 1160 |
+
|
| 1161 |
+
def _evaluate_candidate_on_testset(
    self,
    candidate: Dict[str, str],
    testset: List[Dict]
) -> float:
    """
    Evaluate a candidate prompt on the held-out test set.

    Args:
        candidate: Prompt candidate to evaluate
        testset: Test dataset (not used during optimization)

    Returns:
        Average composite score on test set

    Raises:
        TestSetEvaluationError: If evaluation fails
    """
    from ..utils.exceptions import TestSetEvaluationError

    try:
        # Evaluate using the adapter (same path GEPA uses internally).
        eval_result = self.adapter.evaluate(
            batch=testset,
            candidate=candidate,
            capture_traces=False  # Don't need detailed traces for test
        )

        if not eval_result.scores:
            raise TestSetEvaluationError("No scores returned from test evaluation")

        # Average the per-sample composite scores.
        avg_score = sum(eval_result.scores) / len(eval_result.scores)

        self.logger.debug(
            f"Test set evaluation: {len(eval_result.scores)} samples, "
            f"scores: {eval_result.scores}, avg: {avg_score:.4f}"
        )

        return avg_score

    except TestSetEvaluationError:
        # Already the right exception type — re-raise as-is instead of
        # double-wrapping ("Failed to evaluate...: No scores returned...").
        raise
    except Exception as e:
        # Chain the cause (PEP 3134) so the original traceback survives;
        # the bare `raise X(...)` used previously discarded it.
        raise TestSetEvaluationError(f"Failed to evaluate on test set: {str(e)}") from e
|
| 1204 |
+
|
| 1205 |
+
def optimize_sync(self,
                  model: str,
                  seed_prompt: str,
                  dataset: Any,
                  reflection_lm: str,
                  max_metric_calls: int = 150,
                  **kwargs) -> OptimizedResult:
    """
    Synchronous version of the optimization method.

    Args:
        model: Target model to optimize for
        seed_prompt: Initial prompt to optimize
        dataset: Training data in any format
        reflection_lm: Model for reflection
        max_metric_calls: Budget for optimization attempts
        **kwargs: Additional optimization parameters

    Returns:
        OptimizedResult: Optimization result

    Raises:
        RuntimeError: If called from a thread whose event loop is already
            running (use ``await self.train(...)`` there instead).
    """
    # asyncio.run() creates a fresh event loop, runs the coroutine, and tears
    # everything down cleanly.  The previous new_event_loop()/set_event_loop()
    # approach left a *closed* loop installed as the thread's current event
    # loop after loop.close(), which broke any subsequent
    # asyncio.get_event_loop() call in the same thread.
    return asyncio.run(
        self.train(model, seed_prompt, dataset, reflection_lm, max_metric_calls, **kwargs)
    )
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
# Convenience function for quick optimization
|
| 1240 |
+
def optimize_prompt(
    model: Union[str, ModelConfig],
    seed_prompt: str,
    dataset: Any,
    reflection_model: Optional[Union[str, ModelConfig]] = None,
    **kwargs
) -> OptimizedResult:
    """
    Convenience function for quick prompt optimization without creating an
    optimizer instance.

    Args:
        model: Target model configuration
        seed_prompt: Initial prompt to optimize
        dataset: Training data
        reflection_model: Model for reflection (optional; defaults to `model`)
        **kwargs: Additional optimization parameters
            (recognized config keys: max_iterations, max_metric_calls, batch_size)

    Returns:
        OptimizedResult: Optimization result
    """
    # Default the reflection model to the target model.
    if reflection_model is None:
        reflection_model = model

    config = OptimizationConfig(
        model=model,
        reflection_model=reflection_model,
        max_iterations=kwargs.get('max_iterations', 10),
        max_metric_calls=kwargs.get('max_metric_calls', 50),
        batch_size=kwargs.get('batch_size', 4)
    )

    optimizer = GepaOptimizer(config=config)

    # BUG FIX: train() expects (model, seed_prompt, dataset, reflection_lm,
    # max_metric_calls, **kwargs) — see optimize_sync's call.  The old call
    # `optimizer.train(seed_prompt, dataset, **kwargs)` shifted every argument
    # into the wrong slot (seed_prompt became `model`, dataset became
    # `seed_prompt`).  Pass by keyword so the binding is explicit, and strip
    # config-level keys from kwargs so they are not forwarded twice.
    # NOTE(review): train() annotates model/reflection_lm as str — confirm it
    # also accepts ModelConfig instances before passing one here.
    train_kwargs = {
        k: v for k, v in kwargs.items()
        if k not in ('max_iterations', 'max_metric_calls', 'batch_size')
    }
    return asyncio.run(
        optimizer.train(
            model=model,
            seed_prompt=seed_prompt,
            dataset=dataset,
            reflection_lm=reflection_model,
            max_metric_calls=kwargs.get('max_metric_calls', 50),
            **train_kwargs,
        )
    )
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
|
| 1279 |
+
|
src/gepa_optimizer/core/result.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Result processing for GEPA Optimizer
|
| 3 |
+
Handles extraction and processing of GEPA optimization results
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class ResultProcessor:
    """
    Turns raw GEPA optimization results into clean, structured data.
    """

    @staticmethod
    def extract_optimized_prompt(result: Any) -> str:
        """
        Pull the optimized prompt text out of a GEPA result object.

        Args:
            result: Raw GEPA optimization result

        Returns:
            str: The optimized prompt text
        """
        try:
            if not hasattr(result, 'best_candidate'):
                # No candidate attribute at all — stringify the whole result.
                return str(result)

            best = result.best_candidate
            if not isinstance(best, dict):
                return str(best)

            # Probe the usual prompt keys in priority order.
            for key in ('system_prompt', 'prompt', 'text'):
                if key in best:
                    return str(best[key])

            # Unknown dict layout — fall back to its string representation.
            return str(best)

        except Exception as e:
            logger.warning(f"Failed to extract optimized prompt: {e}")
            return "Optimization completed (prompt extraction failed)"

    @staticmethod
    def extract_metrics(result: Any) -> Dict[str, Any]:
        """
        Collect performance metrics from a GEPA result.

        Args:
            result: Raw GEPA optimization result

        Returns:
            Dict[str, Any]: Extracted metrics
        """
        collected: Dict[str, Any] = {}

        try:
            # Copy over the well-known numeric attributes, coerced to plain types.
            for attr, cast in (('best_score', float),
                               ('baseline_score', float),
                               ('improvement', float),
                               ('iterations', int)):
                if hasattr(result, attr):
                    collected[attr] = cast(getattr(result, attr))

            # Derive the improvement percentage when both scores are present.
            baseline = collected.get('baseline_score')
            if 'best_score' in collected and baseline is not None and baseline > 0:
                pct = ((collected['best_score'] - baseline) / baseline) * 100
                collected['improvement_percent'] = round(pct, 2)

            if hasattr(result, 'metadata'):
                collected['metadata'] = result.metadata

        except Exception as e:
            logger.warning(f"Failed to extract metrics: {e}")

        return collected

    @staticmethod
    def extract_reflection_history(result: Any) -> list:
        """
        Collect the reflection/optimization history from a GEPA result.

        Args:
            result: Raw GEPA optimization result

        Returns:
            list: List of reflection iterations
        """
        entries = []

        try:
            # Missing attribute simply yields an empty history.
            for idx, step in enumerate(getattr(result, 'optimization_history', [])):
                entries.append({
                    'iteration': idx,
                    'score': step.get('score', 0.0),
                    'candidate': step.get('candidate', {}),
                    'feedback': step.get('feedback', ''),
                    'improvement': step.get('improvement', 0.0),
                })
        except Exception as e:
            logger.warning(f"Failed to extract reflection history: {e}")

        return entries

    @staticmethod
    def process_full_result(
        result: Any,
        original_prompt: str,
        optimization_time: float,
        actual_iterations: Optional[int] = None,
        test_metrics: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Process a complete GEPA result into a structured dictionary.

        Args:
            result: Raw GEPA optimization result
            original_prompt: Original seed prompt
            optimization_time: Time taken for optimization
            actual_iterations: Actual number of iterations from GEPA logs (optional)
            test_metrics: Metrics from test set evaluation (optional)

        Returns:
            Dict[str, Any]: Complete processed result
        """
        metrics = ResultProcessor.extract_metrics(result)

        # Resolve the iteration count, preferring the explicit log-derived value.
        iteration_count = 0
        try:
            if actual_iterations is not None:
                iteration_count = actual_iterations
            elif hasattr(result, 'iterations'):
                iteration_count = int(result.iterations)
            elif hasattr(result, 'num_iterations'):
                iteration_count = int(result.num_iterations)
            elif hasattr(result, 'optimization_history'):
                iteration_count = len(result.optimization_history)
            elif 'iterations' in metrics:
                iteration_count = metrics['iterations']
        except Exception as e:
            logger.warning(f"Failed to extract iterations: {e}")

        # Fold any held-out test metrics into the improvement data (copied).
        improvement_data = dict(test_metrics) if test_metrics else {}

        return {
            'original_prompt': original_prompt,
            'optimized_prompt': ResultProcessor.extract_optimized_prompt(result),
            'metrics': metrics,
            'improvement_data': improvement_data,
            'reflection_history': ResultProcessor.extract_reflection_history(result),
            'optimization_time': optimization_time,
            'total_iterations': iteration_count,
            'status': 'completed',
            'raw_result': result  # Keep raw result for advanced users
        }
|
src/gepa_optimizer/core/universal_adapter.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/gepa_optimizer/data/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data module for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .converters import UniversalConverter
|
| 6 |
+
from .loaders import DataLoader
|
| 7 |
+
from .validators import DataValidator
|
| 8 |
+
from .scroll_dataset_loader import ScrollDatasetLoader, load_scroll_dataset
|
| 9 |
+
from .validation_dataset_loader import ValidationDatasetLoader, load_validation_dataset, load_validation_split
|
| 10 |
+
from .index_caching_loader import IndexCachingDatasetLoader, load_index_caching_dataset, load_index_caching_split
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"UniversalConverter",
|
| 14 |
+
"DataLoader",
|
| 15 |
+
"DataValidator",
|
| 16 |
+
# Scroll dataset
|
| 17 |
+
"ScrollDatasetLoader",
|
| 18 |
+
"load_scroll_dataset",
|
| 19 |
+
# Validation dataset
|
| 20 |
+
"ValidationDatasetLoader",
|
| 21 |
+
"load_validation_dataset",
|
| 22 |
+
"load_validation_split",
|
| 23 |
+
# Index caching dataset
|
| 24 |
+
"IndexCachingDatasetLoader",
|
| 25 |
+
"load_index_caching_dataset",
|
| 26 |
+
"load_index_caching_split",
|
| 27 |
+
]
|
src/gepa_optimizer/data/converters.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Universal converter for dataset to GEPA format with 3-way split (train/val/test)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, List, Tuple, Union, Dict, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from .loaders import DataLoader
|
| 13 |
+
from ..utils.exceptions import DatasetError
|
| 14 |
+
from ..models.config import DataSplitConfig
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class UniversalConverter:
    """
    Converts arbitrary datasets into GEPA format.

    Supports 3-way splitting (train/val/test) with configurable ratios and
    degrades gracefully on small datasets.
    """

    def __init__(self, data_split_config: Optional[DataSplitConfig] = None):
        """
        Create a converter, optionally with a custom split configuration.

        Args:
            data_split_config: Configuration for train/val/test splits.
                If None, uses the default 60/20/20 split.
        """
        self.supported_extensions = [
            '.csv', '.json', '.jsonl', '.txt', '.md',
            '.png', '.jpg', '.jpeg'
        ]
        self.loader = DataLoader()
        self.data_split_config = data_split_config or DataSplitConfig()

    def convert(
        self,
        dataset: Union[List[Any], str, Any, Dict[str, Any]],
        split_config: Optional[DataSplitConfig] = None
    ) -> Tuple[List[dict], List[dict], List[dict]]:
        """
        Convert any supported dataset to GEPA format with a 3-way split.

        Args:
            dataset: Input dataset in any supported format
            split_config: Optional split configuration (overrides instance config)

        Returns:
            Tuple of (trainset, valset, testset) where:
            - trainset: Used for reflection/feedback (Dfeedback in GEPA paper)
            - valset: Used for Pareto selection (Dpareto in GEPA paper)
            - testset: Held-out for final evaluation (not passed to GEPA)

        Raises:
            DatasetError: If the dataset cannot be converted or is too small
        """
        try:
            active_config = split_config or self.data_split_config

            # UI tree datasets carry their own directory layout and take a
            # dedicated conversion path.
            if isinstance(dataset, dict) and dataset.get('type') == 'ui_tree_dataset':
                return self.convert_ui_tree_dataset(
                    dataset.get('json_dir', 'json_tree'),
                    dataset.get('screenshots_dir', 'screenshots'),
                    split_config=active_config
                )

            # Normalize the remaining input shapes to a list of records.
            if isinstance(dataset, str):
                records = self._load_from_path(dataset)
            elif hasattr(dataset, 'to_dict'):  # pandas DataFrame
                records = dataset.to_dict(orient='records')
            elif isinstance(dataset, list):
                records = dataset
            else:
                records = [dataset]

            logger.info(f"Normalized data length: {len(records)}")
            normalized = self._standardize(records)
            return self._split_three_way(normalized, active_config)
        except (FileNotFoundError, ValueError, TypeError) as e:
            raise DatasetError(f"Failed to convert dataset: {str(e)}")

    def _load_from_path(self, path: str) -> List[Any]:
        """Load data from a file path, validating existence and extension."""
        file_path = Path(path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        ext = file_path.suffix.lower()
        if ext not in self.supported_extensions:
            raise DatasetError(f"Unsupported file extension: {ext}")
        # NOTE: the loaded object is wrapped as a single record — verify this
        # matches DataLoader.load()'s return shape for multi-row files.
        return [self.loader.load(file_path)]

    def _standardize(self, data: List[Any]) -> List[dict]:
        """Normalize every record to the {'input', 'output', 'image'} shape.

        Accepts both the UI-tree JSON format
        ({'screenshot': str, 'ui_tree': dict, 'expected_output': str}) and
        simple text records ({'input': ..., 'output': ...},
        {'question': ..., 'answer': ...}, etc.).
        """
        standardized = []
        for raw in data:
            record = raw if isinstance(raw, dict) else {'input': str(raw)}

            if 'ui_tree' in record and 'screenshot' in record:
                # UI-tree record: the textual input lives inside the tree node.
                standardized.append({
                    'input': record['ui_tree'].get('text', ''),
                    'output': record.get('expected_output', ''),
                    'image': record.get('screenshot', ''),
                })
            else:
                # Simple record: probe the common key aliases.
                standardized.append({
                    'input': self._extract(record, ['input', 'question', 'text', 'prompt']) or '',
                    'output': self._extract(record, ['output', 'result', 'response', 'answer', 'expected_output']) or '',
                    'image': self._extract(record, ['image', 'image_base64', 'screenshot']) or '',
                })

        return standardized

    def _extract(self, d: dict, keys: List[str]) -> Union[str, None]:
        """Return the first value found under any of *keys*, else None."""
        return next((d[k] for k in keys if k in d), None)

    def _split_three_way(
        self,
        data: List[dict],
        config: DataSplitConfig
    ) -> Tuple[List[dict], List[dict], List[dict]]:
        """
        Partition data into train, validation, and test sets.

        Args:
            data: Standardized dataset
            config: Split configuration with ratios and strategies

        Returns:
            Tuple of (train, val, test) datasets

        Raises:
            DatasetError: If the dataset is too small for the configured splits
        """
        n = len(data)

        # Announce adaptive ratios up front so the split is explainable in logs.
        if config.small_dataset_strategy == 'adaptive':
            train_ratio, val_ratio, test_ratio = config.get_adaptive_ratios(n)
            logger.info(
                f"📊 Adaptive dataset splitting (strategy: adaptive, size: {n}): "
                f"ratios = {train_ratio*100:.0f}%/{val_ratio*100:.0f}%/{test_ratio*100:.0f}% "
                f"(prioritizes validation for reliable candidate ranking)"
            )

        # Delegate index computation to the config; surface its complaints
        # as dataset errors.
        try:
            train_end, val_end, test_end, _ = config.get_split_indices(n)
        except ValueError as e:
            logger.error(f"Dataset split error: {e}")
            raise DatasetError(str(e))

        train = data[:train_end]
        val = data[train_end:val_end]
        test = data[val_end:test_end]

        suffix = " (adaptive)" if config.small_dataset_strategy == 'adaptive' else ""
        logger.info(
            f"Dataset split{suffix}: {len(train)} train ({len(train)/n*100:.1f}%), "
            f"{len(val)} val ({len(val)/n*100:.1f}%), "
            f"{len(test)} test ({len(test)/n*100:.1f}%)"
        )

        # Guard against degenerate splits.
        if not train:
            raise DatasetError("Training set is empty after split")
        if not val:
            logger.warning("Validation set is empty - this may cause issues with Pareto selection")
            val = [train[-1]]  # Fall back to the last training sample
        if not test:
            logger.warning("Test set is empty - final evaluation will not be performed")

        return train, val, test

    def _split(self, data: List[dict], ratio: float = 0.8) -> Tuple[List[dict], List[dict]]:
        """
        DEPRECATED legacy 2-way split kept for backwards compatibility.

        Prefer _split_three_way() in production code.

        Args:
            data: Standardized dataset
            ratio: Train ratio (0.0-1.0)

        Returns:
            Tuple of (train, val) datasets
        """
        import warnings
        warnings.warn(
            "_split() is deprecated. Use _split_three_way() for 3-way splitting.",
            DeprecationWarning,
            stacklevel=2
        )

        cut = max(1, int(len(data) * ratio))
        # Guarantee a non-empty validation slice.
        return data[:cut], (data[cut:] or data[-1:])

    def convert_ui_tree_dataset(
        self,
        json_dir: str,
        screenshots_dir: str,
        split_config: Optional[DataSplitConfig] = None
    ) -> Tuple[List[dict], List[dict], List[dict]]:
        """
        Convert a UI tree dataset (JSON + screenshots) to GEPA format with a 3-way split.

        Args:
            json_dir: Directory containing JSON files
            screenshots_dir: Directory containing screenshot images
            split_config: Optional split configuration (overrides instance config)

        Returns:
            Tuple of (train_data, val_data, test_data) in GEPA format

        Raises:
            DatasetError: If the dataset cannot be loaded or is invalid
        """
        try:
            paired = self.loader.load_ui_tree_dataset(json_dir, screenshots_dir)
            if not paired:
                raise DatasetError("No valid image-JSON pairs found")

            logger.info(f"Loaded {len(paired)} UI tree samples")

            effective_config = split_config or self.data_split_config
            train, val, test = self._split_three_way(paired, effective_config)

            logger.info(
                f"Split UI tree dataset: {len(train)} train, "
                f"{len(val)} validation, {len(test)} test"
            )
            return train, val, test

        except Exception as e:
            raise DatasetError(f"Failed to convert UI tree dataset: {str(e)}")
|
src/gepa_optimizer/data/index_caching_loader.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Index Caching Dataset Loader
|
| 3 |
+
|
| 4 |
+
Loads index caching dataset from JSON file (note2_debug.json format) and converts to GEPA-compatible format.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import base64
|
| 10 |
+
import logging
|
| 11 |
+
from typing import List, Dict, Any, Optional
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class IndexCachingDatasetLoader:
|
| 18 |
+
"""
|
| 19 |
+
Loads index caching dataset from JSON file.
|
| 20 |
+
|
| 21 |
+
Expected JSON format:
|
| 22 |
+
[
|
| 23 |
+
{
|
| 24 |
+
"command": "Tap on first option from the suggestion",
|
| 25 |
+
"image": "element_images/QMxgc_14_0_tap_IkALe_element.png",
|
| 26 |
+
"xml": "xml/IkALe__debug.xml",
|
| 27 |
+
"expected": {
|
| 28 |
+
"is_index_based": true,
|
| 29 |
+
"index_value": 1,
|
| 30 |
+
"parent_element_id": "aaaabf",
|
| 31 |
+
"element_id_of_nth_child_of_parent": "aaaabg",
|
| 32 |
+
"selected_element_is_correct": true
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
...
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
Converts to GEPA format:
|
| 39 |
+
- input: command text (seed prompt will be provided in test script)
|
| 40 |
+
- output: JSON string with expected values
|
| 41 |
+
- image_base64: base64 encoded image (TOP LEVEL for UniversalConverter)
|
| 42 |
+
- input: Command + XML content (combined in user prompt)
|
| 43 |
+
- metadata: All original fields plus converted values
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self, json_path: Optional[str] = None, base_dir: Optional[str] = None):
|
| 47 |
+
"""
|
| 48 |
+
Initialize index caching dataset loader.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
json_path: Path to JSON file. Default: "./note2_debug.json" or from env var
|
| 52 |
+
base_dir: Base directory for resolving relative paths in JSON.
|
| 53 |
+
Default: Directory containing JSON file
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
FileNotFoundError: If JSON file doesn't exist
|
| 57 |
+
json.JSONDecodeError: If JSON file is invalid
|
| 58 |
+
"""
|
| 59 |
+
# Get JSON path from env or use default
|
| 60 |
+
if json_path is None:
|
| 61 |
+
json_path = os.getenv("INDEX_CACHING_DATASET_PATH", "./note2_debug.json")
|
| 62 |
+
|
| 63 |
+
self.json_path = Path(json_path).resolve()
|
| 64 |
+
|
| 65 |
+
if not self.json_path.exists():
|
| 66 |
+
raise FileNotFoundError(
|
| 67 |
+
f"Dataset file not found: {self.json_path}\n"
|
| 68 |
+
f"Make sure note2_debug.json exists in the project root."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Base directory for resolving relative paths
|
| 72 |
+
if base_dir is None:
|
| 73 |
+
base_dir = self.json_path.parent
|
| 74 |
+
self.base_dir = Path(base_dir).resolve()
|
| 75 |
+
|
| 76 |
+
def load_dataset(self) -> List[Dict[str, Any]]:
    """
    Load dataset from JSON file and convert to GEPA format.

    Each JSON entry references an image and an XML dump on disk; both are
    read eagerly here (image -> base64 string, XML -> raw text) so
    downstream consumers need no further filesystem access.

    Returns:
        List of dataset items in GEPA format:
        [
            {
                "input": "<command>\\n\\nXML Content:...",     # Command + XML (for evaluation)
                "reflection_input": "<command>",              # Command only (for reflection)
                "output": '{"is_index_based": true, ...}',    # Expected output as JSON string
                "image_base64": "<base64_encoded_image>",     # TOP LEVEL
                "metadata": {
                    "command": "...",
                    "image_path": "...",
                    "xml_path": "...",
                    "expected": {...}
                }
            },
            ...
        ]

    Raises:
        FileNotFoundError: If a referenced image or XML file doesn't exist
        json.JSONDecodeError: If the JSON file is invalid
    """
    # Load JSON file
    with open(self.json_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    gepa_dataset = []

    for idx, entry in enumerate(dataset):
        # Missing keys degrade to empty values rather than raising KeyError.
        command = entry.get("command", "")
        image_path = entry.get("image", "")
        xml_path = entry.get("xml", "")
        expected = entry.get("expected", {})

        # Resolve paths relative to base_dir
        abs_image_path = (self.base_dir / image_path).resolve()
        abs_xml_path = (self.base_dir / xml_path).resolve()

        # Validate paths — fail fast, including the offending entry's
        # 1-based position and command for easier debugging.
        if not abs_image_path.exists():
            raise FileNotFoundError(
                f"Image file not found: {abs_image_path}\n"
                f"Entry {idx + 1}: {command}"
            )

        if not abs_xml_path.exists():
            raise FileNotFoundError(
                f"XML file not found: {abs_xml_path}\n"
                f"Entry {idx + 1}: {command}"
            )

        # Load and encode image
        with open(abs_image_path, "rb") as f:
            image_data = f.read()
        image_base64 = base64.b64encode(image_data).decode("utf-8")

        # Load XML content
        with open(abs_xml_path, "r", encoding="utf-8") as f:
            xml_content = f.read()

        # Convert expected to JSON string (ensure_ascii=False keeps any
        # non-ASCII text in the expected values readable).
        expected_json = json.dumps(expected, ensure_ascii=False)

        # Create user prompt with command + XML content
        # The XML will be included in the user prompt text (as the agent does)
        user_prompt = f"{command}\n\nXML Content:\n\n```xml\n{xml_content}\n```"

        # For reflection, we don't need full XML - just the command is enough
        # Reflection is about improving the prompt based on evaluation feedback,
        # not analyzing specific XML structures
        reflection_input = command  # Just the command, no XML

        # Create GEPA format item
        gepa_item = {
            "input": user_prompt,  # Command + XML content (for evaluation)
            "reflection_input": reflection_input,  # Just command (for reflection)
            "output": expected_json,  # Expected output as JSON string
            "image_base64": image_base64,  # TOP LEVEL for UniversalConverter
            "metadata": {
                "command": command,
                "image_path": str(image_path),
                "xml_path": str(xml_path),
                "abs_image_path": str(abs_image_path),
                "abs_xml_path": str(abs_xml_path),
                "xml_content": xml_content,  # Store XML separately in metadata
                "expected": expected,
                "dataset_index": idx
            }
        }

        gepa_dataset.append(gepa_item)

    return gepa_dataset
|
| 172 |
+
|
| 173 |
+
def load_split(
    self,
    train_ratio: float = 0.6,
    val_ratio: float = 0.4
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Load dataset and split into train/val sets (no test set).

    The split is a simple ordered prefix/suffix cut: the first
    ``train_ratio`` fraction of items becomes the training set and the
    remainder becomes the validation set.

    Args:
        train_ratio: Ratio for training set (default: 0.6)
        val_ratio: Ratio for validation set (default: 0.4)

    Returns:
        Tuple of (train_set, val_set)

    Raises:
        ValueError: If ratios don't sum to 1.0
    """
    ratio_sum = train_ratio + val_ratio
    # Allow a small tolerance for floating-point imprecision.
    if abs(ratio_sum - 1.0) > 0.01:
        raise ValueError(
            f"Split ratios must sum to 1.0, got {ratio_sum:.3f}"
        )

    items = self.load_dataset()
    cutoff = int(len(items) * train_ratio)
    return items[:cutoff], items[cutoff:]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def load_index_caching_dataset(
    json_path: Optional[str] = None,
    base_dir: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Convenience function to load index caching dataset.

    Thin wrapper that constructs an :class:`IndexCachingDatasetLoader`
    and immediately loads the full dataset.

    Args:
        json_path: Path to JSON file
        base_dir: Base directory for resolving relative paths

    Returns:
        List of dataset items in GEPA format
    """
    return IndexCachingDatasetLoader(json_path=json_path, base_dir=base_dir).load_dataset()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def load_index_caching_split(
    json_path: Optional[str] = None,
    base_dir: Optional[str] = None,
    train_ratio: float = 0.6,
    val_ratio: float = 0.4
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Convenience function to load and split index caching dataset.

    Thin wrapper that constructs an :class:`IndexCachingDatasetLoader`
    and delegates the splitting to its ``load_split`` method.

    Args:
        json_path: Path to JSON file
        base_dir: Base directory for resolving relative paths
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set

    Returns:
        Tuple of (train_set, val_set) - no test set
    """
    loader = IndexCachingDatasetLoader(json_path=json_path, base_dir=base_dir)
    return loader.load_split(train_ratio=train_ratio, val_ratio=val_ratio)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# Example usage
|
| 248 |
+
if __name__ == "__main__":
    # Smoke test: exercises the loader end-to-end against the default
    # dataset file. Requires note2_debug.json plus its referenced
    # image/XML assets to be present in the working directory.
    print("🚀 Testing Index Caching Dataset Loader...")

    # Test loading
    try:
        loader = IndexCachingDatasetLoader(json_path="./note2_debug.json")
        dataset = loader.load_dataset()

        print(f"\n✅ Loaded {len(dataset)} items")

        # Show sample
        if dataset:
            sample = dataset[0]
            print(f"\n📝 Sample Item:")
            print(f" Command: {sample['input']}")
            print(f" Image path: {sample['metadata']['image_path']}")
            print(f" XML path: {sample['metadata']['xml_path']}")
            print(f" Expected: {sample['output'][:100]}...")
            print(f" Image base64 length: {len(sample['image_base64'])}")
            print(f" XML content length: {len(sample['metadata'].get('xml_content', ''))}")

        # Test split (train/val only; this loader uses no test set)
        train, val = loader.load_split()
        print(f"\n📊 Dataset Split:")
        print(f" Training: {len(train)} samples")
        print(f" Validation: {len(val)} samples")
        print(f" Test: Not used (no test set)")

    except Exception as e:
        # Broad catch is acceptable in a demo script: report and exit cleanly.
        print(f"❌ Error: {e}")
|
| 278 |
+
|
src/gepa_optimizer/data/loaders.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading utilities for various file formats
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import base64
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from typing import Any, Optional, Union, List , Dict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class DataLoader:
    """
    Utility class for loading data from various sources.

    Dispatches on file extension (or an explicit format hint) to a
    format-specific ``load_*`` helper. All helpers return ``None`` on
    failure and log the error instead of raising, so callers can
    batch-load optimistically.
    """

    def __init__(self):
        # Lowercase extensions (with leading dot) this loader can dispatch on.
        self.supported_formats = [
            '.csv', '.json', '.jsonl', '.txt', '.md', '.xlsx',
            '.png', '.jpg', '.jpeg'
        ]

    def load(self, source: Union[str, Path], format_hint: Optional[str] = None) -> Optional[Any]:
        """
        Load data from any supported source.

        Args:
            source: File path or data source
            format_hint: Optional format hint to override auto-detection.
                Accepted with or without a leading dot, in any case
                (e.g. "json", ".json", "JSON").

        Returns:
            Loaded data or None if failed
        """
        try:
            path = Path(source)

            if not path.exists():
                logger.error(f"File not found: {source}")
                return None

            # Use format hint or detect from extension.
            # FIX: normalize the hint (case + leading dot) so hints like
            # "json" or ".JSON" behave the same as ".json"; previously such
            # hints fell through to the "unsupported format" branch.
            file_format = (format_hint or path.suffix).lower()
            if file_format and not file_format.startswith('.'):
                file_format = '.' + file_format

            if file_format == '.csv':
                return self.load_csv(path)
            elif file_format == '.json':
                return self.load_json(path)
            elif file_format == '.jsonl':
                return self.load_jsonl(path)
            elif file_format in ['.txt', '.md']:
                return self.load_text(path)
            elif file_format == '.xlsx':
                return self.load_excel(path)
            elif file_format in ['.png', '.jpg', '.jpeg']:
                return self.load_image_base64(path)
            else:
                logger.warning(f"Unsupported format: {file_format}")
                return None

        except Exception as e:
            logger.error(f"Failed to load data from {source}: {str(e)}")
            return None

    def load_csv(self, path: Union[str, Path]) -> Optional[pd.DataFrame]:
        """Load CSV file as pandas DataFrame (None on failure)."""
        try:
            df = pd.read_csv(path)
            logger.info(f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns")
            return df
        except Exception as e:
            logger.error(f"Failed to load CSV {path}: {str(e)}")
            return None

    def load_json(self, path: Union[str, Path]) -> Optional[Any]:
        """Load JSON file (any top-level type; None on failure)."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if isinstance(data, list):
                logger.info(f"Loaded JSON with {len(data)} items")
            else:
                logger.info("Loaded JSON object")

            return data
        except Exception as e:
            logger.error(f"Failed to load JSON {path}: {str(e)}")
            return None

    def load_jsonl(self, path: Union[str, Path]) -> Optional[List[Dict]]:
        """Load JSONL (JSON Lines) file; invalid lines are skipped with a warning."""
        try:
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            # Tolerate malformed lines rather than dropping the file.
                            logger.warning(f"Invalid JSON on line {line_num}: {str(e)}")

            logger.info(f"Loaded JSONL with {len(data)} items")
            return data
        except Exception as e:
            logger.error(f"Failed to load JSONL {path}: {str(e)}")
            return None

    def load_text(self, path: Union[str, Path]) -> Optional[str]:
        """Load plain text file as a single string (None on failure)."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                content = f.read()

            logger.info(f"Loaded text file with {len(content)} characters")
            return content
        except Exception as e:
            logger.error(f"Failed to load text {path}: {str(e)}")
            return None

    def load_excel(self, path: Union[str, Path]) -> Optional[pd.DataFrame]:
        """Load Excel (.xlsx) file as pandas DataFrame (None on failure)."""
        try:
            df = pd.read_excel(path)
            logger.info(f"Loaded Excel with {len(df)} rows and {len(df.columns)} columns")
            return df
        except Exception as e:
            logger.error(f"Failed to load Excel {path}: {str(e)}")
            return None

    def load_image_base64(self, path: Union[str, Path]) -> Optional[str]:
        """Load image file and encode as Base64 string (None on failure)."""
        try:
            with open(path, 'rb') as f:
                encoded_string = base64.b64encode(f.read()).decode('utf-8')
            logger.info(f"Loaded image {path} and encoded to Base64")
            return encoded_string
        except Exception as e:
            logger.error(f"Failed to load image {path}: {str(e)}")
            return None

    def is_supported_format(self, file_path: Union[str, Path]) -> bool:
        """Check if the file's extension is one this loader can dispatch on."""
        path = Path(file_path)
        return path.suffix.lower() in self.supported_formats

    def get_file_info(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """Get basic metadata about a file; {'exists': False} if it is missing."""
        path = Path(file_path)

        if not path.exists():
            return {'exists': False}

        return {
            'exists': True,
            'size': path.stat().st_size,
            'format': path.suffix.lower(),
            'supported': self.is_supported_format(path),
            'name': path.name,
            'stem': path.stem,
            'parent': str(path.parent)
        }

    def load_ui_tree_dataset(self, json_dir: str, screenshots_dir: str) -> List[Dict[str, Any]]:
        """
        Load UI tree dataset by pairing JSON files with corresponding screenshots.

        JSON files are matched to images by stem (e.g. "2.json" pairs with
        "2.jpg"/"2.jpeg"/"2.png"). Unpaired or unreadable entries are
        skipped with a warning rather than failing the whole load.

        Args:
            json_dir: Directory containing JSON files (e.g. "json_tree")
            screenshots_dir: Directory containing screenshot images (e.g. "screenshots")

        Returns:
            List of dictionaries with 'input', 'output', and 'image' keys

        Raises:
            FileNotFoundError: If either directory doesn't exist
        """
        json_path = Path(json_dir)
        screenshots_path = Path(screenshots_dir)

        if not json_path.exists():
            raise FileNotFoundError(f"JSON directory not found: {json_dir}")
        if not screenshots_path.exists():
            raise FileNotFoundError(f"Screenshots directory not found: {screenshots_dir}")

        dataset = []

        # Get all JSON files
        json_files = list(json_path.glob("*.json"))
        logger.info(f"Found {len(json_files)} JSON files in {json_dir}")

        for json_file in json_files:
            # Extract filename without extension (e.g., "2" from "2.json")
            file_stem = json_file.stem

            # Look for the corresponding image file, trying each extension in order.
            image_extensions = ['.jpg', '.jpeg', '.png']
            image_file = None

            for ext in image_extensions:
                potential_image = screenshots_path / f"{file_stem}{ext}"
                if potential_image.exists():
                    image_file = potential_image
                    break

            if not image_file:
                logger.warning(f"No corresponding image found for {json_file.name}")
                continue

            try:
                # Load JSON content
                json_data = self.load_json(json_file)
                if not json_data:
                    logger.warning(f"Failed to load JSON: {json_file}")
                    continue

                # Load image as base64
                image_base64 = self.load_image_base64(image_file)
                if not image_base64:
                    logger.warning(f"Failed to load image: {image_file}")
                    continue

                # Create dataset entry
                dataset_entry = {
                    'input': 'Extract UI elements from this screenshot and provide the complete UI tree structure',
                    'output': json.dumps(json_data, indent=2),  # Convert JSON to string
                    'image': image_base64
                }

                dataset.append(dataset_entry)
                logger.debug(f"Loaded pair: {json_file.name} + {image_file.name}")

            except Exception as e:
                logger.error(f"Error loading {json_file.name}: {str(e)}")
                continue

        logger.info(f"Successfully loaded {len(dataset)} image-JSON pairs")
        return dataset
|
src/gepa_optimizer/data/scroll_dataset_loader.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Scroll Element Dataset Loader for Drizz Mobile App Testing
|
| 3 |
+
|
| 4 |
+
Loads screenshots with bounding boxes and commands to identify scroll elements.
|
| 5 |
+
Converts to GEPA-compatible format for prompt optimization.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import random
|
| 10 |
+
import logging
|
| 11 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ScrollDatasetLoader:
    """
    GENERIC dataset loader for image-based tasks.

    This is a LIBRARY class - NO hardcoded assumptions about:
    - What the task is (OCR, element detection, classification, etc.)
    - Input format (questions, commands, descriptions, etc.)
    - Output format (IDs, text, JSON, etc.)

    Users define their dataset in the test script and pass it here.

    Dataset format per item: (image_filename, input_text, expected_output)

    Example usage (ANY task):
        my_dataset = [
            ("img1.png", "What is the main color?", "blue"),
            ("img2.png", "Count the objects", "5"),
            ("img3.png", "Describe the scene", "A cat on a sofa"),
        ]
        loader = ScrollDatasetLoader(
            images_dir="images",
            dataset_config=my_dataset
        )
        data = loader.load_dataset()
    """

    def __init__(
        self,
        images_dir: str = "images",
        dataset_config: Optional[List[Tuple[str, str, str]]] = None
    ):
        """
        Initialize dataset loader.

        Args:
            images_dir: Directory containing images
            dataset_config: List of (image_filename, input_text, expected_output) tuples.
                REQUIRED - no hardcoded defaults to keep library generic.

        Raises:
            FileNotFoundError: If images_dir doesn't exist
            ValueError: If dataset_config is None
        """
        self.images_dir = Path(images_dir)

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {images_dir}")

        if dataset_config is None:
            raise ValueError(
                "dataset_config is required. This is a library class - define your "
                "dataset in the test script:\n"
                " dataset = [('img1.png', 'your input', 'expected output'), ...]\n"
                " loader = ScrollDatasetLoader(images_dir='...', dataset_config=dataset)"
            )

        self.dataset_config = dataset_config

    def load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load complete dataset with images.

        Entries whose image is missing or unreadable are skipped with a
        warning rather than failing the whole load.

        Returns:
            List of dataset items in GEPA format, each with top-level
            "input", "output" and "image_base64" keys plus a "metadata"
            dict including "element_id" (int extracted from the expected
            output, or None if extraction fails).

        Raises:
            ValueError: If no entry yielded a usable image.
        """
        dataset = []

        # Generic variable names - no assumptions about data type
        for image_filename, input_text, expected_output in self.dataset_config:
            image_path = self.images_dir / image_filename

            # Validate image exists
            if not image_path.exists():
                logger.warning(f"Image not found: {image_path}")
                continue

            # Read and encode image
            try:
                image_base64 = self._encode_image(image_path)
            except Exception as e:
                logger.warning(f"Error encoding {image_filename}: {e}")
                continue

            # Extract element_id from expected_output for robust evaluation
            element_id = self._extract_element_id(expected_output)
            if element_id is None:
                logger.warning(f"Could not extract element_id from '{expected_output}' in {image_filename}")

            # Library stays task-agnostic: just image + input text + expected output.
            # IMPORTANT: image_base64 is at TOP LEVEL so UniversalConverter can find it.
            dataset_item = {
                "input": input_text,  # Generic input text (ANY format)
                "output": expected_output,  # Generic expected output (ANY format, full reasoning)
                "image_base64": image_base64,  # TOP LEVEL for converter
                "metadata": {
                    "image_path": str(image_path),
                    "input_text": input_text,
                    "expected_output": expected_output,
                    "image_filename": image_filename,
                    "element_id": element_id  # Extracted element ID (int or None)
                }
            }

            dataset.append(dataset_item)

        if not dataset:
            raise ValueError("No valid images found in dataset")

        logger.info(f"Loaded {len(dataset)} scroll element detection samples")
        return dataset

    def _extract_element_id(self, expected_output: str) -> Optional[int]:
        """
        Extract element ID from expected output string.

        Handles multiple formats:
        - "Element: 4"
        - "Element 4"
        - "4" (standalone)
        - "Element: 4, Description: ..." (full reasoning)

        Args:
            expected_output: Full expected output string with reasoning

        Returns:
            Element ID as integer (validated to lie in 1..100), or None if not found
        """
        import re

        if not expected_output:
            return None

        # Pattern 1: "Element: X" or "Element X" (case insensitive)
        patterns = [
            r'element[:\s]+(\d+)',  # "Element: 4" or "Element 4"
            r'\belement\s+(\d+)\b',  # "element 4" (word boundary)
        ]

        for pattern in patterns:
            match = re.search(pattern, expected_output, re.IGNORECASE)
            if match:
                try:
                    element_id = int(match.group(1))
                    # Validate range (reasonable UI element IDs)
                    if 1 <= element_id <= 100:
                        return element_id
                except (ValueError, IndexError):
                    continue

        # Pattern 2: first standalone number (if no "Element:" pattern found),
        # only accepted when it looks like a plausible element ID.
        number_match = re.search(r'\b(\d{1,3})\b', expected_output)
        if number_match:
            try:
                element_id = int(number_match.group(1))
                if 1 <= element_id <= 100:  # Reasonable range for UI elements
                    return element_id
            except ValueError:
                pass

        return None

    def _encode_image(self, image_path: Path) -> str:
        """
        Encode image to base64 string.

        Args:
            image_path: Path to image file

        Returns:
            Base64 encoded image string
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def split_dataset(
        self,
        dataset: List[Dict[str, Any]],
        train_size: int = 4,
        val_size: int = 1,
        test_size: int = 1,
        shuffle: bool = True,
        seed: Optional[int] = None
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Split dataset into train, validation, and test sets.

        Shuffling ensures a different image distribution across splits,
        preventing hard images from always landing in the validation set.
        If the requested sizes exceed the dataset size, they are scaled
        down proportionally.

        Args:
            dataset: Complete dataset
            train_size: Number of samples for training (default: 4)
            val_size: Number of samples for validation (default: 1)
            test_size: Number of samples for test (default: 1)
            shuffle: Whether to shuffle dataset before splitting (default: True)
            seed: Random seed for reproducible shuffling (default: None = random)

        Returns:
            Tuple of (train_set, val_set, test_set)
        """
        n = len(dataset)

        # Validate split sizes
        total_size = train_size + val_size + test_size
        if total_size > n:
            logger.warning(f"Requested split ({total_size}) exceeds dataset size ({n}). Adjusting split proportionally...")
            ratio = n / total_size
            train_size = int(train_size * ratio)
            val_size = int(val_size * ratio)
            test_size = n - train_size - val_size

        # Shuffle a copy so the caller's list is left untouched.
        dataset_copy = dataset.copy()
        if shuffle:
            if seed is not None:
                logger.debug(f"Shuffling dataset with seed={seed} for reproducible splits")
            else:
                logger.debug(f"Shuffling dataset randomly (no seed)")
            # FIX: use a dedicated Random instance instead of random.seed(),
            # so the process-global RNG state is never mutated as a side
            # effect. A given seed still yields the same permutation as the
            # previous random.seed()+random.shuffle() combination.
            random.Random(seed).shuffle(dataset_copy)
        else:
            logger.warning(f"Not shuffling dataset - using original order")

        # Split shuffled dataset
        train_set = dataset_copy[:train_size]
        val_set = dataset_copy[train_size:train_size + val_size]
        test_set = dataset_copy[train_size + val_size:train_size + val_size + test_size]

        logger.info(f"Dataset split: {len(train_set)} train, {len(val_set)} val, {len(test_set)} test")

        # Log which images are in each split for debugging
        if shuffle:
            train_images = [item['metadata'].get('image_filename', 'N/A') for item in train_set]
            val_images = [item['metadata'].get('image_filename', 'N/A') for item in val_set]
            test_images = [item['metadata'].get('image_filename', 'N/A') for item in test_set]
            print(f" Train images: {train_images[:5]}{'...' if len(train_images) > 5 else ''}")
            print(f" Val images: {val_images}")
            print(f" Test images: {test_images[:5]}{'...' if len(test_images) > 5 else ''}")

        return train_set, val_set, test_set
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def load_scroll_dataset(
    images_dir: str = "images",
    dataset_config: List[Tuple[str, str, str]] = None,
    split: bool = True
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Load an image-grounded dataset and optionally split it (GENERIC).

    Thin convenience wrapper around ``ScrollDatasetLoader`` that works for
    ANY image + text task, not just scrolling.

    Args:
        images_dir: Directory containing the referenced images.
        dataset_config: ``(image_filename, input_text, expected_output)``
            triples, one per datapoint.
        split: When True, partition into train/val/test; otherwise the full
            dataset is returned as the first tuple element.

    Returns:
        ``(train_set, val_set, test_set)`` when ``split`` is True,
        otherwise ``(full_dataset, [], [])``.

    Example (works for ANY task)::

        dataset_config = [
            ("img1.png", "What color is the sky?", "blue"),
            ("img2.png", "Count the dogs", "2"),
        ]
        train, val, test = load_scroll_dataset(
            images_dir="images",
            dataset_config=dataset_config,
        )
    """
    source_loader = ScrollDatasetLoader(images_dir, dataset_config=dataset_config)
    full_dataset = source_loader.load_dataset()

    # Guard clause: no split requested -> hand back everything untouched.
    if not split:
        return full_dataset, [], []
    return source_loader.split_dataset(full_dataset)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# Example usage (for testing the library loader itself)
if __name__ == "__main__":
    # Running this module directly only prints usage guidance: the loader is a
    # library class with no built-in demo dataset, so there is nothing to run.
    print("🚀 Testing Scroll Dataset Loader...")
    print("⚠️ NOTE: This is a library class. Define your dataset in your test script.")
    print("\nExample:")
    print(" dataset_config = [")
    print(" ('image1.png', 'Scroll down by 50%', '3'),")
    print(" ('image2.png', 'Swipe left', '4'),")
    print(" ]")
    print(" train, val, test = load_scroll_dataset(")
    print(" images_dir='images',")
    print(" dataset_config=dataset_config")
    print(" )")
|
| 334 |
+
|
src/gepa_optimizer/data/validation_dataset_loader.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validation Dataset Loader for UI Validation Use Case
|
| 3 |
+
|
| 4 |
+
Loads validation datapoints from SQLite database and converts to GEPA-compatible format.
|
| 5 |
+
Supports filtering by data_type (trainset/valset/testset) and confirmed status.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sqlite3
|
| 10 |
+
import base64
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Dict, Any, Optional, Literal
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ValidationDatasetLoader:
    """Load UI-validation datapoints from SQLite and emit GEPA-format items.

    Expected schema (created by ``validation_data_ui_server_async.py``):

    * ``validation_data``: id, image_id, command, result (0/1), reasoning,
      data_type, confirmed, created_at
    * ``images``: image_id, mime, bytes (BLOB), created_at

    Each emitted item maps to GEPA as follows:

    * ``input`` — the raw command text (the seed prompt is deliberately NOT
      stored in the database; callers provide it in their test script)
    * ``output`` — ``"true"`` / ``"false"`` converted from the 0/1 column
    * ``image_base64`` — the screenshot, base64-encoded, kept at the TOP
      level so the UniversalConverter can pair image and command
    * ``metadata`` — every original column plus convenience conversions
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        confirmed_only: bool = True
    ):
        """Resolve the database location and remember the default filter.

        Args:
            db_path: SQLite file path. Falls back to the ``VD_DB_PATH``
                environment variable, then to ``./validation_data.db``.
            confirmed_only: Default for :meth:`load_dataset`; when True only
                manually reviewed rows (``confirmed = 1``) are returned.

        Raises:
            FileNotFoundError: If the database file does not exist yet.
        """
        resolved = db_path if db_path is not None else os.getenv("VD_DB_PATH", "./validation_data.db")
        self.db_path = Path(resolved).resolve()

        if not self.db_path.exists():
            raise FileNotFoundError(
                f"Database file not found: {self.db_path}\n"
                f"Make sure validation_data_ui_server_async.py has been run at least once to create the database."
            )

        self.confirmed_only = confirmed_only

    def load_dataset(
        self,
        data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
        confirmed_only: Optional[bool] = None
    ) -> List[Dict[str, Any]]:
        """Query the database and convert each row to a GEPA item.

        Args:
            data_type: Restrict to one split ("trainset"/"valset"/"testset");
                None loads every split.
            confirmed_only: Per-call override of the instance default; None
                means "use ``self.confirmed_only``".

        Returns:
            List of GEPA-format dicts with ``input``, ``output``,
            ``image_base64`` and ``metadata`` keys (see class docstring).

        Raises:
            ValueError: When no rows match the requested filters.
            sqlite3.Error: If the underlying query fails.
        """
        use_confirmed = self.confirmed_only if confirmed_only is None else confirmed_only

        connection = sqlite3.connect(str(self.db_path))
        connection.row_factory = sqlite3.Row  # column access by name

        try:
            # "WHERE 1=1" lets every optional filter below append uniformly.
            query = """
                SELECT
                    v.id,
                    v.image_id,
                    v.command,
                    v.result,
                    v.reasoning,
                    v.data_type,
                    v.confirmed,
                    v.created_at,
                    i.mime,
                    i.bytes
                FROM validation_data v
                INNER JOIN images i ON v.image_id = i.image_id
                WHERE 1=1
            """
            params: List[str] = []

            if use_confirmed:
                query += " AND v.confirmed = 1"
            if data_type:
                query += " AND v.data_type = ?"
                params.append(data_type)
            query += " ORDER BY v.id ASC"

            rows = connection.execute(query, params).fetchall()

            if not rows:
                active_filters = []
                if use_confirmed:
                    active_filters.append("confirmed=1")
                if data_type:
                    active_filters.append(f"data_type='{data_type}'")
                filter_str = " with filters: " + ", ".join(active_filters) if active_filters else ""
                raise ValueError(
                    f"No datapoints found{filter_str} in database: {self.db_path}\n"
                    f"Make sure you have generated and saved datapoints using the validation UI."
                )

            dataset = [self._row_to_gepa_item(row) for row in rows]

            data_type_str = f" ({data_type})" if data_type else ""
            confirmed_str = " (confirmed only)" if use_confirmed else " (all)"
            logger.info(f"Loaded {len(dataset)} validation datapoints{data_type_str}{confirmed_str}")

            return dataset

        finally:
            connection.close()

    @staticmethod
    def _row_to_gepa_item(row: sqlite3.Row) -> Dict[str, Any]:
        """Convert one joined database row into a single GEPA dataset item."""
        return {
            "input": row["command"],  # command only - the seed prompt lives in the test script
            "output": "true" if row["result"] == 1 else "false",
            "image_base64": base64.b64encode(row["bytes"]).decode("utf-8"),  # top level for UniversalConverter
            "metadata": {
                "id": row["id"],
                "image_id": row["image_id"],
                "command": row["command"],          # original command for reference
                "result": bool(row["result"]),      # boolean for reference
                "result_int": row["result"],        # original 0/1 for reference
                "reasoning": row["reasoning"],
                "data_type": row["data_type"],
                "confirmed": bool(row["confirmed"]),
                "created_at": row["created_at"],
                "mime": row["mime"],
            }
        }

    def load_split_dataset(
        self,
        confirmed_only: Optional[bool] = None
    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Load the three canonical splits (trainset/valset/testset) at once.

        Args:
            confirmed_only: Per-call override of the instance default.

        Returns:
            ``(train_set, val_set, test_set)`` in GEPA format.

        Example:
            loader = ValidationDatasetLoader(db_path="./validation_data.db")
            train, val, test = loader.load_split_dataset()
        """
        train_set, val_set, test_set = (
            self.load_dataset(data_type=split_name, confirmed_only=confirmed_only)
            for split_name in ("trainset", "valset", "testset")
        )

        logger.info(f"Dataset Split Summary: Training={len(train_set)}, Validation={len(val_set)}, Test={len(test_set)}, Total={len(train_set) + len(val_set) + len(test_set)}")

        return train_set, val_set, test_set

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Summarize what the database currently contains.

        Returns:
            Dict with ``total``, ``confirmed``, ``unconfirmed``,
            ``by_data_type`` (split name -> count) and ``by_result``
            ("true"/"false" -> count).
        """
        connection = sqlite3.connect(str(self.db_path))
        connection.row_factory = sqlite3.Row

        try:
            total = connection.execute("SELECT COUNT(*) FROM validation_data").fetchone()[0]
            confirmed = connection.execute("SELECT COUNT(*) FROM validation_data WHERE confirmed = 1").fetchone()[0]

            type_rows = connection.execute("""
                SELECT data_type, COUNT(*) as count
                FROM validation_data
                GROUP BY data_type
            """).fetchall()

            result_rows = connection.execute("""
                SELECT result, COUNT(*) as count
                FROM validation_data
                GROUP BY result
            """).fetchall()
            # GROUP BY yields at most one row per distinct result value.
            result_counts = {row["result"]: row["count"] for row in result_rows}

            return {
                "total": total,
                "confirmed": confirmed,
                "unconfirmed": total - confirmed,
                "by_data_type": {row["data_type"]: row["count"] for row in type_rows},
                "by_result": {
                    "true": result_counts.get(1, 0),
                    "false": result_counts.get(0, 0),
                },
            }

        finally:
            connection.close()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def load_validation_dataset(
    db_path: Optional[str] = None,
    data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
    confirmed_only: bool = True
) -> List[Dict[str, Any]]:
    """One-shot helper: build a loader and fetch a (possibly filtered) dataset.

    Args:
        db_path: SQLite database file; default resolution ("./validation_data.db"
            or the ``VD_DB_PATH`` env var) is handled by ``ValidationDatasetLoader``.
        data_type: Optional split filter ("trainset"/"valset"/"testset");
            None loads every split.
        confirmed_only: When True, only manually confirmed rows are loaded.

    Returns:
        List of GEPA-format dataset items.

    Example:
        # Load all confirmed training data
        train_data = load_validation_dataset(data_type="trainset", confirmed_only=True)

        # Load all confirmed data
        all_data = load_validation_dataset(confirmed_only=True)
    """
    return ValidationDatasetLoader(
        db_path=db_path, confirmed_only=confirmed_only
    ).load_dataset(data_type=data_type, confirmed_only=confirmed_only)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def load_validation_split(
    db_path: Optional[str] = None,
    confirmed_only: bool = True
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """One-shot helper: fetch the train/val/test splits in a single call.

    Args:
        db_path: SQLite database file; default resolution is handled by
            ``ValidationDatasetLoader`` ("./validation_data.db" by default).
        confirmed_only: When True, only manually confirmed rows are loaded.

    Returns:
        ``(train_set, val_set, test_set)`` in GEPA format.

    Example:
        train, val, test = load_validation_split(confirmed_only=True)
    """
    return ValidationDatasetLoader(
        db_path=db_path, confirmed_only=confirmed_only
    ).load_split_dataset(confirmed_only=confirmed_only)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# Example usage and testing
if __name__ == "__main__":
    # Smoke test against the local database: print summary statistics, load
    # the three splits, and show one training item. Failure modes are handled
    # explicitly so the user gets actionable guidance instead of a traceback.
    print("🚀 Testing Validation Dataset Loader...")

    try:
        loader = ValidationDatasetLoader()

        # Get stats
        print("\n📊 Dataset Statistics:")
        stats = loader.get_dataset_stats()
        print(f" Total: {stats['total']}")
        print(f" Confirmed: {stats['confirmed']}")
        print(f" Unconfirmed: {stats['unconfirmed']}")
        print(f" By data_type: {stats['by_data_type']}")
        print(f" By result: {stats['by_result']}")

        # Load split dataset
        print("\n📦 Loading split dataset...")
        train, val, test = loader.load_split_dataset()

        # Show sample
        if train:
            sample = train[0]
            print(f"\n📝 Sample Training Item:")
            print(f" Input: {sample['input']}")
            print(f" Output: {sample['output']}")
            print(f" Image ID: {sample['metadata']['image_id'][:8]}...")
            print(f" Data Type: {sample['metadata']['data_type']}")
            print(f" Result: {sample['metadata']['result']} (int: {sample['metadata']['result_int']})")

    # Raised by ValidationDatasetLoader.__init__ when the DB file is absent.
    except FileNotFoundError as e:
        print(f"❌ {e}")
        print("\n💡 Make sure validation_data_ui_server_async.py has been run to create the database.")
    # Raised by load_dataset when the query matches zero rows.
    except ValueError as e:
        print(f"❌ {e}")
        print("\n💡 Generate and save some datapoints using the validation UI first.")
|
| 376 |
+
|
src/gepa_optimizer/data/validators.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data validation utilities for GEPA optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class DataValidator:
    """
    Validates datasets for completeness and GEPA compatibility.

    Items are expected to be dictionaries with a non-empty string 'input',
    an optional string 'output', and an optional dict 'metadata'.
    """

    def __init__(self):
        # Documented item schema: 'input'/'output' are expected on each item;
        # 'metadata'/'id'/'tags' may additionally be present.
        self.required_fields = ['input', 'output']
        self.optional_fields = ['metadata', 'id', 'tags']

    def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Tuple[bool, List[str]]:
        """
        Validate entire dataset

        Args:
            dataset: List of data items to validate

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors = []

        # Basic dataset checks
        if not dataset:
            errors.append("Dataset is empty")
            return False, errors

        if not isinstance(dataset, list):
            errors.append("Dataset must be a list")
            return False, errors

        # Validate each item
        for idx, item in enumerate(dataset):
            item_errors = self.validate_item(item, idx)
            errors.extend(item_errors)

        # Check for minimum dataset size
        if len(dataset) < 2:
            errors.append("Dataset should have at least 2 items for proper train/val split")

        # Log validation results
        if errors:
            logger.warning(f"Dataset validation failed with {len(errors)} errors")
        else:
            logger.info(f"Dataset validation passed for {len(dataset)} items")

        return len(errors) == 0, errors

    def validate_item(self, item: Dict[str, Any], index: Optional[int] = None) -> List[str]:
        """
        Validate a single dataset item

        Args:
            item: Single data item to validate
            index: Optional item index for error reporting

        Returns:
            List[str]: List of validation errors
        """
        errors = []
        item_ref = f"item {index}" if index is not None else "item"

        # Check if item is a dictionary; all other checks assume dict access.
        if not isinstance(item, dict):
            errors.append(f"{item_ref}: Must be a dictionary")
            return errors

        # Check for required fields
        if 'input' not in item:
            errors.append(f"{item_ref}: Missing required 'input' field")
        elif not isinstance(item['input'], str):
            errors.append(f"{item_ref}: 'input' field must be a string")
        elif not item['input'].strip():
            errors.append(f"{item_ref}: 'input' field cannot be empty")

        # Check output field (can be empty but should exist for supervised learning)
        if 'output' in item:
            if not isinstance(item['output'], str):
                errors.append(f"{item_ref}: 'output' field must be a string")

        # Validate metadata if present
        if 'metadata' in item and not isinstance(item['metadata'], dict):
            errors.append(f"{item_ref}: 'metadata' field must be a dictionary")

        return errors

    def validate_gepa_format(self, gepa_data: List[Dict[str, Any]]) -> Tuple[bool, List[str]]:
        """
        Validate data in GEPA format

        Args:
            gepa_data: Data in GEPA format (items need 'input',
                'expected_output', and a dict 'metadata')

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors = []

        if not gepa_data:
            errors.append("GEPA dataset is empty")
            return False, errors

        for idx, item in enumerate(gepa_data):
            if 'input' not in item:
                errors.append(f"GEPA item {idx}: Missing 'input' field")

            if 'expected_output' not in item:
                errors.append(f"GEPA item {idx}: Missing 'expected_output' field")

            if 'metadata' not in item:
                errors.append(f"GEPA item {idx}: Missing 'metadata' field")
            elif not isinstance(item['metadata'], dict):
                errors.append(f"GEPA item {idx}: 'metadata' must be a dictionary")

        return len(errors) == 0, errors

    def validate_split(self, trainset: List[Dict], valset: List[Dict]) -> Tuple[bool, List[str]]:
        """
        Validate train/validation split

        Args:
            trainset: Training data
            valset: Validation data

        Returns:
            Tuple[bool, List[str]]: (is_valid, list_of_errors)
        """
        errors = []

        if not trainset:
            errors.append("Training set is empty")

        if not valset:
            errors.append("Validation set is empty")

        # Check proportions: flag splits outside the 50%-95% training range.
        total_size = len(trainset) + len(valset)
        if total_size > 0:
            train_ratio = len(trainset) / total_size
            if train_ratio < 0.5:
                errors.append(f"Training set too small: {train_ratio:.2%} of total data")
            elif train_ratio > 0.95:
                errors.append(f"Validation set too small: {1-train_ratio:.2%} of total data")

        return len(errors) == 0, errors

    def get_dataset_stats(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Get statistics about the dataset

        Args:
            dataset: Dataset to analyze

        Returns:
            Dict[str, Any]: Dataset statistics (counts, average lengths,
            empty-field counts, and a heuristic 'valid' flag)
        """
        if not dataset:
            return {'total_items': 0, 'valid': False}

        stats = {
            'total_items': len(dataset),
            # FIX: guard with isinstance so a non-dict item doesn't raise
            # AttributeError here — the detailed loop below already skips them.
            'has_output': sum(1 for item in dataset if isinstance(item, dict) and item.get('output')),
            'avg_input_length': 0,
            'avg_output_length': 0,
            'empty_inputs': 0,
            'empty_outputs': 0
        }

        input_lengths = []
        output_lengths = []

        for item in dataset:
            if isinstance(item, dict):
                input_text = item.get('input', '')
                output_text = item.get('output', '')

                if isinstance(input_text, str):
                    input_lengths.append(len(input_text))
                    if not input_text.strip():
                        stats['empty_inputs'] += 1

                if isinstance(output_text, str):
                    output_lengths.append(len(output_text))
                    if not output_text.strip():
                        stats['empty_outputs'] += 1

        if input_lengths:
            stats['avg_input_length'] = sum(input_lengths) / len(input_lengths)

        if output_lengths:
            stats['avg_output_length'] = sum(output_lengths) / len(output_lengths)

        # Determine if dataset looks valid
        stats['valid'] = (
            stats['total_items'] > 0 and
            stats['empty_inputs'] < stats['total_items'] * 0.5  # Less than 50% empty inputs
        )

        return stats
|
src/gepa_optimizer/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation module for GEPA Optimizer
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- UniversalSemanticEvaluator: Works for ANY task (recommended for general use)
|
| 6 |
+
- BaseEvaluator: Abstract base class for custom evaluators
|
| 7 |
+
- Task-specific evaluators for specialized use cases
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .base_evaluator import BaseEvaluator
|
| 11 |
+
from .universal_evaluator import UniversalSemanticEvaluator, create_universal_evaluator
|
| 12 |
+
from .ui_evaluator import UITreeEvaluator
|
| 13 |
+
from .scroll_evaluator import ScrollElementEvaluator
|
| 14 |
+
from .validation_evaluator import ValidationEvaluator
|
| 15 |
+
from .index_caching_evaluator import IndexCachingEvaluator
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
# Universal (recommended)
|
| 19 |
+
"UniversalSemanticEvaluator",
|
| 20 |
+
"create_universal_evaluator",
|
| 21 |
+
# Base class
|
| 22 |
+
"BaseEvaluator",
|
| 23 |
+
# Task-specific
|
| 24 |
+
"UITreeEvaluator",
|
| 25 |
+
"ScrollElementEvaluator",
|
| 26 |
+
"ValidationEvaluator",
|
| 27 |
+
"IndexCachingEvaluator",
|
| 28 |
+
]
|
src/gepa_optimizer/evaluation/base_evaluator.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base evaluator class for all evaluation strategies.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class BaseEvaluator(ABC):
    """Common contract shared by every evaluation strategy.

    Subclasses supply the actual scoring logic in :meth:`evaluate`; this base
    class only stores the (optional) metric weights and a per-class logger,
    keeping the interface uniform so any evaluator plugs into GEPA.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """Store metric weights and create a class-specific logger.

        Args:
            metric_weights: Optional mapping of metric name -> weight.
                When omitted (or falsy), an empty mapping is stored and
                subclasses are expected to provide their own defaults.
        """
        self.metric_weights = metric_weights or {}
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    @abstractmethod
    def evaluate(self, predicted: Any, expected: Any) -> Dict[str, float]:
        """Score *predicted* against the ground-truth *expected* output.

        Args:
            predicted: The model's predicted output.
            expected: The ground-truth expected output.

        Returns:
            Mapping of metric name -> score. Implementations must include a
            'composite_score' entry for GEPA integration.
        """
        ...

    def validate_weights(self) -> bool:
        """Return True when no weights are configured or they sum to ~1.0."""
        if not self.metric_weights:
            return True
        # Tolerate small floating-point drift around 1.0.
        return abs(sum(self.metric_weights.values()) - 1.0) < 0.01
|
src/gepa_optimizer/evaluation/index_caching_evaluator.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Index Caching Evaluator for Index-Based Element Selection Use Case
|
| 3 |
+
|
| 4 |
+
Evaluates predicted index caching results against expected results.
|
| 5 |
+
Compares all 5 fields with equal weight:
|
| 6 |
+
- is_index_based
|
| 7 |
+
- index_value
|
| 8 |
+
- parent_element_id
|
| 9 |
+
- element_id_of_nth_child_of_parent
|
| 10 |
+
- selected_element_is_correct
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Dict, Any, Optional
|
| 14 |
+
import json
|
| 15 |
+
import re
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
from .base_evaluator import BaseEvaluator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class IndexCachingEvaluator(BaseEvaluator):
    """
    Evaluator for index caching use case.

    Features:
    - Compares all 5 fields with equal weight (20% each) by default
    - Parses JSON from LLM response
    - Handles null values correctly
    - Returns detailed field-by-field comparison
    """

    # Maps each compared JSON field name to its metric-weight key.
    # Used so that custom `metric_weights` passed to __init__ actually
    # influence the composite score (previously the 0.2 weights were
    # hard-coded in evaluate(), silently ignoring custom weights).
    _FIELD_TO_WEIGHT_KEY = {
        "is_index_based": "is_index_based_match",
        "index_value": "index_value_match",
        "parent_element_id": "parent_element_id_match",
        "element_id_of_nth_child_of_parent": "element_id_of_nth_child_match",
        "selected_element_is_correct": "selected_element_correct_match",
    }

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initialize index caching evaluator.

        Args:
            metric_weights: Weights for evaluation metrics, keyed by the
                ``*_match`` metric names (see ``_FIELD_TO_WEIGHT_KEY``).
                Default: equal weight for all 5 fields (0.2 each).
        """
        # Each field gets 20% weight (5 fields * 0.2 = 1.0)
        default_weights = {
            "is_index_based_match": 0.2,
            "index_value_match": 0.2,
            "parent_element_id_match": 0.2,
            "element_id_of_nth_child_match": 0.2,
            "selected_element_correct_match": 0.2,
        }
        super().__init__(metric_weights=metric_weights or default_weights)

    def evaluate(self, predicted: str, expected: str) -> Dict[str, Any]:
        """
        Evaluate predicted index caching result against expected result.

        Args:
            predicted: LLM's output (JSON string with all 5 fields)
            expected: Expected output (JSON string or dict with all 5 fields)

        Returns:
            Dictionary with evaluation metrics: one ``*_match`` score per
            field (1.0 or 0.0), a weighted ``composite_score`` in [0, 1],
            the raw and parsed outputs, per-field ``field_scores``, and a
            human-readable ``evaluation_reason``.
        """
        if not predicted or not expected:
            return {
                "is_index_based_match": 0.0,
                "index_value_match": 0.0,
                "parent_element_id_match": 0.0,
                "element_id_of_nth_child_match": 0.0,
                "selected_element_correct_match": 0.0,
                "composite_score": 0.0,
                "predicted_output": str(predicted).strip() if predicted else "",
                "expected_output": str(expected).strip() if expected else "",
                "field_scores": {},
                "evaluation_reason": "❌ Empty or missing input/output"
            }

        # Parse expected (could be JSON string or dict)
        try:
            if isinstance(expected, str):
                expected_dict = json.loads(expected)
            else:
                expected_dict = expected
        except (json.JSONDecodeError, TypeError):
            # If expected is already a dict from dataset
            expected_dict = expected if isinstance(expected, dict) else {}

        # Parse predicted (must be JSON string)
        try:
            predicted_dict = self._parse_json_response(predicted)
        except Exception as e:
            # Log the actual response for debugging
            response_preview = predicted[:200] if predicted else "(empty)"
            self.logger.warning(f"Failed to parse predicted JSON: {e}")
            self.logger.warning(f"Response preview: {response_preview}...")
            predicted_dict = {}

        # NOTE: "notes" field is present in the output but is NOT used for
        # scoring or reflection. It's kept for reference but ignored here.

        # Compare each field (only the 5 core fields, ignoring "notes")
        field_scores = {}
        field_reasons = []

        # 1. is_index_based (boolean). A missing value on either side counts
        #    as a mismatch — booleans get no null==null credit.
        pred_is_index = predicted_dict.get("is_index_based")
        exp_is_index = expected_dict.get("is_index_based")
        is_index_match = (pred_is_index == exp_is_index) if (pred_is_index is not None and exp_is_index is not None) else False
        field_scores["is_index_based"] = 1.0 if is_index_match else 0.0
        field_reasons.append(f"is_index_based: {pred_is_index} vs {exp_is_index} → {'✅' if is_index_match else '❌'}")

        # 2. index_value (int or null). null/None on both sides is a match.
        pred_index_val = predicted_dict.get("index_value")
        exp_index_val = expected_dict.get("index_value")
        index_val_match = (pred_index_val == exp_index_val) or (pred_index_val is None and exp_index_val is None)
        field_scores["index_value"] = 1.0 if index_val_match else 0.0
        field_reasons.append(f"index_value: {pred_index_val} vs {exp_index_val} → {'✅' if index_val_match else '❌'}")

        # 3. parent_element_id (string or null). null/None on both sides matches.
        pred_parent = predicted_dict.get("parent_element_id")
        exp_parent = expected_dict.get("parent_element_id")
        parent_match = (pred_parent == exp_parent) or (pred_parent is None and exp_parent is None)
        field_scores["parent_element_id"] = 1.0 if parent_match else 0.0
        field_reasons.append(f"parent_element_id: {pred_parent} vs {exp_parent} → {'✅' if parent_match else '❌'}")

        # 4. element_id_of_nth_child_of_parent (string or null).
        pred_element = predicted_dict.get("element_id_of_nth_child_of_parent")
        exp_element = expected_dict.get("element_id_of_nth_child_of_parent")
        element_match = (pred_element == exp_element) or (pred_element is None and exp_element is None)
        field_scores["element_id_of_nth_child_of_parent"] = 1.0 if element_match else 0.0
        field_reasons.append(f"element_id_of_nth_child: {pred_element} vs {exp_element} → {'✅' if element_match else '❌'}")

        # 5. selected_element_is_correct (boolean) — same no-null rule as #1.
        pred_selected = predicted_dict.get("selected_element_is_correct")
        exp_selected = expected_dict.get("selected_element_is_correct")
        selected_match = (pred_selected == exp_selected) if (pred_selected is not None and exp_selected is not None) else False
        field_scores["selected_element_is_correct"] = 1.0 if selected_match else 0.0
        field_reasons.append(f"selected_element_is_correct: {pred_selected} vs {exp_selected} → {'✅' if selected_match else '❌'}")

        # Calculate composite score (weighted average). FIX: use the
        # configured self.metric_weights instead of hard-coded 0.2 factors,
        # so custom weights passed to __init__ are honored. With the default
        # weights the summation order matches the old code, so the float
        # result is unchanged.
        composite_score = sum(
            field_scores[field] * self.metric_weights.get(weight_key, 0.0)
            for field, weight_key in self._FIELD_TO_WEIGHT_KEY.items()
        )

        # Build evaluation reason
        all_match = composite_score == 1.0
        reason = "✅ All fields match!" if all_match else f"❌ Partial match ({composite_score:.1%})"
        reason += "\n" + "\n".join(f"   {r}" for r in field_reasons)

        # Log evaluation details
        self.logger.info(f"\n{'─'*70}")
        self.logger.info(f"📊 INDEX CACHING EVALUATION")
        self.logger.info(f"{'─'*70}")
        self.logger.info(f"   🎯 COMPOSITE SCORE: {composite_score:.2f} ({composite_score:.1%})")
        for field, score in field_scores.items():
            status = "✅" if score == 1.0 else "❌"
            self.logger.info(f"   {status} {field}: {score:.0f}")
        self.logger.info(f"{'─'*70}\n")

        return {
            "is_index_based_match": field_scores["is_index_based"],
            "index_value_match": field_scores["index_value"],
            "parent_element_id_match": field_scores["parent_element_id"],
            "element_id_of_nth_child_match": field_scores["element_id_of_nth_child_of_parent"],
            "selected_element_correct_match": field_scores["selected_element_is_correct"],
            "composite_score": composite_score,
            "predicted_output": predicted,
            "expected_output": json.dumps(expected_dict) if isinstance(expected_dict, dict) else str(expected),
            "predicted_dict": predicted_dict,
            "expected_dict": expected_dict,
            "field_scores": field_scores,
            "evaluation_reason": reason
        }

    def _parse_json_response(self, response: str) -> Dict[str, Any]:
        """
        Parse JSON from LLM response, handling markdown code blocks and
        various formats.

        Args:
            response: LLM response string (may contain markdown)

        Returns:
            Parsed JSON dictionary (empty dict if parsing fails)
        """
        if not response or not isinstance(response, str):
            return {}

        response = response.strip()

        # If response is empty, return empty dict
        if not response:
            return {}

        # Strategy 1: Try to extract JSON from markdown code block
        json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response, re.DOTALL)
        if json_match:
            try:
                json_str = json_match.group(1).strip()
                return json.loads(json_str)
            except json.JSONDecodeError:
                pass

        # Strategy 2: Find JSON object in response (handle nested braces)
        json_start = response.find('{')
        if json_start != -1:
            # Find matching closing brace by tracking nesting depth.
            brace_count = 0
            json_end = json_start
            for i in range(json_start, len(response)):
                if response[i] == '{':
                    brace_count += 1
                elif response[i] == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        json_end = i + 1
                        break

            if brace_count == 0:
                json_str = response[json_start:json_end]
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    pass

        # Strategy 3: Try to find any JSON-like structure (more lenient) —
        # patterns like {"key": "value"} even if not perfectly formatted.
        json_pattern = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
        if json_pattern:
            try:
                return json.loads(json_pattern.group(0))
            except json.JSONDecodeError:
                pass

        # Strategy 4: Try parsing entire response as JSON
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            pass

        # If all strategies fail, return empty dict
        self.logger.debug(f"Could not parse JSON from response: {response[:100]}...")
        return {}

    def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
        """
        Get summary statistics for a batch of evaluations.

        Args:
            results: List of evaluation result dictionaries

        Returns:
            Summary statistics including accuracy per field and overall
        """
        if not results:
            return {
                "total_samples": 0,
                "overall_accuracy": 0.0,
                "field_accuracies": {},
                "perfect_matches": 0
            }

        total = len(results)
        perfect_matches = sum(1 for r in results if r.get("composite_score", 0.0) == 1.0)
        overall_accuracy = perfect_matches / total if total > 0 else 0.0

        # Calculate accuracy per field
        field_accuracies = {
            "is_index_based": sum(1 for r in results if r.get("is_index_based_match", 0.0) == 1.0) / total,
            "index_value": sum(1 for r in results if r.get("index_value_match", 0.0) == 1.0) / total,
            "parent_element_id": sum(1 for r in results if r.get("parent_element_id_match", 0.0) == 1.0) / total,
            "element_id_of_nth_child": sum(1 for r in results if r.get("element_id_of_nth_child_match", 0.0) == 1.0) / total,
            "selected_element_is_correct": sum(1 for r in results if r.get("selected_element_correct_match", 0.0) == 1.0) / total,
        }

        return {
            "total_samples": total,
            "overall_accuracy": overall_accuracy,
            "field_accuracies": field_accuracies,
            "perfect_matches": perfect_matches,
            "partial_matches": total - perfect_matches
        }
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# Example usage and testing
|
| 303 |
+
if __name__ == "__main__":
    print("🚀 Testing Index Caching Evaluator...")

    evaluator = IndexCachingEvaluator()

    # Each entry: (predicted JSON string, expected dict, should_be_perfect)
    test_cases = [
        (
            '{"is_index_based": true, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": false, "index_value": null, "parent_element_id": null, "element_id_of_nth_child_of_parent": null, "selected_element_is_correct": true}',
            {"is_index_based": False, "index_value": None, "parent_element_id": None, "element_id_of_nth_child_of_parent": None, "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": true, "index_value": 3, "parent_element_id": null, "element_id_of_nth_child_of_parent": "aaaaaw", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 3, "parent_element_id": None, "element_id_of_nth_child_of_parent": "aaaaaw", "selected_element_is_correct": True},
            True
        ),
        (
            '{"is_index_based": true, "index_value": 2, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": true}',
            {"is_index_based": True, "index_value": 1, "parent_element_id": "aaaabf", "element_id_of_nth_child_of_parent": "aaaabg", "selected_element_is_correct": True},
            False  # index_value mismatch
        ),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    results = []
    for case_predicted, case_expected, expect_perfect in test_cases:
        outcome = evaluator.evaluate(case_predicted, case_expected)
        got_perfect = outcome["composite_score"] == 1.0

        # "✅" means the evaluator agreed with our expectation for this case.
        marker = "✅" if got_perfect == expect_perfect else "❌"
        print(f"{marker} Test: Perfect match = {got_perfect} (expected {expect_perfect})")
        print(f"   Score: {outcome['composite_score']:.2f}")
        print()

        results.append(outcome)

    # Aggregate statistics over all cases.
    print("\n📊 Summary:")
    summary = evaluator.get_evaluation_summary(results)
    print(f"   Total: {summary['total_samples']}")
    print(f"   Perfect matches: {summary['perfect_matches']}")
    print(f"   Overall accuracy: {summary['overall_accuracy']:.1%}")
    print(f"   Field accuracies:")
    for field_name, field_acc in summary['field_accuracies'].items():
        print(f"      {field_name}: {field_acc:.1%}")
|
| 357 |
+
|
src/gepa_optimizer/evaluation/scroll_evaluator.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GENERIC String Match Evaluator
|
| 3 |
+
|
| 4 |
+
Compares predicted output against expected output (simple string comparison).
|
| 5 |
+
NO assumptions about what the output represents (IDs, text, JSON, etc.).
|
| 6 |
+
|
| 7 |
+
Let GEPA discover the correct output format through evolution and feedback!
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict, Any
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .base_evaluator import BaseEvaluator
|
| 14 |
+
except ImportError:
|
| 15 |
+
# For standalone testing
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 19 |
+
from gepa_optimizer.evaluation.base_evaluator import BaseEvaluator
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ScrollElementEvaluator(BaseEvaluator):
    """
    GENERIC evaluator - just compares strings!

    NO assumptions about:
    - Output format (element IDs, text, JSON, etc.)
    - Output structure
    - What the task is

    GEPA will learn the correct format through feedback and evolution.
    """

    def __init__(self, metric_weights: Dict[str, float] = None):
        """
        Initialize evaluator.

        Args:
            metric_weights: Weights for evaluation metrics
                           Default: {"output_match": 1.0}
        """
        default_weights = {
            "output_match": 1.0  # Simple string comparison
        }

        weights = metric_weights or default_weights
        super().__init__(metric_weights=weights)

    def evaluate(self, predicted: str, expected: str) -> Dict[str, float]:
        """
        Binary evaluation with element ID extraction.

        Phase 1 Implementation:
        - Extracts element IDs using regex patterns (flexible format support)
        - Uses INTEGER comparison for robustness (prevents "4" vs "14" bugs)
        - Binary scoring: correct element = 1.0, wrong/missing = 0.0

        Scoring Strategy:
        1. Extract element ID from both predicted and expected outputs
        2. Compare using integer arithmetic (not string comparison)
        3. Return 1.0 if match, 0.0 otherwise (no partial credit)

        Args:
            predicted: LLM's output (may include verbose explanation)
            expected: Expected output (may include verbose explanation)

        Returns:
            Dictionary with evaluation metrics and extracted element IDs
        """
        # Local import: `re` is only used by this method and is not a
        # module-level dependency of this file.
        import re

        if not predicted or not expected:
            return {
                "content_match": 0.0,
                "output_match": 0.0,
                "composite_score": 0.0,
                "predicted_output": str(predicted).strip() if predicted else "",
                "expected_output": str(expected).strip() if expected else "",
                "predicted_element": "None",
                "expected_element": "None",
                "evaluation_reason": "❌ Empty or missing input/output"
            }

        predicted_str = str(predicted).strip()
        expected_str = str(expected).strip()

        # 1. Extract element numbers using MULTIPLE strategies (flexible!)
        # Strategy A: "Element: X" or "Element X" (explicit format)
        element_pattern_a = r'element[:\s]+(\d+)'

        # Strategy B: "element X" or "Element X" anywhere in text
        element_pattern_b = r'\belement\s+(\d+)\b'

        # Strategy C: Just find ANY number if other strategies fail (last resort)
        number_pattern = r'\b(\d+)\b'

        # Try to extract from predicted
        pred_match = re.search(element_pattern_a, predicted_str, re.IGNORECASE)
        if not pred_match:
            pred_match = re.search(element_pattern_b, predicted_str, re.IGNORECASE)
        if not pred_match:
            # Last resort: find first number in the text
            pred_match = re.search(number_pattern, predicted_str)

        # Try to extract from expected
        exp_match = re.search(element_pattern_a, expected_str, re.IGNORECASE)
        if not exp_match:
            exp_match = re.search(element_pattern_b, expected_str, re.IGNORECASE)
        if not exp_match:
            exp_match = re.search(number_pattern, expected_str)

        # 2. Check if we found element numbers in both
        if not exp_match:
            # Expected doesn't have element pattern - fallback to exact match
            content_score = 1.0 if predicted_str.lower() == expected_str.lower() else 0.0
        elif not pred_match:
            # Predicted doesn't have element number - WRONG
            content_score = 0.0
        else:
            # Both have element pattern - compare using INTEGER comparison
            pred_element = pred_match.group(1)
            exp_element = exp_match.group(1)

            # 🔥 Phase 1: Use INTEGER comparison for robustness
            # This prevents bugs like "4" != "14" string comparison issues
            try:
                pred_num = int(pred_element)
                exp_num = int(exp_element)

                # Integer comparison (more robust than string)
                content_score = 1.0 if pred_num == exp_num else 0.0

                # Log comparison for debugging. FIX: use self.logger (set up
                # by BaseEvaluator.__init__) instead of re-importing logging
                # and building a fresh module logger inside the method.
                if pred_num != exp_num:
                    self.logger.debug(f"Element mismatch: predicted={pred_num}, expected={exp_num}")

            except (ValueError, TypeError) as e:
                # Fallback to string comparison if conversion fails
                self.logger.warning(f"Could not convert elements to integers: {e}, using string comparison")
                content_score = 1.0 if pred_element == exp_element else 0.0

        # 3. Binary score and reason
        if content_score == 1.0:
            composite_score = 1.0
            reason = "✅ Correct! Element number matches"
        else:
            composite_score = 0.0
            if pred_match and exp_match:
                reason = "❌ Wrong element number (predicted different element)"
            else:
                reason = "❌ Missing or invalid element number"

        pred_element = pred_match.group(1) if pred_match else "None"
        exp_element = exp_match.group(1) if exp_match else "None"

        # Detailed logging for transparency (via the evaluator's own logger)
        self.logger.info(f"\n{'─'*70}")
        self.logger.info(f"📊 EVALUATION DETAILS")
        self.logger.info(f"{'─'*70}")
        self.logger.info(f"   Expected:  '{expected_str}' (Element: {exp_element})")
        self.logger.info(f"   Predicted: '{predicted_str}' (Element: {pred_element})")
        self.logger.info(f"   {'─'*66}")
        self.logger.info(f"   🎯 SCORE: {composite_score:.2f} - {reason}")
        self.logger.info(f"{'─'*70}\n")

        return {
            "content_match": content_score,
            "output_match": composite_score,  # This is what GEPA uses
            "composite_score": composite_score,
            "predicted_output": predicted_str,
            "expected_output": expected_str,
            "predicted_element": pred_element,
            "expected_element": exp_element,
            "evaluation_reason": reason
        }

    def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
        """
        Get summary statistics for a batch of evaluations.

        Args:
            results: List of evaluation result dictionaries

        Returns:
            Summary statistics
        """
        if not results:
            return {
                "total_samples": 0,
                "accuracy": 0.0,
                "correct_predictions": 0
            }

        total = len(results)
        correct = sum(1 for r in results if r.get("output_match", 0.0) == 1.0)
        accuracy = correct / total if total > 0 else 0.0

        return {
            "total_samples": total,
            "accuracy": accuracy,
            "correct_predictions": correct,
            "incorrect_predictions": total - correct
        }
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Example usage and testing
|
| 213 |
+
if __name__ == "__main__":
    print("🚀 Testing Scroll Element Evaluator...")

    evaluator = ScrollElementEvaluator()

    # Each entry: (predicted text, expected text, should_match)
    test_cases = [
        ("4", "4", True),
        ("Element: 4", "4", True),
        ("Element 4", "4", True),
        ("The element to interact with is 4", "4", True),
        ("Element ID: 4", "4", True),
        ("Click on element 4 to scroll", "4", True),
        ("5", "4", False),
        ("Element: 5", "4", False),
        ("No element found", "4", False),
        ("", "4", False),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    results = []
    for case_predicted, case_expected, expect_match in test_cases:
        outcome = evaluator.evaluate(case_predicted, case_expected)
        did_match = outcome["composite_score"] == 1.0

        # "✅" means the evaluator agreed with our expectation for this case.
        marker = "✅" if did_match == expect_match else "❌"
        print(f"{marker} Predicted: '{case_predicted}' | Expected: '{case_expected}' | Match: {did_match}")

        results.append(outcome)

    # Aggregate statistics over all cases.
    print("\n📊 Summary:")
    summary = evaluator.get_evaluation_summary(results)
    print(f"   Total: {summary['total_samples']}")
    print(f"   Correct: {summary['correct_predictions']}")
    print(f"   Accuracy: {summary['accuracy']:.1%}")
|
| 251 |
+
|
src/gepa_optimizer/evaluation/ui_evaluator.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UI Tree Evaluator for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import difflib
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
from .base_evaluator import BaseEvaluator
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class UITreeEvaluator(BaseEvaluator):
    """
    Comprehensive evaluator for UI tree extraction quality.

    A UI tree is a nested dict with optional "type", "text" and "style" keys
    and a "children" list of sub-trees.  Five metrics are computed and folded
    into a weighted composite score.
    """

    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        """
        Initializes the UITreeEvaluator with configurable metric weights.

        Args:
            metric_weights: A dictionary of weights for different metrics.
                            If None, default weights will be used.
        """
        # Set default weights for UI tree evaluation
        default_weights = {
            "element_completeness": 0.3,    # How many elements are captured
            "element_type_accuracy": 0.25,  # Correct element types (Button, Text, etc.)
            "text_content_accuracy": 0.2,   # Text content matches
            "hierarchy_accuracy": 0.15,     # Parent-child relationships
            "style_accuracy": 0.1,          # Style properties captured
        }

        # Use provided weights or defaults
        weights = metric_weights or default_weights

        # Initialize parent class
        super().__init__(metric_weights=weights)

        # Normalize weights so the composite score stays in [0, 1]
        self._normalize_weights()

    def _normalize_weights(self) -> None:
        """Normalize weights to sum to 1.0."""
        total_weight = sum(self.metric_weights.values())
        if total_weight > 0:
            self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}
        else:
            # Consistency fix: use the module-level logger like every other method here.
            logger.warning("Total metric weight is zero. Scores will be zero.")

    def evaluate(self, predicted_json: Dict[str, Any], expected_json: Dict[str, Any]) -> Dict[str, float]:
        """
        Generates a weighted composite score from individual metrics.

        Args:
            predicted_json: The JSON generated by the LLM.
            expected_json: The ground truth JSON.

        Returns:
            A dictionary of individual metric scores and the composite score.
        """
        scores = {
            "element_completeness": self.calculate_element_completeness(predicted_json, expected_json),
            "element_type_accuracy": self.calculate_element_type_accuracy(predicted_json, expected_json),
            "text_content_accuracy": self.calculate_text_content_accuracy(predicted_json, expected_json),
            "hierarchy_accuracy": self.calculate_hierarchy_accuracy(predicted_json, expected_json),
            "style_accuracy": self.calculate_style_accuracy(predicted_json, expected_json),
        }

        composite_score = sum(scores[metric] * self.metric_weights.get(metric, 0) for metric in scores)

        # Small improvement bonus for better prompts (encourage GEPA to accept improvements).
        # BUG FIX: previously the bonus was applied AFTER the composite score had been
        # stored in `scores`, so it never reached the returned result.
        if composite_score > 0.05:  # If we have any meaningful content
            composite_score = min(composite_score + 0.001, 1.0)  # Small bonus to encourage acceptance

        scores["composite_score"] = composite_score

        # Detailed logging for debugging
        logger.debug(f"Evaluation scores: {scores}")
        logger.debug(f"Composite score: {composite_score:.4f}")

        return scores

    def calculate_element_completeness(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates how many UI elements are captured in the predicted JSON.
        This is the most important metric for UI tree extraction.
        """
        def _count_elements(node):
            """Count total elements in the tree (root + all descendants)."""
            if not isinstance(node, dict):
                return 0
            count = 1  # Count current node
            for child in node.get("children", []):
                count += _count_elements(child)
            return count

        try:
            predicted_count = _count_elements(predicted)
            expected_count = _count_elements(expected)

            if expected_count == 0:
                return 1.0 if predicted_count == 0 else 0.0

            # Score based on how many elements are captured
            completeness_ratio = predicted_count / expected_count

            # Reward full coverage; penalize progressively for missing elements.
            if completeness_ratio >= 1.0:
                return 1.0  # Perfect or better
            elif completeness_ratio >= 0.8:
                return completeness_ratio  # Good coverage
            elif completeness_ratio >= 0.5:
                return completeness_ratio * 0.8  # Moderate coverage with penalty
            else:
                return completeness_ratio * 0.5  # Poor coverage with heavy penalty

        except Exception as e:
            logger.warning(f"Error calculating element completeness: {e}")
            return 0.0

    def calculate_element_type_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates element type accuracy by comparing the 'type' attribute of corresponding nodes.
        Focuses on common UI element types like Button, Text, Image, etc.
        """
        def _get_all_types(node):
            """Collect the 'type' value of every node in the tree (None dropped)."""
            if not isinstance(node, dict):
                return []
            types = [node.get("type")]
            for child in node.get("children", []):
                types.extend(_get_all_types(child))
            return [t for t in types if t is not None]

        try:
            predicted_types = _get_all_types(predicted)
            expected_types = _get_all_types(expected)

            if not expected_types:
                return 1.0 if not predicted_types else 0.5

            if not predicted_types:
                return 0.0

            # Count matching types with frequency consideration
            expected_type_counts = {}
            for t in expected_types:
                expected_type_counts[t] = expected_type_counts.get(t, 0) + 1

            predicted_type_counts = {}
            for t in predicted_types:
                predicted_type_counts[t] = predicted_type_counts.get(t, 0) + 1

            # Calculate accuracy based on type matches (capped at expected frequency)
            total_matches = 0
            for type_name, expected_count in expected_type_counts.items():
                predicted_count = predicted_type_counts.get(type_name, 0)
                # Count matches up to the expected count
                total_matches += min(predicted_count, expected_count)

            return total_matches / len(expected_types)

        except Exception as e:
            logger.warning(f"Error calculating element type accuracy: {e}")
            return 0.0

    def calculate_hierarchy_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates hierarchy accuracy by comparing parent-child relationships.
        """
        def _get_hierarchy_structure(node, parent_type="ROOT"):
            """Extract hierarchy structure as (parent_type, child_type) pairs."""
            if not isinstance(node, dict):
                return []

            current_type = node.get("type", "unknown")
            hierarchy = [(parent_type, current_type)]

            for child in node.get("children", []):
                hierarchy.extend(_get_hierarchy_structure(child, current_type))

            return hierarchy

        try:
            predicted_hierarchy = _get_hierarchy_structure(predicted)
            expected_hierarchy = _get_hierarchy_structure(expected)

            if not expected_hierarchy:
                return 1.0 if not predicted_hierarchy else 0.5

            if not predicted_hierarchy:
                return 0.0

            # Count matching hierarchy relationships (set semantics: duplicates collapse)
            expected_hierarchy_set = set(expected_hierarchy)
            predicted_hierarchy_set = set(predicted_hierarchy)

            matches = len(expected_hierarchy_set.intersection(predicted_hierarchy_set))
            total_expected = len(expected_hierarchy_set)

            return matches / total_expected if total_expected > 0 else 0.0

        except Exception as e:
            logger.warning(f"Error calculating hierarchy accuracy: {e}")
            return 0.0

    def calculate_text_content_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates text content accuracy by comparing the 'text' attribute of corresponding nodes.
        """
        def _get_all_texts(node):
            """Collect every non-empty 'text' value in the tree."""
            if not isinstance(node, dict):
                return []
            texts = [node.get("text")]
            for child in node.get("children", []):
                texts.extend(_get_all_texts(child))
            return [t for t in texts if t is not None and str(t).strip()]

        try:
            predicted_texts = _get_all_texts(predicted)
            expected_texts = _get_all_texts(expected)

            if not expected_texts:
                return 1.0 if not predicted_texts else 0.5  # Partial credit if predicted has texts but expected doesn't

            if not predicted_texts:
                return 0.0  # No predicted texts, so no match

            # For each predicted text, take its best fuzzy match against expected texts.
            total_similarity = 0.0
            for p_text in predicted_texts:
                best_similarity = 0.0
                for e_text in expected_texts:
                    similarity = difflib.SequenceMatcher(None, str(p_text).strip(), str(e_text).strip()).ratio()
                    best_similarity = max(best_similarity, similarity)
                total_similarity += best_similarity

            # Average similarity over all predicted texts.
            # (Dead-code fix: the empty-list branches were unreachable here —
            #  both cases are already handled by the early returns above.)
            return total_similarity / len(predicted_texts)
        except Exception as e:
            logger.warning(f"Error calculating text content accuracy: {e}")
            return 0.0

    def calculate_style_accuracy(self, predicted: Dict, expected: Dict) -> float:
        """
        Calculates style accuracy by comparing style properties.
        """
        def _get_all_styles(node):
            """Extract all style dicts from the tree."""
            if not isinstance(node, dict):
                return []

            styles = []
            if "style" in node and isinstance(node["style"], dict):
                styles.append(node["style"])

            for child in node.get("children", []):
                styles.extend(_get_all_styles(child))

            return styles

        try:
            predicted_styles = _get_all_styles(predicted)
            expected_styles = _get_all_styles(expected)

            if not expected_styles:
                return 1.0 if not predicted_styles else 0.5

            if not predicted_styles:
                return 0.0

            # Calculate style property overlap: each expected (name, value) pair
            # counts as matched if ANY predicted style dict carries it.
            total_style_properties = 0
            matching_properties = 0

            for exp_style in expected_styles:
                for prop_name, prop_value in exp_style.items():
                    total_style_properties += 1

                    # Find matching property in predicted styles
                    for pred_style in predicted_styles:
                        if prop_name in pred_style and pred_style[prop_name] == prop_value:
                            matching_properties += 1
                            break

            return matching_properties / total_style_properties if total_style_properties > 0 else 0.0

        except Exception as e:
            logger.warning(f"Error calculating style accuracy: {e}")
            return 0.0
|
src/gepa_optimizer/evaluation/universal_evaluator.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Universal Semantic Evaluator for ANY prompt optimization use case.
|
| 3 |
+
|
| 4 |
+
This evaluator uses LLM-powered semantic analysis to compare predicted vs expected outputs,
|
| 5 |
+
enabling prompt optimization for ANY task without requiring custom evaluator code.
|
| 6 |
+
|
| 7 |
+
Key Features:
|
| 8 |
+
- Semantic understanding (not just string matching)
|
| 9 |
+
- Works with text, JSON, numbers, structured outputs
|
| 10 |
+
- Provides rich feedback for GEPA reflection
|
| 11 |
+
- No task-specific assumptions
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import re
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Dict, Any, Optional, List
|
| 18 |
+
from difflib import SequenceMatcher
|
| 19 |
+
|
| 20 |
+
from .base_evaluator import BaseEvaluator
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class UniversalSemanticEvaluator(BaseEvaluator):
|
| 26 |
+
"""
|
| 27 |
+
Universal evaluator using LLM for semantic comparison.
|
| 28 |
+
|
| 29 |
+
Works for ANY task without hardcoded assumptions:
|
| 30 |
+
- Text outputs: "The answer is 42" vs "42"
|
| 31 |
+
- JSON outputs: {"count": 23} vs {"count": 22}
|
| 32 |
+
- Structured data: Lists, nested objects
|
| 33 |
+
- Multi-modal: Image descriptions, analysis results
|
| 34 |
+
|
| 35 |
+
Evaluation Strategy:
|
| 36 |
+
1. Quick checks (exact match, empty handling)
|
| 37 |
+
2. Structural comparison (for JSON/structured data)
|
| 38 |
+
3. LLM semantic analysis (for meaning understanding)
|
| 39 |
+
4. Combine into composite score with rich feedback
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
llm_client=None,
|
| 45 |
+
use_llm_analysis: bool = True,
|
| 46 |
+
semantic_weight: float = 0.6,
|
| 47 |
+
structural_weight: float = 0.25,
|
| 48 |
+
exact_match_bonus: float = 0.15,
|
| 49 |
+
metric_weights: Optional[Dict[str, float]] = None
|
| 50 |
+
):
|
| 51 |
+
"""
|
| 52 |
+
Initialize Universal Semantic Evaluator.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
llm_client: LLM client for semantic analysis (optional, falls back to heuristics)
|
| 56 |
+
use_llm_analysis: Whether to use LLM for semantic comparison
|
| 57 |
+
semantic_weight: Weight for semantic similarity (0.0-1.0)
|
| 58 |
+
structural_weight: Weight for structural similarity (0.0-1.0)
|
| 59 |
+
exact_match_bonus: Bonus weight for exact matches (0.0-1.0)
|
| 60 |
+
metric_weights: Optional custom weights (overrides above)
|
| 61 |
+
"""
|
| 62 |
+
default_weights = metric_weights or {
|
| 63 |
+
"semantic_similarity": semantic_weight,
|
| 64 |
+
"structural_similarity": structural_weight,
|
| 65 |
+
"exact_match": exact_match_bonus
|
| 66 |
+
}
|
| 67 |
+
super().__init__(metric_weights=default_weights)
|
| 68 |
+
|
| 69 |
+
self.llm_client = llm_client
|
| 70 |
+
self.use_llm_analysis = use_llm_analysis and llm_client is not None
|
| 71 |
+
|
| 72 |
+
# Cache for LLM analysis to reduce API calls
|
| 73 |
+
self._analysis_cache: Dict[str, Dict] = {}
|
| 74 |
+
|
| 75 |
+
logger.info(f"🎯 Universal Semantic Evaluator initialized")
|
| 76 |
+
logger.info(f" LLM analysis: {'enabled' if self.use_llm_analysis else 'disabled (using heuristics)'}")
|
| 77 |
+
logger.info(f" Weights: semantic={semantic_weight}, structural={structural_weight}, exact={exact_match_bonus}")
|
| 78 |
+
|
| 79 |
+
def evaluate(self, predicted: Any, expected: Any) -> Dict[str, float]:
|
| 80 |
+
"""
|
| 81 |
+
Evaluate predicted output against expected output using semantic understanding.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
predicted: The model's predicted output (string, dict, or any serializable type)
|
| 85 |
+
expected: The ground truth expected output
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Dictionary with metrics including 'composite_score' (required for GEPA)
|
| 89 |
+
"""
|
| 90 |
+
# Convert to strings for comparison
|
| 91 |
+
predicted_str = self._to_string(predicted)
|
| 92 |
+
expected_str = self._to_string(expected)
|
| 93 |
+
|
| 94 |
+
# Initialize result
|
| 95 |
+
result = {
|
| 96 |
+
"composite_score": 0.0,
|
| 97 |
+
"exact_match": 0.0,
|
| 98 |
+
"semantic_similarity": 0.0,
|
| 99 |
+
"structural_similarity": 0.0,
|
| 100 |
+
"predicted_output": predicted_str[:500], # Truncate for logging
|
| 101 |
+
"expected_output": expected_str[:500],
|
| 102 |
+
"analysis": {},
|
| 103 |
+
"improvement_feedback": ""
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
# Handle empty/missing outputs
|
| 107 |
+
if not predicted_str or not predicted_str.strip():
|
| 108 |
+
result["improvement_feedback"] = "❌ Output is EMPTY. The prompt must instruct the model to produce output."
|
| 109 |
+
result["analysis"] = {"status": "empty_predicted"}
|
| 110 |
+
return result
|
| 111 |
+
|
| 112 |
+
if not expected_str or not expected_str.strip():
|
| 113 |
+
result["improvement_feedback"] = "⚠️ Expected output is empty - cannot evaluate."
|
| 114 |
+
result["analysis"] = {"status": "empty_expected"}
|
| 115 |
+
result["composite_score"] = 0.5 # Neutral score
|
| 116 |
+
return result
|
| 117 |
+
|
| 118 |
+
# ─────────────────────────────────────────────────────
|
| 119 |
+
# STEP 1: Exact Match Check (Fast Path)
|
| 120 |
+
# ─────────────────────────────────────────────────────
|
| 121 |
+
normalized_pred = self._normalize(predicted_str)
|
| 122 |
+
normalized_exp = self._normalize(expected_str)
|
| 123 |
+
|
| 124 |
+
if normalized_pred == normalized_exp:
|
| 125 |
+
result["exact_match"] = 1.0
|
| 126 |
+
result["semantic_similarity"] = 1.0
|
| 127 |
+
result["structural_similarity"] = 1.0
|
| 128 |
+
result["composite_score"] = 1.0
|
| 129 |
+
result["improvement_feedback"] = "✅ Perfect match! Output exactly matches expected."
|
| 130 |
+
result["analysis"] = {"status": "exact_match"}
|
| 131 |
+
return result
|
| 132 |
+
|
| 133 |
+
# ─────────────────────────────────────────────────────
|
| 134 |
+
# STEP 1.5: FORMAT MISMATCH DETECTION (CRITICAL FIX)
|
| 135 |
+
# ─────────────────────────────────────────────────────
|
| 136 |
+
# 🔥 CRITICAL: Detect when expected is JSON but predicted is narrative text
|
| 137 |
+
# This causes catastrophically low scores and needs explicit handling
|
| 138 |
+
expected_is_json = self._try_parse_json(expected_str) is not None
|
| 139 |
+
predicted_is_json = self._try_parse_json(predicted_str) is not None
|
| 140 |
+
|
| 141 |
+
format_mismatch = expected_is_json and not predicted_is_json
|
| 142 |
+
if format_mismatch:
|
| 143 |
+
# Expected JSON but got narrative - this is a CRITICAL format error
|
| 144 |
+
# Give partial credit for semantic content but penalize heavily for format
|
| 145 |
+
result["analysis"]["format_mismatch"] = True
|
| 146 |
+
result["improvement_feedback"] = (
|
| 147 |
+
"❌ FORMAT ERROR: Expected JSON output but received narrative text. "
|
| 148 |
+
"The prompt MUST enforce JSON output format. "
|
| 149 |
+
"Add explicit instructions like: 'Output ONLY valid JSON, no explanations.' "
|
| 150 |
+
"Consider adding: 'Do NOT write prose or explanations.'"
|
| 151 |
+
)
|
| 152 |
+
# Still evaluate semantic content but cap the score
|
| 153 |
+
# This gives feedback for improving the prompt
|
| 154 |
+
logger.warning(f"⚠️ Format mismatch: expected JSON ({len(expected_str)} chars), got narrative ({len(predicted_str)} chars)")
|
| 155 |
+
|
| 156 |
+
# ─────────────────────────────────────────────────────
|
| 157 |
+
# STEP 2: Structural Comparison (for JSON/structured data)
|
| 158 |
+
# ─────────────────────────────────────────────────────
|
| 159 |
+
structural_result = self._compare_structure(predicted_str, expected_str)
|
| 160 |
+
result["structural_similarity"] = structural_result["score"]
|
| 161 |
+
result["analysis"]["structural"] = structural_result.get("details", {})
|
| 162 |
+
|
| 163 |
+
# ─────────────────────────────────────────────────────
|
| 164 |
+
# STEP 3: Semantic Analysis
|
| 165 |
+
# ─────────────────────────────────────────────────────
|
| 166 |
+
if self.use_llm_analysis:
|
| 167 |
+
semantic_result = self._llm_semantic_analysis(predicted_str, expected_str)
|
| 168 |
+
else:
|
| 169 |
+
semantic_result = self._heuristic_semantic_analysis(predicted_str, expected_str)
|
| 170 |
+
|
| 171 |
+
result["semantic_similarity"] = semantic_result["score"]
|
| 172 |
+
result["analysis"]["semantic"] = semantic_result.get("details", {})
|
| 173 |
+
result["improvement_feedback"] = semantic_result.get("feedback", "")
|
| 174 |
+
|
| 175 |
+
# ─────────────────────────────────────────────────────
|
| 176 |
+
# STEP 4: Compute Composite Score
|
| 177 |
+
# ─────────────────────────────────────────────────────
|
| 178 |
+
weights = self.metric_weights
|
| 179 |
+
composite = (
|
| 180 |
+
result["semantic_similarity"] * weights.get("semantic_similarity", 0.6) +
|
| 181 |
+
result["structural_similarity"] * weights.get("structural_similarity", 0.25) +
|
| 182 |
+
result["exact_match"] * weights.get("exact_match", 0.15)
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# 🔥 CRITICAL FIX: Apply format mismatch penalty
|
| 186 |
+
# If expected JSON but got narrative, cap the score to encourage format compliance
|
| 187 |
+
if result.get("analysis", {}).get("format_mismatch"):
|
| 188 |
+
# Cap at 0.3 to indicate "partial semantic match but wrong format"
|
| 189 |
+
# This ensures format-correct outputs always score higher
|
| 190 |
+
composite = min(composite, 0.30)
|
| 191 |
+
logger.debug(f"📊 Format mismatch penalty applied: score capped at {composite:.3f}")
|
| 192 |
+
|
| 193 |
+
result["composite_score"] = min(max(composite, 0.0), 1.0)
|
| 194 |
+
|
| 195 |
+
# Add score breakdown to feedback
|
| 196 |
+
if not result["improvement_feedback"]:
|
| 197 |
+
result["improvement_feedback"] = self._generate_default_feedback(result)
|
| 198 |
+
|
| 199 |
+
# Log evaluation
|
| 200 |
+
logger.debug(f"📊 Evaluation: composite={result['composite_score']:.3f}, "
|
| 201 |
+
f"semantic={result['semantic_similarity']:.3f}, "
|
| 202 |
+
f"structural={result['structural_similarity']:.3f}")
|
| 203 |
+
|
| 204 |
+
# #region agent log
|
| 205 |
+
try:
|
| 206 |
+
import json as _json_debug
|
| 207 |
+
import time as _time_debug
|
| 208 |
+
import os as _os_debug
|
| 209 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 210 |
+
_os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
|
| 211 |
+
with open(_debug_log_path, "a") as _f:
|
| 212 |
+
_f.write(_json_debug.dumps({"hypothesisId": "G", "location": "universal_evaluator.py:final_score", "message": "Final evaluation score breakdown", "data": {"composite": result["composite_score"], "semantic": result["semantic_similarity"], "structural": result["structural_similarity"], "exact_match": result["exact_match"], "format_mismatch": result.get("analysis", {}).get("format_mismatch", False), "predicted_preview": predicted_str[:150] if predicted_str else "EMPTY", "expected_preview": expected_str[:150] if expected_str else "EMPTY"}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
|
| 213 |
+
except Exception as _e:
|
| 214 |
+
pass # Silent fail for instrumentation
|
| 215 |
+
# #endregion
|
| 216 |
+
|
| 217 |
+
return result
|
| 218 |
+
|
| 219 |
+
def _to_string(self, value: Any) -> str:
|
| 220 |
+
"""Convert any value to string for comparison."""
|
| 221 |
+
if value is None:
|
| 222 |
+
return ""
|
| 223 |
+
if isinstance(value, str):
|
| 224 |
+
return value.strip()
|
| 225 |
+
if isinstance(value, dict):
|
| 226 |
+
try:
|
| 227 |
+
return json.dumps(value, sort_keys=True, indent=2)
|
| 228 |
+
except (TypeError, ValueError):
|
| 229 |
+
return str(value)
|
| 230 |
+
if isinstance(value, (list, tuple)):
|
| 231 |
+
try:
|
| 232 |
+
return json.dumps(list(value), sort_keys=True)
|
| 233 |
+
except (TypeError, ValueError):
|
| 234 |
+
return str(value)
|
| 235 |
+
return str(value).strip()
|
| 236 |
+
|
| 237 |
+
def _normalize(self, text: str) -> str:
|
| 238 |
+
"""Normalize text for comparison (lowercase, whitespace)."""
|
| 239 |
+
# Lowercase and normalize whitespace
|
| 240 |
+
normalized = ' '.join(text.lower().split())
|
| 241 |
+
# Remove common punctuation that doesn't affect meaning
|
| 242 |
+
normalized = re.sub(r'[.,;:!?\'"]+$', '', normalized)
|
| 243 |
+
return normalized
|
| 244 |
+
|
| 245 |
+
def _compare_structure(self, predicted: str, expected: str) -> Dict[str, Any]:
|
| 246 |
+
"""
|
| 247 |
+
Compare structural similarity (especially for JSON/structured outputs).
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Dict with 'score' (0.0-1.0) and 'details'
|
| 251 |
+
"""
|
| 252 |
+
result = {"score": 0.0, "details": {}}
|
| 253 |
+
|
| 254 |
+
# Try to parse as JSON
|
| 255 |
+
pred_json = self._try_parse_json(predicted)
|
| 256 |
+
exp_json = self._try_parse_json(expected)
|
| 257 |
+
|
| 258 |
+
if pred_json is not None and exp_json is not None:
|
| 259 |
+
# Both are valid JSON - do structural comparison
|
| 260 |
+
return self._compare_json_structures(pred_json, exp_json)
|
| 261 |
+
|
| 262 |
+
# Fallback: Compare as text structure
|
| 263 |
+
return self._compare_text_structure(predicted, expected)
|
| 264 |
+
|
| 265 |
+
def _try_parse_json(self, text: str) -> Optional[Any]:
|
| 266 |
+
"""
|
| 267 |
+
Try to parse text as JSON with robust extraction.
|
| 268 |
+
|
| 269 |
+
🔥 FIX: LLMs often wrap JSON in markdown code blocks or add extra text.
|
| 270 |
+
This method now handles multiple formats:
|
| 271 |
+
- Direct JSON
|
| 272 |
+
- ```json ... ``` blocks
|
| 273 |
+
- ``` ... ``` blocks (no language tag)
|
| 274 |
+
- JSON embedded in prose
|
| 275 |
+
- Escaped newlines and quotes
|
| 276 |
+
"""
|
| 277 |
+
if not text or not isinstance(text, str):
|
| 278 |
+
return None
|
| 279 |
+
|
| 280 |
+
# 🔥 PREPROCESSING: Clean common LLM output issues
|
| 281 |
+
cleaned = text.strip()
|
| 282 |
+
|
| 283 |
+
# Remove BOM and other invisible characters
|
| 284 |
+
cleaned = cleaned.lstrip('\ufeff\u200b\u200c\u200d')
|
| 285 |
+
|
| 286 |
+
# Strategy 1: Try direct parse (cleanest case)
|
| 287 |
+
try:
|
| 288 |
+
return json.loads(cleaned)
|
| 289 |
+
except json.JSONDecodeError:
|
| 290 |
+
pass
|
| 291 |
+
|
| 292 |
+
# Strategy 2: Extract JSON from markdown code block (```json ... ```)
|
| 293 |
+
# More permissive regex that handles optional language tags
|
| 294 |
+
json_match = re.search(r'```(?:json|JSON)?\s*([\{|\[].*?[\}|\]])\s*```', cleaned, re.DOTALL)
|
| 295 |
+
if json_match:
|
| 296 |
+
try:
|
| 297 |
+
return json.loads(json_match.group(1))
|
| 298 |
+
except json.JSONDecodeError:
|
| 299 |
+
pass
|
| 300 |
+
|
| 301 |
+
# Strategy 3: Find JSON using balanced brace matching (handles nested objects)
|
| 302 |
+
def extract_balanced_json(s: str, start_char: str, end_char: str) -> Optional[str]:
|
| 303 |
+
"""Extract JSON with balanced braces/brackets."""
|
| 304 |
+
count = 0
|
| 305 |
+
start_idx = -1
|
| 306 |
+
for i, char in enumerate(s):
|
| 307 |
+
if char == start_char:
|
| 308 |
+
if count == 0:
|
| 309 |
+
start_idx = i
|
| 310 |
+
count += 1
|
| 311 |
+
elif char == end_char:
|
| 312 |
+
count -= 1
|
| 313 |
+
if count == 0 and start_idx >= 0:
|
| 314 |
+
return s[start_idx:i+1]
|
| 315 |
+
return None
|
| 316 |
+
|
| 317 |
+
# Try to find JSON object
|
| 318 |
+
json_obj = extract_balanced_json(cleaned, '{', '}')
|
| 319 |
+
if json_obj:
|
| 320 |
+
try:
|
| 321 |
+
return json.loads(json_obj)
|
| 322 |
+
except json.JSONDecodeError:
|
| 323 |
+
# Try to repair common issues
|
| 324 |
+
repaired = self._repair_json(json_obj)
|
| 325 |
+
try:
|
| 326 |
+
return json.loads(repaired)
|
| 327 |
+
except json.JSONDecodeError:
|
| 328 |
+
pass
|
| 329 |
+
|
| 330 |
+
# Try to find JSON array
|
| 331 |
+
json_arr = extract_balanced_json(cleaned, '[', ']')
|
| 332 |
+
if json_arr:
|
| 333 |
+
try:
|
| 334 |
+
return json.loads(json_arr)
|
| 335 |
+
except json.JSONDecodeError:
|
| 336 |
+
repaired = self._repair_json(json_arr)
|
| 337 |
+
try:
|
| 338 |
+
return json.loads(repaired)
|
| 339 |
+
except json.JSONDecodeError:
|
| 340 |
+
pass
|
| 341 |
+
|
| 342 |
+
return None
|
| 343 |
+
|
| 344 |
+
def _repair_json(self, json_str: str) -> str:
|
| 345 |
+
"""
|
| 346 |
+
Attempt to repair common JSON issues from LLM output.
|
| 347 |
+
|
| 348 |
+
Fixes:
|
| 349 |
+
- Trailing commas before } or ]
|
| 350 |
+
- Single quotes instead of double quotes
|
| 351 |
+
- Unquoted keys
|
| 352 |
+
- Comments (// and /* */)
|
| 353 |
+
"""
|
| 354 |
+
repaired = json_str
|
| 355 |
+
|
| 356 |
+
# Remove trailing commas
|
| 357 |
+
repaired = re.sub(r',\s*}', '}', repaired)
|
| 358 |
+
repaired = re.sub(r',\s*]', ']', repaired)
|
| 359 |
+
|
| 360 |
+
# Remove single-line comments
|
| 361 |
+
repaired = re.sub(r'//[^\n]*', '', repaired)
|
| 362 |
+
|
| 363 |
+
# Remove multi-line comments
|
| 364 |
+
repaired = re.sub(r'/\*.*?\*/', '', repaired, flags=re.DOTALL)
|
| 365 |
+
|
| 366 |
+
# Replace single quotes with double quotes (but be careful with apostrophes)
|
| 367 |
+
# Only replace when it looks like a JSON delimiter
|
| 368 |
+
def replace_single_quotes(match):
|
| 369 |
+
content = match.group(0)
|
| 370 |
+
# Skip if it looks like an apostrophe in a word
|
| 371 |
+
if re.match(r"'\w+'\s*:", content) or re.match(r":\s*'[^']*'", content):
|
| 372 |
+
return content.replace("'", '"')
|
| 373 |
+
return content
|
| 374 |
+
|
| 375 |
+
# Basic single quote replacement for keys
|
| 376 |
+
repaired = re.sub(r"'([^']+)'\s*:", r'"\1":', repaired)
|
| 377 |
+
|
| 378 |
+
return repaired
|
| 379 |
+
|
| 380 |
+
def _compare_json_structures(self, pred: Any, exp: Any) -> Dict[str, Any]:
|
| 381 |
+
"""Compare two JSON structures."""
|
| 382 |
+
result = {"score": 0.0, "details": {"type": "json", "matches": [], "mismatches": []}}
|
| 383 |
+
|
| 384 |
+
if type(pred) != type(exp):
|
| 385 |
+
result["details"]["mismatches"].append(f"Type mismatch: predicted={type(pred).__name__}, expected={type(exp).__name__}")
|
| 386 |
+
result["score"] = 0.2 # Some credit for being JSON
|
| 387 |
+
return result
|
| 388 |
+
|
| 389 |
+
if isinstance(pred, dict) and isinstance(exp, dict):
|
| 390 |
+
return self._compare_dicts(pred, exp)
|
| 391 |
+
elif isinstance(pred, list) and isinstance(exp, list):
|
| 392 |
+
return self._compare_lists(pred, exp)
|
| 393 |
+
else:
|
| 394 |
+
# Primitive types
|
| 395 |
+
if pred == exp:
|
| 396 |
+
result["score"] = 1.0
|
| 397 |
+
result["details"]["matches"].append(f"Values match: {pred}")
|
| 398 |
+
else:
|
| 399 |
+
result["score"] = self._value_similarity(pred, exp)
|
| 400 |
+
result["details"]["mismatches"].append(f"Value mismatch: predicted={pred}, expected={exp}")
|
| 401 |
+
return result
|
| 402 |
+
|
| 403 |
+
def _compare_dicts(self, pred: dict, exp: dict) -> Dict[str, Any]:
|
| 404 |
+
"""
|
| 405 |
+
Compare two dictionaries with CASE-INSENSITIVE key matching.
|
| 406 |
+
|
| 407 |
+
🔥 FIX: LLMs often produce keys like 'Category' when expected is 'category'.
|
| 408 |
+
This method now normalizes keys before comparison for fair scoring.
|
| 409 |
+
"""
|
| 410 |
+
result = {"score": 0.0, "details": {"type": "dict", "matches": [], "mismatches": [], "missing_keys": [], "extra_keys": []}}
|
| 411 |
+
|
| 412 |
+
# 🔥 NORMALIZE: Convert all keys to lowercase for comparison
|
| 413 |
+
# Also handle common variations like underscores vs camelCase
|
| 414 |
+
def normalize_key(key: str) -> str:
|
| 415 |
+
"""Normalize key: lowercase, underscores to nothing, strip spaces."""
|
| 416 |
+
return re.sub(r'[_\s-]', '', str(key).lower())
|
| 417 |
+
|
| 418 |
+
# Build normalized key mappings
|
| 419 |
+
pred_normalized = {normalize_key(k): (k, v) for k, v in pred.items()}
|
| 420 |
+
exp_normalized = {normalize_key(k): (k, v) for k, v in exp.items()}
|
| 421 |
+
|
| 422 |
+
pred_norm_keys = set(pred_normalized.keys())
|
| 423 |
+
exp_norm_keys = set(exp_normalized.keys())
|
| 424 |
+
|
| 425 |
+
# Check for missing/extra keys (using normalized comparison)
|
| 426 |
+
missing_norm = exp_norm_keys - pred_norm_keys
|
| 427 |
+
extra_norm = pred_norm_keys - exp_norm_keys
|
| 428 |
+
common_norm = pred_norm_keys & exp_norm_keys
|
| 429 |
+
|
| 430 |
+
# Convert back to original key names for reporting
|
| 431 |
+
missing = [exp_normalized[k][0] for k in missing_norm]
|
| 432 |
+
extra = [pred_normalized[k][0] for k in extra_norm]
|
| 433 |
+
|
| 434 |
+
result["details"]["missing_keys"] = missing
|
| 435 |
+
result["details"]["extra_keys"] = extra
|
| 436 |
+
|
| 437 |
+
if not exp_norm_keys:
|
| 438 |
+
result["score"] = 1.0 if not pred_norm_keys else 0.5
|
| 439 |
+
return result
|
| 440 |
+
|
| 441 |
+
# Score based on key overlap (normalized)
|
| 442 |
+
key_score = len(common_norm) / len(exp_norm_keys) if exp_norm_keys else 1.0
|
| 443 |
+
|
| 444 |
+
# Score based on value matches
|
| 445 |
+
value_scores = []
|
| 446 |
+
for norm_key in common_norm:
|
| 447 |
+
pred_orig_key, pred_val = pred_normalized[norm_key]
|
| 448 |
+
exp_orig_key, exp_val = exp_normalized[norm_key]
|
| 449 |
+
|
| 450 |
+
if pred_val == exp_val:
|
| 451 |
+
value_scores.append(1.0)
|
| 452 |
+
result["details"]["matches"].append(f"{exp_orig_key}: {exp_val}")
|
| 453 |
+
else:
|
| 454 |
+
sim = self._value_similarity(pred_val, exp_val)
|
| 455 |
+
value_scores.append(sim)
|
| 456 |
+
if sim < 0.8:
|
| 457 |
+
result["details"]["mismatches"].append(f"{exp_orig_key}: predicted={pred_val}, expected={exp_val}")
|
| 458 |
+
|
| 459 |
+
value_score = sum(value_scores) / len(value_scores) if value_scores else 0.0
|
| 460 |
+
|
| 461 |
+
# Combine scores
|
| 462 |
+
result["score"] = 0.3 * key_score + 0.7 * value_score
|
| 463 |
+
|
| 464 |
+
# Penalty for missing keys (reduced from 0.1 to 0.05 per key)
|
| 465 |
+
if missing:
|
| 466 |
+
result["score"] *= (1 - 0.05 * len(missing))
|
| 467 |
+
|
| 468 |
+
result["score"] = max(0.0, min(1.0, result["score"]))
|
| 469 |
+
return result
|
| 470 |
+
|
| 471 |
+
def _compare_lists(self, pred: list, exp: list) -> Dict[str, Any]:
|
| 472 |
+
"""Compare two lists."""
|
| 473 |
+
result = {"score": 0.0, "details": {"type": "list", "length_match": False, "item_matches": 0}}
|
| 474 |
+
|
| 475 |
+
if not exp:
|
| 476 |
+
result["score"] = 1.0 if not pred else 0.5
|
| 477 |
+
return result
|
| 478 |
+
|
| 479 |
+
result["details"]["length_match"] = len(pred) == len(exp)
|
| 480 |
+
|
| 481 |
+
# Compare items (order-sensitive)
|
| 482 |
+
matches = 0
|
| 483 |
+
for i, exp_item in enumerate(exp):
|
| 484 |
+
if i < len(pred):
|
| 485 |
+
if pred[i] == exp_item:
|
| 486 |
+
matches += 1
|
| 487 |
+
else:
|
| 488 |
+
# Check if item exists elsewhere
|
| 489 |
+
if exp_item in pred:
|
| 490 |
+
matches += 0.5 # Partial credit for wrong position
|
| 491 |
+
|
| 492 |
+
result["details"]["item_matches"] = matches
|
| 493 |
+
result["score"] = matches / len(exp)
|
| 494 |
+
|
| 495 |
+
# Penalty for length mismatch
|
| 496 |
+
if len(pred) != len(exp):
|
| 497 |
+
len_ratio = min(len(pred), len(exp)) / max(len(pred), len(exp))
|
| 498 |
+
result["score"] *= (0.7 + 0.3 * len_ratio)
|
| 499 |
+
|
| 500 |
+
return result
|
| 501 |
+
|
| 502 |
+
def _value_similarity(self, pred: Any, exp: Any) -> float:
|
| 503 |
+
"""
|
| 504 |
+
Calculate similarity between two values.
|
| 505 |
+
|
| 506 |
+
🔥 ENHANCED: Now handles:
|
| 507 |
+
- Case-insensitive string comparison
|
| 508 |
+
- Semantic similarity for common variations
|
| 509 |
+
- Underscore/space/dash normalization
|
| 510 |
+
- Numeric comparison with tolerance
|
| 511 |
+
"""
|
| 512 |
+
# Same value (exact match)
|
| 513 |
+
if pred == exp:
|
| 514 |
+
return 1.0
|
| 515 |
+
|
| 516 |
+
# Numeric comparison
|
| 517 |
+
try:
|
| 518 |
+
pred_num = float(pred)
|
| 519 |
+
exp_num = float(exp)
|
| 520 |
+
if exp_num == 0:
|
| 521 |
+
return 1.0 if pred_num == 0 else 0.0
|
| 522 |
+
# Relative error with tolerance
|
| 523 |
+
error = abs(pred_num - exp_num) / abs(exp_num)
|
| 524 |
+
return max(0.0, 1.0 - error)
|
| 525 |
+
except (ValueError, TypeError):
|
| 526 |
+
pass
|
| 527 |
+
|
| 528 |
+
# String comparison with normalization
|
| 529 |
+
pred_str = str(pred).strip()
|
| 530 |
+
exp_str = str(exp).strip()
|
| 531 |
+
|
| 532 |
+
# Case-insensitive exact match
|
| 533 |
+
if pred_str.lower() == exp_str.lower():
|
| 534 |
+
return 0.98 # Slight penalty for case mismatch
|
| 535 |
+
|
| 536 |
+
# Normalize strings (remove underscores, spaces, dashes for comparison)
|
| 537 |
+
def normalize_str(s: str) -> str:
|
| 538 |
+
return re.sub(r'[_\s\-]+', '', s.lower())
|
| 539 |
+
|
| 540 |
+
pred_norm = normalize_str(pred_str)
|
| 541 |
+
exp_norm = normalize_str(exp_str)
|
| 542 |
+
|
| 543 |
+
if pred_norm == exp_norm:
|
| 544 |
+
return 0.95 # Good match despite formatting differences
|
| 545 |
+
|
| 546 |
+
# Check if one contains the other (partial match)
|
| 547 |
+
if pred_norm in exp_norm or exp_norm in pred_norm:
|
| 548 |
+
ratio = min(len(pred_norm), len(exp_norm)) / max(len(pred_norm), len(exp_norm))
|
| 549 |
+
return 0.7 + (0.2 * ratio) # 0.7-0.9 for partial matches
|
| 550 |
+
|
| 551 |
+
# 🔥 SEMANTIC SIMILARITY: Check for common equivalent terms
|
| 552 |
+
semantic_equivalents = {
|
| 553 |
+
# Priority levels
|
| 554 |
+
'low': ['low', 'minor', 'trivial', 'p3', 'p4'],
|
| 555 |
+
'medium': ['medium', 'normal', 'moderate', 'p2'],
|
| 556 |
+
'high': ['high', 'important', 'major', 'p1', 'critical', 'urgent'],
|
| 557 |
+
# Boolean variations
|
| 558 |
+
'true': ['true', 'yes', '1', 'on', 'enabled'],
|
| 559 |
+
'false': ['false', 'no', '0', 'off', 'disabled'],
|
| 560 |
+
# Status variations
|
| 561 |
+
'success': ['success', 'succeeded', 'completed', 'done', 'passed'],
|
| 562 |
+
'failure': ['failure', 'failed', 'error', 'crashed'],
|
| 563 |
+
'pending': ['pending', 'waiting', 'queued', 'in_progress', 'processing'],
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
for canonical, equivalents in semantic_equivalents.items():
|
| 567 |
+
pred_match = any(eq in pred_norm for eq in equivalents)
|
| 568 |
+
exp_match = any(eq in exp_norm for eq in equivalents)
|
| 569 |
+
if pred_match and exp_match:
|
| 570 |
+
return 0.85 # Semantic match
|
| 571 |
+
|
| 572 |
+
# Sequence matching (character-level similarity)
|
| 573 |
+
ratio = SequenceMatcher(None, pred_str.lower(), exp_str.lower()).ratio()
|
| 574 |
+
|
| 575 |
+
# 🔥 WORD-LEVEL SIMILARITY: Check word overlap
|
| 576 |
+
pred_words = set(re.findall(r'\w+', pred_str.lower()))
|
| 577 |
+
exp_words = set(re.findall(r'\w+', exp_str.lower()))
|
| 578 |
+
|
| 579 |
+
if pred_words and exp_words:
|
| 580 |
+
word_overlap = len(pred_words & exp_words) / max(len(pred_words), len(exp_words))
|
| 581 |
+
# Combine character and word similarity
|
| 582 |
+
return max(ratio, word_overlap * 0.9)
|
| 583 |
+
|
| 584 |
+
def _compare_text_structure(self, predicted: str, expected: str) -> Dict[str, Any]:
|
| 585 |
+
"""Compare text structure when not JSON."""
|
| 586 |
+
result = {"score": 0.0, "details": {"type": "text"}}
|
| 587 |
+
|
| 588 |
+
# Word overlap
|
| 589 |
+
pred_words = set(predicted.lower().split())
|
| 590 |
+
exp_words = set(expected.lower().split())
|
| 591 |
+
|
| 592 |
+
if not exp_words:
|
| 593 |
+
result["score"] = 1.0 if not pred_words else 0.5
|
| 594 |
+
return result
|
| 595 |
+
|
| 596 |
+
overlap = len(pred_words & exp_words)
|
| 597 |
+
result["details"]["word_overlap"] = overlap
|
| 598 |
+
result["details"]["expected_words"] = len(exp_words)
|
| 599 |
+
|
| 600 |
+
# Jaccard similarity
|
| 601 |
+
union = len(pred_words | exp_words)
|
| 602 |
+
result["score"] = overlap / union if union > 0 else 0.0
|
| 603 |
+
|
| 604 |
+
return result
|
| 605 |
+
|
| 606 |
+
def _llm_semantic_analysis(self, predicted: str, expected: str) -> Dict[str, Any]:
    """
    Use LLM for semantic analysis of predicted vs expected.

    Uses XML-delimited prompt structure to prevent context bleeding
    and Multi-Dimensional Scoring (Semantics vs. Syntax).

    Args:
        predicted: Raw model output to grade.
        expected: Ground-truth output to grade against.

    Returns:
        Dict with 'score' (0.0-1.0), 'details', and 'feedback'
    """
    # Check cache
    # NOTE(review): str hash() is salted per process, so this cache key is
    # only stable within a single run — confirm that is intended.
    cache_key = f"{hash(predicted)}:{hash(expected)}"
    if cache_key in self._analysis_cache:
        return self._analysis_cache[cache_key]

    result = {"score": 0.0, "details": {}, "feedback": ""}

    try:
        # Truncate for token limits but preserve enough context
        expected_truncated = expected[:10000]
        predicted_truncated = predicted[:10000]

        # OPTIMIZED: Penalty-based scoring with self-verification
        # Starts at 1.0 and deducts for failures - more consistent than subjective scoring
        analysis_prompt = f"""<system_role>
You are a **Semantic Logic Engine** tasked with grading AI performance.
You must compare a [PREDICTED] output against a [EXPECTED] truth.
</system_role>

<input_data>
<expected_output>
{expected_truncated}
</expected_output>

<predicted_output>
{predicted_truncated}
</predicted_output>
</input_data>

<scoring_algorithm>
Calculate the score based on these STRICT rules. Start with 1.0 and deduct penalties.

1. **Information Completeness (Max -0.5)**:
- If key facts/fields are missing, deduct proportional to importance.
- If a nested JSON field is missing, deduct 0.1 per field.

2. **Accuracy & Hallucination (Max -1.0)**:
- If factual numbers/IDs are wrong: Score = 0 immediately.
- If the model invents information NOT in the input: Deduct 0.3.

3. **Format Compliance (Max -0.3)**:
- If JSON is requested but Markdown is returned: Deduct 0.3.
- If keys are lowercase instead of snake_case: Deduct 0.1.

4. **Semantic Equivalence (No Penalty)**:
- Synonyms are ACCEPTED (e.g., "Purchase" == "Buy").
- Formatting differences (whitespace) are IGNORED.
</scoring_algorithm>

<self_verification>
Before finalizing the score, ask: "If I used the predicted output in code expecting the original output, would the code crash?"
- If YES (Crash) -> Score must be < 0.5.
- If NO (Safe) -> Score can be high.
</self_verification>

<output_schema>
Return JSON ONLY:
{{
"semantic_similarity": 0.0-1.0,
"structural_similarity": 0.0-1.0,
"verdict": "PERFECT" | "ACCEPTABLE" | "FORMAT_ERROR" | "DATA_CORRUPTION",
"critical_failures": ["List specific failures that caused score < 1.0"],
"penalty_breakdown": {{"completeness": -0.0, "accuracy": -0.0, "format": -0.0}},
"fix_directive": "Imperative command to fix the prompt"
}}
</output_schema>
"""

        response = self.llm_client.generate(
            system_prompt="You are a Semantic Logic Engine. Calculate scores using penalty-based deduction from 1.0. Respond only with valid JSON.",
            user_prompt=analysis_prompt,
            image_base64=""
        )

        # The client may return either a dict with a 'content' field or a
        # bare string; normalize to str before parsing.
        content = response.get("content", str(response)) if isinstance(response, dict) else str(response)

        # Parse JSON response
        analysis = self._extract_json_from_response(content)

        if analysis:
            # Extract semantic similarity (primary score)
            semantic_sim = float(analysis.get("semantic_similarity", 0.5))
            structural_sim = float(analysis.get("structural_similarity", semantic_sim))

            # Compute weighted score based on verdict (updated for new schema)
            verdict = analysis.get("verdict", "ACCEPTABLE")
            # Unknown verdict strings default to a neutral 0.5 multiplier.
            verdict_multiplier = {
                "PERFECT": 1.0,
                "ACCEPTABLE": 0.85,
                "FORMAT_ERROR": 0.6,  # New: was WRONG_FORMAT
                "DATA_CORRUPTION": 0.1,  # New: replaces WRONG_CONTENT + HALLUCINATION
                # Legacy support
                "WRONG_FORMAT": 0.6,
                "WRONG_CONTENT": 0.3,
                "HALLUCINATION": 0.1
            }.get(verdict, 0.5)

            # Final score: weighted combination
            # (0.6 semantic + 0.3 structural + 0.1 verdict multiplier, capped at 1.0)
            result["score"] = min(1.0, semantic_sim * 0.6 + structural_sim * 0.3 + verdict_multiplier * 0.1)

            # Extract penalty breakdown if available
            penalty_breakdown = analysis.get("penalty_breakdown", {})
            critical_failures = analysis.get("critical_failures", [])

            result["details"] = {
                "verdict": verdict,
                "semantic_similarity": semantic_sim,
                "structural_similarity": structural_sim,
                "critical_failures": critical_failures,
                "penalty_breakdown": penalty_breakdown,
                # Legacy field support
                "key_matches": analysis.get("key_matches", []),
                "key_differences": analysis.get("key_differences", critical_failures),
                "value_errors": analysis.get("value_errors", {}),
                "reasoning": analysis.get("reasoning", "")
            }
            result["feedback"] = analysis.get("fix_directive", "")
        else:
            # Fallback if JSON parsing fails
            result = self._heuristic_semantic_analysis(predicted, expected)

        # Cache result
        self._analysis_cache[cache_key] = result

    except Exception as e:
        # Any client/parsing failure degrades gracefully to heuristics;
        # note the heuristic result is NOT cached in this branch.
        logger.warning(f"LLM semantic analysis failed: {e}, falling back to heuristics")
        result = self._heuristic_semantic_analysis(predicted, expected)

    return result
|
| 745 |
+
|
| 746 |
+
def _extract_json_from_response(self, content: str) -> Optional[Dict]:
|
| 747 |
+
"""Extract JSON from LLM response."""
|
| 748 |
+
# Try to find JSON in response
|
| 749 |
+
json_match = re.search(r'\{[\s\S]*\}', content)
|
| 750 |
+
if json_match:
|
| 751 |
+
try:
|
| 752 |
+
return json.loads(json_match.group(0))
|
| 753 |
+
except json.JSONDecodeError:
|
| 754 |
+
pass
|
| 755 |
+
return None
|
| 756 |
+
|
| 757 |
+
def _heuristic_semantic_analysis(self, predicted: str, expected: str) -> Dict[str, Any]:
    """
    Heuristic-based semantic analysis when LLM is not available.

    Uses multiple signals:
    - Word overlap (Jaccard)
    - Sequence matching (SequenceMatcher)
    - Number extraction and comparison
    - Key phrase matching

    Args:
        predicted: Raw model output.
        expected: Ground-truth output.

    Returns:
        Dict with 'score' (weighted blend of the four signals), 'details',
        and a human-readable 'feedback' string.
    """
    result = {"score": 0.0, "details": {}, "feedback": ""}

    pred_lower = predicted.lower()
    exp_lower = expected.lower()

    # 1. Sequence similarity (character-level)
    seq_sim = SequenceMatcher(None, pred_lower, exp_lower).ratio()

    # 2. Word overlap (Jaccard)
    pred_words = set(pred_lower.split())
    exp_words = set(exp_lower.split())
    jaccard = len(pred_words & exp_words) / len(pred_words | exp_words) if (pred_words | exp_words) else 0.0

    # 3. Number comparison — numbers are extracted as STRINGS, so an exact
    # textual match ("1.0" vs "1.0") counts first; near-misses get 0.9.
    pred_nums = re.findall(r'-?\d+\.?\d*', predicted)
    exp_nums = re.findall(r'-?\d+\.?\d*', expected)

    num_score = 1.0
    num_errors = []
    if exp_nums:
        matches = 0
        for exp_num in exp_nums:
            if exp_num in pred_nums:
                matches += 1
            else:
                # Check for close matches
                try:
                    exp_val = float(exp_num)
                    # for/else: the else runs only when NO predicted number
                    # is within 1 of the expected value (no break taken).
                    for pred_num in pred_nums:
                        pred_val = float(pred_num)
                        if abs(pred_val - exp_val) <= 1:  # Off by 1
                            matches += 0.9
                            num_errors.append(f"Number close: expected {exp_num}, got {pred_num}")
                            break
                    else:
                        num_errors.append(f"Number missing: expected {exp_num}")
                except ValueError:
                    pass
        num_score = matches / len(exp_nums) if exp_nums else 1.0

    # 4. Key entity extraction (simple approach)
    # Look for capitalized words, quoted strings, etc.
    # NOTE(review): this only catches Capitalized single words — acronyms
    # (all-caps) and multi-word names are not matched by this pattern.
    pred_entities = set(re.findall(r'\b[A-Z][a-z]+\b', predicted))
    exp_entities = set(re.findall(r'\b[A-Z][a-z]+\b', expected))
    entity_overlap = len(pred_entities & exp_entities) / len(exp_entities) if exp_entities else 1.0

    # Combine scores (weights: 0.3 sequence, 0.25 words, 0.25 numbers, 0.2 entities)
    result["score"] = (
        0.3 * seq_sim +
        0.25 * jaccard +
        0.25 * num_score +
        0.2 * entity_overlap
    )

    result["details"] = {
        "sequence_similarity": seq_sim,
        "word_overlap": jaccard,
        "number_accuracy": num_score,
        "entity_overlap": entity_overlap,
        "number_errors": num_errors
    }

    # Generate feedback — at most the first 3 number errors / missing entities.
    feedback_parts = []
    if jaccard < 0.5:
        feedback_parts.append("Low word overlap - output may be missing key terms.")
    if num_errors:
        feedback_parts.append(f"Number issues: {'; '.join(num_errors[:3])}")
    if entity_overlap < 0.5 and exp_entities:
        missing = exp_entities - pred_entities
        feedback_parts.append(f"Missing entities: {', '.join(list(missing)[:3])}")

    if feedback_parts:
        result["feedback"] = " | ".join(feedback_parts)
    else:
        result["feedback"] = "Output is semantically similar but not exact match."

    return result
|
| 845 |
+
|
| 846 |
+
def _generate_default_feedback(self, result: Dict) -> str:
|
| 847 |
+
"""Generate default feedback based on scores."""
|
| 848 |
+
score = result["composite_score"]
|
| 849 |
+
semantic = result["semantic_similarity"]
|
| 850 |
+
structural = result["structural_similarity"]
|
| 851 |
+
|
| 852 |
+
if score >= 0.9:
|
| 853 |
+
return "✅ Excellent match! Minor differences only."
|
| 854 |
+
elif score >= 0.7:
|
| 855 |
+
return f"⚠️ Good match (semantic={semantic:.0%}, structural={structural:.0%}). Some differences to address."
|
| 856 |
+
elif score >= 0.5:
|
| 857 |
+
return f"⚠️ Partial match (semantic={semantic:.0%}, structural={structural:.0%}). Significant differences found."
|
| 858 |
+
else:
|
| 859 |
+
return f"❌ Poor match (semantic={semantic:.0%}, structural={structural:.0%}). Major issues to fix."
|
| 860 |
+
|
| 861 |
+
def get_evaluation_summary(self, results: List[Dict]) -> Dict[str, Any]:
    """
    Get summary statistics for a batch of evaluations.

    Args:
        results: List of evaluation result dictionaries

    Returns:
        Summary statistics
    """
    # No results: return a zeroed-out summary.
    if not results:
        return {
            "total_samples": 0,
            "accuracy": 0.0,
            "avg_semantic_similarity": 0.0,
            "avg_structural_similarity": 0.0
        }

    count = len(results)
    composite = [entry.get("composite_score", 0.0) for entry in results]
    semantic = [entry.get("semantic_similarity", 0.0) for entry in results]
    structural = [entry.get("structural_similarity", 0.0) for entry in results]

    # "Accuracy" counts samples whose composite score clears the 0.8 bar.
    passing = sum(1 for value in composite if value >= 0.8)

    return {
        "total_samples": count,
        "accuracy": passing / count,
        "avg_composite_score": sum(composite) / count,
        "avg_semantic_similarity": sum(semantic) / count,
        "avg_structural_similarity": sum(structural) / count,
        "min_score": min(composite),
        "max_score": max(composite)
    }
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
# Convenience function to create evaluator
|
| 896 |
+
def create_universal_evaluator(llm_client=None) -> UniversalSemanticEvaluator:
    """
    Create a Universal Semantic Evaluator.

    Args:
        llm_client: Optional LLM client for semantic analysis.
            If not provided, uses heuristic-based analysis.

    Returns:
        Configured UniversalSemanticEvaluator instance
    """
    # LLM-backed analysis is enabled exactly when a client is supplied.
    enable_llm = llm_client is not None
    return UniversalSemanticEvaluator(llm_client=llm_client, use_llm_analysis=enable_llm)
|
| 911 |
+
|
src/gepa_optimizer/evaluation/validation_evaluator.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validation Evaluator for UI Validation Use Case
|
| 3 |
+
|
| 4 |
+
Evaluates predicted validation results (true/false) against expected results.
|
| 5 |
+
Extracts reasoning from both predicted and expected outputs for LLM-as-judge feedback.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
import re
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .base_evaluator import BaseEvaluator
|
| 14 |
+
except ImportError:
|
| 15 |
+
# For standalone testing
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 19 |
+
from gepa_optimizer.evaluation.base_evaluator import BaseEvaluator
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ValidationEvaluator(BaseEvaluator):
|
| 23 |
+
"""
|
| 24 |
+
Evaluator for validation use case (true/false results).
|
| 25 |
+
|
| 26 |
+
Features:
|
| 27 |
+
- Normalizes boolean formats ("true"/"True"/"1" → True, "false"/"False"/"0" → False)
|
| 28 |
+
- Extracts reasoning from both predicted and expected outputs (REQUIRED for LLM-as-judge)
|
| 29 |
+
- Binary scoring: correct boolean = 1.0, wrong = 0.0
|
| 30 |
+
- Returns reasoning in evaluation results for LLM-as-judge feedback
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
    """
    Initialize validation evaluator.

    Args:
        metric_weights: Weights for evaluation metrics
            Default: {"output_match": 1.0}
    """
    default_weights = {
        "output_match": 1.0  # Binary boolean comparison
    }

    # NOTE(review): `or` means an explicitly passed EMPTY dict also falls
    # back to the defaults — confirm that is intended.
    weights = metric_weights or default_weights
    super().__init__(metric_weights=weights)
|
| 47 |
+
|
| 48 |
+
def evaluate(self, predicted: str, expected: str) -> Dict[str, float]:
|
| 49 |
+
"""
|
| 50 |
+
Evaluate predicted validation result against expected result.
|
| 51 |
+
|
| 52 |
+
Scoring Strategy:
|
| 53 |
+
1. Normalize both predicted and expected to boolean
|
| 54 |
+
2. Compare booleans (exact match required)
|
| 55 |
+
3. Extract reasoning from both (for LLM-as-judge)
|
| 56 |
+
4. Return 1.0 if match, 0.0 otherwise (binary scoring)
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
predicted: LLM's output (may include "true"/"false" + reasoning)
|
| 60 |
+
expected: Expected output (should be "true" or "false", may include reasoning)
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Dictionary with evaluation metrics, extracted booleans, and reasoning:
|
| 64 |
+
{
|
| 65 |
+
"output_match": 1.0 or 0.0,
|
| 66 |
+
"composite_score": 1.0 or 0.0,
|
| 67 |
+
"predicted_output": str,
|
| 68 |
+
"expected_output": str,
|
| 69 |
+
"predicted_boolean": True/False,
|
| 70 |
+
"expected_boolean": True/False,
|
| 71 |
+
"predicted_reasoning": str, # REQUIRED for LLM-as-judge
|
| 72 |
+
"expected_reasoning": str, # REQUIRED for LLM-as-judge
|
| 73 |
+
"evaluation_reason": str
|
| 74 |
+
}
|
| 75 |
+
"""
|
| 76 |
+
if not predicted or not expected:
|
| 77 |
+
return {
|
| 78 |
+
"output_match": 0.0,
|
| 79 |
+
"composite_score": 0.0,
|
| 80 |
+
"predicted_output": str(predicted).strip() if predicted else "",
|
| 81 |
+
"expected_output": str(expected).strip() if expected else "",
|
| 82 |
+
"predicted_boolean": None,
|
| 83 |
+
"expected_boolean": None,
|
| 84 |
+
"predicted_reasoning": "",
|
| 85 |
+
"expected_reasoning": "",
|
| 86 |
+
"evaluation_reason": "❌ Empty or missing input/output"
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
predicted_str = str(predicted).strip()
|
| 90 |
+
expected_str = str(expected).strip()
|
| 91 |
+
|
| 92 |
+
# 1. Extract boolean from predicted output
|
| 93 |
+
pred_bool = self._normalize_to_bool(predicted_str)
|
| 94 |
+
pred_reasoning = self._extract_reasoning(predicted_str)
|
| 95 |
+
|
| 96 |
+
# 2. Extract boolean from expected output
|
| 97 |
+
exp_bool = self._normalize_to_bool(expected_str)
|
| 98 |
+
exp_reasoning = self._extract_reasoning(expected_str)
|
| 99 |
+
|
| 100 |
+
# 🔥 NEW: Detect output structure for both expected and predicted
|
| 101 |
+
expected_structure = self._detect_output_structure(expected_str)
|
| 102 |
+
predicted_structure = self._detect_output_structure(predicted_str)
|
| 103 |
+
|
| 104 |
+
# Compare structures
|
| 105 |
+
structure_match = (expected_structure['format'] == predicted_structure['format'])
|
| 106 |
+
|
| 107 |
+
# 3. Compare booleans (binary scoring)
|
| 108 |
+
if pred_bool is None or exp_bool is None:
|
| 109 |
+
# Could not extract boolean from one or both
|
| 110 |
+
score = 0.0
|
| 111 |
+
reason = "❌ Could not extract boolean value"
|
| 112 |
+
if pred_bool is None:
|
| 113 |
+
reason += " from predicted output"
|
| 114 |
+
if exp_bool is None:
|
| 115 |
+
reason += " from expected output"
|
| 116 |
+
else:
|
| 117 |
+
# Both booleans extracted successfully - compare
|
| 118 |
+
score = 1.0 if pred_bool == exp_bool else 0.0
|
| 119 |
+
if score == 1.0:
|
| 120 |
+
reason = f"✅ Correct! Result matches (both are {exp_bool})"
|
| 121 |
+
# 🔥 NEW: Add note if structure doesn't match
|
| 122 |
+
if not structure_match:
|
| 123 |
+
reason += f" (but format differs: expected {expected_structure['format']}, got {predicted_structure['format']})"
|
| 124 |
+
else:
|
| 125 |
+
reason = f"❌ Wrong result (predicted: {pred_bool}, expected: {exp_bool})"
|
| 126 |
+
|
| 127 |
+
# 4. Log evaluation details
|
| 128 |
+
self.logger.info(f"\n{'─'*70}")
|
| 129 |
+
self.logger.info(f"📊 VALIDATION EVALUATION")
|
| 130 |
+
self.logger.info(f"{'─'*70}")
|
| 131 |
+
self.logger.info(f" Expected: '{expected_str[:100]}...' → {exp_bool}")
|
| 132 |
+
self.logger.info(f" Predicted: '{predicted_str[:100]}...' → {pred_bool}")
|
| 133 |
+
self.logger.info(f" {'─'*66}")
|
| 134 |
+
self.logger.info(f" 🎯 SCORE: {score:.2f} - {reason}")
|
| 135 |
+
if pred_reasoning:
|
| 136 |
+
self.logger.info(f" 📝 Predicted Reasoning: {pred_reasoning[:150]}...")
|
| 137 |
+
if exp_reasoning:
|
| 138 |
+
self.logger.info(f" 📝 Expected Reasoning: {exp_reasoning[:150]}...")
|
| 139 |
+
# 🔥 NEW: Log structure comparison
|
| 140 |
+
self.logger.info(f" 📐 Expected Format: {expected_structure['format']} (reasoning: {expected_structure['reasoning_quality']})")
|
| 141 |
+
self.logger.info(f" 📐 Predicted Format: {predicted_structure['format']} (reasoning: {predicted_structure['reasoning_quality']})")
|
| 142 |
+
if not structure_match:
|
| 143 |
+
self.logger.warning(f" ⚠️ OUTPUT STRUCTURE MISMATCH!")
|
| 144 |
+
self.logger.info(f"{'─'*70}\n")
|
| 145 |
+
|
| 146 |
+
return {
|
| 147 |
+
"output_match": score,
|
| 148 |
+
"composite_score": score, # This is what GEPA uses
|
| 149 |
+
"predicted_output": predicted_str,
|
| 150 |
+
"expected_output": expected_str,
|
| 151 |
+
"predicted_boolean": pred_bool,
|
| 152 |
+
"expected_boolean": exp_bool,
|
| 153 |
+
"predicted_reasoning": pred_reasoning, # REQUIRED for LLM-as-judge
|
| 154 |
+
"expected_reasoning": exp_reasoning, # REQUIRED for LLM-as-judge
|
| 155 |
+
"evaluation_reason": reason,
|
| 156 |
+
# 🔥 NEW: Structure metadata for LLM-as-judge
|
| 157 |
+
"expected_structure": expected_structure,
|
| 158 |
+
"predicted_structure": predicted_structure,
|
| 159 |
+
"output_structure_match": structure_match,
|
| 160 |
+
"expected_has_reasoning": expected_structure['has_reasoning'],
|
| 161 |
+
"predicted_has_reasoning": predicted_structure['has_reasoning'],
|
| 162 |
+
"reasoning_quality_gap": expected_structure['reasoning_quality'] + " → " + predicted_structure['reasoning_quality']
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
def _normalize_to_bool(self, value: str) -> Optional[bool]:
|
| 166 |
+
"""
|
| 167 |
+
Normalize various formats to boolean.
|
| 168 |
+
|
| 169 |
+
Handles:
|
| 170 |
+
- "true", "True", "TRUE" → True
|
| 171 |
+
- "false", "False", "FALSE" → False
|
| 172 |
+
- "1", "0" → True, False
|
| 173 |
+
- "yes", "no" → True, False
|
| 174 |
+
- "correct", "incorrect" → True, False
|
| 175 |
+
- JSON: {"result": true} → True
|
| 176 |
+
- Text with boolean: "The result is true because..." → True
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
value: String that may contain a boolean value
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Boolean value or None if cannot be determined
|
| 183 |
+
"""
|
| 184 |
+
if not value:
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
value_lower = value.lower().strip()
|
| 188 |
+
|
| 189 |
+
# Direct boolean strings
|
| 190 |
+
if value_lower in ("true", "1", "yes", "correct", "valid", "pass"):
|
| 191 |
+
return True
|
| 192 |
+
if value_lower in ("false", "0", "no", "incorrect", "invalid", "fail"):
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
# JSON format: {"action": "TRUE"} or {"action": "FALSE"} or {"action": "LOADING"}
|
| 196 |
+
# This handles the production prompt's JSON output format
|
| 197 |
+
# Match both quoted and unquoted values, case-insensitive
|
| 198 |
+
action_match = re.search(r'["\']?action["\']?\s*:\s*["\']?(true|false|loading)["\']?', value_lower)
|
| 199 |
+
if action_match:
|
| 200 |
+
action_value = action_match.group(1).lower()
|
| 201 |
+
if action_value == "true":
|
| 202 |
+
return True
|
| 203 |
+
elif action_value == "false":
|
| 204 |
+
return False
|
| 205 |
+
elif action_value == "loading":
|
| 206 |
+
# Treat LOADING as False for validation purposes (screen not ready)
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
# Also try to parse full JSON structure if present (more robust)
|
| 210 |
+
try:
|
| 211 |
+
import json
|
| 212 |
+
# Try to find and parse JSON object
|
| 213 |
+
json_start = value.find('{')
|
| 214 |
+
if json_start != -1:
|
| 215 |
+
# Try to extract JSON from the response
|
| 216 |
+
for end_idx in range(len(value), json_start, -1):
|
| 217 |
+
try:
|
| 218 |
+
json_str = value[json_start:end_idx]
|
| 219 |
+
data = json.loads(json_str)
|
| 220 |
+
# Check for "action" field (production prompt format)
|
| 221 |
+
if "action" in data:
|
| 222 |
+
action_val = str(data["action"]).upper()
|
| 223 |
+
if action_val == "TRUE":
|
| 224 |
+
return True
|
| 225 |
+
elif action_val == "FALSE":
|
| 226 |
+
return False
|
| 227 |
+
elif action_val == "LOADING":
|
| 228 |
+
return False # Treat as False
|
| 229 |
+
# Check for "result" field (alternative format)
|
| 230 |
+
if "result" in data:
|
| 231 |
+
result_val = data["result"]
|
| 232 |
+
if isinstance(result_val, bool):
|
| 233 |
+
return result_val
|
| 234 |
+
elif isinstance(result_val, str):
|
| 235 |
+
return result_val.lower() in ("true", "1", "yes")
|
| 236 |
+
except (json.JSONDecodeError, KeyError, ValueError):
|
| 237 |
+
continue
|
| 238 |
+
except Exception:
|
| 239 |
+
pass # Fall through to other extraction methods
|
| 240 |
+
|
| 241 |
+
# JSON format: {"result": true} or {"result": false}
|
| 242 |
+
json_match = re.search(r'["\']?result["\']?\s*:\s*(true|false)', value_lower)
|
| 243 |
+
if json_match:
|
| 244 |
+
return json_match.group(1) == "true"
|
| 245 |
+
|
| 246 |
+
# Pattern: "result is true" or "result: true"
|
| 247 |
+
pattern_match = re.search(r'result[:\s]+(true|false)', value_lower)
|
| 248 |
+
if pattern_match:
|
| 249 |
+
return pattern_match.group(1) == "true"
|
| 250 |
+
|
| 251 |
+
# Pattern: "is true" or "is false" (standalone)
|
| 252 |
+
is_match = re.search(r'\b(is|are)\s+(true|false)\b', value_lower)
|
| 253 |
+
if is_match:
|
| 254 |
+
return is_match.group(2) == "true"
|
| 255 |
+
|
| 256 |
+
# Pattern: "true" or "false" as standalone word (not in other words)
|
| 257 |
+
standalone_match = re.search(r'\b(true|false)\b', value_lower)
|
| 258 |
+
if standalone_match:
|
| 259 |
+
return standalone_match.group(1) == "true"
|
| 260 |
+
|
| 261 |
+
# Last resort: check if "true" appears before "false" in text
|
| 262 |
+
true_pos = value_lower.find("true")
|
| 263 |
+
false_pos = value_lower.find("false")
|
| 264 |
+
|
| 265 |
+
if true_pos != -1 and false_pos != -1:
|
| 266 |
+
# Both found - use the one that appears first
|
| 267 |
+
return true_pos < false_pos
|
| 268 |
+
elif true_pos != -1:
|
| 269 |
+
return True
|
| 270 |
+
elif false_pos != -1:
|
| 271 |
+
return False
|
| 272 |
+
|
| 273 |
+
# Cannot determine
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
def _detect_output_structure(self, output: str) -> Dict[str, Any]:
|
| 277 |
+
"""
|
| 278 |
+
Dynamically detect the structure/components of the output.
|
| 279 |
+
|
| 280 |
+
This detects:
|
| 281 |
+
- Boolean result presence
|
| 282 |
+
- Reasoning/explanation presence and quality
|
| 283 |
+
- Output format (boolean only, boolean+reasoning, etc.)
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
output: Output string to analyze
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Dictionary with structure information:
|
| 290 |
+
{
|
| 291 |
+
"has_boolean": bool,
|
| 292 |
+
"has_reasoning": bool,
|
| 293 |
+
"reasoning_length": int,
|
| 294 |
+
"reasoning_quality": str, # "missing", "minimal", "adequate", "detailed"
|
| 295 |
+
"format": str # "boolean_only", "boolean_with_reasoning", "unknown"
|
| 296 |
+
}
|
| 297 |
+
"""
|
| 298 |
+
if not output:
|
| 299 |
+
return {
|
| 300 |
+
"has_boolean": False,
|
| 301 |
+
"has_reasoning": False,
|
| 302 |
+
"reasoning_length": 0,
|
| 303 |
+
"reasoning_quality": "missing",
|
| 304 |
+
"format": "empty"
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
output_clean = output.strip()
|
| 308 |
+
|
| 309 |
+
# Detect boolean
|
| 310 |
+
has_boolean = self._normalize_to_bool(output_clean) is not None
|
| 311 |
+
|
| 312 |
+
# Extract reasoning
|
| 313 |
+
reasoning = self._extract_reasoning(output_clean)
|
| 314 |
+
has_reasoning = len(reasoning) > 15 # Minimum 15 chars to count as reasoning
|
| 315 |
+
reasoning_length = len(reasoning)
|
| 316 |
+
|
| 317 |
+
# Classify reasoning quality
|
| 318 |
+
if reasoning_length == 0:
|
| 319 |
+
reasoning_quality = "missing"
|
| 320 |
+
elif reasoning_length < 30:
|
| 321 |
+
reasoning_quality = "minimal" # Just a few words
|
| 322 |
+
elif reasoning_length < 100:
|
| 323 |
+
reasoning_quality = "adequate" # Brief explanation
|
| 324 |
+
else:
|
| 325 |
+
reasoning_quality = "detailed" # Full explanation
|
| 326 |
+
|
| 327 |
+
# Determine format
|
| 328 |
+
if has_boolean and has_reasoning:
|
| 329 |
+
output_format = "boolean_with_reasoning"
|
| 330 |
+
elif has_boolean and not has_reasoning:
|
| 331 |
+
output_format = "boolean_only"
|
| 332 |
+
elif not has_boolean and has_reasoning:
|
| 333 |
+
output_format = "reasoning_only"
|
| 334 |
+
else:
|
| 335 |
+
output_format = "unknown"
|
| 336 |
+
|
| 337 |
+
return {
|
| 338 |
+
"has_boolean": has_boolean,
|
| 339 |
+
"has_reasoning": has_reasoning,
|
| 340 |
+
"reasoning_length": reasoning_length,
|
| 341 |
+
"reasoning_quality": reasoning_quality,
|
| 342 |
+
"format": output_format
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
def _extract_reasoning(self, output: str) -> str:
|
| 346 |
+
"""
|
| 347 |
+
Extract reasoning/explanation from output string.
|
| 348 |
+
|
| 349 |
+
This is REQUIRED for LLM-as-judge feedback. The reasoning helps
|
| 350 |
+
the judge understand why the result was true/false and compare
|
| 351 |
+
predicted vs expected reasoning.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
output: Full output string that may contain reasoning
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
Extracted reasoning text, or empty string if not found
|
| 358 |
+
"""
|
| 359 |
+
if not output:
|
| 360 |
+
return ""
|
| 361 |
+
|
| 362 |
+
# Patterns to find reasoning sections
|
| 363 |
+
reasoning_patterns = [
|
| 364 |
+
r'[Rr]eason[:\s]+(.*?)(?:\n\n|\Z)', # "Reason: ..."
|
| 365 |
+
r'[Ee]xplanation[:\s]+(.*?)(?:\n\n|\Z)', # "Explanation: ..."
|
| 366 |
+
r'[Bb]ecause[:\s]+(.*?)(?:\n\n|\Z)', # "Because: ..."
|
| 367 |
+
r'[Ww]hy[:\s]+(.*?)(?:\n\n|\Z)', # "Why: ..."
|
| 368 |
+
r'[Dd]etails[:\s]+(.*?)(?:\n\n|\Z)', # "Details: ..."
|
| 369 |
+
]
|
| 370 |
+
|
| 371 |
+
# Try each pattern
|
| 372 |
+
for pattern in reasoning_patterns:
|
| 373 |
+
match = re.search(pattern, output, re.DOTALL | re.IGNORECASE)
|
| 374 |
+
if match:
|
| 375 |
+
reasoning = match.group(1).strip()
|
| 376 |
+
if len(reasoning) > 20: # Only return if substantial
|
| 377 |
+
return reasoning
|
| 378 |
+
|
| 379 |
+
# If no explicit reasoning section, check if output has substantial text
|
| 380 |
+
# after boolean (likely contains reasoning)
|
| 381 |
+
bool_match = re.search(r'\b(true|false)\b', output.lower())
|
| 382 |
+
if bool_match:
|
| 383 |
+
# Get text after the boolean
|
| 384 |
+
bool_pos = bool_match.end()
|
| 385 |
+
remaining = output[bool_pos:].strip()
|
| 386 |
+
|
| 387 |
+
# If remaining text is substantial (more than just punctuation), use it
|
| 388 |
+
if len(remaining) > 30:
|
| 389 |
+
# Clean up common prefixes
|
| 390 |
+
remaining = re.sub(r'^[:\s.,;!?-]+', '', remaining)
|
| 391 |
+
if remaining:
|
| 392 |
+
return remaining
|
| 393 |
+
|
| 394 |
+
# If output is long and doesn't start with boolean, might be all reasoning
|
| 395 |
+
if len(output) > 100 and not re.match(r'^\s*(true|false)\s*$', output, re.IGNORECASE):
|
| 396 |
+
# Return first 500 chars as reasoning
|
| 397 |
+
return output[:500].strip()
|
| 398 |
+
|
| 399 |
+
# No reasoning found
|
| 400 |
+
return ""
|
| 401 |
+
|
| 402 |
+
def get_evaluation_summary(self, results: list) -> Dict[str, Any]:
|
| 403 |
+
"""
|
| 404 |
+
Get summary statistics for a batch of evaluations.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
results: List of evaluation result dictionaries
|
| 408 |
+
|
| 409 |
+
Returns:
|
| 410 |
+
Summary statistics including accuracy, true/false distribution
|
| 411 |
+
"""
|
| 412 |
+
if not results:
|
| 413 |
+
return {
|
| 414 |
+
"total_samples": 0,
|
| 415 |
+
"accuracy": 0.0,
|
| 416 |
+
"correct_predictions": 0,
|
| 417 |
+
"incorrect_predictions": 0,
|
| 418 |
+
"true_predictions": 0,
|
| 419 |
+
"false_predictions": 0
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
total = len(results)
|
| 423 |
+
correct = sum(1 for r in results if r.get("output_match", 0.0) == 1.0)
|
| 424 |
+
accuracy = correct / total if total > 0 else 0.0
|
| 425 |
+
|
| 426 |
+
# Count true/false predictions
|
| 427 |
+
true_preds = sum(1 for r in results if r.get("predicted_boolean") is True)
|
| 428 |
+
false_preds = sum(1 for r in results if r.get("predicted_boolean") is False)
|
| 429 |
+
|
| 430 |
+
return {
|
| 431 |
+
"total_samples": total,
|
| 432 |
+
"accuracy": accuracy,
|
| 433 |
+
"correct_predictions": correct,
|
| 434 |
+
"incorrect_predictions": total - correct,
|
| 435 |
+
"true_predictions": true_preds,
|
| 436 |
+
"false_predictions": false_preds
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# Example usage and testing
|
| 441 |
+
if __name__ == "__main__":
    # Smoke test for ValidationEvaluator: exercises boolean normalization,
    # JSON payloads, embedded reasoning, and failure cases.
    print("🚀 Testing Validation Evaluator...")

    validator = ValidationEvaluator()

    # Each entry: (predicted, expected, should_match)
    cases = [
        ("true", "true", True),
        ("false", "false", True),
        ("True", "true", True),
        ("FALSE", "false", True),
        ("1", "true", True),
        ("0", "false", True),
        ("true", "false", False),
        ("false", "true", False),
        ("The result is true because the button is visible", "true", True),
        ("The result is false because the element is not found", "false", True),
        ('{"result": true, "reasoning": "Button is visible"}', "true", True),
        ("Result: true\n\nReasoning: The submit button is clearly visible at the bottom of the screen.", "true", True),
        ("", "true", False),
        ("invalid", "true", False),
    ]

    print("\n📝 Running test cases:")
    print("-" * 80)

    collected = []
    for pred_text, exp_text, should_match in cases:
        outcome = validator.evaluate(pred_text, exp_text)
        did_match = outcome["composite_score"] == 1.0

        marker = "✅" if did_match == should_match else "❌"
        pred_bool = outcome.get("predicted_boolean", "?")
        exp_bool = outcome.get("expected_boolean", "?")
        reasoning_head = outcome.get("predicted_reasoning", "")[:50]

        print(f"{marker} Predicted: '{pred_text[:40]}...' → {pred_bool}")
        print(f"   Expected: '{exp_text}' → {exp_bool}")
        print(f"   Match: {did_match} (should be {should_match})")
        if reasoning_head:
            print(f"   Reasoning: {reasoning_head}...")
        print()

        collected.append(outcome)

    # Aggregate statistics over all cases
    print("\n📊 Summary:")
    stats = validator.get_evaluation_summary(collected)
    print(f"   Total: {stats['total_samples']}")
    print(f"   Correct: {stats['correct_predictions']}")
    print(f"   Accuracy: {stats['accuracy']:.1%}")
    print(f"   True predictions: {stats['true_predictions']}")
    print(f"   False predictions: {stats['false_predictions']}")
src/gepa_optimizer/infrastructure/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Infrastructure module for cross-cutting concerns.
|
| 3 |
+
|
| 4 |
+
This module contains infrastructure components that are used across
|
| 5 |
+
the entire application, including logging, metrics, and configuration.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .logging import get_logger, configure_logging, LogContext
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"get_logger",
|
| 12 |
+
"configure_logging",
|
| 13 |
+
"LogContext",
|
| 14 |
+
]
|
| 15 |
+
|
src/gepa_optimizer/infrastructure/logging/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Centralized Logging Infrastructure for GEPA Optimizer.
|
| 3 |
+
|
| 4 |
+
This module provides a unified logging system with:
|
| 5 |
+
- Structured logging with context
|
| 6 |
+
- Consistent formatting across all modules
|
| 7 |
+
- Log level configuration
|
| 8 |
+
- Operation tracking with timing
|
| 9 |
+
- Contextual logging for debugging
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
from gepa_optimizer.infrastructure.logging import get_logger, LogContext
|
| 13 |
+
|
| 14 |
+
logger = get_logger(__name__)
|
| 15 |
+
logger.info("Starting optimization", extra={"iteration": 1})
|
| 16 |
+
|
| 17 |
+
with LogContext(logger, "evaluation", sample_id=123):
|
| 18 |
+
logger.info("Evaluating sample")
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from .logger import (
|
| 22 |
+
get_logger,
|
| 23 |
+
configure_logging,
|
| 24 |
+
LogLevel,
|
| 25 |
+
GEPA_LOGGER_NAME,
|
| 26 |
+
)
|
| 27 |
+
from .context import LogContext, log_operation
|
| 28 |
+
from .formatters import GepaFormatter, JsonFormatter
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
# Core logging
|
| 32 |
+
"get_logger",
|
| 33 |
+
"configure_logging",
|
| 34 |
+
"LogLevel",
|
| 35 |
+
"GEPA_LOGGER_NAME",
|
| 36 |
+
# Context management
|
| 37 |
+
"LogContext",
|
| 38 |
+
"log_operation",
|
| 39 |
+
# Formatters
|
| 40 |
+
"GepaFormatter",
|
| 41 |
+
"JsonFormatter",
|
| 42 |
+
]
|
| 43 |
+
|
src/gepa_optimizer/infrastructure/logging/context.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Logging Context Management.
|
| 3 |
+
|
| 4 |
+
Provides context managers and decorators for:
|
| 5 |
+
- Operation tracking with timing
|
| 6 |
+
- Contextual logging with nested contexts
|
| 7 |
+
- Automatic exception logging
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import time
|
| 12 |
+
import functools
|
| 13 |
+
from contextlib import contextmanager
|
| 14 |
+
from typing import Any, Callable, Dict, Optional, TypeVar, ParamSpec
|
| 15 |
+
|
| 16 |
+
P = ParamSpec('P')
|
| 17 |
+
R = TypeVar('R')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LogContext:
    """
    Context manager for logging operations with timing and context.

    Features:
    - Automatic start/end logging
    - Timing measurement
    - Exception capture (stored on ``self.exception``)
    - Nested context support

    Example:
        logger = get_logger(__name__)

        with LogContext(logger, "optimization", iteration=5):
            # ... optimization code ...
            logger.info("Processing sample")  # Inherits context

        # Output:
        # INFO | Starting optimization | iteration=5
        # INFO | Processing sample | iteration=5
        # INFO | Completed optimization | iteration=5 duration_ms=1234
    """

    def __init__(
        self,
        logger: logging.Logger,
        operation: str,
        log_start: bool = True,
        log_end: bool = True,
        log_level: int = logging.INFO,
        **context_fields: Any
    ):
        """
        Initialize log context.

        Args:
            logger: Logger instance to use
            operation: Name of the operation being performed
            log_start: Whether to log when entering the context
            log_end: Whether to log when exiting the context
            log_level: Log level for start/end messages
            **context_fields: Additional fields to include in all logs
        """
        self.logger = logger
        self.operation = operation
        self.log_start = log_start
        self.log_end = log_end
        self.log_level = log_level
        self.context_fields = context_fields
        self.start_time: Optional[float] = None   # set in __enter__
        self.exception: Optional[Exception] = None  # set in __exit__ on failure

    def __enter__(self) -> "LogContext":
        """Enter the context, logging start if configured."""
        self.start_time = time.perf_counter()

        if self.log_start:
            self.logger.log(
                self.log_level,
                f"Starting {self.operation}",
                extra=self.context_fields
            )

        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Exit the context, logging completion or error. Never suppresses."""
        # FIX: guard against __exit__ being reached without __enter__
        # (e.g. manual invocation) — the original dereferenced
        # self.start_time unconditionally and raised TypeError.
        started = self.start_time if self.start_time is not None else time.perf_counter()
        duration_ms = (time.perf_counter() - started) * 1000

        extra = {
            **self.context_fields,
            "duration_ms": round(duration_ms, 2)
        }

        if exc_type is not None:
            # Capture and log the exception, then propagate it
            self.exception = exc_val
            self.logger.error(
                f"Failed {self.operation}: {exc_type.__name__}: {exc_val}",
                extra=extra,
                exc_info=True
            )
            # Don't suppress the exception
            return False

        if self.log_end:
            self.logger.log(
                self.log_level,
                f"Completed {self.operation}",
                extra=extra
            )

        return False

    def log(self, level: int, message: str, **extra_fields: Any) -> None:
        """Log a message within this context, inheriting context fields."""
        self.logger.log(
            level,
            message,
            extra={**self.context_fields, **extra_fields}
        )

    def info(self, message: str, **extra_fields: Any) -> None:
        """Log info message within context."""
        self.log(logging.INFO, message, **extra_fields)

    def debug(self, message: str, **extra_fields: Any) -> None:
        """Log debug message within context."""
        self.log(logging.DEBUG, message, **extra_fields)

    def warning(self, message: str, **extra_fields: Any) -> None:
        """Log warning message within context."""
        self.log(logging.WARNING, message, **extra_fields)

    def error(self, message: str, **extra_fields: Any) -> None:
        """Log error message within context."""
        self.log(logging.ERROR, message, **extra_fields)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def log_operation(
    logger: Optional[logging.Logger] = None,
    operation: Optional[str] = None,
    log_args: bool = False,
    log_result: bool = False,
    log_level: int = logging.INFO,
) -> "Callable[[Callable[P, R]], Callable[P, R]]":
    """
    Decorator for logging function execution.

    Automatically logs:
    - Function entry (with arguments if configured)
    - Function exit (with result if configured)
    - Execution duration
    - Exceptions

    Args:
        logger: Logger to use (defaults to logger named after the function's module)
        operation: Operation name (defaults to function name)
        log_args: Whether to log function arguments
        log_result: Whether to log function result
        log_level: Log level for messages

    Example:
        @log_operation(log_args=True)
        def process_batch(batch_id: int, items: List[str]) -> int:
            return len(items)

        # Output:
        # INFO | Starting process_batch | batch_id=123 items=['a', 'b']
        # INFO | Completed process_batch | duration_ms=45.2 result=2
    """
    def decorator(func: "Callable[P, R]") -> "Callable[P, R]":
        # FIX: resolve per-function defaults into LOCAL names. The original
        # used `nonlocal logger, operation` and rebound the enclosing
        # variables, so the first function decorated by a shared decorator
        # instance leaked its logger/operation into every later one.
        func_logger = logger if logger is not None else logging.getLogger(func.__module__)
        op_name = operation if operation is not None else func.__name__

        @functools.wraps(func)
        def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> "R":
            start_time = time.perf_counter()

            # Build context fields from the call, if requested
            extra: Dict[str, Any] = {}
            if log_args:
                # Positional args by name (skip self for methods).
                # NOTE(review): arg names that clash with reserved LogRecord
                # attributes (e.g. "args", "msg") would raise in logging —
                # unchanged from the original behavior.
                arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
                for name, value in zip(arg_names, args):
                    if name != 'self':
                        extra[name] = _safe_repr(value)
                # Keyword args as-is
                for key, value in kwargs.items():
                    extra[key] = _safe_repr(value)

            func_logger.log(log_level, f"Starting {op_name}", extra=extra)

            try:
                result = func(*args, **kwargs)

                duration_ms = (time.perf_counter() - start_time) * 1000
                result_extra: Dict[str, Any] = {"duration_ms": round(duration_ms, 2)}

                if log_result:
                    result_extra["result"] = _safe_repr(result)

                func_logger.log(log_level, f"Completed {op_name}", extra=result_extra)

                return result

            except Exception as e:
                # Log the failure with timing, then re-raise unchanged
                duration_ms = (time.perf_counter() - start_time) * 1000
                func_logger.error(
                    f"Failed {op_name}: {type(e).__name__}: {e}",
                    extra={"duration_ms": round(duration_ms, 2)},
                    exc_info=True
                )
                raise

        return wrapper

    return decorator
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@contextmanager
def timed_block(logger: logging.Logger, description: str, log_level: int = logging.DEBUG):
    """
    Time a block of code and emit a single log line when it finishes.

    A lighter-weight alternative to LogContext for quick measurements:
    no start message, no context fields, just elapsed time — logged even
    if the block raises.

    Example:
        with timed_block(logger, "data processing"):
            process_data()
        # Output: DEBUG | data processing completed in 123.45ms
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - started) * 1000
        logger.log(log_level, f"{description} completed in {elapsed_ms:.2f}ms")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _safe_repr(value: Any, max_length: int = 100) -> str:
|
| 245 |
+
"""
|
| 246 |
+
Create a safe string representation of a value for logging.
|
| 247 |
+
|
| 248 |
+
Truncates long strings and handles non-serializable objects.
|
| 249 |
+
"""
|
| 250 |
+
try:
|
| 251 |
+
repr_str = repr(value)
|
| 252 |
+
if len(repr_str) > max_length:
|
| 253 |
+
return repr_str[:max_length] + "..."
|
| 254 |
+
return repr_str
|
| 255 |
+
except Exception:
|
| 256 |
+
return f"<{type(value).__name__}>"
|
| 257 |
+
|
src/gepa_optimizer/infrastructure/logging/formatters.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom Log Formatters for GEPA Optimizer.
|
| 3 |
+
|
| 4 |
+
Provides formatters for:
|
| 5 |
+
- Console output with colors and emoji
|
| 6 |
+
- JSON structured logging for production
|
| 7 |
+
- Plain text for file logging
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ANSI color codes for terminal output
|
| 17 |
+
class Colors:
    """ANSI escape codes used to colorize terminal log output."""
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"

    # Log level colors
    DEBUG = "\033[36m"    # Cyan
    INFO = "\033[32m"     # Green
    WARNING = "\033[33m"  # Yellow
    ERROR = "\033[31m"    # Red
    CRITICAL = "\033[35m" # Magenta

    # Semantic colors
    TIMESTAMP = "\033[90m"  # Gray
    MODULE = "\033[34m"     # Blue (logger name)
    MESSAGE = "\033[0m"     # Default terminal color
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Emoji prefix per log level, for quick visual scanning of console output.
LEVEL_EMOJI = {
    logging.DEBUG: "🔍",
    logging.INFO: "ℹ️ ",
    logging.WARNING: "⚠️ ",
    logging.ERROR: "❌",
    logging.CRITICAL: "🚨",
}

# ANSI color per log level (codes defined in Colors).
LEVEL_COLORS = {
    logging.DEBUG: Colors.DEBUG,
    logging.INFO: Colors.INFO,
    logging.WARNING: Colors.WARNING,
    logging.ERROR: Colors.ERROR,
    logging.CRITICAL: Colors.CRITICAL,
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class GepaFormatter(logging.Formatter):
    """
    Custom formatter for GEPA Optimizer logs.

    Features:
    - Optional color output for console
    - Optional emoji prefixes for visual scanning
    - Structured extra fields rendered as key=value pairs
    - Clean, readable format

    Example output:
        2024-01-15 10:30:45 | INFO | ℹ️ gepa_optimizer.core.optimizer | Starting optimization iteration=5
    """

    def __init__(
        self,
        fmt: Optional[str] = None,
        datefmt: Optional[str] = None,
        use_colors: bool = True,
        include_emoji: bool = True,
    ):
        """
        Initialize the formatter.

        Args:
            fmt: Format string (uses logging default if not provided)
            datefmt: Date format string
            use_colors: Whether to use ANSI colors
            include_emoji: Whether to include emoji prefixes
        """
        super().__init__(fmt=fmt, datefmt=datefmt)
        self.use_colors = use_colors
        self.include_emoji = include_emoji

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record with colors and emoji."""
        # Store originals so the record is left unmodified for any other
        # handler that formats the same record after this one.
        # FIX: the previous version restored msg/levelname but NOT
        # record.name, so ANSI color codes were permanently embedded into
        # the record's logger name once colors were enabled.
        original_msg = record.msg
        original_levelname = record.levelname
        original_name = record.name

        try:
            # Add emoji prefix if enabled
            if self.include_emoji:
                emoji = LEVEL_EMOJI.get(record.levelno, "")
                record.levelname = f"{emoji} {record.levelname}"

            # Add colors if enabled
            if self.use_colors:
                color = LEVEL_COLORS.get(record.levelno, Colors.RESET)
                record.levelname = f"{color}{record.levelname}{Colors.RESET}"
                record.name = f"{Colors.MODULE}{record.name}{Colors.RESET}"

            # Append structured extra fields to the message, if any
            extra_str = self._format_extra(record)
            if extra_str:
                record.msg = f"{record.msg} | {extra_str}"

            return super().format(record)

        finally:
            # Restore every attribute mutated above
            record.msg = original_msg
            record.levelname = original_levelname
            record.name = original_name

    def _format_extra(self, record: logging.LogRecord) -> str:
        """
        Format extra fields from the log record as "key=value" pairs.

        Extra fields are passed via the 'extra' parameter to logging calls:
            logger.info("Message", extra={"key": "value"})
        """
        # Standard LogRecord attributes to exclude
        standard_attrs = {
            'name', 'msg', 'args', 'created', 'filename', 'funcName',
            'levelname', 'levelno', 'lineno', 'module', 'msecs',
            'pathname', 'process', 'processName', 'relativeCreated',
            'stack_info', 'exc_info', 'exc_text', 'thread', 'threadName',
            'taskName', 'message'
        }

        # Collect everything the caller attached via extra={...}
        extra_fields = {
            k: v for k, v in record.__dict__.items()
            if k not in standard_attrs and not k.startswith('_')
        }

        if not extra_fields:
            return ""

        parts = []
        for key, value in extra_fields.items():
            # FIX: bool must be tested before int/float — bool is a subclass
            # of int, so the old ordering made this branch unreachable and
            # booleans rendered as "True"/"False" instead of "true"/"false".
            if isinstance(value, bool):
                parts.append(f"{key}={str(value).lower()}")
            elif isinstance(value, (str, int, float)):
                parts.append(f"{key}={value}")
            else:
                parts.append(f"{key}={repr(value)}")

        return " ".join(parts)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class JsonFormatter(logging.Formatter):
    """
    JSON formatter for structured logging.

    Outputs each log record as a single JSON line, suitable for:
    - Log aggregation systems (ELK, Splunk)
    - Cloud logging (CloudWatch, Stackdriver)
    - Log parsing and analysis

    Example output:
        {"timestamp": "2024-01-15T10:30:45Z", "level": "INFO", "logger": "gepa_optimizer.core", "message": "Starting optimization", "iteration": 5}
    """

    def __init__(
        self,
        include_timestamp: bool = True,
        include_location: bool = False,
    ):
        """
        Initialize JSON formatter.

        Args:
            include_timestamp: Include ISO-8601 UTC timestamp
            include_location: Include file/line/function information
        """
        super().__init__()
        self.include_timestamp = include_timestamp
        self.include_location = include_location

    def format(self, record: logging.LogRecord) -> str:
        """Format record as a single-line JSON string."""
        log_dict: Dict[str, Any] = {}

        # Timestamp in UTC with a trailing "Z".
        # FIX: datetime.utcfromtimestamp() is deprecated (Python 3.12+);
        # the timezone-aware equivalent below produces identical output.
        if self.include_timestamp:
            log_dict["timestamp"] = datetime.fromtimestamp(
                record.created, tz=timezone.utc
            ).isoformat().replace("+00:00", "Z")

        # Core fields
        log_dict["level"] = record.levelname
        log_dict["logger"] = record.name
        log_dict["message"] = record.getMessage()

        # Location info
        if self.include_location:
            log_dict["file"] = record.filename
            log_dict["line"] = record.lineno
            log_dict["function"] = record.funcName

        # Exception info (formatted traceback text)
        if record.exc_info:
            log_dict["exception"] = self.formatException(record.exc_info)

        # Extra fields: everything the caller attached via extra={...}
        standard_attrs = {
            'name', 'msg', 'args', 'created', 'filename', 'funcName',
            'levelname', 'levelno', 'lineno', 'module', 'msecs',
            'pathname', 'process', 'processName', 'relativeCreated',
            'stack_info', 'exc_info', 'exc_text', 'thread', 'threadName',
            'taskName', 'message'
        }

        for key, value in record.__dict__.items():
            if key not in standard_attrs and not key.startswith('_'):
                try:
                    # Keep the value as-is only if it is JSON serializable
                    json.dumps(value)
                    log_dict[key] = value
                except (TypeError, ValueError):
                    log_dict[key] = str(value)

        # default=str is a last-resort stringifier for anything that slipped
        # through the serializability check above (e.g. in nested structures).
        return json.dumps(log_dict, default=str)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class CompactFormatter(logging.Formatter):
    """
    Compact formatter for minimal log output.

    Useful for:
    - CI/CD pipelines
    - Reduced log verbosity
    - Quick debugging

    Example output:
        10:30:45 INFO  optimizer: Starting optimization
    """

    def format(self, record: logging.LogRecord) -> str:
        """Format record as "<HH:MM:SS> <LEVEL> <leaf-module>: <message>"."""
        # Local-time timestamp, time portion only
        stamp = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
        # Last dotted component of the logger name
        leaf = record.name.rsplit(".", 1)[-1]
        return f"{stamp} {record.levelname:5s} {leaf}: {record.getMessage()}"
|
| 259 |
+
|
src/gepa_optimizer/infrastructure/logging/logger.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core Logger Factory and Configuration.
|
| 3 |
+
|
| 4 |
+
This module provides the centralized logger factory that should be used
|
| 5 |
+
across all GEPA Optimizer modules. It ensures consistent logging behavior
|
| 6 |
+
and formatting throughout the application.
|
| 7 |
+
|
| 8 |
+
Design Principles:
|
| 9 |
+
- Single source of truth for logger configuration
|
| 10 |
+
- Lazy initialization (loggers created on first use)
|
| 11 |
+
- Thread-safe logger access
|
| 12 |
+
- Configurable log levels per module
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import sys
|
| 17 |
+
from enum import Enum
|
| 18 |
+
from typing import Optional, Dict, Any
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
|
| 21 |
+
from .formatters import GepaFormatter
|
| 22 |
+
|
| 23 |
+
# Root logger name under which all GEPA Optimizer loggers are namespaced.
GEPA_LOGGER_NAME = "gepa_optimizer"

# Default log line layout and timestamp format for GepaFormatter.
DEFAULT_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LogLevel(str, Enum):
    """Log levels supported by GEPA logging, with string representation."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

    @classmethod
    def from_string(cls, level: str) -> "LogLevel":
        """Parse *level* case-insensitively; unknown values fall back to INFO."""
        normalized = level.upper()
        try:
            return cls(normalized)
        except ValueError:
            return cls.INFO
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class LoggerConfig:
    """
    Configuration class for GEPA logging.

    This class holds all logging configuration and can be modified
    before calling configure_logging() to customize behavior.

    NOTE: all state lives in class-level attributes and is therefore global
    to the process; configure_logging() and set_log_level() mutate it.
    """

    # Default configuration
    level: LogLevel = LogLevel.INFO
    format: str = DEFAULT_FORMAT
    date_format: str = DEFAULT_DATE_FORMAT

    # Module-specific log levels (for fine-grained control); shared mutable
    # dict, replaced or updated by configure_logging()/set_log_level()
    module_levels: Dict[str, LogLevel] = {}

    # Output configuration: console on by default, file disabled until a
    # path is set
    log_to_console: bool = True
    log_to_file: Optional[str] = None

    # Formatting options
    use_colors: bool = True
    include_emoji: bool = True  # For visual clarity in development

    @classmethod
    def reset(cls) -> None:
        """Reset every configuration attribute to its default value."""
        cls.level = LogLevel.INFO
        cls.format = DEFAULT_FORMAT
        cls.date_format = DEFAULT_DATE_FORMAT
        cls.module_levels = {}
        cls.log_to_console = True
        cls.log_to_file = None
        cls.use_colors = True
        cls.include_emoji = True
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Module-level guard so get_logger() can lazily invoke configure_logging()
# exactly once with defaults when the caller never configured logging.
_logging_configured = False
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def configure_logging(
    level: Optional[str] = None,
    log_file: Optional[str] = None,
    use_colors: bool = True,
    include_emoji: bool = True,
    format_string: Optional[str] = None,
    module_levels: Optional[Dict[str, str]] = None,
) -> None:
    """
    Configure the GEPA logging system.

    This should be called once at application startup. Subsequent calls
    will update the configuration (existing handlers are replaced, not
    duplicated).

    Args:
        level: Global log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        log_file: Optional path to log file
        use_colors: Whether to use colored output in console
        include_emoji: Whether to include emoji prefixes for visual clarity
        format_string: Custom format string (optional)
        module_levels: Dict mapping module names to their specific log levels

    Example:
        configure_logging(
            level="DEBUG",
            log_file="optimization.log",
            module_levels={
                "gepa_optimizer.core.optimizer": "INFO",
                "gepa_optimizer.llms": "DEBUG"
            }
        )
    """
    global _logging_configured

    # Persist the requested settings into the shared LoggerConfig state;
    # unspecified options keep their current values.
    if level:
        LoggerConfig.level = LogLevel.from_string(level)
    if log_file:
        LoggerConfig.log_to_file = log_file
    LoggerConfig.use_colors = use_colors
    LoggerConfig.include_emoji = include_emoji
    if format_string:
        LoggerConfig.format = format_string
    if module_levels:
        LoggerConfig.module_levels = {
            k: LogLevel.from_string(v) for k, v in module_levels.items()
        }

    # Get or create the root GEPA logger; all package loggers inherit from it
    root_logger = logging.getLogger(GEPA_LOGGER_NAME)
    root_logger.setLevel(getattr(logging, LoggerConfig.level.value))

    # Remove existing handlers so repeated calls do not duplicate output
    root_logger.handlers.clear()

    # Console handler (colored/emoji formatting per the arguments above)
    if LoggerConfig.log_to_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(getattr(logging, LoggerConfig.level.value))

        # Use custom formatter
        formatter = GepaFormatter(
            fmt=LoggerConfig.format,
            datefmt=LoggerConfig.date_format,
            use_colors=use_colors,
            include_emoji=include_emoji,
        )
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    # File handler (if configured)
    if LoggerConfig.log_to_file:
        file_handler = logging.FileHandler(LoggerConfig.log_to_file)
        file_handler.setLevel(getattr(logging, LoggerConfig.level.value))

        # File logs don't use colors or emoji — plain text only
        file_formatter = GepaFormatter(
            fmt=LoggerConfig.format,
            datefmt=LoggerConfig.date_format,
            use_colors=False,
            include_emoji=False,
        )
        file_handler.setFormatter(file_formatter)
        root_logger.addHandler(file_handler)

    # Apply module-specific levels on top of the global level
    for module_name, module_level in LoggerConfig.module_levels.items():
        module_logger = logging.getLogger(module_name)
        module_logger.setLevel(getattr(logging, module_level.value))

    _logging_configured = True

    # Log that configuration is complete (visible only at DEBUG level)
    root_logger.debug(
        f"Logging configured: level={LoggerConfig.level.value}, "
        f"file={LoggerConfig.log_to_file}"
    )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@lru_cache(maxsize=128)
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger instance for the given module name.

    This is the primary factory function for obtaining loggers.
    All GEPA modules should use this instead of logging.getLogger().

    Results are memoized per name (lru_cache), so the module-specific level
    is applied only on the first lookup of each name; use set_log_level()
    to change levels at runtime.

    Args:
        name: Module name (typically __name__)

    Returns:
        Configured Logger instance

    Example:
        from gepa_optimizer.infrastructure.logging import get_logger

        logger = get_logger(__name__)
        logger.info("Starting process")
        logger.error("Failed to connect", exc_info=True)
    """
    # Auto-configure with defaults on first use so callers never have to
    # call configure_logging() explicitly.
    # (Cleanup: removed the unnecessary `global` declaration — the flag is
    # only read here — and a dead no-op branch checking the name prefix.)
    if not _logging_configured:
        configure_logging()

    logger = logging.getLogger(name)

    # Apply module-specific level if one was configured for this name
    if name in LoggerConfig.module_levels:
        logger.setLevel(getattr(logging, LoggerConfig.module_levels[name].value))

    return logger
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def set_log_level(level: str, module: Optional[str] = None) -> None:
    """
    Dynamically change log level at runtime.

    Args:
        level: New log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        module: Optional module name. If None, changes the global level.

    Example:
        # Enable debug for specific module
        set_log_level("DEBUG", "gepa_optimizer.core.optimizer")

        # Change global level
        set_log_level("WARNING")
    """
    new_level = LogLevel.from_string(level)
    numeric = getattr(logging, new_level.value)

    if module is None:
        # Global change: update the shared config, the root GEPA logger,
        # and every attached handler so the new level takes effect.
        LoggerConfig.level = new_level
        root_logger = logging.getLogger(GEPA_LOGGER_NAME)
        root_logger.setLevel(numeric)
        for handler in root_logger.handlers:
            handler.setLevel(numeric)
    else:
        # Module-scoped change: apply immediately and remember it so future
        # lookups of this module agree with the runtime override.
        logging.getLogger(module).setLevel(numeric)
        LoggerConfig.module_levels[module] = new_level
|
| 260 |
+
|
src/gepa_optimizer/llms/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM module for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .base_llm import BaseLLMClient
|
| 6 |
+
from .vision_llm import VisionLLMClient
|
| 7 |
+
from .batch_llm import BatchLLMClient
|
| 8 |
+
from .llego_enhanced_llm import LLEGOEnhancedLLMClient
|
| 9 |
+
|
| 10 |
+
__all__ = ["BaseLLMClient", "VisionLLMClient", "BatchLLMClient", "LLEGOEnhancedLLMClient"]
|
src/gepa_optimizer/llms/base_llm.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base LLM client class for all LLM providers.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any, Dict, Optional, Union
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class BaseLLMClient(ABC):
    """
    Abstract base class for all LLM clients.

    Provides a consistent interface for different LLM providers and models.
    Concrete subclasses must implement generate().
    """

    def __init__(self, provider: str, model_name: str, **kwargs):
        """
        Initialize LLM client.

        Args:
            provider: LLM provider (e.g., 'openai', 'anthropic')
            model_name: Specific model name
            **kwargs: Additional provider-specific parameters
        """
        self.provider = provider
        self.model_name = model_name
        # Per-subclass logger, e.g. "<module>.VisionLLMClient"
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        # Keep any extra provider-specific options for subclasses to consume
        self.config = kwargs

    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a response from the LLM.

        Args:
            system_prompt: System-level instructions
            user_prompt: User's input prompt
            **kwargs: Additional generation parameters (e.g., image_base64)

        Returns:
            Dictionary with a 'content' key containing the generated
            response, plus additional metadata.
        """

    def get_model_info(self) -> Dict[str, str]:
        """Return provider/model/class metadata for logging and debugging."""
        return {
            'provider': self.provider,
            'model_name': self.model_name,
            'class': self.__class__.__name__,
        }
|
src/gepa_optimizer/llms/batch_llm.py
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch LLM Client for cost-effective processing using Gemini Batch API.
|
| 3 |
+
|
| 4 |
+
This client provides 50% cost savings by using Google's Gemini Batch API
|
| 5 |
+
instead of real-time API calls. Ideal for large-scale prompt optimization
|
| 6 |
+
where latency is acceptable.
|
| 7 |
+
|
| 8 |
+
Features:
|
| 9 |
+
- 50% cost reduction compared to standard API
|
| 10 |
+
- Automatic batching and job management
|
| 11 |
+
- Built-in retry and polling logic
|
| 12 |
+
- Thread-safe operation
|
| 13 |
+
- Comprehensive error handling
|
| 14 |
+
|
| 15 |
+
Author: GEPA Optimizer Team
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import json
|
| 20 |
+
import time
|
| 21 |
+
import logging
|
| 22 |
+
import tempfile
|
| 23 |
+
import io
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Dict, List, Any, Optional, Tuple
|
| 26 |
+
from .base_llm import BaseLLMClient
|
| 27 |
+
|
| 28 |
+
# Optional dependency: Pillow (image handling for batch vision requests).
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    Image = None  # sentinel so attribute access fails loudly, not at import

# Optional dependency: google-genai SDK (required for actual Batch API calls).
try:
    from google import genai
    from google.genai import types
    GENAI_AVAILABLE = True
except ImportError:
    GENAI_AVAILABLE = False
    genai = None
    types = None
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BatchLLMClient(BaseLLMClient):
|
| 48 |
+
"""
|
| 49 |
+
Batch LLM client that uses Gemini Batch API for cost-effective processing.
|
| 50 |
+
|
| 51 |
+
This client processes multiple requests together in batch jobs, providing:
|
| 52 |
+
- 50% cost savings vs standard API
|
| 53 |
+
- No rate limit impact
|
| 54 |
+
- Automatic job management and polling
|
| 55 |
+
|
| 56 |
+
Usage:
|
| 57 |
+
>>> from gepa_optimizer.llms import BatchLLMClient
|
| 58 |
+
>>>
|
| 59 |
+
>>> client = BatchLLMClient(
|
| 60 |
+
... provider="google",
|
| 61 |
+
... model_name="gemini-2.5-flash",
|
| 62 |
+
... api_key="your-key",
|
| 63 |
+
... batch_size=20,
|
| 64 |
+
... polling_interval=30
|
| 65 |
+
... )
|
| 66 |
+
>>>
|
| 67 |
+
>>> # Use just like VisionLLMClient - adapter handles the rest!
|
| 68 |
+
>>> result = client.generate(
|
| 69 |
+
... system_prompt="You are a helpful assistant",
|
| 70 |
+
... user_prompt="Analyze this image",
|
| 71 |
+
... image_base64="..."
|
| 72 |
+
... )
|
| 73 |
+
|
| 74 |
+
Performance Note:
|
| 75 |
+
Batch processing adds latency (30s+ polling time) but reduces costs by 50%.
|
| 76 |
+
Choose this mode for large-scale optimization where cost > speed.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
    self,
    provider: str,
    model_name: str,
    api_key: Optional[str] = None,
    batch_size: int = 20,
    polling_interval: int = 30,
    max_polling_time: int = 3600,
    temp_dir: str = ".gepa_batch_temp",
    **kwargs
):
    """
    Initialize Batch LLM Client.

    Args:
        provider: Must be "google" or "gemini" (case-insensitive).
        model_name: Gemini model (e.g., "gemini-2.5-flash", "gemini-1.5-flash").
        api_key: Google API key (falls back to APIKeyManager / GEMINI_API_KEY).
        batch_size: Number of samples to process per batch job.
        polling_interval: Seconds between job status checks (default: 30).
        max_polling_time: Maximum seconds to wait for job completion (default: 3600).
        temp_dir: Directory for temporary files (created if missing).
        **kwargs: Additional parameters forwarded to BaseLLMClient.

    Raises:
        ValueError: If provider is not Google/Gemini, or no API key is available.
        ImportError: If google-genai is not installed.
    """
    super().__init__(provider=provider, model_name=model_name, **kwargs)

    # Validate provider: the batch workflow below is Gemini-specific.
    if provider.lower() not in ["google", "gemini"]:
        raise ValueError(
            f"BatchLLMClient only supports Google/Gemini provider. Got: {provider}"
        )

    # Check dependencies (google-genai import is guarded at module level).
    if not GENAI_AVAILABLE:
        raise ImportError(
            "google-genai not installed. Install with: pip install google-genai"
        )

    # Configuration for batch sizing and job polling.
    self.batch_size = batch_size
    self.polling_interval = polling_interval
    self.max_polling_time = max_polling_time
    self.temp_dir = Path(temp_dir)
    self.temp_dir.mkdir(exist_ok=True)

    # Resolve API key: explicit argument wins, otherwise the project's
    # APIKeyManager (which reads environment variables) is consulted.
    from ..utils.api_keys import APIKeyManager
    self.api_key = api_key or APIKeyManager().get_api_key("google")

    if not self.api_key:
        raise ValueError(
            "Google API key required. Provide via api_key parameter or "
            "set GEMINI_API_KEY environment variable."
        )

    # Single client instance reused for file uploads and batch jobs.
    self.client = genai.Client(api_key=self.api_key)

    logger.info(
        f"✓ BatchLLMClient initialized: {model_name} "
        f"(batch_size={batch_size}, polling={polling_interval}s)"
    )
|
| 145 |
+
def generate(
    self,
    system_prompt: str,
    user_prompt: str,
    image_base64: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Generate a single response by delegating to a batch of size one.

    Provided mainly for interface compatibility with real-time clients;
    during GEPA optimization the adapter calls generate_batch() directly.

    Args:
        system_prompt: System-level instructions.
        user_prompt: User's input prompt.
        image_base64: Optional base64 encoded image.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        Dict with 'content' key containing generated text (plus an
        'error' key when the batch produced no results).
    """
    single_request = [{
        'system_prompt': system_prompt,
        'user_prompt': user_prompt,
        'image_base64': image_base64
    }]

    batch_results = self.generate_batch(single_request)
    if not batch_results:
        return {"content": "", "error": "No results"}
    return batch_results[0]
| 177 |
+
def generate_batch(
    self,
    requests: List[Dict[str, Any]],
    timeout_override: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Process multiple requests in a single batch job.

    This is the main entry point used by the GEPA adapter. The pipeline:
    upload images -> write JSONL -> submit job -> poll -> parse results.

    Args:
        requests: List of request dicts with 'system_prompt',
            'user_prompt' and optional 'image_base64' keys.
        timeout_override: Override max_polling_time for this batch.

    Returns:
        List of response dicts with a 'content' key each.

    Raises:
        RuntimeError: If the batch job fails.
        TimeoutError: If polling exceeds the timeout.
    """
    logger.info(f"📦 Processing batch of {len(requests)} requests via Gemini Batch API...")

    start_time = time.time()

    try:
        # Upload any images and keep URI/MIME lists index-aligned with requests.
        uris, mimes = self._upload_images_for_batch(requests)

        # Write the JSONL request file and hand it to the Batch API.
        jsonl_path = self._create_batch_jsonl(requests, uris, mimes)
        job_name = self._submit_batch_job(jsonl_path)

        # Block until the job reaches a terminal state.
        effective_timeout = timeout_override or self.max_polling_time
        self._wait_for_batch_completion(job_name, effective_timeout)

        results = self._retrieve_batch_results(job_name)

        # Remove the local JSONL file; it is no longer needed.
        jsonl_path.unlink(missing_ok=True)

        elapsed_time = time.time() - start_time
        logger.info(
            f"✓ Batch processing complete: {len(results)} results in {elapsed_time:.1f}s "
            f"(~{elapsed_time/len(results):.1f}s per request)"
        )
        return results

    except Exception as e:
        elapsed_time = time.time() - start_time
        logger.error(f"❌ Batch processing failed after {elapsed_time:.1f}s: {e}")
        raise
|
| 238 |
+
def _upload_images_for_batch(self, requests: List[Dict]) -> Tuple[List[Optional[str]], List[Optional[str]]]:
|
| 239 |
+
"""
|
| 240 |
+
Upload images to Gemini and return file URIs and MIME types.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
requests: List of request dicts
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
Tuple of (file_uris, mime_types) - both are lists with None for requests without images
|
| 247 |
+
"""
|
| 248 |
+
file_uris = []
|
| 249 |
+
mime_types = []
|
| 250 |
+
images_to_upload = sum(1 for r in requests if r.get('image_base64'))
|
| 251 |
+
|
| 252 |
+
if images_to_upload > 0:
|
| 253 |
+
logger.info(f" ⬆️ Uploading {images_to_upload} images to Gemini...")
|
| 254 |
+
|
| 255 |
+
for i, request in enumerate(requests):
|
| 256 |
+
image_base64 = request.get('image_base64')
|
| 257 |
+
|
| 258 |
+
if not image_base64:
|
| 259 |
+
file_uris.append(None)
|
| 260 |
+
mime_types.append(None)
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
try:
|
| 264 |
+
# Decode image data
|
| 265 |
+
import base64
|
| 266 |
+
image_data = base64.b64decode(image_base64)
|
| 267 |
+
|
| 268 |
+
# Detect image format using Pillow
|
| 269 |
+
image_format = None
|
| 270 |
+
if PIL_AVAILABLE:
|
| 271 |
+
try:
|
| 272 |
+
img = Image.open(io.BytesIO(image_data))
|
| 273 |
+
image_format = img.format.lower() if img.format else None
|
| 274 |
+
except Exception as e:
|
| 275 |
+
logger.warning(f" ⚠️ Could not detect image format: {e}")
|
| 276 |
+
|
| 277 |
+
# Map format to extension and MIME type
|
| 278 |
+
format_map = {
|
| 279 |
+
'jpeg': ('.jpg', 'image/jpeg'),
|
| 280 |
+
'jpg': ('.jpg', 'image/jpeg'),
|
| 281 |
+
'png': ('.png', 'image/png'),
|
| 282 |
+
'gif': ('.gif', 'image/gif'),
|
| 283 |
+
'webp': ('.webp', 'image/webp'),
|
| 284 |
+
'bmp': ('.bmp', 'image/bmp'),
|
| 285 |
+
'tiff': ('.tiff', 'image/tiff'),
|
| 286 |
+
'tif': ('.tiff', 'image/tiff'),
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
# Get extension and MIME type (default to PNG if unknown)
|
| 290 |
+
ext, mime_type = format_map.get(image_format, ('.png', 'image/png'))
|
| 291 |
+
|
| 292 |
+
if image_format and image_format not in format_map:
|
| 293 |
+
logger.warning(f" ⚠️ Unknown image format '{image_format}' for image {i}, defaulting to PNG")
|
| 294 |
+
elif not image_format:
|
| 295 |
+
logger.debug(f" ℹ️ Could not detect format for image {i}, using PNG")
|
| 296 |
+
|
| 297 |
+
# Save to temp file with correct extension
|
| 298 |
+
temp_file = tempfile.NamedTemporaryFile(
|
| 299 |
+
delete=False,
|
| 300 |
+
suffix=ext,
|
| 301 |
+
dir=self.temp_dir
|
| 302 |
+
)
|
| 303 |
+
temp_file.write(image_data)
|
| 304 |
+
temp_file.close()
|
| 305 |
+
|
| 306 |
+
# Upload to Gemini with correct MIME type
|
| 307 |
+
uploaded_file = self.client.files.upload(
|
| 308 |
+
file=temp_file.name,
|
| 309 |
+
config=types.UploadFileConfig(
|
| 310 |
+
display_name=f"batch_image_{i}_{int(time.time())}{ext}",
|
| 311 |
+
mime_type=mime_type
|
| 312 |
+
)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
logger.debug(f" ✓ Uploaded image {i} as {mime_type}")
|
| 316 |
+
|
| 317 |
+
# Wait for file to be active
|
| 318 |
+
self._wait_for_file_active(uploaded_file)
|
| 319 |
+
file_uris.append(uploaded_file.uri)
|
| 320 |
+
mime_types.append(mime_type)
|
| 321 |
+
|
| 322 |
+
# Cleanup temp file
|
| 323 |
+
Path(temp_file.name).unlink()
|
| 324 |
+
|
| 325 |
+
except Exception as e:
|
| 326 |
+
logger.error(f" ✗ Failed to upload image {i}: {e}")
|
| 327 |
+
file_uris.append(None)
|
| 328 |
+
mime_types.append(None)
|
| 329 |
+
|
| 330 |
+
if images_to_upload > 0:
|
| 331 |
+
successful = sum(1 for uri in file_uris if uri is not None)
|
| 332 |
+
logger.info(f" ✓ Uploaded {successful}/{images_to_upload} images successfully")
|
| 333 |
+
|
| 334 |
+
return file_uris, mime_types
|
| 335 |
+
|
| 336 |
+
def _create_batch_jsonl(
    self,
    requests: List[Dict],
    file_uris: List[Optional[str]],
    mime_types: List[Optional[str]]
) -> Path:
    """
    Create the JSONL input file for a Gemini batch job.

    Each line is one batch request: the system and user prompts are
    concatenated into a single text part, and an optional file_data part
    references a previously uploaded image by URI.

    Args:
        requests: List of request dicts (see generate_batch()).
        file_uris: Uploaded file URIs, index-aligned with ``requests``.
        mime_types: MIME types for the uploaded files, same alignment.

    Returns:
        Path to the created JSONL file (inside self.temp_dir).
    """
    timestamp = int(time.time())
    jsonl_path = self.temp_dir / f"batch_{timestamp}.jsonl"

    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for i, (request, file_uri, mime_type) in enumerate(zip(requests, file_uris, mime_types)):
            # Combine system and user prompts into one text part; the
            # batch request format used here carries a single user turn.
            system_prompt = request.get('system_prompt', '')
            user_prompt = request.get('user_prompt', '')
            full_prompt = f"{system_prompt}\n\n{user_prompt}".strip()

            # Build request parts: text first, then the optional image.
            parts = [{"text": full_prompt}]

            if file_uri:
                parts.append({
                    "file_data": {
                        "file_uri": file_uri,
                        "mime_type": mime_type or "image/png"  # Use actual MIME type; PNG fallback
                    }
                })

            # Gemini Batch API format according to official docs
            # Reference: https://ai.google.dev/gemini-api/docs/batch-inference
            # NOTE: The "request" wrapper is REQUIRED for Gemini 2.5 batch API
            batch_request = {
                "custom_id": f"request-{i}",
                "request": {
                    "contents": [{
                        "role": "user",
                        "parts": parts
                    }]
                }
            }

            # ensure_ascii=False keeps non-ASCII prompt text readable in the file.
            f.write(json.dumps(batch_request, ensure_ascii=False) + '\n')

    logger.info(f"  📝 Created JSONL file: {jsonl_path.name} ({len(requests)} requests)")
    return jsonl_path
+
|
| 392 |
+
def _submit_batch_job(self, jsonl_path: Path) -> str:
    """
    Submit a batch job to Gemini.

    Validates and uploads the JSONL request file (with a string-path
    fallback because the google-genai SDK upload can be finicky), waits
    for the file to become ACTIVE, then creates the batch job.

    Args:
        jsonl_path: Path to the JSONL file from _create_batch_jsonl().

    Returns:
        Batch job name (used for polling and result retrieval).

    Raises:
        RuntimeError: If the upload or job creation fails.
    """
    # Upload JSONL file
    # Try multiple methods as the google-genai SDK can be finicky
    try:
        logger.info(f"  📤 Uploading JSONL file: {jsonl_path.name}")

        # Read and validate file content before uploading.
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            content = f.read()
            line_count = len(content.strip().split('\n'))
            logger.debug(f"  📄 JSONL: {len(content)} bytes, {line_count} lines")

        # Validate JSONL format line-by-line so a malformed request
        # surfaces here with a line number, not as an opaque API error.
        # NOTE: the ValueError raised below is converted to RuntimeError
        # by the outer `except Exception` handler.
        for line_num, line in enumerate(content.strip().split('\n'), 1):
            try:
                json.loads(line)
            except json.JSONDecodeError as e:
                logger.error(f"  ❌ Invalid JSON at line {line_num}: {e}")
                logger.error(f"  Content: {line[:100]}...")
                raise ValueError(f"Invalid JSONL format at line {line_num}") from e

        # Method 1: Try uploading with Path object
        logger.info(f"  🔄 Upload method 1: Using Path object...")
        try:
            jsonl_file = self.client.files.upload(
                file=jsonl_path,
                config=types.UploadFileConfig(
                    display_name=f'gepa-batch-{int(time.time())}',
                    mime_type='application/json'  # Try application/json instead of application/jsonl
                )
            )
            logger.info(f"  ✓ JSONL file uploaded: {jsonl_file.name}")

        except Exception as e1:
            logger.warning(f"  ⚠️ Method 1 failed: {e1}")
            logger.info(f"  🔄 Upload method 2: Using string path...")

            # Method 2: Fallback to string path
            try:
                jsonl_file = self.client.files.upload(
                    file=str(jsonl_path.absolute()),
                    config=types.UploadFileConfig(
                        display_name=f'gepa-batch-{int(time.time())}',
                        mime_type='application/json'
                    )
                )
                logger.info(f"  ✓ JSONL file uploaded (method 2): {jsonl_file.name}")
            except Exception as e2:
                logger.error(f"  ❌ Method 2 also failed: {e2}")
                raise e2

    except KeyError as e:
        # Observed when the SDK's response parsing breaks against a
        # changed API payload shape.
        logger.error(f"❌ KeyError during JSONL upload: {e}")
        logger.error(f"  This suggests the Gemini API response format changed")
        logger.error(f"  Try updating google-genai: pip install --upgrade google-genai")
        raise RuntimeError(f"Gemini Batch API response format error: {e}") from e
    except Exception as e:
        logger.error(f"❌ Failed to upload JSONL file: {e}")
        logger.error(f"  File path: {jsonl_path}")
        logger.error(f"  File exists: {jsonl_path.exists()}")
        logger.error(f"  File size: {jsonl_path.stat().st_size if jsonl_path.exists() else 'N/A'} bytes")
        raise RuntimeError(f"Gemini Batch API file upload failed: {e}") from e

    # Wait for JSONL to be active before referencing it in a job.
    try:
        logger.info(f"  ⏳ Waiting for JSONL file to be processed...")
        self._wait_for_file_active(jsonl_file)
    except Exception as e:
        logger.error(f"❌ JSONL file processing failed: {e}")
        raise

    # Create batch job referencing the uploaded file.
    try:
        logger.info(f"  🚀 Creating batch job...")
        batch_job = self.client.batches.create(
            model=self.model_name,
            src=jsonl_file.name,
            config={'display_name': f'gepa-opt-{int(time.time())}'}
        )

        logger.info(f"  ✓ Batch job submitted: {batch_job.name}")
        return batch_job.name

    except Exception as e:
        logger.error(f"❌ Failed to create batch job: {e}")
        raise RuntimeError(f"Batch job creation failed: {e}") from e
+
|
| 488 |
+
def _wait_for_batch_completion(self, job_name: str, timeout: int):
    """
    Poll a batch job until it reaches a terminal state.

    Args:
        job_name: Batch job name from _submit_batch_job().
        timeout: Maximum seconds to wait for completion.

    Raises:
        TimeoutError: If polling exceeds ``timeout``.
        RuntimeError: If the batch job fails or is cancelled.
    """
    logger.info(f"  ⏳ Polling for completion (checking every {self.polling_interval}s)...")

    start_time = time.time()
    poll_count = 0

    while True:
        elapsed = time.time() - start_time

        # Timeout is checked before each status request.
        if elapsed > timeout:
            raise TimeoutError(
                f"Batch job timeout after {elapsed:.0f}s "
                f"(max: {timeout}s)"
            )

        try:
            batch_job = self.client.batches.get(name=job_name)
            state = batch_job.state.name

            # Success states (both spellings handled defensively).
            if state in ['JOB_STATE_SUCCEEDED', 'SUCCEEDED']:
                logger.info(f"  ✓ Batch job completed in {elapsed:.0f}s")
                return

            # Failure states
            if state in ['JOB_STATE_FAILED', 'FAILED']:
                raise RuntimeError(f"Batch job failed with state: {state}")

            if state in ['JOB_STATE_CANCELLED', 'CANCELLED']:
                raise RuntimeError(f"Batch job was cancelled: {state}")

            # Still processing
            poll_count += 1
            if poll_count % 5 == 0:  # Log every 5 polls to limit log noise
                logger.info(f"  ... still processing ({elapsed:.0f}s elapsed, state: {state})")

            time.sleep(self.polling_interval)

        except (TimeoutError, RuntimeError):
            # Terminal conditions raised above must propagate unchanged.
            raise
        except Exception as e:
            # Transient API/network errors: log, pause briefly, retry.
            logger.warning(f"  ⚠️ Error checking job status: {e}, retrying...")
            time.sleep(5)
+
|
| 543 |
+
def _retrieve_batch_results(self, job_name: str) -> List[Dict[str, Any]]:
|
| 544 |
+
"""
|
| 545 |
+
Retrieve and parse batch results.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
job_name: Batch job name
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
List of result dicts
|
| 552 |
+
"""
|
| 553 |
+
batch_job = self.client.batches.get(name=job_name)
|
| 554 |
+
|
| 555 |
+
# Check for inline responses (preferred)
|
| 556 |
+
if hasattr(batch_job.dest, 'inlined_responses') and batch_job.dest.inlined_responses:
|
| 557 |
+
logger.info(f" 📥 Processing inline responses...")
|
| 558 |
+
return self._parse_inline_results(batch_job.dest.inlined_responses)
|
| 559 |
+
|
| 560 |
+
# Download results file (fallback)
|
| 561 |
+
if hasattr(batch_job.dest, 'file_name') and batch_job.dest.file_name:
|
| 562 |
+
logger.info(f" 📥 Downloading results file: {batch_job.dest.file_name}")
|
| 563 |
+
file_data = self.client.files.download(file=batch_job.dest.file_name)
|
| 564 |
+
return self._parse_file_results(file_data)
|
| 565 |
+
|
| 566 |
+
raise RuntimeError("No results available from batch job")
|
| 567 |
+
|
| 568 |
+
def _parse_inline_results(self, inline_responses) -> List[Dict[str, Any]]:
|
| 569 |
+
"""Parse inline batch results."""
|
| 570 |
+
results = []
|
| 571 |
+
|
| 572 |
+
for response_obj in inline_responses:
|
| 573 |
+
if hasattr(response_obj, 'response') and response_obj.response:
|
| 574 |
+
text = self._extract_text_from_response(response_obj.response)
|
| 575 |
+
results.append({
|
| 576 |
+
"content": text,
|
| 577 |
+
"role": "assistant",
|
| 578 |
+
"model": self.model_name,
|
| 579 |
+
"provider": "google"
|
| 580 |
+
})
|
| 581 |
+
else:
|
| 582 |
+
error_msg = str(getattr(response_obj, 'error', 'Unknown error'))
|
| 583 |
+
logger.warning(f" ⚠️ Response error: {error_msg}")
|
| 584 |
+
results.append({
|
| 585 |
+
"content": "",
|
| 586 |
+
"error": error_msg
|
| 587 |
+
})
|
| 588 |
+
|
| 589 |
+
return results
|
| 590 |
+
|
| 591 |
+
def _parse_file_results(self, file_data) -> List[Dict[str, Any]]:
|
| 592 |
+
"""Parse JSONL results file."""
|
| 593 |
+
if isinstance(file_data, bytes):
|
| 594 |
+
jsonl_content = file_data.decode('utf-8')
|
| 595 |
+
else:
|
| 596 |
+
jsonl_content = file_data
|
| 597 |
+
|
| 598 |
+
results = []
|
| 599 |
+
|
| 600 |
+
for line_num, line in enumerate(jsonl_content.strip().split('\n'), 1):
|
| 601 |
+
if not line.strip():
|
| 602 |
+
continue
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
result = json.loads(line)
|
| 606 |
+
|
| 607 |
+
if 'response' in result:
|
| 608 |
+
text = self._extract_text_from_dict(result['response'])
|
| 609 |
+
results.append({
|
| 610 |
+
"content": text,
|
| 611 |
+
"role": "assistant",
|
| 612 |
+
"model": self.model_name,
|
| 613 |
+
"provider": "google"
|
| 614 |
+
})
|
| 615 |
+
else:
|
| 616 |
+
error_msg = result.get('error', 'Unknown error')
|
| 617 |
+
logger.warning(f" ⚠️ Line {line_num} error: {error_msg}")
|
| 618 |
+
results.append({
|
| 619 |
+
"content": "",
|
| 620 |
+
"error": error_msg
|
| 621 |
+
})
|
| 622 |
+
|
| 623 |
+
except json.JSONDecodeError as e:
|
| 624 |
+
logger.error(f" ✗ Line {line_num}: JSON decode error: {e}")
|
| 625 |
+
results.append({"content": "", "error": f"JSON decode error: {e}"})
|
| 626 |
+
|
| 627 |
+
return results
|
| 628 |
+
|
| 629 |
+
def _extract_text_from_response(self, response_obj) -> str:
|
| 630 |
+
"""Extract text from response object."""
|
| 631 |
+
try:
|
| 632 |
+
# Direct text attribute
|
| 633 |
+
if hasattr(response_obj, 'text'):
|
| 634 |
+
return response_obj.text
|
| 635 |
+
|
| 636 |
+
# Navigate through candidates
|
| 637 |
+
if hasattr(response_obj, 'candidates') and response_obj.candidates:
|
| 638 |
+
candidate = response_obj.candidates[0]
|
| 639 |
+
if hasattr(candidate, 'content'):
|
| 640 |
+
content = candidate.content
|
| 641 |
+
if hasattr(content, 'parts') and content.parts:
|
| 642 |
+
part = content.parts[0]
|
| 643 |
+
if hasattr(part, 'text'):
|
| 644 |
+
return part.text
|
| 645 |
+
|
| 646 |
+
# Fallback to string representation
|
| 647 |
+
return str(response_obj)
|
| 648 |
+
|
| 649 |
+
except Exception as e:
|
| 650 |
+
logger.error(f"Error extracting text from response: {e}")
|
| 651 |
+
return ""
|
| 652 |
+
|
| 653 |
+
def _extract_text_from_dict(self, response_dict: Dict) -> str:
|
| 654 |
+
"""Extract text from response dictionary."""
|
| 655 |
+
try:
|
| 656 |
+
# Direct text key
|
| 657 |
+
if 'text' in response_dict:
|
| 658 |
+
return response_dict['text']
|
| 659 |
+
|
| 660 |
+
# Navigate through candidates
|
| 661 |
+
if 'candidates' in response_dict and response_dict['candidates']:
|
| 662 |
+
candidate = response_dict['candidates'][0]
|
| 663 |
+
if 'content' in candidate and 'parts' in candidate['content']:
|
| 664 |
+
parts = candidate['content']['parts']
|
| 665 |
+
if parts and 'text' in parts[0]:
|
| 666 |
+
return parts[0]['text']
|
| 667 |
+
|
| 668 |
+
# Fallback to JSON string
|
| 669 |
+
return json.dumps(response_dict)
|
| 670 |
+
|
| 671 |
+
except Exception as e:
|
| 672 |
+
logger.error(f"Error extracting text from dict: {e}")
|
| 673 |
+
return ""
|
| 674 |
+
|
| 675 |
+
def _wait_for_file_active(self, uploaded_file, timeout: int = 60):
|
| 676 |
+
"""
|
| 677 |
+
Wait for uploaded file to become active.
|
| 678 |
+
|
| 679 |
+
Args:
|
| 680 |
+
uploaded_file: Uploaded file object
|
| 681 |
+
timeout: Maximum seconds to wait
|
| 682 |
+
|
| 683 |
+
Raises:
|
| 684 |
+
TimeoutError: If file processing exceeds timeout
|
| 685 |
+
RuntimeError: If file processing fails
|
| 686 |
+
"""
|
| 687 |
+
start_time = time.time()
|
| 688 |
+
|
| 689 |
+
while uploaded_file.state.name == "PROCESSING":
|
| 690 |
+
if time.time() - start_time > timeout:
|
| 691 |
+
raise TimeoutError(f"File processing timeout: {uploaded_file.name}")
|
| 692 |
+
|
| 693 |
+
time.sleep(1)
|
| 694 |
+
uploaded_file = self.client.files.get(name=uploaded_file.name)
|
| 695 |
+
|
| 696 |
+
if uploaded_file.state.name != "ACTIVE":
|
| 697 |
+
raise RuntimeError(
|
| 698 |
+
f"File processing failed: {uploaded_file.name} "
|
| 699 |
+
f"(state: {uploaded_file.state.name})"
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
def get_model_info(self) -> Dict[str, str]:
    """Return a summary of this client's configuration for logging/debugging."""
    info = {
        'provider': self.provider,
        'model_name': self.model_name,
        'class': self.__class__.__name__,
        'mode': 'batch',
    }
    # Batch-specific settings, rendered as strings like the rest.
    info['batch_size'] = str(self.batch_size)
    info['polling_interval'] = f'{self.polling_interval}s'
    return info
+
|
src/gepa_optimizer/llms/llego_enhanced_llm.py
ADDED
|
@@ -0,0 +1,1625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLEGO-Enhanced LLM Client Wrapper
|
| 3 |
+
|
| 4 |
+
This wrapper intercepts LLM calls and uses LLEGO genetic operators
|
| 5 |
+
when generating new prompt candidates during GEPA's reflection phase.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import re
|
| 10 |
+
from typing import Optional, Dict, Any, Callable, List
|
| 11 |
+
from .base_llm import BaseLLMClient
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
# Fallback system prompt for sequential generation (when JSON parsing fails)
# Uses Linear Command structure for reliability when complex JSON generation fails
# NOTE: the text below is sent to the LLM verbatim; do not reword casually.
_FALLBACK_SYSTEM_PROMPT: str = """You are a Prompt Optimization Engine operating in **SAFE MODE**.

<task>
Rewrite the prompt based on the feedback provided below.
</task>

<output_rules>
1. Output **ONLY** the new prompt text.
2. No JSON. No Explanations. No "Here is the prompt".
3. The prompt must be fully functional and self-contained.
4. START directly with the prompt content (e.g., "You are a..." or task instructions).
5. Preserve the core task/domain - only improve HOW it's described.
</output_rules>

<quality_standards>
- Be specific and concrete (no vague instructions)
- Use clear, imperative language
- Include edge case handling if feedback identifies confusion
- Ensure the prompt is self-contained and unambiguous
- Add explicit constraints for format/output if needed
</quality_standards>

<forbidden_outputs>
- Analysis of what went wrong
- Explanations of your changes
- Meta-text like "Here's an improved version..."
- Anything other than the raw prompt text
</forbidden_outputs>

Start of New Prompt:"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LLEGOEnhancedLLMClient(BaseLLMClient):
|
| 50 |
+
"""
|
| 51 |
+
Wrapper around BaseLLMClient that uses LLEGO for candidate generation.
|
| 52 |
+
|
| 53 |
+
This wrapper detects when GEPA is asking for new prompt candidates
|
| 54 |
+
and routes those requests through LLEGO's genetic operators instead
|
| 55 |
+
of standard LLM generation.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
    self,
    base_llm: BaseLLMClient,
    llego_layer,
    config=None,
    verbose: bool = True
):
    """
    Initialize LLEGO-enhanced LLM client.

    Args:
        base_llm: The underlying LLM client (VisionLLMClient, etc.)
        llego_layer: LLEGOIntegrationLayer instance
        config: Optional OptimizationConfig for hybrid mode settings
        verbose: Whether to log LLEGO operations
    """
    self.base_llm = base_llm
    self.llego = llego_layer
    self.config = config
    self.verbose = verbose

    # Get log level from config (default to INFO when absent or unset).
    self.log_level = getattr(config, 'log_level', 'INFO') if config else 'INFO'

    # Track context for detecting reflection calls; populated by
    # set_reflection_context() before GEPA's reflection-phase generate() calls.
    self.reflection_context = {
        'current_prompt': None,
        'feedback': None,
        'in_reflection': False
    }

    # Queue for hybrid mode candidates (GEPA will call generate() multiple times,
    # draining one candidate per call).
    self._candidate_queue = []
    self._hybrid_generation_complete = False

    # 🔥 CRITICAL: Queue for adapter-generated candidates (from make_reflective_dataset).
    # When the adapter generates candidates at adapter level, they're stored here;
    # GEPA will call generate() for proposals, and we'll return these candidates.
    self._adapter_generated_candidates = []

    # 🔥 FORMAT AWARENESS: Store format info from adapter for use in candidate
    # generation. Set by the adapter after format detection; None until then.
    self._detected_format = None

    # FIX #5: Circuit breaker for LLEGO failures — after the threshold of
    # consecutive failures, LLEGO is disabled and the base LLM is used directly.
    self._llego_failures = 0
    self._llego_disabled = False
    self._llego_failure_threshold = 3  # Disable after 3 consecutive failures

    logger.info("🧬 LLEGO-Enhanced LLM Client initialized")
    logger.info(f" Base LLM: {base_llm.__class__.__name__}")
    logger.info(f" LLEGO enabled: {llego_layer is not None}")
    if config and hasattr(config, 'enable_gepa_reflection_with_llego'):
        logger.info(f" Hybrid mode: {config.enable_gepa_reflection_with_llego}")
    logger.debug(f" Log level: {self.log_level}")
|
| 113 |
+
|
| 114 |
+
def _should_log_debug(self) -> bool:
    """
    Check if DEBUG logging is enabled.

    True when either the configured log level is "DEBUG" or the root
    logger is enabled for DEBUG output.

    Returns:
        True if DEBUG level logging is enabled, False otherwise
    """
    # NOTE: the previous `hasattr(logging, 'getLogger')` guard was always
    # True (getLogger is a core logging API), so it has been removed;
    # behavior is unchanged.
    return self.log_level == "DEBUG" or logging.getLogger().isEnabledFor(logging.DEBUG)
|
| 125 |
+
|
| 126 |
+
def _extract_clean_prompt_from_reflection(self, reflection_output: str) -> str:
    """
    🛡️ DEFENSIVE FALLBACK: Extract clean prompt if LLM adds analysis despite system prompt instructions.

    NOTE: The system prompt now explicitly instructs the LLM to output ONLY the prompt text.
    However, this extraction logic serves as a safety net in case the LLM still adds:
    "Based on the performance analysis...
    ### Recommendations...
    ### Revised Prompt Example:
    [THE ACTUAL PROMPT HERE]
    ### Conclusion..."

    This is now a defensive measure, not the primary mechanism.

    Args:
        reflection_output: Full reflection output (should be clean prompt, but may contain analysis)

    Returns:
        str: Clean, extracted prompt (or original if extraction fails or not needed)
    """
    # Non-string / empty inputs are returned untouched so callers can decide.
    if not reflection_output or not isinstance(reflection_output, str):
        return reflection_output

    # Pattern 1: Look for "Revised Prompt Example:" or "### Revised Prompt Example:"
    # Each pattern captures everything up to the next heading / rule / end of text.
    patterns = [
        r'(?:###\s*)?Revised\s+Prompt\s+(?:Example|:)?\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
        r'(?:###\s*)?Revised\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
        r'(?:###\s*)?Optimized\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
        r'(?:###\s*)?New\s+Prompt\s*:\s*\n(.*?)(?:\n###|\n##|\n---|\Z)',
        r'(?:Here\s+is|Here\'s)\s+a?\s*refined?\s+(?:version\s+of\s+)?(?:the\s+)?prompt\s*[:\n](.*?)(?:\n###|\n##|\n---|\Z)',
    ]

    for pattern in patterns:
        match = re.search(pattern, reflection_output, re.IGNORECASE | re.DOTALL)
        if match:
            extracted = match.group(1).strip()
            # Clean up common artifacts: markdown code fences wrapping the prompt.
            extracted = re.sub(r'^```(?:plaintext|markdown|text)?\s*\n', '', extracted, flags=re.MULTILINE)
            extracted = re.sub(r'\n```\s*$', '', extracted, flags=re.MULTILINE)
            extracted = extracted.strip()

            if len(extracted) > 50:  # Reasonable minimum length for a prompt
                logger.debug(f"✅ Extracted clean prompt using pattern: {pattern[:50]}...")
                logger.debug(f" Original length: {len(reflection_output)} chars")
                logger.debug(f" Extracted length: {len(extracted)} chars")
                return extracted

    # Pattern 2: If output starts with a quote or prompt-like structure.
    # Look for text that starts with "You are..." and is substantial.
    if 'You are' in reflection_output:
        # Find the longest continuous block that starts with "You are"
        prompt_match = re.search(r'(You are[^#]*?)(?:\n###|\n##|###|##|Conclusion|\Z)',
                                 reflection_output, re.IGNORECASE | re.DOTALL)
        if prompt_match:
            extracted = prompt_match.group(1).strip()
            if len(extracted) > 50:
                logger.debug(f"✅ Extracted prompt starting with 'You are...'")
                return extracted

    # Pattern 3: If the reflection output is actually just a clean prompt (no analysis).
    # Check if it's relatively short and doesn't contain analysis keywords.
    analysis_keywords = ['recommendation', 'suggestion', 'improvement', 'conclusion',
                         'optimization', 'analysis', 'feedback']
    if (len(reflection_output) < 2000 and
            not any(keyword in reflection_output.lower() for keyword in analysis_keywords)):
        # Likely a clean prompt, return as-is
        logger.debug(f"✅ Reflection output appears to be a clean prompt (no analysis detected)")
        return reflection_output.strip()

    # Fallback: Try to extract ANY valid prompt-like text.
    # Look for text that might be a prompt even if not perfectly formatted.
    if 'You are' in reflection_output:
        # Try to find a substantial block starting with "You are"; the negative
        # lookahead stops at sentences that introduce meta-commentary.
        potential_prompt = re.search(
            r'(You are(?:[^\.]|\.(?!\s*(?:Here|This|These|The above)))*?)(?:\n\n|\n###|Conclusion|\Z)',
            reflection_output,
            re.IGNORECASE | re.DOTALL
        )
        if potential_prompt and len(potential_prompt.group(1)) > 100:
            extracted = potential_prompt.group(1).strip()
            logger.warning(f"⚠️ Could not extract clean prompt using standard patterns")
            logger.warning(f" Falling back to 'You are...' block (length: {len(extracted)} chars)")
            logger.warning(f" This may still contain some analysis text")
            return extracted

    # Final fallback: If still nothing, return original but log strongly.
    logger.warning(f"⚠️ Could not extract clean prompt from reflection output")
    logger.warning(f" Output length: {len(reflection_output)} chars")
    logger.warning(f" Output preview: {reflection_output[:200]}...")
    logger.warning(f" ⚠️ WARNING: Returning original output (may contain analysis text or be invalid)")
    logger.warning(f" This candidate may perform poorly - consider improving extraction logic")
    return reflection_output.strip()
|
| 218 |
+
|
| 219 |
+
def _parse_json_variations(self, response_text: str, num_expected: int) -> List[str]:
    """
    🔥 OPTIMIZED: Parse N prompt variations from JSON format response.

    Applies progressively more permissive strategies:
      0. Python-dict syntax (single quotes) via ast.literal_eval
      1. Direct json.loads
      2. JSON inside a markdown code fence
      3. First balanced {...} object in the text
      4. Any {...} region containing a "variations" key (re-balanced)
      5. Numbered-section fallback (non-JSON)

    Handles common LLM output issues: markdown code blocks, extra text
    before/after the JSON, trailing commas, comments, and BOM /
    zero-width characters.

    Args:
        response_text: Raw LLM response expected to carry a "variations" list.
        num_expected: Number of variations the caller requires.

    Returns:
        List of exactly `num_expected` non-empty prompt strings.

    Raises:
        ValueError: If no strategy can recover the variations.
    """
    # FIX: removed debug writes to a hard-coded absolute developer path
    # (/Users/suhas/...). The final one was unguarded and could raise
    # FileNotFoundError, masking the real ValueError below.
    import json
    import re

    if not response_text or not isinstance(response_text, str):
        raise ValueError("Empty or invalid response text")

    # 🔥 PREPROCESSING: strip whitespace plus BOM and zero-width chars.
    cleaned = response_text.strip().lstrip('\ufeff\u200b\u200c\u200d')

    # Strategy 0: Python dict syntax ({'key': 'value'} instead of JSON).
    # LLMs sometimes emit repr()-style dicts with single quotes.
    if "'variations'" in cleaned or cleaned.startswith("{'"):
        try:
            import ast
            # literal_eval handles single quotes, True/False, None safely.
            python_dict = ast.literal_eval(cleaned)
            if isinstance(python_dict, dict) and 'variations' in python_dict:
                data = json.loads(json.dumps(python_dict))
                if 'variations' in data:
                    logger.debug("Parsed Python dict syntax (%d variations expected)", num_expected)
                    return self._extract_variations_from_json(data, num_expected)
        except (ValueError, SyntaxError, TypeError):
            # Heuristic fallback: naive quote replacement. May fail when
            # string values themselves contain quotes; that's acceptable here.
            try:
                data = json.loads(cleaned.replace("'", '"'))
                if 'variations' in data:
                    logger.debug("Parsed Python dict via string replacement")
                    return self._extract_variations_from_json(data, num_expected)
            except json.JSONDecodeError:
                pass

    # Strategy 1: Direct JSON parse (cleanest case).
    try:
        data = json.loads(cleaned)
        if 'variations' in data:
            return self._extract_variations_from_json(data, num_expected)
    except json.JSONDecodeError:
        pass

    # Strategy 2: Extract from markdown code block.
    # More permissive regexes that handle various fence formats.
    code_block_patterns = [
        r'```(?:json|JSON)?\s*(\{[\s\S]*?\})\s*```',  # Standard markdown
        r'```\s*(\{[\s\S]*"variations"[\s\S]*\})\s*```',  # With "variations" keyword
    ]
    for pattern in code_block_patterns:
        json_match = re.search(pattern, cleaned)
        if json_match:
            variations = self._try_parse_with_repair(json_match.group(1), num_expected)
            if variations is not None:
                return variations

    # Strategy 3: Balanced brace extraction (handles nested objects).
    json_str = self._extract_balanced_json(cleaned)
    if json_str:
        variations = self._try_parse_with_repair(json_str, num_expected)
        if variations is not None:
            return variations

    # Strategy 4: Find JSON object with "variations" keyword.
    # Greedy match to capture the full region, then re-balance within it.
    json_match = re.search(r'(\{[\s\S]*"variations"[\s\S]*\})', cleaned)
    if json_match:
        balanced = self._extract_balanced_json(json_match.group(1))
        if balanced:
            variations = self._try_parse_with_repair(balanced, num_expected)
            if variations is not None:
                return variations

    # Strategy 5: Fallback to numbered sections.
    logger.warning("JSON parsing failed, trying numbered section fallback...")
    try:
        return self._parse_numbered_section_variations(response_text, num_expected)
    except ValueError:
        pass

    logger.debug(
        "All parsing strategies failed (expected %d variations); preview: %.500s",
        num_expected,
        response_text,
    )
    raise ValueError(f"Could not parse {num_expected} variations from response")

def _try_parse_with_repair(self, json_str: str, num_expected: int) -> Optional[List[str]]:
    """Parse `json_str` as JSON (retrying once after repair); return variations or None."""
    import json

    for candidate in (json_str, self._repair_json_string(json_str)):
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, dict) and 'variations' in data:
            return self._extract_variations_from_json(data, num_expected)
    return None
|
| 373 |
+
|
| 374 |
+
def _extract_balanced_json(self, text: str) -> Optional[str]:
    """Return the first top-level {...} object in *text*, or None.

    Scans character by character, tracking string state and backslash
    escapes so that braces inside JSON string values are ignored.
    """
    depth = 0
    start = -1
    inside_string = False
    skip_next = False

    for pos, ch in enumerate(text):
        # A character flagged by a preceding backslash is consumed blindly.
        if skip_next:
            skip_next = False
            continue
        if ch == '\\' and inside_string:
            skip_next = True
            continue
        if ch == '"' and not skip_next:
            inside_string = not inside_string
            continue

        # Brace counting only applies outside string literals.
        if inside_string:
            continue

        if ch == '{':
            if depth == 0:
                start = pos
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0 and start >= 0:
                return text[start:pos + 1]

    # No balanced object found.
    return None
|
| 407 |
+
|
| 408 |
+
def _repair_json_string(self, json_str: str) -> str:
    """
    Repair common JSON issues from LLM output.

    Fixes:
    - Trailing commas before '}' or ']'
    - Single-line (//) comments
    - Multi-line (/* */) comments

    FIX: the single-line comment regex now skips '//' preceded by ':'
    so URL schemes inside string values (e.g. "https://example.com")
    are no longer destroyed by the repair pass.

    Args:
        json_str: Possibly-malformed JSON text.

    Returns:
        The repaired JSON text (may still be invalid; callers re-parse).
    """
    repaired = json_str

    # Remove trailing commas before } or ]
    repaired = re.sub(r',\s*}', '}', repaired)
    repaired = re.sub(r',\s*]', ']', repaired)

    # Remove single-line comments, but not the '//' of a URL scheme ("://").
    repaired = re.sub(r'(?<!:)//[^\n]*\n', '\n', repaired)

    # Remove multi-line comments
    repaired = re.sub(r'/\*[\s\S]*?\*/', '', repaired)

    return repaired
|
| 430 |
+
|
| 431 |
+
def _extract_variations_from_json(self, data: Dict[str, Any], num_expected: int) -> List[str]:
    """Pull ordered, validated prompt variations out of parsed JSON data.

    Entries are sorted by their 'index' field; the list is padded (by
    repeating the last entry) or truncated to exactly `num_expected`.
    Raises ValueError on malformed structure or empty results.
    """
    if not isinstance(data, dict):
        raise ValueError("JSON data is not a dictionary")

    raw_entries = data.get('variations', [])
    if not isinstance(raw_entries, list):
        raise ValueError("'variations' field is not a list")

    # Collect (index, prompt) pairs, skipping malformed entries.
    indexed = []
    for entry in raw_entries:
        if not isinstance(entry, dict):
            continue
        text = entry.get('prompt', '')
        if text and isinstance(text, str):
            indexed.append((entry.get('index', 0), text.strip()))

    indexed.sort(key=lambda pair: pair[0])
    variations = [text for _, text in indexed]

    # Pad with the last valid variation if we came up short.
    if len(variations) < num_expected:
        logger.warning(f"Only {len(variations)} valid variations found, expected {num_expected}")
        while len(variations) < num_expected:
            variations.append(variations[-1] if variations else "")

    variations = variations[:num_expected]

    if not all(v for v in variations):
        raise ValueError(f"Some variations are empty after parsing")

    return variations
|
| 466 |
+
|
| 467 |
+
def _parse_numbered_section_variations(self, response_text: str, num_expected: int) -> List[str]:
    """Fallback parser: recover variations from numbered/delimited sections.

    Tries, in order of specificity: '--- VARIATION n ---' delimiters,
    'Variation n:' headings, then bare 'n.' numbering. Raises ValueError
    when fewer than `num_expected` sections can be found.
    """
    import re

    delim_hits = re.findall(
        r'---\s*VARIATION\s+(\d+)\s*---\s*\n(.*?)(?=\n---\s*VARIATION|\Z)',
        response_text, re.DOTALL | re.IGNORECASE)
    heading_hits = re.findall(
        r'Variation\s+(\d+)\s*:?\s*\n(.*?)(?=\nVariation\s+\d+|$)',
        response_text, re.DOTALL | re.IGNORECASE)
    numbered_hits = re.findall(
        r'(\d+)\.\s*\n(.*?)(?=\n\d+\.|$)',
        response_text, re.DOTALL)

    # Prefer the most explicit format that yielded enough sections.
    if len(delim_hits) >= num_expected:
        hits = delim_hits
    elif len(heading_hits) >= num_expected:
        hits = heading_hits
    else:
        hits = numbered_hits

    variations: List[str] = []
    if len(hits) >= num_expected:
        hits.sort(key=lambda hit: int(hit[0]))
        variations = [body.strip() for _, body in hits[:num_expected]]

    if len(variations) != num_expected:
        raise ValueError(f"Numbered section parsing found {len(variations)} variations, expected {num_expected}")

    return variations
|
| 492 |
+
|
| 493 |
+
def _is_valid_prompt(self, prompt: str) -> bool:
    """Heuristically decide whether extracted text is a usable system prompt.

    Deliberately conservative: only OBVIOUSLY wrong text is rejected,
    since a false negative (discarding a good prompt) is worse than a
    false positive (evaluating a bad one). Quality is left to evaluation.

    Args:
        prompt: Extracted text to validate.

    Returns:
        True if the text appears to be a valid prompt, False if it is
        clearly analysis/meta text.
    """
    if not prompt or not prompt.strip():
        return False

    lowered = prompt.lower().strip()

    # High-confidence rejection: phrasing that essentially never appears
    # inside a real system prompt, checked only near the start of the text.
    rejection_phrases = (
        'in conclusion',
        'to summarize',
        'based on the analysis',
        'the analysis shows',
        'here are some suggestions',
        'it seems you\'re looking for',
    )
    head = lowered[:200]
    for pattern in rejection_phrases:
        if pattern in head:
            if self._should_log_debug():
                logger.debug(f"Rejected prompt: contains analysis pattern '{pattern}'")
            return False

    # High-confidence acceptance: common prompt openers.
    accepted_openers = (
        'you are',
        'you\'re',
        'your task',
        'your role',
        'analyze',
        'identify',
        'select',
        'determine',
        'given',
        'when',
    )
    if lowered[:100].startswith(accepted_openers):
        return True

    # Default: accept and let evaluation judge quality — better to score a
    # weak prompt than to throw away a good one.
    return True
|
| 554 |
+
|
| 555 |
+
def set_reflection_context(
    self,
    current_prompt: Optional[str] = None,
    feedback: Optional[Any] = None,
    in_reflection: bool = False
):
    """Record context for the next generate() call.

    Args:
        current_prompt: The prompt being reflected upon.
        feedback: Evaluation feedback.
        in_reflection: Whether we're entering reflection mode; when True,
            the hybrid-mode candidate queue is reset for the new phase.
    """
    self.reflection_context = {
        'current_prompt': current_prompt,
        'feedback': feedback,
        'in_reflection': in_reflection
    }

    # Guard clause: nothing further to do unless entering reflection.
    if not in_reflection:
        return

    # Fresh reflection phase — discard any stale hybrid-mode candidates.
    self._candidate_queue = []
    self._hybrid_generation_complete = False
    if self._should_log_debug():
        logger.debug("🔄 Entering LLEGO reflection mode (queue reset)")
    else:
        logger.info("🔄 Entering LLEGO reflection mode")
|
| 583 |
+
|
| 584 |
+
def generate(
    self,
    system_prompt: str = "",
    user_prompt: str = "",
    image_base64: str = "",
    **kwargs
) -> Dict[str, Any]:
    """
    Generate response, using LLEGO for reflection calls.

    🔥 CRITICAL: This method intercepts ALL LLM calls. For candidate generation,
    it checks if we have pre-generated candidates from hybrid mode and returns those.

    FIX: removed the leftover "#region agent log" instrumentation that wrote a
    debug JSON file to a hard-coded developer-machine path
    (/Users/suhas/Desktop/...). On any other machine it silently failed inside a
    blanket try/except; on the developer machine it created stray directories and
    files on every single LLM call. It must not ship in deployed code.

    Args:
        system_prompt: System prompt
        user_prompt: User prompt
        image_base64: Base64-encoded image (if any)
        **kwargs: Additional arguments

    Returns:
        Dict with 'content' key containing the generated text
    """
    # 🔍 DEBUG: Log generate calls (full details at DEBUG level)
    if self._should_log_debug():
        logger.debug(f"🔍 LLEGO Wrapper: generate() called")
        logger.debug(f" system_prompt: '{system_prompt[:100]}...' (truncated)")
        logger.debug(f" user_prompt length: {len(user_prompt)} chars")
        logger.debug(f" in_reflection: {self.reflection_context['in_reflection']}")
        logger.debug(f" has_image: {bool(image_base64)}")

    # 🔥 CRITICAL: Check if we have pre-generated candidates from adapter-level generation
    # This happens when GEPA calls adapter.llm_client to generate candidates
    # We intercept and return our pre-generated candidates instead
    # 🔥 NEW: Select BEST candidate instead of FIFO
    # 🔥 FIX: DON'T intercept evaluation calls (those have images!)
    # Only intercept proposal calls (no images, just asking for new candidate)
    # 🔥 FIX 2: DON'T intercept TEST EVALUATION calls!
    # Test evaluation has no images but uses the OPTIMIZED prompt to execute tasks
    # We detect test evaluation by checking if this is a TASK EXECUTION call (not reflection)
    is_task_execution = (
        # Task execution prompts contain task instructions, not optimization requests
        not any(kw in system_prompt.lower() for kw in ['evolutionary', 'mutation', 'variation', 'optimize', 'improve prompt', 'rewrite', 'generate variations']) and
        # Short prompts are usually task prompts, not optimization prompts
        len(system_prompt) < 1000 and
        # User prompt is the actual task input (short), not feedback (long)
        len(user_prompt) < 2000
    )

    # Log task execution detection for debugging
    if is_task_execution and hasattr(self, '_adapter_generated_candidates') and self._adapter_generated_candidates:
        logger.info(f"🔒 NOT intercepting: Task execution detected (not optimization)")
        logger.debug(f" system_prompt_len={len(system_prompt)}, user_prompt_len={len(user_prompt)}")

    if hasattr(self, '_adapter_generated_candidates') and self._adapter_generated_candidates and not image_base64 and not is_task_execution:
        # 🔥 BEST-CANDIDATE SELECTION: Find candidate with highest Dpareto score
        # This ensures we use the best candidate for the current iteration
        best_candidate = None
        best_score = -float('inf')
        best_idx = -1

        # Check if candidates have scores stored
        for idx, cand in enumerate(self._adapter_generated_candidates):
            if isinstance(cand, dict):
                # Try to get score from candidate dict
                score = cand.get('score', -float('inf'))

                # If score not in dict, try to get from Pareto logger
                if score == -float('inf'):
                    from ..utils.pareto_logger import get_pareto_logger
                    pareto_log = get_pareto_logger()

                    # Look up score in Pareto front or evaluated candidates
                    cand_prompt = cand.get('prompt', '')
                    if cand_prompt:
                        normalized = cand_prompt.strip().strip('"\'')
                        # Check in Pareto front
                        for front_cand in pareto_log.pareto_front:
                            if front_cand.get('prompt', '').strip().strip('"\'') == normalized:
                                score = front_cand.get('score', -float('inf'))
                                break

                        # If not in front, check evaluated candidates
                        if score == -float('inf'):
                            for eval_cand in pareto_log.candidates_evaluated:
                                if eval_cand.get('prompt', '').strip().strip('"\'') == normalized:
                                    score = eval_cand.get('score', -float('inf'))
                                    break

                if score > best_score:
                    best_score = score
                    best_candidate = cand
                    best_idx = idx

        # If no scores found, fall back to FIFO (first candidate)
        if best_candidate is None and self._adapter_generated_candidates:
            best_candidate = self._adapter_generated_candidates[0]
            best_idx = 0
            logger.info(f"⚠️ No scores found for candidates - using FIFO selection")

        # Remove selected candidate from queue
        if best_idx >= 0:
            self._adapter_generated_candidates.pop(best_idx)

        # Important event - keep at INFO
        if best_score > -float('inf'):
            logger.info(f"🎯 INTERCEPTING GEPA PROPOSAL CALL - Returning BEST candidate (score: {best_score:.4f})!")
            logger.info(f"🎯 Remaining candidates: {len(self._adapter_generated_candidates)}")
        else:
            logger.info(f"🎯 INTERCEPTING GEPA PROPOSAL CALL - Returning pre-generated candidate!")
            logger.info(f"🎯 Remaining candidates: {len(self._adapter_generated_candidates)}")

        if isinstance(best_candidate, dict) and 'prompt' in best_candidate:
            prompt = best_candidate['prompt']

            # Detailed logging only in DEBUG mode
            if self._should_log_debug():
                logger.debug(f"✅ Pre-generated candidate details:")
                logger.debug(f"{'▓'*80}")
                logger.debug(f"{prompt}")
                logger.debug(f"{'▓'*80}")
            else:
                source = best_candidate.get('source', 'unknown')
                score_info = f" (score: {best_score:.4f})" if best_score > -float('inf') else ""
                logger.info(f"✅ Candidate length: {len(prompt)} chars, Source: {source}{score_info}")

            return {'content': prompt, 'source': best_candidate.get('source', 'adapter_generated')}
        elif isinstance(best_candidate, str):
            if self._should_log_debug():
                logger.debug(f"✅ Pre-generated candidate (string format):")
                logger.debug(f"{'▓'*80}")
                logger.debug(f"{best_candidate}")
                logger.debug(f"{'▓'*80}")
            else:
                logger.info(f"✅ Candidate length: {len(best_candidate)} chars")
            return {'content': best_candidate, 'source': 'adapter_generated'}

    # 🔥 ENHANCED CALL TYPE DETECTION
    # We need to distinguish between 4 types of calls:
    # 1. Evaluation calls: Image + task command → identify element (pass through)
    # 2. Judge calls: Image + "prompt engineer" → analyze feedback (pass through)
    # 3. Proposal calls: No image + feedback → generate candidate (intercept)
    # 4. JSON batch calls: JSON generation request (pass through)

    # FIX: DON'T intercept JSON batch generation calls
    is_json_batch_request = (
        '"variations"' in system_prompt or
        'MUST BE VALID JSON' in system_prompt or
        'Output ONLY the JSON object' in system_prompt or
        '```json' in system_prompt.lower()
    )

    # FIX: DON'T intercept LLM-as-Judge calls (they analyze feedback with images)
    is_judge_call = (
        'prompt engineer' in system_prompt.lower() or
        'analyzing mobile ui automation' in system_prompt.lower() or
        'expert prompt engineer' in system_prompt.lower() or
        ('analyze' in system_prompt.lower() and 'screenshot with numbered bounding boxes' in system_prompt.lower() and image_base64)
    )

    # Check if this is a reflection call (GEPA asking for new candidate)
    is_reflection_call = (
        self.reflection_context['in_reflection'] or
        self._detect_reflection_call(system_prompt, user_prompt)
    )

    # Proposal calls are reflection calls WITHOUT images and NOT judge/JSON calls
    # These are the calls we want to intercept with LLEGO
    is_proposal_call = (
        not is_json_batch_request and  # Not a JSON generation request
        not is_judge_call and  # Not an LLM-as-Judge analysis
        not image_base64 and  # No image = not an evaluation/judge call
        (
            is_reflection_call or
            'improve' in system_prompt.lower() or
            'optimize' in system_prompt.lower() or
            'suggest' in system_prompt.lower() or
            'feedback' in system_prompt.lower() or
            'reflection' in system_prompt.lower()
        ) and
        len(user_prompt) > 100  # Proposal calls have substantial feedback
    )

    # Detailed call detection logging only in DEBUG mode
    if self._should_log_debug():
        logger.debug(f" is_json_batch_request: {is_json_batch_request}")
        logger.debug(f" is_judge_call: {is_judge_call}")
        logger.debug(f" is_reflection_call: {is_reflection_call}")
        logger.debug(f" is_proposal_call: {is_proposal_call}")
        logger.debug(f" has_image: {bool(image_base64)}")
        logger.debug(f" has_llego: {self.llego is not None}")

    # Only intercept proposal calls (not judge, not evaluation, not JSON)
    if is_proposal_call and self.llego:
        # FIX #5: Check if LLEGO is disabled due to repeated failures
        if self._llego_disabled:
            logger.warning("⚠️ LLEGO is disabled (circuit breaker), using base LLM")
            return self.base_llm.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                image_base64=image_base64,
                **kwargs
            )

        # Important event - keep at INFO
        logger.info("🔥 INTERCEPTING REFLECTION/PROPOSAL CALL FOR CANDIDATE GENERATION")
        return self._llego_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)
    else:
        # Standard LLM call (for evaluation, not reflection)
        if self._should_log_debug():
            logger.debug(" → Standard LLM call (evaluation, not reflection)")
        return self.base_llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            image_base64=image_base64,
            **kwargs
        )
|
| 813 |
+
|
| 814 |
+
def _clean_reflection_feedback(self, feedback_text: str, max_length: int = 50000) -> str:
|
| 815 |
+
"""
|
| 816 |
+
Clean reflection feedback by removing base64 images and truncating.
|
| 817 |
+
|
| 818 |
+
🔥 CRITICAL: GEPA's feedback can include massive base64 images (7MB+).
|
| 819 |
+
This function removes them and keeps feedback concise.
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
feedback_text: Original feedback (may contain base64)
|
| 823 |
+
max_length: Maximum length after cleaning (default: 50K chars)
|
| 824 |
+
|
| 825 |
+
Returns:
|
| 826 |
+
Cleaned feedback without base64, within size limits
|
| 827 |
+
"""
|
| 828 |
+
if not feedback_text:
|
| 829 |
+
return feedback_text
|
| 830 |
+
|
| 831 |
+
# Step 1: Remove very long base64-like sequences (50K+ chars of alphanumeric)
|
| 832 |
+
base64_pattern = r'[A-Za-z0-9+/=]{5000,}'
|
| 833 |
+
cleaned = re.sub(base64_pattern, '[IMAGE_DATA_REMOVED]', feedback_text)
|
| 834 |
+
|
| 835 |
+
# Step 2: Remove explicit image_base64 references and their values
|
| 836 |
+
cleaned = re.sub(r'image_base64["\']?\s*[:=]\s*["\']?[A-Za-z0-9+/=]+["\']?',
|
| 837 |
+
'image_base64: [REMOVED]', cleaned, flags=re.IGNORECASE)
|
| 838 |
+
|
| 839 |
+
# Step 3: Remove detailed_scores sections that might contain base64
|
| 840 |
+
cleaned = re.sub(r'##\s+detailed_scores[^\n]*\n[^#]*(?:image_base64|base64)[^\n]*(?:\n[^#]*)*',
|
| 841 |
+
'## detailed_scores: [REMOVED_FOR_BREVITY]', cleaned, flags=re.IGNORECASE | re.MULTILINE)
|
| 842 |
+
|
| 843 |
+
# Step 4: Remove any remaining very long strings (likely base64)
|
| 844 |
+
cleaned = re.sub(r'"[A-Za-z0-9+/=]{10000,}"', '[LARGE_DATA_STRING_REMOVED]', cleaned)
|
| 845 |
+
|
| 846 |
+
# Step 5: Truncate if still too long (keep beginning which has most important info)
|
| 847 |
+
if len(cleaned) > max_length:
|
| 848 |
+
truncated_size = len(cleaned) - max_length
|
| 849 |
+
cleaned = cleaned[:max_length] + f"\n\n[TRUNCATED {truncated_size} characters - keeping essential feedback only]"
|
| 850 |
+
logger.warning(f"⚠️ Reflection feedback truncated: {len(feedback_text)} → {len(cleaned)} chars")
|
| 851 |
+
|
| 852 |
+
return cleaned
|
| 853 |
+
|
| 854 |
+
def _detect_reflection_call(self, system_prompt: str, user_prompt: str) -> bool:
|
| 855 |
+
"""
|
| 856 |
+
Heuristic to detect if this is a reflection call from GEPA.
|
| 857 |
+
|
| 858 |
+
GEPA's reflection calls typically contain feedback/error analysis.
|
| 859 |
+
"""
|
| 860 |
+
reflection_keywords = [
|
| 861 |
+
'improve', 'feedback', 'error', 'failure', 'reflection',
|
| 862 |
+
'better prompt', 'modify', 'enhance', 'optimize'
|
| 863 |
+
]
|
| 864 |
+
|
| 865 |
+
combined = (system_prompt + " " + user_prompt).lower()
|
| 866 |
+
return any(keyword in combined for keyword in reflection_keywords)
|
| 867 |
+
|
| 868 |
+
def _llego_generate(
|
| 869 |
+
self,
|
| 870 |
+
system_prompt: str,
|
| 871 |
+
user_prompt: str,
|
| 872 |
+
image_base64: str = "",
|
| 873 |
+
**kwargs
|
| 874 |
+
) -> Dict[str, Any]:
|
| 875 |
+
"""
|
| 876 |
+
Use LLEGO (or Hybrid mode) to generate new prompt candidates.
|
| 877 |
+
|
| 878 |
+
Args:
|
| 879 |
+
system_prompt: System prompt
|
| 880 |
+
user_prompt: User prompt (contains reflection feedback)
|
| 881 |
+
image_base64: Image data (for reflection, always empty)
|
| 882 |
+
**kwargs: Additional arguments (may contain image_base64, will be removed)
|
| 883 |
+
|
| 884 |
+
Returns:
|
| 885 |
+
Dict with 'content' key containing a new prompt candidate
|
| 886 |
+
"""
|
| 887 |
+
try:
|
| 888 |
+
# 🔥 CRITICAL: Remove image_base64 from kwargs to avoid duplicate argument error
|
| 889 |
+
kwargs.pop('image_base64', None) # Remove if present to avoid conflict
|
| 890 |
+
|
| 891 |
+
# 🔥 HYBRID MODE: Generate from BOTH GEPA reflection AND LLEGO
|
| 892 |
+
if (self.config and
|
| 893 |
+
hasattr(self.config, 'enable_gepa_reflection_with_llego') and
|
| 894 |
+
self.config.enable_gepa_reflection_with_llego):
|
| 895 |
+
|
| 896 |
+
return self._hybrid_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)
|
| 897 |
+
|
| 898 |
+
# STANDARD LLEGO MODE (LLEGO only)
|
| 899 |
+
return self._llego_only_generate(system_prompt, user_prompt, image_base64=image_base64, **kwargs)
|
| 900 |
+
|
| 901 |
+
except Exception as e:
|
| 902 |
+
# FIX #5: Circuit breaker - track failures and disable LLEGO if needed
|
| 903 |
+
self._llego_failures += 1
|
| 904 |
+
|
| 905 |
+
logger.error(f"❌ LLEGO generation failed ({self._llego_failures}/{self._llego_failure_threshold}): {e}")
|
| 906 |
+
logger.error("⚠️ Falling back to base LLM")
|
| 907 |
+
|
| 908 |
+
if self._llego_failures >= self._llego_failure_threshold:
|
| 909 |
+
self._llego_disabled = True
|
| 910 |
+
logger.error(f"🚫 LLEGO DISABLED - {self._llego_failures} consecutive failures detected")
|
| 911 |
+
logger.error(" All future requests will use base LLM only")
|
| 912 |
+
|
| 913 |
+
import traceback
|
| 914 |
+
logger.debug(traceback.format_exc())
|
| 915 |
+
|
| 916 |
+
# Fallback to base LLM - ensure image_base64 is not in kwargs
|
| 917 |
+
kwargs.pop('image_base64', None)
|
| 918 |
+
return self.base_llm.generate(
|
| 919 |
+
system_prompt=system_prompt,
|
| 920 |
+
user_prompt=user_prompt,
|
| 921 |
+
image_base64=image_base64,
|
| 922 |
+
**kwargs
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
def _hybrid_generate(
|
| 926 |
+
self,
|
| 927 |
+
system_prompt: str,
|
| 928 |
+
user_prompt: str,
|
| 929 |
+
image_base64: str = "",
|
| 930 |
+
**kwargs
|
| 931 |
+
) -> Dict[str, Any]:
|
| 932 |
+
"""
|
| 933 |
+
🔥 HYBRID MODE: Generate candidates from BOTH GEPA reflection AND LLEGO operators.
|
| 934 |
+
|
| 935 |
+
Smart Compensation Strategy:
|
| 936 |
+
- When crossover can't run (< 2 parents), compensates with extra GEPA reflection
|
| 937 |
+
- GEPA is smarter than mutation (uses semantic understanding of feedback)
|
| 938 |
+
- Crossover only runs when we have 2+ scored parents to combine
|
| 939 |
+
|
| 940 |
+
GEPA will call generate() multiple times. On first call, we generate all candidates
|
| 941 |
+
and queue them. Subsequent calls return from the queue.
|
| 942 |
+
"""
|
| 943 |
+
# If we already generated candidates, return next from queue
|
| 944 |
+
if self._hybrid_generation_complete and self._candidate_queue:
|
| 945 |
+
candidate = self._candidate_queue.pop(0)
|
| 946 |
+
source = candidate.get('source', 'unknown')
|
| 947 |
+
logger.info(f"📦 Returning queued candidate (source: {source}, {len(self._candidate_queue)} remaining)")
|
| 948 |
+
return {'content': candidate['prompt'], 'source': source}
|
| 949 |
+
|
| 950 |
+
# First call: Generate ALL candidates
|
| 951 |
+
from ..utils.clean_logger import get_clean_logger
|
| 952 |
+
clean_log = get_clean_logger()
|
| 953 |
+
|
| 954 |
+
all_candidates = []
|
| 955 |
+
|
| 956 |
+
# ─────────────────────────────────────────────────────
|
| 957 |
+
# PHASE 0: Check if crossover will be possible
|
| 958 |
+
# ─────────────────────────────────────────────────────
|
| 959 |
+
from ..utils.pareto_logger import get_pareto_logger
|
| 960 |
+
pareto_log = get_pareto_logger()
|
| 961 |
+
gepa_pareto_front = pareto_log.pareto_front
|
| 962 |
+
|
| 963 |
+
# Determine if we need to compensate for crossover
|
| 964 |
+
crossover_possible = len(gepa_pareto_front) >= 2
|
| 965 |
+
n_crossover_config = self.config.n_crossover if hasattr(self.config, 'n_crossover') else 2
|
| 966 |
+
crossover_compensation = 0 if crossover_possible else n_crossover_config
|
| 967 |
+
|
| 968 |
+
if not crossover_possible:
|
| 969 |
+
logger.info(f"⚠️ Crossover NOT possible (have {len(gepa_pareto_front)} parents, need 2+)")
|
| 970 |
+
logger.info(f" → Smart compensation: +{crossover_compensation} extra GEPA reflection candidates")
|
| 971 |
+
|
| 972 |
+
# ─────────────────────────────────────────────────────
|
| 973 |
+
# PHASE 1: GEPA REFLECTION (Semantic Understanding)
|
| 974 |
+
# More GEPA = better, it understands WHY things fail
|
| 975 |
+
# ─────────────────────────────────────────────────────
|
| 976 |
+
base_gepa_count = self.config.num_gepa_reflection_candidates if hasattr(self.config, 'num_gepa_reflection_candidates') else 3
|
| 977 |
+
|
| 978 |
+
# 🔥 SMART COMPENSATION: More GEPA when crossover can't run
|
| 979 |
+
num_gepa = base_gepa_count + crossover_compensation
|
| 980 |
+
|
| 981 |
+
logger.info("─" * 80)
|
| 982 |
+
logger.info("PHASE 1: GEPA REFLECTION (Semantic Understanding)")
|
| 983 |
+
if crossover_compensation > 0:
|
| 984 |
+
logger.info(f"Generating {num_gepa} candidates ({base_gepa_count} base + {crossover_compensation} compensation for skipped crossover)")
|
| 985 |
+
else:
|
| 986 |
+
logger.info(f"Generating {num_gepa} candidates")
|
| 987 |
+
logger.info("─" * 80)
|
| 988 |
+
|
| 989 |
+
# 🔥 OPTIMIZED: Single call with JSON format for multiple variations
|
| 990 |
+
try:
|
| 991 |
+
# Clean user_prompt before sending to LLM
|
| 992 |
+
cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)
|
| 993 |
+
|
| 994 |
+
# Build diversity requirements based on num_gepa
|
| 995 |
+
diversity_requirements = self._build_diversity_requirements(num_gepa)
|
| 996 |
+
|
| 997 |
+
# 🔥 FORMAT AWARENESS: Get format constraint if available
|
| 998 |
+
format_constraint = ""
|
| 999 |
+
if self._detected_format and self._detected_format.get('format_constraint'):
|
| 1000 |
+
format_constraint = self._detected_format['format_constraint']
|
| 1001 |
+
logger.info(f"📐 Injecting format constraint into candidate generation")
|
| 1002 |
+
# #region agent log
|
| 1003 |
+
import json as _json_debug
|
| 1004 |
+
import time as _time_debug
|
| 1005 |
+
import os as _os_debug
|
| 1006 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 1007 |
+
_os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
|
| 1008 |
+
with open(_debug_log_path, "a") as _f:
|
| 1009 |
+
_f.write(_json_debug.dumps({"hypothesisId": "FORMAT_CONSTRAINT", "location": "llego_enhanced_llm.py:format_injection", "message": "Format constraint injected", "data": {"format_type": self._detected_format.get('format_type', 'unknown'), "constraint_length": len(format_constraint), "avg_length": self._detected_format.get('avg_length', 0)}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
|
| 1010 |
+
# #endregion
|
| 1011 |
+
else:
|
| 1012 |
+
format_constraint = "No specific format detected - ensure output is CONCISE and matches expected examples."
|
| 1013 |
+
# #region agent log
|
| 1014 |
+
import json as _json_debug
|
| 1015 |
+
import time as _time_debug
|
| 1016 |
+
import os as _os_debug
|
| 1017 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 1018 |
+
_os_debug.makedirs(_os_debug.path.dirname(_debug_log_path), exist_ok=True)
|
| 1019 |
+
with open(_debug_log_path, "a") as _f:
|
| 1020 |
+
_f.write(_json_debug.dumps({"hypothesisId": "FORMAT_CONSTRAINT", "location": "llego_enhanced_llm.py:format_injection", "message": "No format constraint available", "data": {"has_detected_format": bool(self._detected_format)}, "timestamp": int(_time_debug.time() * 1000), "sessionId": "debug-session"}) + "\n")
|
| 1021 |
+
# #endregion
|
| 1022 |
+
|
| 1023 |
+
# 🔥 EVOLUTIONARY PROMPT ENGINEER: Forces radically different mutations
|
| 1024 |
+
# Each variation MUST use a distinct genetic strategy to maximize search space
|
| 1025 |
+
optimization_system_prompt = f"""<system_core>
|
| 1026 |
+
You are an **Evolutionary Prompt Engineer**. Your task is to mutate a [FAILING_PROMPT] into a high-performance instruction set using genetic strategies.
|
| 1027 |
+
You must generate {num_gepa} radically different prompt variations based on the [FAILURE_FEEDBACK].
|
| 1028 |
+
</system_core>
|
| 1029 |
+
|
| 1030 |
+
<input_data>
|
| 1031 |
+
<failure_feedback_log>
|
| 1032 |
+
{cleaned_user_prompt}
|
| 1033 |
+
</failure_feedback_log>
|
| 1034 |
+
</input_data>
|
| 1035 |
+
|
| 1036 |
+
<mutation_strategies>
|
| 1037 |
+
You MUST use a different strategy for each variation. Assign strategies in order:
|
| 1038 |
+
|
| 1039 |
+
1. **STRATEGY A: The Strict Auditor (Constraints)**
|
| 1040 |
+
- Focus: Add "Negative Constraints" (e.g., "Do NOT...", "NEVER...", "FORBIDDEN:").
|
| 1041 |
+
- Use strict XML tagging for the output schema.
|
| 1042 |
+
- Goal: Fix hallucinations and formatting errors.
|
| 1043 |
+
|
| 1044 |
+
2. **STRATEGY B: The Reasoning Expert (Chain of Thought)**
|
| 1045 |
+
- Focus: Add a "Reasoning Steps" section.
|
| 1046 |
+
- Instruct the model to "Think step-by-step" before generating the final output.
|
| 1047 |
+
- Goal: Fix logic errors and complex multi-step reasoning failures.
|
| 1048 |
+
|
| 1049 |
+
3. **STRATEGY C: The Few-Shot Teacher (Examples)**
|
| 1050 |
+
- Focus: Generate a *synthetic* example of Input -> Correct Output within the prompt.
|
| 1051 |
+
- Goal: Fix understanding of abstract concepts or strict schema requirements.
|
| 1052 |
+
|
| 1053 |
+
4. **STRATEGY D: The Role-Player (Persona)**
|
| 1054 |
+
- Focus: Change the persona to a hyper-specific expert (e.g., "Senior Data Engineer at Fortune 500" vs "Coder").
|
| 1055 |
+
- Add domain-specific vocabulary and expertise markers.
|
| 1056 |
+
- Goal: Fix domain-specific terminology errors.
|
| 1057 |
+
|
| 1058 |
+
5. **STRATEGY E: The Structure Architect (Format)**
|
| 1059 |
+
- Focus: Add explicit output schema with field-by-field instructions.
|
| 1060 |
+
- Use markdown or XML headers to organize the prompt.
|
| 1061 |
+
- Goal: Fix output structure and field naming errors.
|
| 1062 |
+
</mutation_strategies>
|
| 1063 |
+
|
| 1064 |
+
<output_constraints>
|
| 1065 |
+
1. **Self-Contained**: Each variation must be the FULL prompt text (100-500 words), ready to run.
|
| 1066 |
+
2. **No Meta-Talk**: Do not explain your strategy inside the prompt. Just output the optimized prompt.
|
| 1067 |
+
3. **Preserve Core Task**: Keep the original task/domain - only improve HOW it's described.
|
| 1068 |
+
4. **JSON Output**: Follow the schema below exactly.
|
| 1069 |
+
5. **ENFORCE OUTPUT FORMAT**: The generated prompt MUST instruct the model to output in the EXACT format shown in examples.
|
| 1070 |
+
</output_constraints>
|
| 1071 |
+
|
| 1072 |
+
<critical_output_format_requirement>
|
| 1073 |
+
🚨 THE GENERATED PROMPTS MUST INCLUDE EXPLICIT OUTPUT FORMAT INSTRUCTIONS!
|
| 1074 |
+
Common failure: The model generates explanations/prose instead of the required concise format.
|
| 1075 |
+
|
| 1076 |
+
{format_constraint}
|
| 1077 |
+
|
| 1078 |
+
Your generated prompts MUST include:
|
| 1079 |
+
- Explicit instruction to output ONLY in the required format
|
| 1080 |
+
- "Do NOT explain", "No reasoning", "Output ONLY [format]" constraints
|
| 1081 |
+
- Length constraint to prevent verbose responses
|
| 1082 |
+
</critical_output_format_requirement>
|
| 1083 |
+
|
| 1084 |
+
<response_format>
|
| 1085 |
+
You MUST output ONLY valid JSON. No comments, no explanations, no markdown code blocks.
|
| 1086 |
+
|
| 1087 |
+
Generate exactly {num_gepa} variations in this exact format:
|
| 1088 |
+
|
| 1089 |
+
{{
|
| 1090 |
+
"variations": [
|
| 1091 |
+
{{
|
| 1092 |
+
"index": 1,
|
| 1093 |
+
"strategy": "Strict Auditor",
|
| 1094 |
+
"prompt": "[FULL PROMPT TEXT - Complete, self-contained, ready to use]"
|
| 1095 |
+
}},
|
| 1096 |
+
{{
|
| 1097 |
+
"index": 2,
|
| 1098 |
+
"strategy": "Reasoning Expert",
|
| 1099 |
+
"prompt": "[FULL PROMPT TEXT - Complete, self-contained, ready to use]"
|
| 1100 |
+
}}
|
| 1101 |
+
]
|
| 1102 |
+
}}
|
| 1103 |
+
|
| 1104 |
+
CRITICAL RULES:
|
| 1105 |
+
1. Output ONLY the JSON object - no text before or after
|
| 1106 |
+
2. Do NOT use markdown code blocks (no ```json)
|
| 1107 |
+
3. Do NOT include comments (no // or /* */)
|
| 1108 |
+
4. Ensure all strings are properly escaped
|
| 1109 |
+
5. Generate exactly {num_gepa} variations
|
| 1110 |
+
6. Each variation must have: index (number), strategy (string), prompt (string)
|
| 1111 |
+
</response_format>
|
| 1112 |
+
"""
|
| 1113 |
+
|
| 1114 |
+
# Standard GEPA reflection call
|
| 1115 |
+
call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}
|
| 1116 |
+
result = self.base_llm.generate(
|
| 1117 |
+
system_prompt=optimization_system_prompt,
|
| 1118 |
+
user_prompt=cleaned_user_prompt,
|
| 1119 |
+
image_base64=image_base64,
|
| 1120 |
+
**call_kwargs
|
| 1121 |
+
)
|
| 1122 |
+
|
| 1123 |
+
if isinstance(result, dict):
|
| 1124 |
+
response_text = result.get("content", str(result))
|
| 1125 |
+
else:
|
| 1126 |
+
response_text = str(result)
|
| 1127 |
+
|
| 1128 |
+
# Parse JSON variations
|
| 1129 |
+
gepa_variations = self._parse_json_variations(response_text, num_gepa)
|
| 1130 |
+
|
| 1131 |
+
# Add all variations to candidates
|
| 1132 |
+
for idx, variation_prompt in enumerate(gepa_variations, 1):
|
| 1133 |
+
# 🛡️ DEFENSIVE FALLBACK: Extract clean prompt if LLM adds analysis
|
| 1134 |
+
gepa_candidate = self._extract_clean_prompt_from_reflection(variation_prompt)
|
| 1135 |
+
|
| 1136 |
+
# Validate extracted prompt before adding
|
| 1137 |
+
if not self._is_valid_prompt(gepa_candidate):
|
| 1138 |
+
logger.warning(f" ⚠️ Variation {idx} appears invalid, skipping")
|
| 1139 |
+
continue
|
| 1140 |
+
|
| 1141 |
+
# 🔍 DIAGNOSTIC: Log candidate length to help diagnose scoring issues
|
| 1142 |
+
if self._should_log_debug():
|
| 1143 |
+
logger.debug(f" Candidate {idx} length: {len(gepa_candidate)} chars")
|
| 1144 |
+
logger.debug(f" Candidate {idx} preview: {gepa_candidate[:100]}...")
|
| 1145 |
+
|
| 1146 |
+
all_candidates.append({
|
| 1147 |
+
'prompt': gepa_candidate,
|
| 1148 |
+
'source': 'gepa_reflection',
|
| 1149 |
+
'index': idx
|
| 1150 |
+
})
|
| 1151 |
+
|
| 1152 |
+
clean_log.log_gepa_reflection_candidate(idx, gepa_candidate)
|
| 1153 |
+
|
| 1154 |
+
gepa_count = len(all_candidates)
|
| 1155 |
+
logger.info(f"✅ GEPA Reflection: {gepa_count} candidates generated in single optimized call")
|
| 1156 |
+
|
| 1157 |
+
except Exception as e:
|
| 1158 |
+
logger.error(f"❌ Error generating GEPA reflection candidates: {e}")
|
| 1159 |
+
logger.warning(f" Falling back to sequential generation...")
|
| 1160 |
+
import traceback
|
| 1161 |
+
logger.debug(traceback.format_exc())
|
| 1162 |
+
|
| 1163 |
+
# Fallback: Sequential generation (when JSON parsing fails)
|
| 1164 |
+
gepa_count = self._fallback_sequential_gepa_generation(
|
| 1165 |
+
num_gepa, user_prompt, image_base64, kwargs, all_candidates, clean_log
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
if gepa_count > 0:
|
| 1169 |
+
logger.info(f"GEPA Reflection Complete: {gepa_count} candidates")
|
| 1170 |
+
|
| 1171 |
+
# ─────────────────────────────────────────────────────
|
| 1172 |
+
# PHASE 2: LLEGO GENETIC OPERATORS
|
| 1173 |
+
# ─────────────────────────────────────────────────────
|
| 1174 |
+
logger.info("─" * 80)
|
| 1175 |
+
logger.info("PHASE 2: LLEGO GENETIC OPERATORS")
|
| 1176 |
+
logger.info("─" * 80)
|
| 1177 |
+
|
| 1178 |
+
# Extract current prompt from context
|
| 1179 |
+
current_prompt = self.reflection_context.get('current_prompt', '')
|
| 1180 |
+
if not current_prompt:
|
| 1181 |
+
current_prompt = self._extract_prompt_from_feedback(user_prompt)
|
| 1182 |
+
|
| 1183 |
+
if not current_prompt and self.llego.population:
|
| 1184 |
+
current_prompt = self.llego.population[0].prompt
|
| 1185 |
+
logger.info(f" Using population prompt (length: {len(current_prompt)})")
|
| 1186 |
+
|
| 1187 |
+
# Convert GEPA Pareto front to PromptCandidate format (already fetched in Phase 0)
|
| 1188 |
+
pareto_candidates = self.llego._convert_gepa_pareto_to_candidates(gepa_pareto_front)
|
| 1189 |
+
pareto_front = pareto_candidates
|
| 1190 |
+
|
| 1191 |
+
logger.info(f" Pareto front: {len(pareto_front)} candidates with scores")
|
| 1192 |
+
for idx, p in enumerate(pareto_front, 1):
|
| 1193 |
+
notation = p.metadata.get('notation', 'S') if p.metadata else 'S'
|
| 1194 |
+
logger.info(f" {notation}: fitness={p.fitness:.3f}")
|
| 1195 |
+
|
| 1196 |
+
# Create LLM callable for LLEGO genetic operations (crossover/mutation)
|
| 1197 |
+
call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}
|
| 1198 |
+
|
| 1199 |
+
# LLEGO genetic prompt with SAFETY LOCKS to prevent task drift
|
| 1200 |
+
# Directed mutations ensure prompts improve without losing core functionality
|
| 1201 |
+
genetic_operator_system_prompt = """<system_role>
|
| 1202 |
+
You are a **Prompt Mutation Engine**. Your input is a [PARENT_PROMPT]. Your output is a [MUTATED_CHILD].
|
| 1203 |
+
</system_role>
|
| 1204 |
+
|
| 1205 |
+
<mutation_directives>
|
| 1206 |
+
Apply ONE of the following micro-mutations to improve the prompt:
|
| 1207 |
+
|
| 1208 |
+
1. **COMPRESS**: Remove fluff words ("please", "ensure to", "kindly"). Make it telegraphic and efficient.
|
| 1209 |
+
2. **INTENSIFY**: Capitalize key constraints (e.g., "must return JSON" -> "**MUST** return **VALID JSON**").
|
| 1210 |
+
3. **STRUCTURIZE**: Add markdown headers or XML tags to organize a messy prompt.
|
| 1211 |
+
4. **CLARIFY**: Expand vague nouns (e.g., "code" -> "production-ready Python code with type hints").
|
| 1212 |
+
5. **CONSTRAIN**: Add negative constraints ("Do NOT include explanations", "NEVER output markdown").
|
| 1213 |
+
</mutation_directives>
|
| 1214 |
+
|
| 1215 |
+
<safety_locks>
|
| 1216 |
+
1. **IMMUTABLE CORE**: You MUST NOT change the core task (e.g., do not change "Extract JSON" to "Write a Summary").
|
| 1217 |
+
2. **NO EXPLANATION**: Output ONLY the new prompt string. No meta-commentary.
|
| 1218 |
+
3. **VALIDITY**: The output must remain a functional system prompt.
|
| 1219 |
+
4. **LENGTH LIMIT**: Keep mutations within 20% of original length (no excessive expansion).
|
| 1220 |
+
</safety_locks>"""
|
| 1221 |
+
|
| 1222 |
+
def llm_callable(genetic_prompt: str) -> str:
|
| 1223 |
+
result = self.base_llm.generate(
|
| 1224 |
+
system_prompt=genetic_operator_system_prompt,
|
| 1225 |
+
user_prompt=genetic_prompt,
|
| 1226 |
+
image_base64="",
|
| 1227 |
+
**call_kwargs
|
| 1228 |
+
)
|
| 1229 |
+
if isinstance(result, dict):
|
| 1230 |
+
return result.get('content', str(result))
|
| 1231 |
+
return str(result)
|
| 1232 |
+
|
| 1233 |
+
# Generate LLEGO offspring (crossover will be skipped if < 2 parents)
|
| 1234 |
+
llego_prompts = self.llego.evolve_generation(
|
| 1235 |
+
llm=llm_callable,
|
| 1236 |
+
pareto_front=pareto_front
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
# Track actual crossover count from LLEGO (it tracks internally now)
|
| 1240 |
+
actual_crossover = getattr(self.llego, '_actual_crossover_count', 0)
|
| 1241 |
+
crossover_skipped = getattr(self.llego, '_crossover_skipped', False)
|
| 1242 |
+
|
| 1243 |
+
crossover_idx = 1
|
| 1244 |
+
mutation_idx = 1
|
| 1245 |
+
|
| 1246 |
+
for i, prompt in enumerate(llego_prompts):
|
| 1247 |
+
if i < actual_crossover:
|
| 1248 |
+
source = 'llego_crossover'
|
| 1249 |
+
clean_log.log_llego_crossover_candidate(crossover_idx, prompt)
|
| 1250 |
+
crossover_idx += 1
|
| 1251 |
+
else:
|
| 1252 |
+
source = 'llego_mutation'
|
| 1253 |
+
clean_log.log_llego_mutation_candidate(mutation_idx, prompt)
|
| 1254 |
+
mutation_idx += 1
|
| 1255 |
+
|
| 1256 |
+
all_candidates.append({
|
| 1257 |
+
'prompt': prompt,
|
| 1258 |
+
'source': source,
|
| 1259 |
+
'index': i + 1
|
| 1260 |
+
})
|
| 1261 |
+
|
| 1262 |
+
mutation_count = len(llego_prompts) - actual_crossover
|
| 1263 |
+
logger.info(f"🧬 LLEGO: {actual_crossover} crossover + {mutation_count} mutation = {len(llego_prompts)} candidates")
|
| 1264 |
+
if crossover_skipped:
|
| 1265 |
+
logger.info(f" (Crossover was skipped - compensated with extra GEPA reflection)")
|
| 1266 |
+
|
| 1267 |
+
# ─────────────────────────────────────────────────────
|
| 1268 |
+
# SUMMARY
|
| 1269 |
+
# ─────────────────────────────────────────────────────
|
| 1270 |
+
total_gepa = len([c for c in all_candidates if c.get('source') == 'gepa_reflection'])
|
| 1271 |
+
total_crossover = len([c for c in all_candidates if c.get('source') == 'llego_crossover'])
|
| 1272 |
+
total_mutation = len([c for c in all_candidates if c.get('source') == 'llego_mutation'])
|
| 1273 |
+
|
| 1274 |
+
logger.info("─" * 80)
|
| 1275 |
+
logger.info("CANDIDATE GENERATION SUMMARY")
|
| 1276 |
+
logger.info("─" * 80)
|
| 1277 |
+
logger.info(f" GEPA Reflection: {total_gepa} candidates (semantic understanding)")
|
| 1278 |
+
logger.info(f" LLEGO Crossover: {total_crossover} candidates (combine best)")
|
| 1279 |
+
logger.info(f" LLEGO Mutation: {total_mutation} candidates (exploration)")
|
| 1280 |
+
logger.info(f" TOTAL: {len(all_candidates)} candidates")
|
| 1281 |
+
if crossover_skipped:
|
| 1282 |
+
logger.info(f" 📝 Note: Crossover skipped (waiting for 2+ scored parents)")
|
| 1283 |
+
logger.info("─" * 80)
|
| 1284 |
+
|
| 1285 |
+
clean_log.log_candidate_generation_summary()
|
| 1286 |
+
|
| 1287 |
+
# Store in queue (skip first one - return it now)
|
| 1288 |
+
self._candidate_queue = all_candidates[1:] if len(all_candidates) > 1 else []
|
| 1289 |
+
self._hybrid_generation_complete = True
|
| 1290 |
+
|
| 1291 |
+
# Return first candidate
|
| 1292 |
+
if all_candidates:
|
| 1293 |
+
first = all_candidates[0]
|
| 1294 |
+
logger.info(f"📤 Returning FIRST candidate (source: {first['source']})")
|
| 1295 |
+
return {'content': first['prompt'], 'source': first['source']}
|
| 1296 |
+
else:
|
| 1297 |
+
logger.error("❌ No candidates generated!")
|
| 1298 |
+
return {'content': '', 'source': 'error'}
|
| 1299 |
+
|
| 1300 |
+
def _llego_only_generate(
    self,
    system_prompt: str,
    user_prompt: str,
    image_base64: str = "",
    **kwargs
) -> Dict[str, Any]:
    """
    STANDARD LLEGO MODE: Generate candidates using only LLEGO operators.

    Args:
        system_prompt: Forwarded to the base LLM only on the no-candidate
            fallback path at the bottom of this method.
        user_prompt: GEPA reflection feedback; cleaned first because it may
            embed base64 image data.
        image_base64: Accepted for interface compatibility; genetic
            operations always send an empty image string.
        **kwargs: Extra generation options forwarded to the base LLM
            (``image_base64`` is stripped to avoid a duplicate argument).

    Returns:
        Dict with ``content`` (the evolved prompt), ``source`` and
        ``num_candidates`` on success; on fallback, whatever the base LLM's
        ``generate`` returns.
    """
    # 🔥 CRITICAL: Remove image_base64 from kwargs to avoid duplicate argument error
    kwargs.pop('image_base64', None)

    # 🔥 FIX: Clean user_prompt if it contains feedback (might have base64)
    cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)

    # Extract current prompt from context or user_prompt
    current_prompt = self.reflection_context.get('current_prompt', '')

    if not current_prompt:
        # Try to extract from cleaned user_prompt
        current_prompt = self._extract_prompt_from_feedback(cleaned_user_prompt)

    logger.info(f"🧬 LLEGO: Evolving prompt...")
    if self._should_log_debug():
        logger.debug(f" Current prompt: '{current_prompt[:100]}...' (length: {len(current_prompt)} chars)")
    else:
        logger.info(f" Prompt length: {len(current_prompt)} chars")

    # 🔥 FIX 2: Get Pareto front from GEPA (not LLEGO population)
    # This ensures LLEGO operators use true non-dominated solutions
    from ..utils.pareto_logger import get_pareto_logger
    pareto_log = get_pareto_logger()
    gepa_pareto_front = pareto_log.pareto_front

    # Convert GEPA Pareto front to PromptCandidate format
    pareto_candidates = self.llego._convert_gepa_pareto_to_candidates(gepa_pareto_front)
    pareto_front = pareto_candidates

    logger.info(f" Using GEPA Pareto front (size: {len(gepa_pareto_front)})")
    logger.info(f" Converted to {len(pareto_front)} PromptCandidate objects")

    # Create LLM callable for LLEGO genetic operations
    # Uses Genetic Mutation Engine prompt for micro-mutations
    call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}

    genetic_system_prompt = """You are a **Genetic Mutation Engine** for Text Prompts.

<task>
Apply a specific micro-mutation to the provided prompt to increase its clarity, strictness, or effectiveness.
</task>

<mutation_types>
1. **Compress**: Shorten verbose instructions without losing meaning.
2. **Expand**: Add detail to vague nouns (e.g., "code" -> "production-ready Python 3.10 code").
3. **Emphasize**: Highlight CRITICAL constraints using caps, bold, or explicit markers.
4. **Constrain**: Add explicit boundaries (what NOT to do, format rules, length limits).
5. **Exemplify**: Add a brief example if the task is ambiguous.
</mutation_types>

<output_rules>
1. Output ONLY the mutated prompt text.
2. Do NOT change the core intent or task domain.
3. Do NOT add explanations or meta-commentary.
4. Apply ONE primary mutation type while preserving all existing strengths.
</output_rules>"""

    def llm_callable(prompt: str) -> str:
        # Adapter so LLEGO can call the base LLM with a plain string and get
        # a plain string back, regardless of the client's return shape.
        # Clean prompt before sending (might contain base64 if from feedback)
        cleaned_prompt = self._clean_reflection_feedback(prompt)
        result = self.base_llm.generate(
            system_prompt=genetic_system_prompt,
            user_prompt=cleaned_prompt,
            image_base64="",  # Always empty for LLEGO genetic operations
            **call_kwargs
        )
        if isinstance(result, dict):
            return result.get('content', str(result))
        return str(result)

    # Generate offspring using LLEGO
    new_prompts = self.llego.evolve_generation(
        llm=llm_callable,
        pareto_front=pareto_front
    )

    if new_prompts:
        # Only the first offspring is returned; num_candidates reports how
        # many were produced in total.
        new_prompt = new_prompts[0]
        logger.info(f"✅ LLEGO generated new candidate (length: {len(new_prompt)} chars)")

        if self._should_log_debug():
            logger.debug(f" Full prompt:")
            logger.debug(f" '{new_prompt}'")

        return {
            'content': new_prompt,
            'source': 'llego',
            'num_candidates': len(new_prompts)
        }
    else:
        logger.warning("⚠️ LLEGO returned no candidates, falling back to base LLM")
        return self.base_llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            image_base64="",
            **kwargs
        )
def _build_diversity_requirements(self, num_gepa: int) -> str:
    """
    Build diversity requirements using research-backed Prompt Design Patterns.

    These are proven strategies from prompt engineering literature:
    - Chain-of-Thought (CoT)
    - Few-Shot Learning
    - Negative Constraints
    - Persona Pattern

    Args:
        num_gepa: Number of GEPA variations to generate; at most seven
            strategies are available, so extra requests reuse none beyond that.

    Returns:
        String with diversity requirements for the optimization prompt
    """
    # Research-backed Prompt Design Patterns that solve specific classes of problems
    strategies = [
        """
<variation_1>
**STRATEGY: COGNITIVE DECOMPOSITION (Chain-of-Thought)**
- **Goal**: Fixes logic/reasoning errors.
- **Action**: Add a thinking process section that forces step-by-step reasoning.
- **Implementation**: Include instructions like "First analyze..., then identify..., finally conclude..."
- **Pattern**: Force the model to "Plan before executing".
</variation_1>
""",

        """
<variation_2>
**STRATEGY: FEW-SHOT SIMULATION (In-Context Learning)**
- **Goal**: Fixes formatting/syntax errors and output structure issues.
- **Action**: Invent 1-2 realistic "Input -> Output" examples that mirror the expected format.
- **Implementation**: Add "Example: Given [input], respond with: [expected output format]"
- **Pattern**: Show, don't just tell. Demonstrate the gold standard.
</variation_2>
""",

        """
<variation_3>
**STRATEGY: SEMANTIC CONSTRAINING (Negative Constraints)**
- **Goal**: Fixes hallucinations, verbosity, and off-topic responses.
- **Action**: Add explicit forbidden actions and boundaries.
- **Implementation**: Include "Do NOT explain your reasoning", "Do NOT add preambles", "Do NOT include information not asked for"
- **Pattern**: Define the walls, not just the path.
</variation_3>
""",

        """
<variation_4>
**STRATEGY: PERSONA & ROLE HARDENING**
- **Goal**: Fixes tone, domain knowledge gaps, and inconsistent behavior.
- **Action**: Define a hyper-specific expert role with clear responsibilities.
- **Implementation**: Instead of "You are a helpful assistant", use "You are a Senior Data Analyst with 10 years of experience in [domain]"
- **Pattern**: Adopt the mental model and rigorous standards of a real expert.
</variation_4>
""",

        """
<variation_5>
**STRATEGY: OUTPUT SCHEMA ENFORCEMENT**
- **Goal**: Fixes structural and format compliance issues.
- **Action**: Define an explicit output schema with field names and types.
- **Implementation**: Include "Your response MUST follow this exact format: {field1: type, field2: type}"
- **Pattern**: Leave no ambiguity about what the output should look like.
</variation_5>
""",

        """
<variation_6>
**STRATEGY: SELF-VERIFICATION LOOP**
- **Goal**: Fixes errors that could be caught by double-checking.
- **Action**: Add instructions for the model to verify its own output.
- **Implementation**: Include "Before responding, verify: 1) Does this match the required format? 2) Did I include all requested information?"
- **Pattern**: Build in quality control before submission.
</variation_6>
""",

        """
<variation_7>
**STRATEGY: TASK DECOMPOSITION**
- **Goal**: Fixes complex tasks that overwhelm the model.
- **Action**: Break the task into numbered sub-tasks.
- **Implementation**: "Step 1: [subtask]. Step 2: [subtask]. Step 3: Combine results."
- **Pattern**: Divide and conquer complexity.
</variation_7>
"""
    ]

    # Select strategies based on num_gepa
    selected = strategies[:min(num_gepa, len(strategies))]

    requirements = "<required_strategies>\n"
    requirements += "Each variation MUST use a DIFFERENT strategy from the list below:\n"
    requirements += "\n".join(selected)
    requirements += "\n</required_strategies>"

    requirements += """

<strategy_application_rules>
1. Each variation must apply its assigned strategy comprehensively.
2. Each variation must ALSO address ALL issues mentioned in the feedback.
3. The strategies are not mutually exclusive - but the PRIMARY focus of each variation should be its assigned strategy.
4. Do not just add a single line - transform the prompt structure according to the strategy.
</strategy_application_rules>
"""

    return requirements
def _fallback_sequential_gepa_generation(
    self,
    num_gepa: int,
    user_prompt: str,
    image_base64: str,
    kwargs: dict,
    all_candidates: list,
    clean_log
) -> int:
    """
    Fallback to sequential generation when JSON parsing fails.

    Generates up to ``num_gepa`` reflection candidates one LLM call at a
    time, each steered by a different research-backed optimization strategy.
    Invalid candidates and per-call failures are skipped rather than raised,
    so one bad generation never aborts the whole fallback pass.

    Args:
        num_gepa: Number of candidates to generate
        user_prompt: The feedback/context
        image_base64: Image data (if any)
        kwargs: Additional kwargs
        all_candidates: List to append candidates to (mutated in place)
        clean_log: Logger for clean output

    Returns:
        Number of candidates generated
    """
    generated_count = 0

    # Loop-invariant data hoisted out of the loop: the strategy list is a
    # pure literal and call_kwargs depends only on the incoming kwargs dict,
    # so rebuilding them per iteration was pure waste.
    # Use research-backed strategy for each variation
    strategy_prompts = [
        "<optimization_rule>\nApply CHAIN-OF-THOUGHT: Add step-by-step reasoning instructions. Force the model to 'think before answering'.\n</optimization_rule>",
        "<optimization_rule>\nApply FEW-SHOT LEARNING: Add 1-2 concrete input/output examples within the prompt. Show, don't just tell.\n</optimization_rule>",
        "<optimization_rule>\nApply NEGATIVE CONSTRAINTS: Add explicit 'Do NOT' rules. Define what the model must avoid.\n</optimization_rule>",
        "<optimization_rule>\nApply PERSONA HARDENING: Define a specific expert role with clear responsibilities and standards.\n</optimization_rule>",
        "<optimization_rule>\nApply OUTPUT SCHEMA: Define the exact output format with field names and types. Leave no ambiguity.\n</optimization_rule>",
    ]
    call_kwargs = {k: v for k, v in kwargs.items() if k != 'image_base64'}

    for i in range(num_gepa):
        logger.debug(f"Generating Reflection Candidate #{i+1}/{num_gepa} (fallback mode)...")
        try:
            # Cleaning stays inside the try so a failure here is handled
            # per-iteration, exactly as before.
            cleaned_user_prompt = self._clean_reflection_feedback(user_prompt)

            # Cycle through strategies when num_gepa exceeds the list length
            strategy = strategy_prompts[i % len(strategy_prompts)]

            fallback_prompt = f"""You are a Prompt Optimization Engine in **SAFE MODE**.

{strategy}

{_FALLBACK_SYSTEM_PROMPT}"""

            result = self.base_llm.generate(
                system_prompt=fallback_prompt,
                user_prompt=cleaned_user_prompt,
                image_base64=image_base64,
                **call_kwargs
            )

            if isinstance(result, dict):
                gepa_candidate_raw = result.get("content", str(result))
            else:
                gepa_candidate_raw = str(result)

            # Strip any reflection scaffolding around the actual prompt text
            gepa_candidate = self._extract_clean_prompt_from_reflection(gepa_candidate_raw)

            if not self._is_valid_prompt(gepa_candidate):
                logger.warning(f" ⚠️ Fallback candidate #{i+1} appears invalid, skipping")
                continue

            all_candidates.append({
                'prompt': gepa_candidate,
                'source': 'gepa_reflection',
                'index': i + 1
            })

            clean_log.log_gepa_reflection_candidate(i + 1, gepa_candidate)
            generated_count += 1

        except Exception as fallback_error:
            # Best-effort: log and continue with the remaining candidates
            logger.error(f"❌ Error in fallback generation #{i+1}: {fallback_error}")

    return generated_count
def _extract_prompt_from_feedback(self, user_prompt: str) -> str:
|
| 1598 |
+
"""
|
| 1599 |
+
Try to extract the current prompt from GEPA's reflection feedback.
|
| 1600 |
+
|
| 1601 |
+
Args:
|
| 1602 |
+
user_prompt: The feedback text from GEPA
|
| 1603 |
+
|
| 1604 |
+
Returns:
|
| 1605 |
+
Extracted prompt or empty string
|
| 1606 |
+
"""
|
| 1607 |
+
# Look for common patterns in GEPA's feedback
|
| 1608 |
+
if "current prompt:" in user_prompt.lower():
|
| 1609 |
+
lines = user_prompt.split('\n')
|
| 1610 |
+
for i, line in enumerate(lines):
|
| 1611 |
+
if "current prompt:" in line.lower():
|
| 1612 |
+
# Return the next line(s) as the prompt
|
| 1613 |
+
return '\n'.join(lines[i+1:i+10])
|
| 1614 |
+
|
| 1615 |
+
return ""
|
| 1616 |
+
|
| 1617 |
+
# Forward other methods to base LLM
|
| 1618 |
+
def get_model_info(self) -> str:
    """Return the wrapped client's model info tagged with the LLEGO prefix."""
    base_info = self.base_llm.get_model_info()
    return f"LLEGO({base_info})"
|
| 1622 |
+
def __getattr__(self, name):
    """Forward unknown attributes to base LLM."""
    # Only invoked when normal attribute lookup fails, so attributes defined
    # on this wrapper take precedence and everything else transparently
    # delegates to the wrapped client.
    return getattr(self.base_llm, name)
|
| 1625 |
+
|
src/gepa_optimizer/llms/vision_llm.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vision LLM Client for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from enum import Enum
|
| 9 |
+
import requests
|
| 10 |
+
from typing import Dict, Optional, Any, TYPE_CHECKING, Union
|
| 11 |
+
|
| 12 |
+
# Assuming APIKeyManager is available from utils
|
| 13 |
+
from ..utils.api_keys import APIKeyManager
|
| 14 |
+
|
| 15 |
+
# Import ModelConfig only for type checking to avoid circular imports
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from ..models.config import ModelConfig
|
| 18 |
+
|
| 19 |
+
from .base_llm import BaseLLMClient
|
| 20 |
+
|
| 21 |
+
class ProviderType(str, Enum):
    """Supported LLM API providers.

    Mixing in ``str`` lets members compare equal to their plain string
    values (e.g. ``ProviderType.OPENAI == "openai"`` is True).
    """
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    HUGGINGFACE = "huggingface"
    VLLM = "vllm"
    GOOGLE = "google"
    # NOTE(review): GOOGLE and GEMINI are distinct members but are mapped to
    # the same API key elsewhere in this module.
    GEMINI = "gemini"
+
|
| 29 |
+
class ErrorType(str, Enum):
    """Categories used by :class:`GepaLLMError` to classify failures."""
    API_ERROR = "api_error"
    VALIDATION_ERROR = "validation_error"
    NETWORK_ERROR = "network_error"
    RATE_LIMIT = "rate_limit"
    TIMEOUT = "timeout"
+
|
| 36 |
+
class GepaLLMError(Exception):
    """Base exception for GEPA LLM related errors"""

    def __init__(self, message: str, error_type: ErrorType, status_code: Optional[int] = None):
        # Keep the raw pieces on the instance so callers can branch on them.
        self.message = message
        self.error_type = error_type
        self.status_code = status_code
        super().__init__(self.message)

    def __str__(self):
        # Render as "<type> (HTTP <code>): <msg>" when a status code exists,
        # otherwise "<type>: <msg>".
        category = self.error_type.value
        if self.status_code:
            return f"{category} (HTTP {self.status_code}): {self.message}"
        return f"{category}: {self.message}"
| 48 |
+
|
| 49 |
+
# Module-level logger for this client implementation.
logger = logging.getLogger(__name__)

# Default chat-completions endpoint used when no custom ``base_url`` is given.
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
+
|
| 53 |
+
class VisionLLMClient(BaseLLMClient):
|
| 54 |
+
"""
|
| 55 |
+
A client for interacting with multi-modal Vision LLMs (e.g., OpenAI GPT-4 Vision).
|
| 56 |
+
|
| 57 |
+
Example:
|
| 58 |
+
```python
|
| 59 |
+
# Basic usage
|
| 60 |
+
client = VisionLLMClient(
|
| 61 |
+
provider="openai",
|
| 62 |
+
model_name="gpt-4-vision-preview",
|
| 63 |
+
temperature=0.7,
|
| 64 |
+
max_tokens=2048
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# With custom configuration
|
| 68 |
+
config = ModelConfig(
|
| 69 |
+
provider="openai",
|
| 70 |
+
model_name="gpt-4-vision-preview",
|
| 71 |
+
temperature=0.5,
|
| 72 |
+
max_tokens=1024
|
| 73 |
+
)
|
| 74 |
+
client = VisionLLMClient.from_config(config)
|
| 75 |
+
```
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
    self,
    provider: Union[str, ProviderType],
    model_name: str,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    timeout: int = 120,  # Increase to 2 minutes for large prompts
    max_retries: int = 3
):
    """
    Initializes the VisionLLMClient with model configuration.

    Args:
        provider: The provider of the model (e.g., 'openai', 'anthropic')
        model_name: The name of the multi-modal LLM model to use (e.g., "gpt-4-vision-preview").
        api_key: Optional API key. If not provided, it will be fetched from APIKeyManager.
        base_url: Optional base URL for the API endpoint.
        temperature: Controls randomness in the response generation.
        max_tokens: Maximum number of tokens to generate.
        top_p: Controls diversity via nucleus sampling.
        frequency_penalty: Penalizes repeated tokens.
        presence_penalty: Penalizes new tokens based on their presence in the text so far.
    """
    # Register configuration with the base class first, then perform the
    # provider-specific setup and validation.
    super().__init__(
        provider=str(provider),
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        timeout=timeout,
        max_retries=max_retries
    )

    # Initialize the actual client
    self._initialize_client(provider, model_name, api_key, base_url, temperature,
                            max_tokens, top_p, frequency_penalty, presence_penalty,
                            timeout, max_retries)
|
| 124 |
+
def _initialize_client(self, provider, model_name, api_key, base_url, temperature,
                       max_tokens, top_p, frequency_penalty, presence_penalty,
                       timeout, max_retries):
    """Initialize the actual client (existing logic).

    Validates the configuration, resolves the API key, stores all settings
    on the instance, and builds a ``requests`` session with automatic
    retries for transient HTTP failures.

    Raises:
        GepaLLMError: With ``VALIDATION_ERROR`` for bad inputs or a missing
            API key, and ``API_ERROR`` for unexpected key-lookup failures.
    """
    # Input validation
    try:
        self.provider = ProviderType(provider.lower())
    except ValueError:
        raise GepaLLMError(
            f"Unsupported provider: {provider}. "
            f"Supported providers: {[p.value for p in ProviderType]}",
            ErrorType.VALIDATION_ERROR
        )

    if not model_name:
        raise GepaLLMError("model_name cannot be empty", ErrorType.VALIDATION_ERROR)

    if not isinstance(temperature, (int, float)) or not 0 <= temperature <= 2:
        raise GepaLLMError(
            f"temperature must be between 0 and 2, got {temperature}",
            ErrorType.VALIDATION_ERROR
        )

    if not isinstance(max_tokens, int) or max_tokens <= 0:
        raise GepaLLMError(
            f"max_tokens must be a positive integer, got {max_tokens}",
            ErrorType.VALIDATION_ERROR
        )

    # Initialize API key
    try:
        self.api_key = api_key or APIKeyManager().get_api_key(self.provider.value)
        if not self.api_key:
            raise GepaLLMError(
                f"No API key found for provider: {self.provider}",
                ErrorType.VALIDATION_ERROR
            )
    except GepaLLMError:
        # FIX: let our own "missing key" VALIDATION_ERROR propagate untouched
        # instead of being swallowed below and re-wrapped as API_ERROR,
        # which made the intended error type unreachable.
        raise
    except Exception as e:
        raise GepaLLMError(
            f"Failed to initialize API key: {str(e)}",
            ErrorType.API_ERROR
        ) from e

    self.model_name = model_name
    self.base_url = base_url or OPENAI_API_URL
    self.temperature = temperature
    self.max_tokens = max_tokens
    self.top_p = top_p
    self.frequency_penalty = frequency_penalty
    self.presence_penalty = presence_penalty
    self.timeout = timeout
    self.max_retries = max_retries
    self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    # Configure session with retry: exponential backoff on rate limits and
    # transient server errors; POST must be listed explicitly because it is
    # not retried by default.
    self.session = requests.Session()
    retry_strategy = requests.adapters.Retry(
        total=max_retries,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["POST"]
    )
    adapter = requests.adapters.HTTPAdapter(max_retries=retry_strategy)
    self.session.mount("https://", adapter)
    self.session.mount("http://", adapter)

    # No hardcoded model restrictions - user can specify any model name
    # The API provider will validate if the model exists and supports vision
+
def _get_api_key(self) -> Optional[str]:
|
| 194 |
+
"""Get API key based on provider"""
|
| 195 |
+
if self.provider == 'openai':
|
| 196 |
+
return APIKeyManager().get_api_key('openai')
|
| 197 |
+
elif self.provider == 'anthropic':
|
| 198 |
+
return APIKeyManager().get_api_key('anthropic')
|
| 199 |
+
elif self.provider in ['google', 'gemini']:
|
| 200 |
+
return APIKeyManager().get_api_key('google')
|
| 201 |
+
# Add other providers as needed
|
| 202 |
+
return None
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def from_config(cls, config: 'ModelConfig') -> 'VisionLLMClient':
|
| 206 |
+
"""Create a VisionLLMClient from a ModelConfig object.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
config: ModelConfig instance with provider and model settings
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
Configured VisionLLMClient instance
|
| 213 |
+
|
| 214 |
+
Example:
|
| 215 |
+
```python
|
| 216 |
+
config = ModelConfig(
|
| 217 |
+
provider="openai",
|
| 218 |
+
model_name="gpt-4-vision-preview",
|
| 219 |
+
temperature=0.7
|
| 220 |
+
)
|
| 221 |
+
client = VisionLLMClient.from_config(config)
|
| 222 |
+
```
|
| 223 |
+
"""
|
| 224 |
+
return cls(
|
| 225 |
+
provider=config.provider,
|
| 226 |
+
model_name=config.model_name,
|
| 227 |
+
api_key=config.api_key,
|
| 228 |
+
base_url=config.base_url,
|
| 229 |
+
temperature=config.temperature,
|
| 230 |
+
max_tokens=config.max_tokens,
|
| 231 |
+
top_p=config.top_p,
|
| 232 |
+
frequency_penalty=config.frequency_penalty,
|
| 233 |
+
presence_penalty=config.presence_penalty
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
@classmethod
|
| 237 |
+
def from_model_string(cls, model_string: str, **kwargs) -> 'VisionLLMClient':
|
| 238 |
+
"""Create a VisionLLMClient from a model string like "provider/model-name".
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
model_string: Model identifier in format "provider/model-name" or just "model-name"
|
| 242 |
+
Examples: "google/gemini-2.0-flash", "openai/gpt-4o", "gemini-1.5-pro"
|
| 243 |
+
**kwargs: Additional configuration options (temperature, max_tokens, etc.)
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
Configured VisionLLMClient instance
|
| 247 |
+
|
| 248 |
+
Example:
|
| 249 |
+
```python
|
| 250 |
+
# With provider
|
| 251 |
+
client = VisionLLMClient.from_model_string("google/gemini-2.0-flash")
|
| 252 |
+
|
| 253 |
+
# Without provider (defaults to openai)
|
| 254 |
+
client = VisionLLMClient.from_model_string("gpt-4o")
|
| 255 |
+
|
| 256 |
+
# With additional options
|
| 257 |
+
client = VisionLLMClient.from_model_string(
|
| 258 |
+
"google/gemini-2.0-flash",
|
| 259 |
+
temperature=0.5,
|
| 260 |
+
max_tokens=4096
|
| 261 |
+
)
|
| 262 |
+
```
|
| 263 |
+
"""
|
| 264 |
+
import os
|
| 265 |
+
|
| 266 |
+
# Parse "provider/model-name" format
|
| 267 |
+
if "/" in model_string:
|
| 268 |
+
provider, model_name = model_string.split("/", 1)
|
| 269 |
+
else:
|
| 270 |
+
# Default to openai if no provider specified
|
| 271 |
+
provider = "openai"
|
| 272 |
+
model_name = model_string
|
| 273 |
+
|
| 274 |
+
# Normalize provider names
|
| 275 |
+
provider = provider.lower()
|
| 276 |
+
if provider == "gemini":
|
| 277 |
+
provider = "google"
|
| 278 |
+
|
| 279 |
+
# Get API key from environment if not provided
|
| 280 |
+
api_key = kwargs.pop('api_key', None)
|
| 281 |
+
if not api_key:
|
| 282 |
+
env_var_map = {
|
| 283 |
+
"openai": "OPENAI_API_KEY",
|
| 284 |
+
"anthropic": "ANTHROPIC_API_KEY",
|
| 285 |
+
"google": "GOOGLE_API_KEY",
|
| 286 |
+
}
|
| 287 |
+
env_var = env_var_map.get(provider, f"{provider.upper()}_API_KEY")
|
| 288 |
+
api_key = os.getenv(env_var)
|
| 289 |
+
|
| 290 |
+
return cls(
|
| 291 |
+
provider=provider,
|
| 292 |
+
model_name=model_name,
|
| 293 |
+
api_key=api_key,
|
| 294 |
+
**kwargs
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def generate(
|
| 298 |
+
self,
|
| 299 |
+
system_prompt: str,
|
| 300 |
+
user_prompt: str,
|
| 301 |
+
image_base64: Optional[str] = None,
|
| 302 |
+
**generation_kwargs
|
| 303 |
+
) -> Dict[str, Any]:
|
| 304 |
+
"""
|
| 305 |
+
Generates a response from the Vision LLM.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
system_prompt: The system-level instructions for the LLM.
|
| 309 |
+
user_prompt: The user's query or task.
|
| 310 |
+
image_base64: Optional Base64 encoded image string.
|
| 311 |
+
**generation_kwargs: Additional model-specific generation parameters
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
A dictionary containing the generated response and metadata.
|
| 315 |
+
|
| 316 |
+
Raises:
|
| 317 |
+
GepaLLMError: If there's an error during generation
|
| 318 |
+
|
| 319 |
+
Example:
|
| 320 |
+
```python
|
| 321 |
+
response = client.generate(
|
| 322 |
+
system_prompt="You are a helpful assistant.",
|
| 323 |
+
user_prompt="What's in this image?",
|
| 324 |
+
image_base64="base64_encoded_image"
|
| 325 |
+
)
|
| 326 |
+
```
|
| 327 |
+
"""
|
| 328 |
+
if not system_prompt or not user_prompt:
|
| 329 |
+
raise GepaLLMError(
|
| 330 |
+
"system_prompt and user_prompt are required",
|
| 331 |
+
ErrorType.VALIDATION_ERROR
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
try:
|
| 335 |
+
if self.provider == ProviderType.OPENAI:
|
| 336 |
+
return self._generate_openai(system_prompt, user_prompt, image_base64, **generation_kwargs)
|
| 337 |
+
elif self.provider in [ProviderType.GOOGLE, ProviderType.GEMINI]:
|
| 338 |
+
return self._generate_google(system_prompt, user_prompt, image_base64, **generation_kwargs)
|
| 339 |
+
else:
|
| 340 |
+
raise GepaLLMError(
|
| 341 |
+
f"Provider {self.provider} is not yet supported",
|
| 342 |
+
ErrorType.VALIDATION_ERROR
|
| 343 |
+
)
|
| 344 |
+
except requests.exceptions.RequestException as e:
|
| 345 |
+
self.logger.error(f"Network error during generation: {str(e)}")
|
| 346 |
+
raise GepaLLMError(
|
| 347 |
+
f"Network error: {str(e)}",
|
| 348 |
+
ErrorType.NETWORK_ERROR,
|
| 349 |
+
getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
|
| 350 |
+
) from e
|
| 351 |
+
except GepaLLMError:
|
| 352 |
+
raise
|
| 353 |
+
except Exception as e:
|
| 354 |
+
self.logger.error(f"Unexpected error during generation: {str(e)}")
|
| 355 |
+
raise GepaLLMError(
|
| 356 |
+
f"Generation failed: {str(e)}",
|
| 357 |
+
ErrorType.API_ERROR
|
| 358 |
+
) from e
|
| 359 |
+
|
| 360 |
+
def _generate_openai(
|
| 361 |
+
self,
|
| 362 |
+
system_prompt: str,
|
| 363 |
+
user_prompt: str,
|
| 364 |
+
image_base64: Optional[str] = None,
|
| 365 |
+
**generation_kwargs
|
| 366 |
+
) -> Dict[str, Any]:
|
| 367 |
+
"""
|
| 368 |
+
Generate response using OpenAI's API with configured parameters.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
system_prompt: System instructions for the model
|
| 372 |
+
user_prompt: User's input prompt
|
| 373 |
+
image_base64: Optional base64 encoded image
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
Dictionary containing the API response
|
| 377 |
+
|
| 378 |
+
Raises:
|
| 379 |
+
GepaDependencyError: If API call fails
|
| 380 |
+
"""
|
| 381 |
+
headers = {
|
| 382 |
+
"Content-Type": "application/json",
|
| 383 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 384 |
+
"User-Agent": "GepaOptimizer/1.0 (Python)"
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
messages = [
|
| 388 |
+
{"role": "system", "content": system_prompt},
|
| 389 |
+
{
|
| 390 |
+
"role": "user",
|
| 391 |
+
"content": [
|
| 392 |
+
{"type": "text", "text": user_prompt}
|
| 393 |
+
]
|
| 394 |
+
}
|
| 395 |
+
]
|
| 396 |
+
|
| 397 |
+
if image_base64:
|
| 398 |
+
# #region agent log
|
| 399 |
+
import json as _json_debug
|
| 400 |
+
import time as _time_debug
|
| 401 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 402 |
+
try:
|
| 403 |
+
with open(_debug_log_path, "a") as _f:
|
| 404 |
+
_f.write(_json_debug.dumps({
|
| 405 |
+
"id": f"log_{int(_time_debug.time() * 1000)}",
|
| 406 |
+
"timestamp": int(_time_debug.time() * 1000),
|
| 407 |
+
"location": "vision_llm.py:_generate_openai",
|
| 408 |
+
"message": "Image base64 BEFORE processing",
|
| 409 |
+
"data": {
|
| 410 |
+
"image_base64_length": len(image_base64) if image_base64 else 0,
|
| 411 |
+
"has_data_uri_prefix": image_base64.startswith("data:image") if image_base64 else False,
|
| 412 |
+
"prefix": image_base64[:50] if image_base64 and len(image_base64) > 50 else image_base64,
|
| 413 |
+
"is_none": image_base64 is None,
|
| 414 |
+
"is_empty": image_base64 == "" if image_base64 else True
|
| 415 |
+
},
|
| 416 |
+
"sessionId": "debug-session",
|
| 417 |
+
"runId": "run1",
|
| 418 |
+
"hypothesisId": "A,C,D"
|
| 419 |
+
}) + "\n")
|
| 420 |
+
except Exception:
|
| 421 |
+
pass
|
| 422 |
+
# #endregion
|
| 423 |
+
|
| 424 |
+
# Detect and extract image format
|
| 425 |
+
detected_format = "jpeg" # Default fallback
|
| 426 |
+
clean_base64 = image_base64
|
| 427 |
+
|
| 428 |
+
# Extract format from data URI prefix if present
|
| 429 |
+
if image_base64.startswith("data:image"):
|
| 430 |
+
# Parse format from prefix: data:image/png;base64,...
|
| 431 |
+
if "," in image_base64:
|
| 432 |
+
prefix_part = image_base64.split(",", 1)[0]
|
| 433 |
+
clean_base64 = image_base64.split(",", 1)[1]
|
| 434 |
+
# Extract format from "data:image/PNG;base64" or "data:image/png"
|
| 435 |
+
if "/" in prefix_part and ";" in prefix_part:
|
| 436 |
+
detected_format = prefix_part.split("/")[1].split(";")[0].lower()
|
| 437 |
+
elif "/" in prefix_part:
|
| 438 |
+
detected_format = prefix_part.split("/")[1].lower()
|
| 439 |
+
else:
|
| 440 |
+
# Fallback: try to extract format
|
| 441 |
+
if "/" in image_base64:
|
| 442 |
+
detected_format = image_base64.split("/")[1].split(";")[0].lower() if ";" in image_base64 else "jpeg"
|
| 443 |
+
clean_base64 = image_base64.replace("data:image/", "").replace(";base64", "")
|
| 444 |
+
|
| 445 |
+
# If no format detected from prefix, try to detect from image data
|
| 446 |
+
if detected_format == "jpeg" or not detected_format:
|
| 447 |
+
try:
|
| 448 |
+
import base64 as b64
|
| 449 |
+
from PIL import Image
|
| 450 |
+
import io
|
| 451 |
+
image_data = b64.b64decode(clean_base64)
|
| 452 |
+
img = Image.open(io.BytesIO(image_data))
|
| 453 |
+
if img.format:
|
| 454 |
+
detected_format = img.format.lower()
|
| 455 |
+
# Normalize format names
|
| 456 |
+
if detected_format in ["jpg", "jpeg"]:
|
| 457 |
+
detected_format = "jpeg"
|
| 458 |
+
except Exception:
|
| 459 |
+
# If detection fails, keep default
|
| 460 |
+
pass
|
| 461 |
+
|
| 462 |
+
# Normalize format for data URI (OpenAI accepts: jpeg, png, gif, webp)
|
| 463 |
+
format_map = {
|
| 464 |
+
"jpg": "jpeg",
|
| 465 |
+
"jpeg": "jpeg",
|
| 466 |
+
"png": "png",
|
| 467 |
+
"gif": "gif",
|
| 468 |
+
"webp": "webp",
|
| 469 |
+
"bmp": "png", # Convert BMP to PNG (OpenAI doesn't support BMP)
|
| 470 |
+
"tiff": "png", # Convert TIFF to PNG
|
| 471 |
+
"tif": "png"
|
| 472 |
+
}
|
| 473 |
+
final_format = format_map.get(detected_format, "jpeg")
|
| 474 |
+
|
| 475 |
+
final_url = f"data:image/{final_format};base64,{clean_base64}"
|
| 476 |
+
|
| 477 |
+
# #region agent log
|
| 478 |
+
try:
|
| 479 |
+
with open(_debug_log_path, "a") as _f:
|
| 480 |
+
_f.write(_json_debug.dumps({
|
| 481 |
+
"id": f"log_{int(_time_debug.time() * 1000)}",
|
| 482 |
+
"timestamp": int(_time_debug.time() * 1000),
|
| 483 |
+
"location": "vision_llm.py:_generate_openai",
|
| 484 |
+
"message": "Image URL AFTER processing",
|
| 485 |
+
"data": {
|
| 486 |
+
"detected_format": detected_format,
|
| 487 |
+
"final_format": final_format,
|
| 488 |
+
"clean_base64_length": len(clean_base64),
|
| 489 |
+
"final_url_length": len(final_url),
|
| 490 |
+
"final_url_prefix": final_url[:60]
|
| 491 |
+
},
|
| 492 |
+
"sessionId": "debug-session",
|
| 493 |
+
"runId": "run1",
|
| 494 |
+
"hypothesisId": "A,B"
|
| 495 |
+
}) + "\n")
|
| 496 |
+
except Exception:
|
| 497 |
+
pass
|
| 498 |
+
# #endregion
|
| 499 |
+
|
| 500 |
+
messages[1]["content"].append({
|
| 501 |
+
"type": "image_url",
|
| 502 |
+
"image_url": {
|
| 503 |
+
"url": final_url
|
| 504 |
+
}
|
| 505 |
+
})
|
| 506 |
+
|
| 507 |
+
payload = {
|
| 508 |
+
"model": self.model_name,
|
| 509 |
+
"messages": messages,
|
| 510 |
+
# "temperature": self.temperature,
|
| 511 |
+
# "max_tokens": self.max_tokens,
|
| 512 |
+
"top_p": self.top_p,
|
| 513 |
+
"frequency_penalty": self.frequency_penalty,
|
| 514 |
+
"presence_penalty": self.presence_penalty
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
self.logger.debug(f"Sending request to {self.base_url} with model {self.model_name}")
|
| 518 |
+
|
| 519 |
+
try:
|
| 520 |
+
self.logger.debug(f"Sending request to {self.model_name}")
|
| 521 |
+
|
| 522 |
+
# Make the API request with retry
|
| 523 |
+
response = self.session.post(
|
| 524 |
+
self.base_url,
|
| 525 |
+
headers=headers,
|
| 526 |
+
json=payload,
|
| 527 |
+
timeout=300
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# Handle rate limiting
|
| 531 |
+
if response.status_code == 429:
|
| 532 |
+
retry_after = int(response.headers.get('Retry-After', 5))
|
| 533 |
+
self.logger.warning(f"Rate limited. Retrying after {retry_after} seconds...")
|
| 534 |
+
time.sleep(retry_after)
|
| 535 |
+
return self._generate_openai(system_prompt, user_prompt, image_base64, **generation_kwargs)
|
| 536 |
+
|
| 537 |
+
response.raise_for_status()
|
| 538 |
+
|
| 539 |
+
result = response.json()
|
| 540 |
+
self.logger.debug(f"Received response from {self.model_name}")
|
| 541 |
+
|
| 542 |
+
# Extract and validate the response
|
| 543 |
+
try:
|
| 544 |
+
message = result["choices"][0]["message"]
|
| 545 |
+
llm_response_content = message["content"]
|
| 546 |
+
|
| 547 |
+
# Log token usage if available
|
| 548 |
+
if "usage" in result:
|
| 549 |
+
usage = result["usage"]
|
| 550 |
+
self.logger.info(
|
| 551 |
+
f"Tokens used - Prompt: {usage.get('prompt_tokens', 'N/A')}, "
|
| 552 |
+
f"Completion: {usage.get('completion_tokens', 'N/A')}, "
|
| 553 |
+
f"Total: {usage.get('total_tokens', 'N/A')}"
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
# Try to parse JSON if the response looks like JSON
|
| 557 |
+
if isinstance(llm_response_content, str) and (
|
| 558 |
+
llm_response_content.startswith('{') or
|
| 559 |
+
llm_response_content.startswith('[')
|
| 560 |
+
):
|
| 561 |
+
try:
|
| 562 |
+
return json.loads(llm_response_content)
|
| 563 |
+
except json.JSONDecodeError:
|
| 564 |
+
pass
|
| 565 |
+
|
| 566 |
+
# Default response format
|
| 567 |
+
return {
|
| 568 |
+
"content": llm_response_content,
|
| 569 |
+
"role": message.get("role", "assistant"),
|
| 570 |
+
"model": self.model_name,
|
| 571 |
+
"provider": self.provider.value
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
except (KeyError, IndexError) as e:
|
| 575 |
+
self.logger.error(f"Unexpected response format: {result}")
|
| 576 |
+
raise GepaLLMError(
|
| 577 |
+
f"Unexpected response format from {self.provider} API",
|
| 578 |
+
ErrorType.API_ERROR,
|
| 579 |
+
response.status_code
|
| 580 |
+
) from e
|
| 581 |
+
|
| 582 |
+
except requests.exceptions.HTTPError as e:
|
| 583 |
+
status_code = e.response.status_code if hasattr(e, 'response') else None
|
| 584 |
+
error_msg = f"HTTP error {status_code} from {self.provider} API"
|
| 585 |
+
|
| 586 |
+
try:
|
| 587 |
+
error_data = e.response.json()
|
| 588 |
+
error_msg = error_data.get('error', {}).get('message', error_msg)
|
| 589 |
+
except Exception:
|
| 590 |
+
error_msg = str(e)
|
| 591 |
+
|
| 592 |
+
self.logger.error(f"{error_msg}: {error_data if 'error_data' in locals() else str(e)}")
|
| 593 |
+
raise GepaLLMError(
|
| 594 |
+
error_msg,
|
| 595 |
+
ErrorType.RATE_LIMIT if status_code == 429 else ErrorType.API_ERROR,
|
| 596 |
+
status_code
|
| 597 |
+
) from e
|
| 598 |
+
|
| 599 |
+
except requests.exceptions.Timeout:
|
| 600 |
+
self.logger.error(f"Request to {self.provider} API timed out after {self.timeout} seconds")
|
| 601 |
+
raise GepaLLMError(
|
| 602 |
+
f"Request timed out after {self.timeout} seconds",
|
| 603 |
+
ErrorType.TIMEOUT
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
except requests.exceptions.RequestException as e:
|
| 607 |
+
self.logger.error(f"Network error: {str(e)}")
|
| 608 |
+
raise GepaLLMError(
|
| 609 |
+
f"Network error: {str(e)}",
|
| 610 |
+
ErrorType.NETWORK_ERROR
|
| 611 |
+
) from e
|
| 612 |
+
|
| 613 |
+
except Exception as e:
|
| 614 |
+
self.logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
| 615 |
+
raise GepaLLMError(
|
| 616 |
+
f"Unexpected error: {str(e)}",
|
| 617 |
+
ErrorType.API_ERROR
|
| 618 |
+
) from e
|
| 619 |
+
|
| 620 |
+
def _generate_google(
|
| 621 |
+
self,
|
| 622 |
+
system_prompt: str,
|
| 623 |
+
user_prompt: str,
|
| 624 |
+
image_base64: Optional[str] = None,
|
| 625 |
+
**generation_kwargs
|
| 626 |
+
) -> Dict[str, Any]:
|
| 627 |
+
"""
|
| 628 |
+
Generate response using Google Gemini API with configured parameters.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
system_prompt: System instructions for the model
|
| 632 |
+
user_prompt: User's input prompt
|
| 633 |
+
image_base64: Optional base64 encoded image
|
| 634 |
+
|
| 635 |
+
Returns:
|
| 636 |
+
Dictionary containing the API response
|
| 637 |
+
|
| 638 |
+
Raises:
|
| 639 |
+
GepaLLMError: If API call fails
|
| 640 |
+
"""
|
| 641 |
+
try:
|
| 642 |
+
import google.generativeai as genai
|
| 643 |
+
import base64
|
| 644 |
+
from PIL import Image
|
| 645 |
+
import io
|
| 646 |
+
except ImportError as e:
|
| 647 |
+
raise GepaLLMError(
|
| 648 |
+
f"Required dependencies for Google Gemini not installed: {str(e)}. "
|
| 649 |
+
f"Please install: pip install google-generativeai Pillow",
|
| 650 |
+
ErrorType.VALIDATION_ERROR
|
| 651 |
+
) from e
|
| 652 |
+
|
| 653 |
+
# Configure Gemini
|
| 654 |
+
genai.configure(api_key=self.api_key)
|
| 655 |
+
|
| 656 |
+
# Use the model name directly as specified by the user
|
| 657 |
+
# No hardcoded mappings or restrictions - fully configurable
|
| 658 |
+
# The Gemini API will validate if the model exists
|
| 659 |
+
gemini_model_name = self.model_name
|
| 660 |
+
|
| 661 |
+
try:
|
| 662 |
+
model = genai.GenerativeModel(gemini_model_name)
|
| 663 |
+
except Exception as e:
|
| 664 |
+
raise GepaLLMError(
|
| 665 |
+
f"Failed to initialize Gemini model {gemini_model_name}: {str(e)}",
|
| 666 |
+
ErrorType.API_ERROR
|
| 667 |
+
) from e
|
| 668 |
+
|
| 669 |
+
# Prepare content
|
| 670 |
+
content_parts = []
|
| 671 |
+
|
| 672 |
+
# Add system prompt and user prompt
|
| 673 |
+
full_prompt = f"{system_prompt}\n\n{user_prompt}"
|
| 674 |
+
content_parts.append(full_prompt)
|
| 675 |
+
|
| 676 |
+
# Add image if provided
|
| 677 |
+
if image_base64:
|
| 678 |
+
# #region agent log
|
| 679 |
+
import json as _json_debug
|
| 680 |
+
import time as _time_debug
|
| 681 |
+
_debug_log_path = "/Users/suhas/Desktop/Projects/Prompt-Optimizer/.cursor/debug.log"
|
| 682 |
+
try:
|
| 683 |
+
with open(_debug_log_path, "a") as _f:
|
| 684 |
+
_f.write(_json_debug.dumps({
|
| 685 |
+
"id": f"log_{int(_time_debug.time() * 1000)}",
|
| 686 |
+
"timestamp": int(_time_debug.time() * 1000),
|
| 687 |
+
"location": "vision_llm.py:_generate_google",
|
| 688 |
+
"message": "Image base64 BEFORE processing (Google)",
|
| 689 |
+
"data": {
|
| 690 |
+
"image_base64_length": len(image_base64) if image_base64 else 0,
|
| 691 |
+
"has_data_uri_prefix": image_base64.startswith("data:image") if image_base64 else False,
|
| 692 |
+
"prefix": image_base64[:50] if image_base64 and len(image_base64) > 50 else image_base64,
|
| 693 |
+
"is_none": image_base64 is None,
|
| 694 |
+
"is_empty": image_base64 == "" if image_base64 else True
|
| 695 |
+
},
|
| 696 |
+
"sessionId": "debug-session",
|
| 697 |
+
"runId": "run1",
|
| 698 |
+
"hypothesisId": "A,C,D"
|
| 699 |
+
}) + "\n")
|
| 700 |
+
except Exception:
|
| 701 |
+
pass
|
| 702 |
+
# #endregion
|
| 703 |
+
|
| 704 |
+
try:
|
| 705 |
+
# Strip data URI prefix if present (hypothesis A fix)
|
| 706 |
+
clean_base64 = image_base64
|
| 707 |
+
if image_base64.startswith("data:image"):
|
| 708 |
+
# Extract just the base64 part after the comma
|
| 709 |
+
if "," in image_base64:
|
| 710 |
+
clean_base64 = image_base64.split(",", 1)[1]
|
| 711 |
+
else:
|
| 712 |
+
clean_base64 = image_base64.replace("data:image/", "").replace(";base64", "")
|
| 713 |
+
|
| 714 |
+
# Decode base64 image
|
| 715 |
+
image_data = base64.b64decode(clean_base64)
|
| 716 |
+
image = Image.open(io.BytesIO(image_data))
|
| 717 |
+
content_parts.append(image)
|
| 718 |
+
self.logger.debug(f"Added image to Gemini request")
|
| 719 |
+
except Exception as e:
|
| 720 |
+
self.logger.warning(f"Failed to process image for Gemini: {str(e)}")
|
| 721 |
+
# Continue without image rather than failing
|
| 722 |
+
|
| 723 |
+
self.logger.debug(f"Sending request to Gemini model {gemini_model_name}")
|
| 724 |
+
|
| 725 |
+
try:
|
| 726 |
+
# Generate response with retry logic
|
| 727 |
+
max_retries = 3
|
| 728 |
+
for attempt in range(max_retries):
|
| 729 |
+
try:
|
| 730 |
+
# Configure generation parameters
|
| 731 |
+
generation_config = genai.types.GenerationConfig(
|
| 732 |
+
temperature=self.temperature,
|
| 733 |
+
max_output_tokens=self.max_tokens,
|
| 734 |
+
top_p=self.top_p,
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
response = model.generate_content(
|
| 738 |
+
content_parts,
|
| 739 |
+
generation_config=generation_config
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# Check if response was blocked
|
| 743 |
+
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
| 744 |
+
raise GepaLLMError(
|
| 745 |
+
f"Gemini blocked the prompt: {response.prompt_feedback.block_reason}",
|
| 746 |
+
ErrorType.VALIDATION_ERROR
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
# Check if response was blocked
|
| 750 |
+
if not response.text:
|
| 751 |
+
if response.candidates and response.candidates[0].finish_reason:
|
| 752 |
+
finish_reason = response.candidates[0].finish_reason
|
| 753 |
+
if finish_reason == genai.types.FinishReason.SAFETY:
|
| 754 |
+
raise GepaLLMError(
|
| 755 |
+
"Gemini response blocked due to safety concerns",
|
| 756 |
+
ErrorType.VALIDATION_ERROR
|
| 757 |
+
)
|
| 758 |
+
elif finish_reason == genai.types.FinishReason.RECITATION:
|
| 759 |
+
raise GepaLLMError(
|
| 760 |
+
"Gemini response blocked due to recitation concerns",
|
| 761 |
+
ErrorType.VALIDATION_ERROR
|
| 762 |
+
)
|
| 763 |
+
raise GepaLLMError(
|
| 764 |
+
"Gemini returned empty response",
|
| 765 |
+
ErrorType.API_ERROR
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
self.logger.debug(f"Received response from Gemini model {gemini_model_name}")
|
| 769 |
+
|
| 770 |
+
# Log usage information if available
|
| 771 |
+
if hasattr(response, 'usage_metadata') and response.usage_metadata:
|
| 772 |
+
usage = response.usage_metadata
|
| 773 |
+
self.logger.info(
|
| 774 |
+
f"Tokens used - Prompt: {usage.prompt_token_count}, "
|
| 775 |
+
f"Completion: {usage.candidates_token_count}, "
|
| 776 |
+
f"Total: {usage.total_token_count}"
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
# Try to parse JSON if the response looks like JSON
|
| 780 |
+
response_text = response.text
|
| 781 |
+
if isinstance(response_text, str) and (
|
| 782 |
+
response_text.startswith('{') or
|
| 783 |
+
response_text.startswith('[')
|
| 784 |
+
):
|
| 785 |
+
try:
|
| 786 |
+
return json.loads(response_text)
|
| 787 |
+
except json.JSONDecodeError:
|
| 788 |
+
pass
|
| 789 |
+
|
| 790 |
+
# Default response format
|
| 791 |
+
return {
|
| 792 |
+
"content": response_text,
|
| 793 |
+
"role": "assistant",
|
| 794 |
+
"model": gemini_model_name,
|
| 795 |
+
"provider": "google"
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
except Exception as e:
|
| 799 |
+
if attempt < max_retries - 1:
|
| 800 |
+
self.logger.warning(f"Gemini API attempt {attempt + 1} failed: {str(e)}. Retrying...")
|
| 801 |
+
time.sleep(2 ** attempt) # Exponential backoff
|
| 802 |
+
continue
|
| 803 |
+
else:
|
| 804 |
+
raise
|
| 805 |
+
|
| 806 |
+
except GepaLLMError:
|
| 807 |
+
raise
|
| 808 |
+
except Exception as e:
|
| 809 |
+
self.logger.error(f"Unexpected error with Gemini API: {str(e)}")
|
| 810 |
+
raise GepaLLMError(
|
| 811 |
+
f"Gemini API error: {str(e)}",
|
| 812 |
+
ErrorType.API_ERROR
|
| 813 |
+
) from e
|
src/gepa_optimizer/models/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Models module for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .config import ModelConfig, OptimizationConfig
|
| 6 |
+
from .dataset import DatasetItem
|
| 7 |
+
from .result import OptimizationResult, OptimizedResult
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"ModelConfig",
|
| 11 |
+
"OptimizationConfig",
|
| 12 |
+
"DatasetItem",
|
| 13 |
+
"OptimizationResult",
|
| 14 |
+
"OptimizedResult"
|
| 15 |
+
]
|
src/gepa_optimizer/models/config.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration models for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import List, Optional, Dict, Any, Union, Tuple
|
| 8 |
+
|
| 9 |
+
@dataclass
class ModelConfig:
    """Configuration for any LLM provider.

    Holds the provider identity, model name, credentials and the common
    sampling parameters forwarded to the provider API.
    """
    provider: str               # Required: "openai", "anthropic", "huggingface", "vllm", etc.
    model_name: str             # Required: actual model name
    api_key: str                # Required: API key for the provider
    base_url: Optional[str] = None  # Optional: custom endpoint URL
    temperature: float = 0.7
    max_tokens: int = 2048
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    def __post_init__(self):
        """Validate required fields after initialization"""
        if not self.provider:
            raise ValueError("Provider is required (e.g., 'openai', 'anthropic', 'huggingface')")
        if not self.model_name:
            raise ValueError("Model name is required (e.g., 'gpt-4', 'claude-3-opus')")
        if not self.api_key:
            raise ValueError(f"API key is required for {self.provider} provider")

    @classmethod
    def from_string(cls, model_string: str) -> 'ModelConfig':
        """Create ModelConfig from string like 'openai/gpt-4' or 'gpt-4'"""
        # Bare model names (no "provider/" prefix) default to OpenAI.
        provider, sep, model_name = model_string.partition("/")
        if not sep:
            provider, model_name = "openai", model_string

        # Credentials come from the provider's environment variable.
        api_key = cls._get_api_key_for_provider(provider)
        if not api_key:
            raise ValueError(
                f"No API key found for {provider}. Please set {provider.upper()}_API_KEY environment variable"
            )

        return cls(
            provider=provider,
            model_name=model_name,
            api_key=api_key
        )

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'ModelConfig':
        """Create ModelConfig from dictionary"""
        return cls(**config_dict)

    def to_dict(self) -> dict:
        """Convert ModelConfig to dictionary"""
        keys = (
            'provider', 'model_name', 'api_key', 'base_url', 'temperature',
            'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty',
        )
        return {key: getattr(self, key) for key in keys}

    @staticmethod
    def _get_api_key_for_provider(provider: str) -> Optional[str]:
        """Get API key for provider from environment variables"""
        # Known providers with non-obvious variable names are listed
        # explicitly; anything else falls back to <PROVIDER>_API_KEY.
        env_var_map = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "huggingface": "HUGGINGFACE_API_KEY",
            "cohere": "COHERE_API_KEY",
            "ai21": "AI21_API_KEY",
            "together": "TOGETHER_API_KEY",
            "replicate": "REPLICATE_API_TOKEN",
            "groq": "GROQ_API_KEY",
            "ollama": "OLLAMA_API_KEY",
        }
        env_var = env_var_map.get(provider.lower(), f"{provider.upper()}_API_KEY")
        return os.getenv(env_var)
|
| 94 |
+
|
| 95 |
+
@dataclass
class DataSplitConfig:
    """Configuration for dataset splitting into train/val/test sets

    🔥 ADAPTIVE SPLITTING: Automatically adjusts ratios based on dataset size for optimal results.
    - Small datasets (< 15): Prioritizes validation set (70/25/5) for reliable candidate ranking
    - Medium datasets (15-50): Balanced split (60/20/20)
    - Large datasets (50+): More training data (70/15/15)
    """

    # Split ratios (must sum to 1.0) - used as defaults, but adaptive strategy overrides for small datasets
    train_ratio: float = 0.6  # 60% for training (Dfeedback - reflection examples)
    val_ratio: float = 0.2  # 20% for validation (Dpareto - Pareto selection)
    test_ratio: float = 0.2  # 20% for test (held-out final evaluation)

    # Minimum samples per split (floor the computed split sizes in get_split_indices)
    min_train_samples: int = 3
    min_val_samples: int = 3  # 🔥 INCREASED from 2 to 3 for more reliable validation scores
    min_test_samples: int = 1  # 🔥 REDUCED from 2 to 1 (test set less critical, only used once)

    # Strategy for handling small datasets
    small_dataset_strategy: str = 'adaptive'  # 🔥 DEFAULT: 'adaptive', 'duplicate_val', 'no_test', 'error'

    def __post_init__(self):
        """Validate split configuration.

        Raises:
            ValueError: If the three ratios do not sum to ~1.0 (within a
                +/-0.01 float tolerance), if train/val ratios are not strictly
                positive (test may be 0 to disable the hold-out set), or if
                ``small_dataset_strategy`` is not one of the known strategies.
        """
        total = self.train_ratio + self.val_ratio + self.test_ratio
        if not (0.99 <= total <= 1.01):  # Allow small floating point errors
            raise ValueError(
                f"Split ratios must sum to 1.0, got {total:.3f} "
                f"(train={self.train_ratio}, val={self.val_ratio}, test={self.test_ratio})"
            )

        # train/val are required to be > 0; only the test split may be disabled (== 0)
        if self.train_ratio <= 0 or self.val_ratio <= 0 or self.test_ratio < 0:
            raise ValueError("Split ratios must be positive (test_ratio can be 0 to disable)")

        if self.small_dataset_strategy not in {'adaptive', 'duplicate_val', 'no_test', 'error'}:
            raise ValueError(
                f"Invalid small_dataset_strategy: {self.small_dataset_strategy}. "
                f"Must be 'adaptive', 'duplicate_val', 'no_test', or 'error'"
            )

    def get_adaptive_ratios(self, dataset_size: int) -> Tuple[float, float, float]:
        """
        🔥 NEW: Get adaptive split ratios based on dataset size.

        For prompt optimization:
        - Small datasets (< 15): Prioritize validation (70/25/5) for reliable candidate ranking
        - Medium (15-50): Balanced (60/20/20)
        - Large (50+): More training (70/15/15)

        Args:
            dataset_size: Total number of samples in dataset

        Returns:
            Tuple of (train_ratio, val_ratio, test_ratio)
        """
        # Thresholds at 15 and 50 are hard-coded policy, matching the class docstring.
        if dataset_size < 15:
            # Small dataset: Prioritize validation for reliable candidate ranking
            # Validation set is CRITICAL - used for every candidate evaluation
            return (0.70, 0.25, 0.05)  # 70% train, 25% val, 5% test
        elif dataset_size < 50:
            # Medium dataset: Balanced split
            return (0.60, 0.20, 0.20)  # 60% train, 20% val, 20% test
        else:
            # Large dataset: More training data, can reduce validation/test
            return (0.70, 0.15, 0.15)  # 70% train, 15% val, 15% test

    def get_split_indices(self, dataset_size: int) -> Tuple[int, int, int, int]:
        """
        Calculate split indices for a dataset with adaptive ratios.

        🔥 ADAPTIVE SPLITTING: Automatically adjusts ratios based on dataset size.
        This ensures optimal allocation:
        - Small datasets: More validation samples for reliable ranking
        - Medium datasets: Balanced split
        - Large datasets: More training data

        The returned values are slice boundaries: train = [0, train_end),
        val = [train_end, val_end), test = [val_end, test_end).

        Args:
            dataset_size: Total number of samples in dataset

        Returns:
            Tuple of (train_end, val_end, test_end, dataset_size) indices

        Raises:
            ValueError: If dataset is too small for configured splits
        """
        # 🔥 NEW: Use adaptive ratios if strategy is 'adaptive'
        if self.small_dataset_strategy == 'adaptive':
            train_ratio, val_ratio, test_ratio = self.get_adaptive_ratios(dataset_size)
        else:
            train_ratio, val_ratio, test_ratio = self.train_ratio, self.val_ratio, self.test_ratio

        # NOTE(review): for non-'error' strategies this undersized case falls through
        # and the validation slice below can end up smaller than min_val_samples
        # (possibly empty when train_end == dataset_size) — confirm callers guard.
        if dataset_size < self.min_train_samples + self.min_val_samples:
            if self.small_dataset_strategy == 'error':
                raise ValueError(
                    f"Dataset too small ({dataset_size} samples). "
                    f"Need at least {self.min_train_samples + self.min_val_samples} samples."
                )

        # Calculate ideal split points with adaptive ratios
        # (max() enforces the configured per-split minimums as a floor)
        train_end = max(self.min_train_samples, int(dataset_size * train_ratio))
        val_end = train_end + max(self.min_val_samples, int(dataset_size * val_ratio))

        # Adjust for small datasets
        if val_end >= dataset_size:
            if self.small_dataset_strategy in {'adaptive', 'duplicate_val'}:
                # Ensure minimum validation samples, use remainder for test
                val_end = min(dataset_size, train_end + self.min_val_samples)
                test_end = dataset_size
            elif self.small_dataset_strategy == 'no_test':
                # No test set for small datasets
                val_end = dataset_size
                test_end = dataset_size
            else:  # error
                raise ValueError(
                    f"Dataset too small ({dataset_size} samples) for train/val/test split. "
                    f"Need at least {self.min_train_samples + self.min_val_samples + self.min_test_samples} samples."
                )
        else:
            test_end = dataset_size

        return train_end, val_end, test_end, dataset_size
|
| 217 |
+
|
| 218 |
+
@dataclass
class OptimizationConfig:
    """Configuration class for GEPA optimization process.

    Core model and budget fields have no defaults and must be supplied by the
    user. ``__post_init__`` normalizes string model specs ("provider/model")
    into ``ModelConfig`` objects and validates all ranges.
    """

    # Core models - REQUIRED by user
    model: Union[str, ModelConfig]  # No default - user must specify
    reflection_model: Union[str, ModelConfig]  # No default - user must specify

    # Optimization parameters - REQUIRED by user
    max_iterations: int  # No default - user decides their budget
    max_metric_calls: int  # No default - user sets their budget
    batch_size: int  # No default - user decides based on memory

    # Dataset splitting configuration
    data_split: DataSplitConfig = field(default_factory=DataSplitConfig)

    # Reflection settings (separate from evaluation batch_size)
    reflection_examples: int = 3  # Number of examples for each reflection (small!)

    # Optional optimization settings with sensible fallbacks
    early_stopping: bool = True
    learning_rate: float = 0.01

    # Multi-objective optimization
    multi_objective: bool = False
    objectives: List[str] = field(default_factory=lambda: ["accuracy"])

    # Advanced settings
    custom_metrics: Optional[Dict[str, Any]] = None
    use_cache: bool = True
    parallel_evaluation: bool = False

    # Backwards compatibility (deprecated)
    train_split_ratio: Optional[float] = None  # Use data_split instead
    min_dataset_size: int = 2

    # Cost and budget - user controlled
    max_cost_usd: Optional[float] = None
    timeout_seconds: Optional[int] = None

    # GEPA-specific optimization parameters (based on actual GEPA library)
    candidate_selection_strategy: str = 'pareto'  # Use Pareto selection strategy
    skip_perfect_score: bool = False  # Don't skip perfect scores (set to True for early stopping)
    reflection_minibatch_size: Optional[int] = None  # Will use reflection_examples if None
    perfect_score: float = 1.0  # Perfect score threshold
    module_selector: str = 'round_robin'  # Component selection strategy
    verbose: bool = True  # Enable detailed GEPA logging

    # Test set evaluation
    evaluate_on_test: bool = True  # Evaluate final prompt on held-out test set

    # 🆕 LLEGO Genetic Operator Parameters (Optional - for faster convergence)
    # Based on ICLR 2025 paper: "Decision Tree Induction Through LLMs via Semantically-Aware Evolution"
    # Optimized for small datasets (6-10 samples)
    use_llego_operators: bool = False  # Enable LLEGO genetic operators

    # 🔥 HYBRID MODE: Combine GEPA Reflection + LLEGO Operators
    # When both enabled, candidates are generated from BOTH sources for maximum diversity
    enable_gepa_reflection_with_llego: bool = False  # Enable hybrid GEPA+LLEGO mode
    num_gepa_reflection_candidates: int = 3  # Number of GEPA reflection candidates per iteration (default: 3 for better exploration, range: 2-5)

    # Fitness-guided crossover parameters (FIX #3: Conservative alpha)
    alpha: float = 0.05  # FIX #3: Fitness extrapolation (0.05 = 5% above best parent, realistic for prompt optimization)
    n_crossover: int = 2  # Number of offspring from crossover per iteration

    # Diversity-guided mutation parameters
    tau: float = 8.0  # Diversity temperature (8.0 = moderate diversity, balanced exploration/exploitation)
    nu: int = 3  # Parent arity (3 parents optimal for small populations ~6 samples)
    n_mutation: int = 2  # Number of offspring from mutation per iteration (total 4 offspring with crossover)

    # Population management (for genetic operators)
    population_size: int = 8  # Size of prompt population (small but diverse for 6-sample dataset)

    # 🆕 LLM-as-Judge configuration (Phase 2)
    use_llm_as_judge: bool = True  # Enable LLM-as-Judge feedback for detailed, actionable analysis
    llm_as_judge_threshold: float = 0.8  # Use LLM-as-Judge for scores below this threshold
    llm_as_judge_model: Optional[ModelConfig] = None  # Optional: use different model (defaults to reflection_model)

    # 🆕 Logging configuration (Phase 3)
    log_level: str = "INFO"  # Logging level: "DEBUG", "INFO", "WARNING", "ERROR"

    def __post_init__(self):
        """Validate and process configuration after initialization"""
        # Handle backwards compatibility for train_split_ratio
        # NOTE(review): the `!= 0.8` guard presumably skips the warning when the
        # caller passed what used to be the old default — confirm 0.8 was the
        # legacy default; values of exactly 0.8 are silently ignored here.
        if self.train_split_ratio is not None and self.train_split_ratio != 0.8:
            import warnings
            warnings.warn(
                "train_split_ratio is deprecated. Use data_split=DataSplitConfig(...) instead. "
                "Converting to 3-way split with your ratio.",
                DeprecationWarning,
                stacklevel=2
            )
            # Convert 2-way split to 3-way: use train_ratio, split remainder between val/test
            remainder = 1.0 - self.train_split_ratio
            self.data_split = DataSplitConfig(
                train_ratio=self.train_split_ratio,
                val_ratio=remainder * 0.5,
                test_ratio=remainder * 0.5
            )

        # Convert string models to ModelConfig objects
        self.model = self._parse_model_config(self.model, "model")
        self.reflection_model = self._parse_model_config(self.reflection_model, "reflection_model")

        # Set reflection_minibatch_size default
        if self.reflection_minibatch_size is None:
            self.reflection_minibatch_size = self.reflection_examples

        # Validate required parameters
        self._validate_required_params()

        # Validate ranges
        self._validate_ranges()

    def _parse_model_config(self, model: Union[str, ModelConfig], field_name: str) -> ModelConfig:
        """Parse string model specification into ModelConfig.

        Accepts either a ready ``ModelConfig`` (returned unchanged) or a string
        in "provider/model-name" form; a bare model name defaults to openai.

        Raises:
            ValueError: If no API key can be resolved for the provider, or if
                ``model`` is neither a string nor a ModelConfig.
        """
        if isinstance(model, ModelConfig):
            return model

        if isinstance(model, str):
            # Parse "provider/model-name" format
            if "/" in model:
                provider, model_name = model.split("/", 1)
            else:
                # Default to openai if no provider specified
                provider = "openai"
                model_name = model

            # Try to get API key from environment
            api_key = self._get_api_key_for_provider(provider)
            if not api_key:
                raise ValueError(
                    f"No API key found for {provider}. Please set environment variable "
                    f"or provide ModelConfig with api_key for {field_name}"
                )

            return ModelConfig(
                provider=provider,
                model_name=model_name,
                api_key=api_key
            )

        raise ValueError(f"{field_name} must be either a string or ModelConfig object")

    def _get_api_key_for_provider(self, provider: str) -> Optional[str]:
        """Get API key for provider from environment variables (delegates to ModelConfig)."""
        return ModelConfig._get_api_key_for_provider(provider)

    def _validate_required_params(self):
        """Validate that all required parameters are provided"""
        required_fields = {
            "max_iterations": self.max_iterations,
            "max_metric_calls": self.max_metric_calls,
            "batch_size": self.batch_size,
        }

        for field_name, value in required_fields.items():
            if value is None:
                raise ValueError(f"{field_name} is required and must be specified by user")

    def _validate_ranges(self):
        """Validate parameter ranges.

        Raises:
            ValueError: On the first out-of-range parameter found.
        """
        if self.max_iterations <= 0:
            raise ValueError("max_iterations must be positive")

        if self.max_metric_calls <= 0:
            raise ValueError("max_metric_calls must be positive")

        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")

        if self.reflection_examples <= 0 or self.reflection_examples > 10:
            raise ValueError("reflection_examples must be between 1 and 10 (recommended: 2-5)")

        if self.reflection_minibatch_size <= 0:
            raise ValueError("reflection_minibatch_size must be positive")

        # hasattr guard: model may be a ModelConfig without a max_tokens attribute
        if hasattr(self.model, 'max_tokens') and self.model.max_tokens <= 0:
            raise ValueError("model.max_tokens must be a positive integer")

        # Validate hybrid mode parameters
        if self.enable_gepa_reflection_with_llego and not self.use_llego_operators:
            raise ValueError("enable_gepa_reflection_with_llego requires use_llego_operators=True")

        if self.num_gepa_reflection_candidates <= 0 or self.num_gepa_reflection_candidates > 5:
            raise ValueError("num_gepa_reflection_candidates must be between 1 and 5 (recommended: 3 for balanced exploration)")

        # Validate log_level (case-insensitive; compared upper-cased)
        valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
        if self.log_level.upper() not in valid_log_levels:
            raise ValueError(f"log_level must be one of {valid_log_levels}, got: {self.log_level}")

    def validate_api_connectivity(self) -> Dict[str, bool]:
        """Test API connectivity for both models.

        NOTE: Currently only checks that provider/model/api_key fields are
        present — no network call is made.
        """
        results = {}

        for model_name, model_config in [("model", self.model), ("reflection_model", self.reflection_model)]:
            try:
                # This would be implemented to actually test the API
                # For now, just check if we have the required info
                if model_config.api_key and model_config.provider and model_config.model_name:
                    results[model_name] = True
                else:
                    results[model_name] = False
            except Exception:
                results[model_name] = False

        return results

    def get_estimated_cost(self) -> Dict[str, Any]:
        """Estimate cost based on configuration.

        Returns a placeholder breakdown; no real pricing lookup is performed.
        """
        # This would calculate estimated costs based on:
        # - max_metric_calls
        # - model pricing
        # - expected tokens per call
        return {
            "max_calls": self.max_metric_calls,
            "estimated_cost_range": "To be calculated based on provider pricing",
            "cost_factors": {
                "model_calls": self.max_metric_calls,
                "reflection_calls": self.max_iterations,
                "batch_size": self.batch_size
            }
        }

    @classmethod
    def create_example_config(cls, provider: str = "openai") -> str:
        """Generate example configuration code for users.

        Args:
            provider: One of "openai", "anthropic", "mixed"; unknown values
                fall back to the openai example.

        Returns:
            A code snippet (as a string) showing a typical configuration.
        """
        examples = {
            "openai": '''
# Example OpenAI Configuration
config = OptimizationConfig(
    model="openai/gpt-4-turbo",  # or ModelConfig(...)
    reflection_model="openai/gpt-4-turbo",
    max_iterations=50,  # Your choice based on budget
    max_metric_calls=300,  # Your choice based on budget
    batch_size=8,  # Your choice based on memory
    early_stopping=True,
    learning_rate=0.01
)
''',
            "anthropic": '''
# Example Anthropic Configuration
config = OptimizationConfig(
    model=ModelConfig(
        provider="anthropic",
        model_name="claude-3-opus-20240229",
        api_key="your-anthropic-key",
        temperature=0.7
    ),
    reflection_model="anthropic/claude-3-sonnet-20240229",
    max_iterations=30,
    max_metric_calls=200,
    batch_size=4
)
''',
            "mixed": '''
# Example Mixed Providers Configuration
config = OptimizationConfig(
    model="openai/gpt-4-turbo",  # Main model
    reflection_model="anthropic/claude-3-opus",  # Reflection model
    max_iterations=25,
    max_metric_calls=250,
    batch_size=6,
    max_cost_usd=100.0,  # Budget limit
    timeout_seconds=3600  # 1 hour limit
)
'''
        }

        return examples.get(provider, examples["openai"])
|
src/gepa_optimizer/models/dataset.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset models for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
import uuid
|
| 8 |
+
|
| 9 |
+
@dataclass
class DatasetItem:
    """Single item in a dataset"""

    # Unique identifier, auto-generated per item
    item_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # Payload: the input plus an optional reference answer / attached image
    input_data: Any = ""
    expected_output: Optional[str] = None
    image_base64: Optional[str] = None

    # Free-form metadata and labels
    metadata: Dict[str, Any] = field(default_factory=dict)
    tags: List[str] = field(default_factory=list)

    # Paths of any files this item refers to
    file_paths: List[str] = field(default_factory=list)

    # Quality tracking: score in [0, 1] plus validation state and notes
    quality_score: float = 1.0
    is_validated: bool = False
    validation_notes: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Reject items whose quality score falls outside [0, 1]."""
        if not 0 <= self.quality_score <= 1:
            raise ValueError("quality_score must be between 0 and 1")

    def add_tag(self, tag: str):
        """Attach *tag* to this item, ignoring duplicates."""
        if tag in self.tags:
            return
        self.tags.append(tag)

    def mark_validated(self, notes: Optional[List[str]] = None):
        """Flag this item as validated, optionally recording reviewer notes."""
        self.is_validated = True
        if notes:
            self.validation_notes.extend(notes)
|
| 48 |
+
|
| 49 |
+
@dataclass
class ProcessedDataset:
    """Dataset after processing for GEPA optimization"""

    # Identity
    dataset_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Untitled Dataset"

    # All items, plus the train/validation partitions derived from them
    items: List[DatasetItem] = field(default_factory=list)
    train_split: List[DatasetItem] = field(default_factory=list)
    val_split: List[DatasetItem] = field(default_factory=list)

    # Provenance and processing bookkeeping
    source_info: Dict[str, Any] = field(default_factory=dict)
    processing_stats: Dict[str, Any] = field(default_factory=dict)

    # Aggregate quality metrics, computed once in __post_init__
    total_items: int = 0
    validated_items: int = 0
    avg_quality_score: float = 0.0

    def __post_init__(self):
        """Derive aggregate counts and the mean quality score from the items."""
        count = len(self.items)
        self.total_items = count
        if not self.items:
            return
        self.validated_items = sum(1 for entry in self.items if entry.is_validated)
        self.avg_quality_score = sum(entry.quality_score for entry in self.items) / count

    def get_stats(self) -> Dict[str, Any]:
        """Summarize the dataset as a plain dict of counts and rates."""
        rate = self.validated_items / self.total_items if self.total_items > 0 else 0
        with_expected = sum(1 for entry in self.items if entry.expected_output)
        return {
            'total_items': self.total_items,
            'validated_items': self.validated_items,
            'validation_rate': rate,
            'avg_quality_score': self.avg_quality_score,
            'train_size': len(self.train_split),
            'val_size': len(self.val_split),
            'has_expected_outputs': with_expected,
        }
|
src/gepa_optimizer/models/result.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Result models for GEPA Optimizer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Dict, Any, Optional, List
|
| 8 |
+
import uuid
|
| 9 |
+
|
| 10 |
+
@dataclass
class OptimizationResult:
    """Complete optimization result with all metadata"""

    # Unique id for this optimization run
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # The before/after prompts
    original_prompt: str = ""
    optimized_prompt: str = ""

    # Metric payloads: improvement summary plus baseline/final scores
    improvement_data: Dict[str, Any] = field(default_factory=dict)
    baseline_metrics: Dict[str, float] = field(default_factory=dict)
    final_metrics: Dict[str, float] = field(default_factory=dict)

    # Run statistics
    optimization_time: float = 0.0
    dataset_size: int = 0
    total_iterations: int = 0

    # Lifecycle state: pending -> running -> completed | failed
    status: str = "pending"
    error_message: Optional[str] = None

    # Wall-clock bookkeeping
    created_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None

    # One entry per reflection step performed during optimization
    reflection_history: List[Dict[str, Any]] = field(default_factory=list)

    # Resource accounting
    estimated_cost: Optional[float] = None
    api_calls_made: int = 0

    def mark_completed(self):
        """Transition this result to the 'completed' state, stamping the finish time."""
        self.completed_at = datetime.now()
        self.status = "completed"

    def mark_failed(self, error: str):
        """Transition this result to the 'failed' state, recording *error*."""
        self.completed_at = datetime.now()
        self.error_message = error
        self.status = "failed"
|
| 56 |
+
|
| 57 |
+
class OptimizedResult:
    """
    User-facing result class that provides clean interface.

    Wraps an internal :class:`OptimizationResult` and exposes its fields as
    read-only properties, hiding the detailed record from everyday callers.
    """

    def __init__(self,
                 original_prompt: str = "",
                 optimized_prompt: str = "",
                 improvement_data: Optional[Dict[str, Any]] = None,
                 optimization_time: float = 0.0,
                 dataset_size: int = 0,
                 total_iterations: int = 0,
                 status: str = "pending",
                 error_message: Optional[str] = None,
                 detailed_result: Optional[OptimizationResult] = None,
                 session_id: Optional[str] = None):
        """
        Initialize OptimizedResult with individual parameters

        Args:
            original_prompt: Original seed prompt
            optimized_prompt: Optimized prompt
            improvement_data: Performance improvement data
            optimization_time: Time taken for optimization
            dataset_size: Size of dataset used
            total_iterations: Number of optimization iterations
            status: Optimization status
            error_message: Error message if failed
            detailed_result: Optional detailed OptimizationResult; when given,
                it is used directly and the other arguments are ignored
            session_id: Optional session ID (a fresh UUID is generated if None)
        """
        # FIX: if a detailed result was supplied, use it directly instead of
        # first building a placeholder OptimizationResult and discarding it.
        if detailed_result is not None:
            self._result = detailed_result
            return

        # FIX: annotation corrected to Optional[Dict[str, Any]] — the old
        # `Dict[str, Any] = None` default violated its own type hint.
        if improvement_data is None:
            improvement_data = {}

        self._result = OptimizationResult(
            session_id=session_id or str(uuid.uuid4()),
            original_prompt=original_prompt,
            optimized_prompt=optimized_prompt,
            improvement_data=improvement_data,
            optimization_time=optimization_time,
            dataset_size=dataset_size,
            total_iterations=total_iterations,
            status=status,
            error_message=error_message
        )

    @property
    def prompt(self) -> str:
        """The optimized prompt ready for production use"""
        return self._result.optimized_prompt

    @property
    def original_prompt(self) -> str:
        """The original seed prompt for reference"""
        return self._result.original_prompt

    @property
    def session_id(self) -> str:
        """Unique session identifier"""
        return self._result.session_id

    @property
    def improvement_data(self) -> Dict[str, Any]:
        """Performance improvement data"""
        return self._result.improvement_data

    @property
    def status(self) -> str:
        """Optimization status"""
        return self._result.status

    @property
    def error_message(self) -> Optional[str]:
        """Error message if optimization failed"""
        return self._result.error_message

    @property
    def is_successful(self) -> bool:
        """Whether optimization completed successfully"""
        return (
            self._result.status == "completed" and
            self._result.error_message is None
        )

    @property
    def optimization_time(self) -> float:
        """Time taken for optimization in seconds"""
        return self._result.optimization_time

    @property
    def dataset_size(self) -> int:
        """Size of dataset used for optimization"""
        return self._result.dataset_size

    @property
    def total_iterations(self) -> int:
        """Total optimization iterations performed"""
        return self._result.total_iterations

    @property
    def estimated_cost(self) -> Optional[float]:
        """Estimated cost in USD"""
        return self._result.estimated_cost

    def get_improvement_summary(self) -> Dict[str, Any]:
        """Get summary of improvements made.

        Returns:
            Dict with run statistics; includes 'improvement_percent' only
            when the underlying improvement data provides it.
        """
        summary = {
            'has_improvement': bool(self._result.improvement_data),
            'optimization_time': self.optimization_time,
            'iterations': self.total_iterations,
            'dataset_size': self.dataset_size
        }

        # Add improvement percentage if available
        if 'improvement_percent' in self._result.improvement_data:
            summary['improvement_percent'] = self._result.improvement_data['improvement_percent']

        return summary

    def get_reflection_summary(self) -> Dict[str, Any]:
        """Get summary of reflection process (at most the first 3 reflection points)."""
        if not self._result.reflection_history:
            return {'total_reflections': 0}

        return {
            'total_reflections': len(self._result.reflection_history),
            'reflection_points': [
                r.get('summary', 'No summary')
                for r in self._result.reflection_history[:3]  # First 3
            ]
        }

    def get_detailed_result(self) -> OptimizationResult:
        """Get the full detailed result for advanced users"""
        return self._result

    def __str__(self) -> str:
        """String representation"""
        status_emoji = "✅" if self.is_successful else "❌" if self.status == "failed" else "⏳"
        return f"OptimizedResult({status_emoji} {self.status}, time={self.optimization_time:.2f}s)"

    def __repr__(self) -> str:
        return self.__str__()
|
src/gepa_optimizer/operators/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLEGO Genetic Operators for GEPA.
|
| 3 |
+
|
| 4 |
+
This module provides genetic operators for prompt optimization:
|
| 5 |
+
- FitnessGuidedCrossover: Combines high-performing prompts
|
| 6 |
+
- DiversityGuidedMutation: Explores diverse variations
|
| 7 |
+
- LLEGOIntegrationLayer: Manages the genetic algorithm workflow
|
| 8 |
+
|
| 9 |
+
Based on: Decision Tree Induction Through LLMs via Semantically-Aware Evolution (ICLR 2025)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# Base interfaces (SOLID: Interface Segregation)
|
| 13 |
+
from .base_operator import (
|
| 14 |
+
BaseGeneticOperator,
|
| 15 |
+
BaseCrossoverOperator,
|
| 16 |
+
BaseMutationOperator,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Data models
|
| 20 |
+
from .models import (
|
| 21 |
+
PromptCandidate,
|
| 22 |
+
PromptMetadata,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Concrete operators (SOLID: Single Responsibility)
|
| 26 |
+
from .crossover import FitnessGuidedCrossover
|
| 27 |
+
from .mutation import DiversityGuidedMutation
|
| 28 |
+
|
| 29 |
+
# Integration layer
|
| 30 |
+
from .llego_operators import LLEGOIntegrationLayer
|
| 31 |
+
|
| 32 |
+
__all__ = [
|
| 33 |
+
# Base interfaces
|
| 34 |
+
'BaseGeneticOperator',
|
| 35 |
+
'BaseCrossoverOperator',
|
| 36 |
+
'BaseMutationOperator',
|
| 37 |
+
# Data models
|
| 38 |
+
'PromptCandidate',
|
| 39 |
+
'PromptMetadata',
|
| 40 |
+
# Operators
|
| 41 |
+
'FitnessGuidedCrossover',
|
| 42 |
+
'DiversityGuidedMutation',
|
| 43 |
+
# Integration
|
| 44 |
+
'LLEGOIntegrationLayer',
|
| 45 |
+
]
|
src/gepa_optimizer/operators/base_operator.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base Genetic Operator Interface.
|
| 3 |
+
|
| 4 |
+
Defines the abstract interface for all genetic operators following
|
| 5 |
+
the Interface Segregation Principle (ISP) of SOLID.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import List, Callable
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseGeneticOperator(ABC):
    """Abstract base class for all genetic operators.

    Every operator (crossover, mutation, ...) is a callable object: invoking
    it performs the genetic operation and returns a new prompt string.

    Design principles applied:
    - Single Responsibility: one operation per operator
    - Open/Closed: extend through subclassing, never by modification
    - Liskov Substitution: any operator is usable where the base is expected
    - Interface Segregation: only the minimal interface is required
    - Dependency Inversion: operators depend on an abstract LLM callable
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> str:
        """Run the genetic operation.

        Returns:
            str: New prompt produced by the operation.
        """
        ...

    @abstractmethod
    def _build_prompt(self, *args, **kwargs) -> str:
        """Compose the instruction prompt sent to the LLM for this operation.

        Returns:
            str: Prompt text to send to the LLM.
        """
        ...
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class BaseCrossoverOperator(BaseGeneticOperator):
    """Abstract base class for crossover operators.

    A crossover operator blends two or more parent prompts into an offspring
    prompt intended to inherit desirable traits from each parent.
    """

    @abstractmethod
    def __call__(
        self,
        parents: List,  # List[PromptCandidate]
        target_fitness: float,
        llm: Callable[[str], str]
    ) -> str:
        """Produce an offspring prompt from the given parents.

        Args:
            parents: Parent PromptCandidate objects to combine.
            target_fitness: Fitness level the offspring should aim for.
            llm: Callable mapping a prompt string to an LLM response.

        Returns:
            str: The offspring prompt.
        """
        ...
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class BaseMutationOperator(BaseGeneticOperator):
    """Abstract base class for mutation operators.

    A mutation operator derives a variation of a single parent prompt,
    exploring new regions of the search space.
    """

    @abstractmethod
    def __call__(
        self,
        parent,  # PromptCandidate
        population: List,  # List[PromptCandidate]
        llm: Callable[[str], str]
    ) -> str:
        """Create a variation of *parent*.

        Args:
            parent: Parent PromptCandidate to mutate.
            population: Current population used for diversity guidance.
            llm: Callable mapping a prompt string to an LLM response.

        Returns:
            str: The mutated prompt.
        """
        ...
|
| 107 |
+
|
src/gepa_optimizer/operators/crossover.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fitness-Guided Crossover Operator.
|
| 3 |
+
|
| 4 |
+
Adapts LLEGO's fitness-guided crossover for text prompts.
|
| 5 |
+
Based on: Decision Tree Induction Through LLMs via Semantically-Aware Evolution (ICLR 2025)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Callable, TYPE_CHECKING
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from .base_operator import BaseCrossoverOperator
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from .models import PromptCandidate
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FitnessGuidedCrossover(BaseCrossoverOperator):
    """
    Fitness-guided crossover for text prompts.

    Combines high-performing parent prompts to generate offspring
    that target specific fitness levels using LLM semantic understanding.

    From LLEGO paper:
    "Fitness-guided crossover exploits high-performing regions of the search space
    by combining parent trees targeting a desired fitness level f* = f_max + α(f_max - f_min)"

    Reference: https://github.com/nicolashuynh/LLEGO
    """

    def __init__(self, alpha: float = 0.1):
        """
        Initialize crossover operator.

        Args:
            alpha: Fitness extrapolation parameter.
                Higher α = target higher fitness than parents.
                Default 0.1 from LLEGO paper (target 10% above best parent).
        """
        # NOTE(review): alpha is stored but never used in this class;
        # target_fitness is passed in by the caller — confirm intent.
        self.alpha = alpha
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("FitnessGuidedCrossover initialized with α=%s", alpha)

    def __call__(
        self,
        parents: List["PromptCandidate"],
        target_fitness: float,
        llm: Callable[[str], str]
    ) -> str:
        """
        Combine parent prompts targeting specific fitness.

        Args:
            parents: List of PromptCandidate objects (2+ parents)
            target_fitness: Desired fitness for offspring
            llm: Language model callable

        Returns:
            str: Offspring prompt

        Raises:
            ValueError: If fewer than 2 parents provided
        """
        if len(parents) < 2:
            raise ValueError("Crossover requires at least 2 parents")

        # Sort parents by fitness (best first) so the strongest parent
        # appears as P1 in the crossover prompt.
        sorted_parents = sorted(parents, key=lambda p: p.fitness, reverse=True)

        logger.debug(
            "Crossover: %d parents, target fitness=%.3f",
            len(parents), target_fitness,
        )

        # Build crossover prompt and delegate generation to the LLM.
        crossover_prompt = self._build_prompt(sorted_parents, target_fitness)
        return llm(crossover_prompt)

    def _build_prompt(
        self,
        parents: List["PromptCandidate"],
        target_fitness: float
    ) -> str:
        """
        Build LLM prompt for crossover operation.

        Args:
            parents: Sorted list of parent candidates (best first)
            target_fitness: Target fitness for offspring

        Returns:
            str: Prompt for LLM
        """
        # Truncate parents to prevent safety filter issues
        MAX_PARENT_LENGTH = 350

        # Build parent descriptions (limit to top 2 — best parents only)
        parent_descriptions = []
        for i, parent in enumerate(parents[:2]):
            truncated = parent.prompt[:MAX_PARENT_LENGTH]
            if len(parent.prompt) > MAX_PARENT_LENGTH:
                truncated += "..."
            parent_descriptions.append(
                f"P{i+1} (f={parent.fitness:.2f}): {truncated}\n"
            )

        prompt = f"""Combine these prompts into ONE improved version (target fitness: {target_fitness:.2f}).

{' '.join(parent_descriptions)}
Instructions:
1. Merge the best rules/principles from both parents
2. Organize logic clearly (e.g., "For X tasks: do Y", "If Z: then A")
3. Add structure to handle different cases systematically
4. Keep output format (Element: X, Description:, Reason:)
5. Max 600 chars

Output ONLY the combined prompt:"""

        return prompt
|
| 120 |
+
|