Spaces:
Sleeping
Sleeping
File size: 2,935 Bytes
cacd4d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
"""
Dataset models for GEPA Optimizer
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import uuid
@dataclass
class DatasetItem:
    """Single item in a dataset.

    Bundles one input with its optional expected output, free-form metadata,
    tags, file references and validation state.

    Raises:
        ValueError: If ``quality_score`` is not within ``[0, 1]`` when the
            instance is created (NaN is rejected as well).
    """

    # Identifiers
    # A fresh UUID4 string is generated per instance unless one is supplied.
    item_id: str = field(default_factory=lambda: str(uuid.uuid4()))

    # Core data
    input_data: Any = ""
    expected_output: Optional[str] = None
    image_base64: Optional[str] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)
    tags: List[str] = field(default_factory=list)

    # File references
    file_paths: List[str] = field(default_factory=list)

    # Quality indicators
    quality_score: float = 1.0  # must lie in [0, 1]; checked in __post_init__
    is_validated: bool = False
    validation_notes: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate the item after initialization.

        Raises:
            ValueError: If ``quality_score`` is outside ``[0, 1]``.
        """
        # A chained comparison evaluates to False for NaN, so NaN is rejected
        # too; separate ``< 0`` / ``> 1`` checks would silently accept it.
        if not (0 <= self.quality_score <= 1):
            raise ValueError("quality_score must be between 0 and 1")

    def add_tag(self, tag: str) -> None:
        """Add a tag to this item, ignoring tags that are already present.

        Args:
            tag: Tag to append to ``tags``.
        """
        if tag not in self.tags:
            self.tags.append(tag)

    def mark_validated(self, notes: Optional[List[str]] = None) -> None:
        """Mark the item as validated and record any accompanying notes.

        Args:
            notes: Optional notes appended to ``validation_notes``. ``None``
                or an empty list leaves the existing notes untouched.
        """
        self.is_validated = True
        if notes:
            self.validation_notes.extend(notes)
@dataclass
class ProcessedDataset:
    """Dataset after processing for GEPA optimization.

    ``total_items``, ``validated_items`` and ``avg_quality_score`` are derived
    from ``items`` when the instance is created; any values passed for them to
    the constructor are overwritten. They are cached counters, so call
    ``refresh_stats()`` after mutating ``items`` to bring them up to date.
    """

    # Identifiers
    # A fresh UUID4 string is generated per instance unless one is supplied.
    dataset_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Untitled Dataset"

    # Data
    items: List[DatasetItem] = field(default_factory=list)
    train_split: List[DatasetItem] = field(default_factory=list)
    val_split: List[DatasetItem] = field(default_factory=list)

    # Metadata
    source_info: Dict[str, Any] = field(default_factory=dict)
    processing_stats: Dict[str, Any] = field(default_factory=dict)

    # Quality metrics (derived from ``items``; see refresh_stats)
    total_items: int = 0
    validated_items: int = 0
    avg_quality_score: float = 0.0

    def __post_init__(self) -> None:
        """Calculate derived fields."""
        self.refresh_stats()

    def refresh_stats(self) -> None:
        """Recompute the cached quality metrics from the current ``items``.

        All three metrics are always rewritten together, so an empty dataset
        consistently reports zero for each of them instead of keeping
        whatever values were previously stored.
        """
        item_count = len(self.items)
        self.total_items = item_count
        self.validated_items = sum(1 for item in self.items if item.is_validated)
        # Guard the division: an empty dataset has no average to compute.
        self.avg_quality_score = (
            sum(item.quality_score for item in self.items) / item_count
            if item_count
            else 0.0
        )

    def get_stats(self) -> Dict[str, Any]:
        """Get dataset statistics.

        Returns:
            A dict with the cached ``total_items``, ``validated_items`` and
            ``avg_quality_score``; ``validation_rate`` (validated / total, or
            0 when there are no items); the current lengths of the train and
            validation splits; and ``has_expected_outputs``, the number of
            items whose ``expected_output`` is truthy.
        """
        return {
            'total_items': self.total_items,
            'validated_items': self.validated_items,
            'validation_rate': self.validated_items / self.total_items if self.total_items > 0 else 0,
            'avg_quality_score': self.avg_quality_score,
            'train_size': len(self.train_split),
            'val_size': len(self.val_split),
            # Counted live from ``items`` rather than read from a cached field.
            'has_expected_outputs': sum(1 for item in self.items if item.expected_output),
        }
|