File size: 936 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Protocol
@dataclass()
class TrainingSample:
    """A single example made of a prompt string, a target string and metadata.

    This is a plain (non-frozen) dataclass: instances are mutable, compare
    equal field-by-field, and all three fields are required constructor
    arguments, accepted positionally in the order below or by keyword.
    """

    # Prompt text of the example.
    prompt: str
    # Target text of the example; presumably the reference/expected output
    # for ``prompt`` — confirm against the DataProvider implementations.
    target: str
    # Arbitrary string-keyed extra data; this module reads no specific key.
    metadata: dict[str, Any]
class DataProvider(Protocol):
    """Structural interface for objects that load lists of ``TrainingSample``.

    Any class with a compatible ``load`` method satisfies this protocol for
    static type checkers. It is not marked ``@runtime_checkable``, so it
    cannot be used with ``isinstance``.
    """

    def load(
        self, split: str, max_samples: int | None = None, cache_dir: str | None = None
    ) -> list[TrainingSample]:
        """Load and return the samples belonging to ``split``.

        Args:
            split: Name of the split to load; the accepted values are
                defined by each implementation.
            max_samples: Optional limit on how many samples to return.
                NOTE(review): ``None`` presumably means "no limit" — confirm
                per implementation.
            cache_dir: Optional directory path an implementation may use
                for caching; ``None`` leaves that choice to the implementation.

        Returns:
            A list of ``TrainingSample`` instances.
        """
class RewardFunction(Protocol):
    """Callable interface that turns a batch of completions into floats.

    Implementations are invoked like a function with four list arguments and
    must return a ``list[float]``.
    """

    def __call__(
        self,
        prompts: list[str], completions: list[str],
        references: list[str], metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Score a batch.

        Args:
            prompts: Prompt strings.
            completions: Completion strings to be scored.
            references: Reference strings.
            metadata: One string-keyed dict per entry.

        Returns:
            A list of float scores.

        NOTE(review): the four inputs look like parallel lists (index ``i``
        describing the same example) with one score returned per entry;
        the protocol itself does not enforce equal lengths — confirm.
        """
class ObjectiveModule(Protocol):
    """Structural interface for a named objective exposing extra rewards.

    Conforming objects provide a ``name`` attribute, a ``reward_names``
    method, and an ``extra_reward`` method whose parameters match
    ``RewardFunction.__call__``.
    """

    # Identifier of this objective.
    name: str

    def reward_names(self) -> list[str]:
        """Return this objective's reward names as a list of strings.

        NOTE(review): presumably these label the values produced by
        ``extra_reward`` — confirm against implementations.
        """

    def extra_reward(
        self,
        prompts: list[str], completions: list[str],
        references: list[str], metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Compute this objective's additional reward values for a batch.

        Args:
            prompts: Prompt strings.
            completions: Completion strings to be scored.
            references: Reference strings.
            metadata: One string-keyed dict per entry.

        Returns:
            A list of float values.

        NOTE(review): inputs are presumably parallel lists with one value
        returned per entry; nothing here enforces that — confirm.
        """
|