# Source: Universal-prompt-Optimizer / src/gepa_optimizer/data/validation_dataset_loader.py
# Deployed to HF Spaces by Suhasdev (commit cacd4d0, "Deploy Universal Prompt Optimizer to HF Spaces (clean)")
"""
Validation Dataset Loader for UI Validation Use Case
Loads validation datapoints from SQLite database and converts to GEPA-compatible format.
Supports filtering by data_type (trainset/valset/testset) and confirmed status.
"""
import os
import sqlite3
import base64
import logging
from typing import List, Dict, Any, Optional, Literal
from pathlib import Path
logger = logging.getLogger(__name__)
class ValidationDatasetLoader:
    """
    Loads a UI-validation dataset from a SQLite database and converts it to
    GEPA-compatible items.

    Database schema:
        - validation_data: id, image_id, command, result (0/1), reasoning,
          data_type, confirmed, created_at
        - images: image_id, mime, bytes (BLOB), created_at

    GEPA item format:
        - "input": command text only (the seed prompt is NOT stored in the
          database; it is provided separately by the test script)
        - "output": "true" / "false" (converted from the 0/1 result column)
        - "image_base64": base64-encoded image at the TOP LEVEL of the item
          (required by UniversalConverter so image + command travel together)
        - "metadata": all original columns plus convenience conversions
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        confirmed_only: bool = True
    ) -> None:
        """
        Initialize the validation dataset loader.

        Args:
            db_path: Path to the SQLite database file. If None, falls back to
                the VD_DB_PATH environment variable, then "./validation_data.db".
            confirmed_only: If True (default), only load datapoints where
                confirmed=1, i.e. manually reviewed data.

        Raises:
            FileNotFoundError: If the database file doesn't exist.
        """
        if db_path is None:
            db_path = os.getenv("VD_DB_PATH", "./validation_data.db")
        self.db_path = Path(db_path).resolve()
        if not self.db_path.exists():
            raise FileNotFoundError(
                f"Database file not found: {self.db_path}\n"
                f"Make sure validation_data_ui_server_async.py has been run at least once to create the database."
            )
        self.confirmed_only = confirmed_only

    def _connect(self) -> sqlite3.Connection:
        """Open a connection with name-based column access (sqlite3.Row)."""
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _row_to_item(row: sqlite3.Row) -> Dict[str, Any]:
        """Convert one joined validation_data/images row into a GEPA item."""
        return {
            # Just the command - the seed prompt is supplied by the test script.
            "input": row["command"],
            # "true"/"false" string converted from the 0/1 result column.
            "output": "true" if row["result"] == 1 else "false",
            # Top-level base64 image for UniversalConverter.
            "image_base64": base64.b64encode(row["bytes"]).decode("utf-8"),
            "metadata": {
                "id": row["id"],
                "image_id": row["image_id"],
                "command": row["command"],      # original, for reference
                "result": bool(row["result"]),  # boolean, for reference
                "result_int": row["result"],    # original 0/1, for reference
                "reasoning": row["reasoning"],
                "data_type": row["data_type"],
                "confirmed": bool(row["confirmed"]),
                "created_at": row["created_at"],
                "mime": row["mime"],
            },
        }

    def load_dataset(
        self,
        data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
        confirmed_only: Optional[bool] = None
    ) -> List[Dict[str, Any]]:
        """
        Load the dataset from the database and convert it to GEPA format.

        Args:
            data_type: Filter by data_type ("trainset", "valset", "testset").
                If None, all types are loaded.
            confirmed_only: Per-call override of the instance default. If None,
                self.confirmed_only is used.

        Returns:
            List of GEPA items (see the class docstring for the item layout).

        Raises:
            sqlite3.Error: If the database query fails.
            ValueError: If no datapoints match the given criteria.
        """
        # Use provided confirmed_only or instance default.
        use_confirmed = confirmed_only if confirmed_only is not None else self.confirmed_only

        # Build query with filters (parameterized to avoid SQL injection).
        query = """
            SELECT
                v.id,
                v.image_id,
                v.command,
                v.result,
                v.reasoning,
                v.data_type,
                v.confirmed,
                v.created_at,
                i.mime,
                i.bytes
            FROM validation_data v
            INNER JOIN images i ON v.image_id = i.image_id
            WHERE 1=1
        """
        params: List[Any] = []
        if use_confirmed:
            query += " AND v.confirmed = 1"
        if data_type:
            query += " AND v.data_type = ?"
            params.append(data_type)
        query += " ORDER BY v.id ASC"

        # fetchall() materializes the rows, so the connection can be closed
        # before the (pure-Python) conversion step.
        conn = self._connect()
        try:
            rows = conn.execute(query, params).fetchall()
        finally:
            conn.close()

        if not rows:
            filter_msg = []
            if use_confirmed:
                filter_msg.append("confirmed=1")
            if data_type:
                filter_msg.append(f"data_type='{data_type}'")
            filter_str = " with filters: " + ", ".join(filter_msg) if filter_msg else ""
            raise ValueError(
                f"No datapoints found{filter_str} in database: {self.db_path}\n"
                f"Make sure you have generated and saved datapoints using the validation UI."
            )

        dataset = [self._row_to_item(row) for row in rows]

        # Log summary (lazy %-args: formatted only if INFO is enabled).
        data_type_str = f" ({data_type})" if data_type else ""
        confirmed_str = " (confirmed only)" if use_confirmed else " (all)"
        logging.getLogger(__name__).info(
            "Loaded %d validation datapoints%s%s", len(dataset), data_type_str, confirmed_str
        )
        return dataset

    def load_split_dataset(
        self,
        confirmed_only: Optional[bool] = None
    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Load the dataset split by data_type (trainset/valset/testset).

        Convenience method that loads all three splits at once.

        Args:
            confirmed_only: Per-call override of the instance default. If None,
                self.confirmed_only is used.

        Returns:
            Tuple of (train_set, val_set, test_set) in GEPA format.

        Raises:
            ValueError: If any split has no matching datapoints.

        Example:
            loader = ValidationDatasetLoader(db_path="./validation_data.db")
            train, val, test = loader.load_split_dataset()
        """
        train_set = self.load_dataset(data_type="trainset", confirmed_only=confirmed_only)
        val_set = self.load_dataset(data_type="valset", confirmed_only=confirmed_only)
        test_set = self.load_dataset(data_type="testset", confirmed_only=confirmed_only)
        logging.getLogger(__name__).info(
            "Dataset Split Summary: Training=%d, Validation=%d, Test=%d, Total=%d",
            len(train_set), len(val_set), len(test_set),
            len(train_set) + len(val_set) + len(test_set)
        )
        return train_set, val_set, test_set

    def get_dataset_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the dataset in the database.

        Returns:
            Dictionary with:
                - "total": total number of datapoints
                - "confirmed" / "unconfirmed": counts by confirmed flag
                - "by_data_type": {data_type: count}
                - "by_result": {"true": count, "false": count}
        """
        conn = self._connect()
        try:
            stats: Dict[str, Any] = {}

            total = conn.execute("SELECT COUNT(*) FROM validation_data").fetchone()[0]
            confirmed = conn.execute("SELECT COUNT(*) FROM validation_data WHERE confirmed = 1").fetchone()[0]
            stats["total"] = total
            stats["confirmed"] = confirmed
            stats["unconfirmed"] = total - confirmed

            # By data_type
            data_type_rows = conn.execute("""
                SELECT data_type, COUNT(*) as count
                FROM validation_data
                GROUP BY data_type
            """).fetchall()
            stats["by_data_type"] = {row["data_type"]: row["count"] for row in data_type_rows}

            # By result (true/false)
            result_rows = conn.execute("""
                SELECT result, COUNT(*) as count
                FROM validation_data
                GROUP BY result
            """).fetchall()
            stats["by_result"] = {
                "true": sum(row["count"] for row in result_rows if row["result"] == 1),
                "false": sum(row["count"] for row in result_rows if row["result"] == 0),
            }
            return stats
        finally:
            conn.close()
def load_validation_dataset(
    db_path: Optional[str] = None,
    data_type: Optional[Literal["trainset", "valset", "testset"]] = None,
    confirmed_only: bool = True
) -> List[Dict[str, Any]]:
    """
    One-call helper: construct a ValidationDatasetLoader and load a dataset.

    Args:
        db_path: Path to the SQLite database file. Default: "./validation_data.db"
        data_type: Filter by data_type; None loads all types.
        confirmed_only: If True, only confirmed datapoints are loaded.

    Returns:
        List of dataset items in GEPA format.

    Example:
        # Load all confirmed training data
        train_data = load_validation_dataset(data_type="trainset", confirmed_only=True)
        # Load all confirmed data
        all_data = load_validation_dataset(confirmed_only=True)
    """
    return ValidationDatasetLoader(
        db_path=db_path, confirmed_only=confirmed_only
    ).load_dataset(data_type=data_type, confirmed_only=confirmed_only)
def load_validation_split(
    db_path: Optional[str] = None,
    confirmed_only: bool = True
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    One-call helper: load the dataset split by data_type.

    Args:
        db_path: Path to the SQLite database file. Default: "./validation_data.db"
        confirmed_only: If True, only confirmed datapoints are loaded.

    Returns:
        Tuple of (train_set, val_set, test_set) in GEPA format.

    Example:
        train, val, test = load_validation_split(confirmed_only=True)
    """
    return ValidationDatasetLoader(
        db_path=db_path, confirmed_only=confirmed_only
    ).load_split_dataset(confirmed_only=confirmed_only)
# Example usage and testing
def _main() -> None:
    """Manual smoke test: print dataset stats, load the splits, show a sample item."""
    print("🚀 Testing Validation Dataset Loader...")
    try:
        loader = ValidationDatasetLoader()

        # Get stats
        print("\n📊 Dataset Statistics:")
        stats = loader.get_dataset_stats()
        print(f" Total: {stats['total']}")
        print(f" Confirmed: {stats['confirmed']}")
        print(f" Unconfirmed: {stats['unconfirmed']}")
        print(f" By data_type: {stats['by_data_type']}")
        print(f" By result: {stats['by_result']}")

        # Load split dataset
        print("\n📦 Loading split dataset...")
        train, val, test = loader.load_split_dataset()

        # Show sample
        if train:
            sample = train[0]
            print("\n📝 Sample Training Item:")
            print(f" Input: {sample['input']}")
            print(f" Output: {sample['output']}")
            print(f" Image ID: {sample['metadata']['image_id'][:8]}...")
            print(f" Data Type: {sample['metadata']['data_type']}")
            print(f" Result: {sample['metadata']['result']} (int: {sample['metadata']['result_int']})")
    except FileNotFoundError as e:
        print(f"❌ {e}")
        print("\n💡 Make sure validation_data_ui_server_async.py has been run to create the database.")
    except ValueError as e:
        print(f"❌ {e}")
        print("\n💡 Generate and save some datapoints using the validation UI first.")


if __name__ == "__main__":
    _main()