# Suhasdev's picture
# Deploy Universal Prompt Optimizer to HF Spaces (clean)
# cacd4d0
"""
Dataset models for GEPA Optimizer
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import uuid
@dataclass
class DatasetItem:
    """Single item in a dataset.

    Holds one example (input data plus an optional expected output and an
    optional image payload) together with free-form metadata, tags, file
    references and quality/validation bookkeeping.
    """

    # Identifiers
    # A fresh random UUID4 string is generated per instance by default.
    item_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # Core data
    input_data: Any = ""
    expected_output: Optional[str] = None
    # NOTE(review): presumably a base64-encoded image string, judging by the
    # name — confirm against the code that populates it.
    image_base64: Optional[str] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)
    tags: List[str] = field(default_factory=list)

    # File references
    file_paths: List[str] = field(default_factory=list)

    # Quality indicators
    # Must lie in the closed interval [0, 1]; enforced by __post_init__.
    quality_score: float = 1.0
    is_validated: bool = False
    validation_notes: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate the item after initialization.

        Raises:
            ValueError: if ``quality_score`` is outside ``[0, 1]`` or is NaN.
        """
        # A chained inclusive-range test is used so that NaN is rejected too:
        # every comparison involving NaN is False, so a check of the form
        # ``score < 0 or score > 1`` would silently accept it.
        if not 0 <= self.quality_score <= 1:
            raise ValueError("quality_score must be between 0 and 1")

    def add_tag(self, tag: str) -> None:
        """Append ``tag`` to ``tags`` unless it is already present."""
        if tag not in self.tags:
            self.tags.append(tag)

    def mark_validated(self, notes: Optional[List[str]] = None) -> None:
        """Mark the item as validated.

        Args:
            notes: Optional notes appended to ``validation_notes``. Ignored
                when ``None`` or empty.
        """
        self.is_validated = True
        if notes:
            self.validation_notes.extend(notes)
@dataclass
class ProcessedDataset:
    """Dataset after processing for GEPA optimization.

    ``total_items``, ``validated_items`` and ``avg_quality_score`` are derived
    from ``items``. They are computed on construction and recomputed by
    ``refresh_stats()``, which ``get_stats()`` also calls, so the reported
    numbers always reflect the current contents of ``items``.
    """

    # Identifiers
    # A fresh random UUID4 string is generated per instance by default.
    dataset_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Untitled Dataset"

    # Data
    items: List[DatasetItem] = field(default_factory=list)
    # NOTE(review): the splits are stored independently of ``items``; nothing
    # in this class checks that they are subsets of it.
    train_split: List[DatasetItem] = field(default_factory=list)
    val_split: List[DatasetItem] = field(default_factory=list)

    # Metadata
    source_info: Dict[str, Any] = field(default_factory=dict)
    processing_stats: Dict[str, Any] = field(default_factory=dict)

    # Quality metrics (derived from ``items``; overwritten by refresh_stats)
    total_items: int = 0
    validated_items: int = 0
    avg_quality_score: float = 0.0

    def __post_init__(self) -> None:
        """Calculate derived fields from ``items``."""
        self.refresh_stats()

    def refresh_stats(self) -> None:
        """Recompute the derived quality metrics from the current ``items``.

        Always resets all three derived fields. An empty dataset yields
        ``validated_items == 0`` and ``avg_quality_score == 0.0``, instead of
        keeping caller-supplied values while ``total_items`` is reset to 0.
        """
        count = len(self.items)
        self.total_items = count
        self.validated_items = sum(1 for item in self.items if item.is_validated)
        # Guard the division so an empty dataset does not raise ZeroDivisionError.
        self.avg_quality_score = (
            sum(item.quality_score for item in self.items) / count if count else 0.0
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return dataset statistics as a plain dict.

        Derived fields are refreshed first, so the result is consistent even
        if ``items`` was mutated after construction.

        Returns:
            A dict with keys ``total_items``, ``validated_items``,
            ``validation_rate`` (validated / total, ``0.0`` when empty),
            ``avg_quality_score``, ``train_size``, ``val_size`` and
            ``has_expected_outputs`` (the *number* of items whose
            ``expected_output`` is truthy).
        """
        self.refresh_stats()
        return {
            'total_items': self.total_items,
            'validated_items': self.validated_items,
            'validation_rate': self.validated_items / self.total_items if self.total_items > 0 else 0.0,
            'avg_quality_score': self.avg_quality_score,
            'train_size': len(self.train_split),
            'val_size': len(self.val_split),
            'has_expected_outputs': sum(1 for item in self.items if item.expected_output),
        }