preference-lab / models.py
Sibam
refactor: apply production readiness recommendations including dataset caching, XSS protection, pure schemas, and JSON decoding logic.
5ee1380
Raw
History Blame Contribute Delete
5.12 kB
"""
Data models for PreferenceLab Environment.
PreferenceLab simulates the RLHF preference data collection pipeline.
An AI agent learns to judge which LLM responses are better using
gold-standard labels from real preference datasets as ground truth.
Tasks:
Task 1 (Easy): Pairwise ranking — pick A or B
Task 2 (Medium): Multi-axis Likert scoring
Task 3 (Hard): Transitive consistency chain ranking
"""
from typing import Literal, Optional
from pydantic import Field
from openenv.core.env_server.types import Action, Observation
# ── Actions ───────────────────────────────────────────────────
class PairwiseAction(Action):
    """Action for Task 1: choose which response is better."""

    # Schema for a pairwise judgement: a required verdict, an optional
    # free-text rationale, and a self-reported confidence value.

    # Required; only the four literal values listed here validate.
    choice: Literal["A", "B", "tie", "skip"] = Field(
        ...,
        description=(
            "Choose 'A' if response A is better, "
            "'B' if response B is better, "
            "'tie' if they are equal, "
            "or 'skip' if undecidable."
        ),
    )

    # Stays None unless the caller supplies a rationale.
    justification: Optional[str] = Field(
        None,
        description="Optional reasoning for the choice (not used for grading).",
    )

    # Constrained to the closed interval [0.0, 1.0]; 0.8 when omitted.
    confidence: float = Field(
        0.8,
        description="Annotator confidence (0.0-1.0)",
        ge=0.0,
        le=1.0,
    )
class LikertAction(Action):
    """Action for Task 2: score a response on 4 quality axes (1-5 each)."""

    # Four required integer scores. Each axis carries the same bounds
    # (ge=1, le=5), so values outside 1..5 fail validation.

    helpfulness: int = Field(
        ...,
        description="How helpful is the response? 1=not helpful, 5=extremely helpful.",
        ge=1,
        le=5,
    )

    honesty: int = Field(
        ...,
        description="How honest and accurate is the response? 1=misleading, 5=fully honest.",
        ge=1,
        le=5,
    )

    harmlessness: int = Field(
        ...,
        description="How harmless is the response? 1=harmful, 5=completely safe.",
        ge=1,
        le=5,
    )

    instruction_following: int = Field(
        ...,
        description=(
            "How well does the response follow the instruction? "
            "1=ignores it, 5=follows perfectly."
        ),
        ge=1,
        le=5,
    )
class ConsistencyAction(Action):
    """Action for Task 3: rank 4 responses (A, B, C, D) from best to worst."""

    # The list must hold exactly four entries (min_length == max_length == 4),
    # each one of the literals "A".."D"; index 0 is the best response.
    # NOTE(review): nothing in this schema rejects repeated IDs such as
    # ['A', 'A', 'B', 'C'] — confirm the grader handles duplicates.
    ranking: list[Literal["A", "B", "C", "D"]] = Field(
        ...,
        description="List of 4 response IDs ordered best to worst, e.g. ['B', 'A', 'D', 'C'].",
        min_length=4,
        max_length=4,
    )
# ── Observations ──────────────────────────────────────────────
class PairwiseObservation(Observation):
    """Observation for Task 1: a prompt with two candidate responses."""

    # --- task identity -------------------------------------------------
    task_id: str = Field(..., description="Unique task identifier.")
    # Fixed tag: the only accepted value is "pairwise", which is also the default.
    task_type: Literal["pairwise"] = "pairwise"

    # --- content to judge (all required) -------------------------------
    prompt: str = Field(..., description="The user prompt / instruction.")
    response_a: str = Field(..., description="Candidate response A.")
    response_b: str = Field(..., description="Candidate response B.")

    # --- episode bookkeeping (all defaulted) ---------------------------
    reward: float = Field(0.0, description="Reward signal from last step.")
    done: bool = Field(False, description="Whether the episode is complete.")
    step_count: int = Field(0, description="Current step within the episode.")
    # default_factory gives every instance its own fresh dict.
    info: dict = Field(default_factory=dict, description="Extra debug info.")
class LikertObservation(Observation):
    """Observation for Task 2: a prompt + single response to score on multiple axes."""

    # --- task identity -------------------------------------------------
    task_id: str = Field(..., description="Unique task identifier.")
    # Fixed tag: the only accepted value is "likert", which is also the default.
    task_type: Literal["likert"] = "likert"

    # --- content to judge (all required) -------------------------------
    prompt: str = Field(..., description="The user prompt / instruction.")
    response: str = Field(..., description="The response to evaluate.")
    rubric: str = Field(..., description="Scoring rubric to guide evaluation.")

    # --- episode bookkeeping (all defaulted) ---------------------------
    reward: float = Field(0.0, description="Reward signal from last step.")
    done: bool = Field(False, description="Whether the episode is complete.")
    step_count: int = Field(0, description="Current step within the episode.")
    # default_factory gives every instance its own fresh dict.
    info: dict = Field(default_factory=dict, description="Extra debug info.")
class ConsistencyObservation(Observation):
    """Observation for Task 3: a prompt + 4 responses to rank transitively."""

    # --- task identity -------------------------------------------------
    task_id: str = Field(..., description="Unique task identifier.")
    # Fixed tag: the only accepted value is "consistency", which is also the default.
    task_type: Literal["consistency"] = "consistency"

    # --- content to judge (all required) -------------------------------
    prompt: str = Field(..., description="The user prompt / instruction.")
    # One field per candidate, named response_a through response_d.
    response_a: str = Field(..., description="Candidate response A.")
    response_b: str = Field(..., description="Candidate response B.")
    response_c: str = Field(..., description="Candidate response C.")
    response_d: str = Field(..., description="Candidate response D.")

    # --- episode bookkeeping (all defaulted) ---------------------------
    reward: float = Field(0.0, description="Reward signal from last step.")
    done: bool = Field(False, description="Whether the episode is complete.")
    step_count: int = Field(0, description="Current step within the episode.")
    # default_factory gives every instance its own fresh dict.
    info: dict = Field(default_factory=dict, description="Extra debug info.")