from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol


@dataclass
class TrainingSample:
    """A single training example: an input prompt, its target text, and metadata.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from the
    annotated fields. Without the decorator these fields were bare annotations
    and the class could not be constructed with arguments.
    """

    # Input text for the example.
    prompt: str
    # Expected output text (presumably what reward functions receive as a
    # "reference" — confirm against callers).
    target: str
    # Free-form per-sample extra data; the key schema is not defined here.
    metadata: dict[str, Any]


class DataProvider(Protocol):
    """Structural interface for sources of ``TrainingSample`` objects.

    Any object whose ``load`` method has a compatible signature satisfies this
    protocol for static type checkers; no inheritance is required.
    """

    def load(
        self, split: str, max_samples: int | None = None, cache_dir: str | None = None
    ) -> list[TrainingSample]:
        """Load and return the samples belonging to ``split``.

        Args:
            split: Name of the split to load (presumably e.g. "train" or
                "eval" — confirm with implementations).
            max_samples: Optional limit on how many samples to return; the
                default is ``None`` (presumably "no limit" — confirm).
            cache_dir: Optional directory path; how it is used is up to the
                implementation.

        Returns:
            A list of ``TrainingSample`` objects.
        """


class RewardFunction(Protocol):
    """Structural interface for a callable that scores a batch of completions.

    Any callable with a compatible ``__call__`` signature satisfies this
    protocol for static type checkers.
    """

    def __call__(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Return one float score per completion.

        The input lists are presumably parallel (index ``i`` in each refers to
        the same sample), and the result presumably has the same length —
        confirm with implementations.

        Args:
            prompts: Prompt texts.
            completions: Generated texts to be scored.
            references: Reference texts to compare against.
            metadata: Per-sample metadata dictionaries.
        """


class ObjectiveModule(Protocol):
    """Structural interface for a named objective that contributes extra rewards.

    Implementations need a ``name`` attribute plus the two methods below;
    no inheritance is required.
    """

    # Identifier for this objective.
    name: str

    def reward_names(self) -> list[str]:
        """Return the names of the rewards this module reports.

        Presumably these label the values produced by ``extra_reward`` —
        confirm with implementations.
        """

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Compute this module's additional reward for a batch of completions.

        Takes the same inputs as a ``RewardFunction``. The lists are presumably
        parallel, with one float returned per completion — confirm with
        implementations.

        Args:
            prompts: Prompt texts.
            completions: Generated texts to be scored.
            references: Reference texts to compare against.
            metadata: Per-sample metadata dictionaries.
        """