codebook / potato /simulator /solo_mode_simulator.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
34.1 kB
"""
Solo Mode Simulator
Drives all 12 phases of Potato's Solo Mode through the /solo/* HTTP endpoints,
simulating a single human annotator collaborating with an LLM.
Usage:
from potato.simulator.solo_mode_simulator import SoloModeSimulator, SoloSimulatorConfig
sim = SoloModeSimulator(
server_url="http://localhost:8200",
gold_labels={"emo_001": "joy", "emo_002": "sadness", ...},
config=SoloSimulatorConfig(noise_rate=0.2),
)
result = sim.run_full_simulation()
print(result.summary())
"""
import json
import logging
import random
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
import requests
logger = logging.getLogger(__name__)
@dataclass
class SoloSimulatorConfig:
    """Tunable knobs for a solo mode simulation run.

    Every field has a default, so ``SoloSimulatorConfig()`` is a usable
    configuration; override individual fields by keyword.
    """

    # --- Identity used for the register/login requests ---
    user_id: str = "solo_simulator"
    password: str = "simulated_password_123"

    # --- Text submitted during the setup phase ---
    task_description: str = (
        "Classify the primary emotion expressed in each text. "
        "Choose the single best label from: "
        "joy, sadness, anger, fear, surprise, neutral."
    )

    # --- Annotation noise: chance of deliberately picking a wrong label ---
    noise_rate: float = 0.2

    # --- Disagreement resolution weights (human / LLM / third label) ---
    disagree_prefer_human: float = 0.6
    disagree_prefer_llm: float = 0.25
    disagree_prefer_third: float = 0.15

    # --- How many annotations to submit in each annotation phase ---
    parallel_annotation_count: int = 30
    active_annotation_count: int = 50

    # --- Probability of approving during review / rule review ---
    review_approve_rate: float = 0.7
    rule_approve_rate: float = 0.7

    # --- Timing (all values in seconds) ---
    max_wait_autonomous: int = 120  # upper bound on waiting for autonomous labeling
    poll_interval: float = 2.0  # pause between status polls
    annotation_delay: float = 0.0  # pause between annotations (realistic timing)
    wait_for_predictions_timeout: int = 60  # upper bound on waiting for LLM predictions

    # --- Phase control: use /api/advance-phase if stuck ---
    force_advance_on_stuck: bool = True

    # --- Schema name (must match config) ---
    schema_name: str = "emotion"
@dataclass
class PhaseResult:
    """Outcome record for one simulated phase.

    Only ``phase`` is required; every counter starts at zero and the
    ``errors`` / ``metadata`` containers are created fresh per instance.
    """

    # Name of the phase this record describes.
    phase: str
    # Whether the phase was considered successful by its simulator.
    success: bool = False
    # Counters filled in by the phase simulators.
    annotations_submitted: int = 0
    disagreements_encountered: int = 0
    disagreements_resolved: int = 0
    reviews_completed: int = 0
    # Human-readable error messages collected while the phase ran.
    errors: List[str] = field(default_factory=list)
    # Elapsed time of the phase, in seconds.
    duration_seconds: float = 0.0
    # Free-form extra details a phase may attach.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SoloSimulationResult:
    """Aggregated outcome of one solo mode simulation run.

    Collects the per-phase results in execution order, the running totals,
    the start/end timestamps of the run and the final status payload.
    """

    phase_results: List[PhaseResult] = field(default_factory=list)
    total_annotations: int = 0
    total_disagreements: int = 0
    total_errors: List[str] = field(default_factory=list)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    final_status: Optional[Dict[str, Any]] = None

    def summary(self) -> str:
        """Render the run as a multi-line, human-readable report.

        Returns:
            Newline-joined text containing the headline totals, one line per
            phase, the final status section (only when a non-empty status
            was captured) and at most the first ten error messages.
        """
        # Duration is only known when both timestamps were recorded.
        if self.start_time and self.end_time:
            elapsed = (self.end_time - self.start_time).total_seconds()
            duration_line = f"Duration: {elapsed:.1f}s"
        else:
            duration_line = "Duration: unknown"

        report = [
            "=== Solo Mode Simulation Summary ===",
            duration_line,
            f"Total annotations: {self.total_annotations}",
            f"Total disagreements: {self.total_disagreements}",
            f"Errors: {len(self.total_errors)}",
            "",
            "Phase Results:",
        ]

        for phase in self.phase_results:
            verdict = "OK" if phase.success else "FAIL"
            report.append(
                f" {phase.phase}: {verdict} "
                f"(annotations={phase.annotations_submitted}, "
                f"disagreements={phase.disagreements_encountered}, "
                f"duration={phase.duration_seconds:.1f}s)"
            )

        if self.final_status:
            status = self.final_status
            report += ["", "Final Status:"]

            agreement = status.get("agreement", {})
            report += [
                f" Agreement rate: {agreement.get('agreement_rate', 'N/A')}",
                f" Total compared: {agreement.get('total_compared', 'N/A')}",
            ]

            labeling = status.get("labeling", {})
            report += [
                f" Human labeled: {labeling.get('human_labeled', 'N/A')}",
                f" LLM labeled: {labeling.get('llm_labeled', 'N/A')}",
            ]

            prompt = status.get("prompt", {})
            report.append(
                f" Prompt versions: {prompt.get('total_versions', 'N/A')}"
            )

        if self.total_errors:
            report += ["", f"Errors ({len(self.total_errors)}):"]
            # Only the first ten messages are listed; the header carries the full count.
            report.extend(f" - {err}" for err in self.total_errors[:10])

        return "\n".join(report)
class SoloModeSimulator:
    """HTTP-level driver that plays one user through the solo mode phases.

    Args:
        server_url: Base URL of the Potato server (a trailing "/" is removed)
        gold_labels: Dict mapping instance_id -> gold label string
        available_labels: List of all valid label names; an empty or missing
            list is stored as a new empty list
        config: SoloSimulatorConfig; defaults are used when omitted
    """

    def __init__(
        self,
        server_url: str,
        gold_labels: Dict[str, str],
        available_labels: Optional[List[str]] = None,
        config: Optional[SoloSimulatorConfig] = None,
    ):
        # Normalise the base URL so endpoint paths can be appended with "/...".
        self.server_url = server_url.rstrip("/")
        self.gold_labels = gold_labels

        self.available_labels = available_labels if available_labels else []
        self.config = config if config else SoloSimulatorConfig()

        # One HTTP session is reused for every request the simulator makes.
        self.session = requests.Session()
        self.result = SoloSimulationResult()
        self._logged_in = False
# === Authentication ===
def _login(self) -> bool:
    """Register and log in the simulated user.

    Posts the configured credentials to /register and then to /auth. The
    /register response is not inspected (presumably because the account may
    already exist from an earlier run — confirm against the server). The
    /auth response is checked: an HTTP error status is treated as a failed
    login instead of being silently reported as success.

    Returns:
        True when /auth answered without an HTTP error status; False on a
        transport error or a 4xx/5xx answer. Failures are also appended to
        ``self.result.total_errors``.
    """
    credentials = {
        "email": self.config.user_id,
        "pass": self.config.password,
    }
    try:
        self.session.post(
            f"{self.server_url}/register",
            data={"action": "signup", **credentials},
            allow_redirects=True,
            timeout=30,
        )
        auth_resp = self.session.post(
            f"{self.server_url}/auth",
            data={"action": "login", **credentials},
            allow_redirects=True,
            timeout=30,
        )
    except requests.exceptions.RequestException as e:
        logger.error(f"Login failed: {e}")
        self.result.total_errors.append(f"Login failed: {e}")
        return False

    # Redirects were followed, so this is the status of the final response.
    if auth_resp.status_code >= 400:
        message = f"Login failed: HTTP {auth_resp.status_code}"
        logger.error(message)
        self.result.total_errors.append(message)
        return False

    self._logged_in = True
    logger.info(f"Logged in as {self.config.user_id}")
    return True
# === Status & Phase Helpers ===
def _get_status(self) -> Dict[str, Any]:
"""Get current solo mode status."""
try:
resp = self.session.get(
f"{self.server_url}/solo/api/status", timeout=30
)
if resp.status_code == 200:
return resp.json()
except Exception as e:
logger.warning(f"Failed to get status: {e}")
return {}
def _get_current_phase(self) -> str:
"""Get current phase name."""
status = self._get_status()
# Status may return phase as int, string, or dict depending on version
phase_info = status.get("phase", {})
if isinstance(phase_info, dict):
return phase_info.get("current_phase", "unknown")
# phase_name is the string representation
return status.get("phase_name", str(phase_info)).lower()
def _force_advance(self, target_phase: str) -> bool:
    """Ask the server to jump straight to ``target_phase``.

    Args:
        target_phase: Phase identifier sent in the request body.

    Returns:
        True only when /solo/api/advance-phase answers with HTTP 200;
        False for any other status or when the request raises.
    """
    try:
        reply = self.session.post(
            f"{self.server_url}/solo/api/advance-phase",
            json={"phase": target_phase, "force": True},
            timeout=30,
        )
        if reply.status_code != 200:
            logger.warning(
                f"Force advance to {target_phase} failed: {reply.status_code} {reply.text}"
            )
            return False
        logger.info(f"Force-advanced to {target_phase}")
        return True
    except Exception as e:
        logger.error(f"Force advance failed: {e}")
        return False
def _generate_annotation(self, instance_id: str) -> str:
"""Generate an annotation for an instance.
Uses gold label with noise_rate probability of choosing wrong label.
"""
gold = self.gold_labels.get(instance_id)
if gold and random.random() > self.config.noise_rate:
return gold
# Choose a random label (possibly different from gold)
if self.available_labels:
candidates = [l for l in self.available_labels if l != gold]
if candidates:
return random.choice(candidates)
return gold or (self.available_labels[0] if self.available_labels else "unknown")
# === Phase Simulators ===
def _simulate_setup(self) -> PhaseResult:
    """Run the SETUP phase by submitting the configured task description.

    Returns:
        PhaseResult for "setup"; successful when /solo/setup answers with
        200 or 302, otherwise carrying an error message.
    """
    began = time.time()
    outcome = PhaseResult(phase="setup")
    try:
        reply = self.session.post(
            f"{self.server_url}/solo/setup",
            data={"task_description": self.config.task_description},
            allow_redirects=True,
            timeout=30,
        )
        if reply.status_code in (200, 302):
            outcome.success = True
        else:
            outcome.success = False
            outcome.errors.append(f"Setup failed: {reply.status_code}")
        logger.info(f"Setup phase: {'OK' if outcome.success else 'FAIL'}")
    except Exception as e:
        outcome.errors.append(f"Setup error: {e}")
        outcome.success = False
    outcome.duration_seconds = time.time() - began
    return outcome
def _simulate_prompt_review(self) -> PhaseResult:
    """Run PROMPT_REVIEW by accepting the prompt as-is and advancing.

    Returns:
        PhaseResult for "prompt_review"; successful when /solo/prompt
        answers with 200 or 302, otherwise carrying an error message.
    """
    began = time.time()
    outcome = PhaseResult(phase="prompt_review")
    try:
        # "advance" accepts the current prompt and moves on to edge cases.
        reply = self.session.post(
            f"{self.server_url}/solo/prompt",
            data={"action": "advance"},
            allow_redirects=True,
            timeout=30,
        )
        accepted = reply.status_code in (200, 302)
        outcome.success = accepted
        if not accepted:
            outcome.errors.append(f"Prompt review failed: {reply.status_code}")
    except Exception as e:
        outcome.errors.append(f"Prompt review error: {e}")
        outcome.success = False
    outcome.duration_seconds = time.time() - began
    return outcome
def _simulate_edge_case_labeling(self) -> PhaseResult:
    """Run EDGE_CASE_SYNTHESIS + EDGE_CASE_LABELING.

    Requests the edge-case page once up front, then loops at most 20 times:
    reload the page, ask /solo/api/edge-cases how many cases are still
    unlabeled, and post one randomly chosen label. The loop ends when the
    page stops answering 200 or the API reports zero unlabeled cases.

    Returns:
        PhaseResult for "edge_case_labeling". The submitted-label count is
        only recorded when the loop finishes without an exception.
    """
    began = time.time()
    outcome = PhaseResult(phase="edge_case_labeling")
    try:
        page_url = f"{self.server_url}/solo/edge-cases"
        api_url = f"{self.server_url}/solo/api/edge-cases"

        # Initial GET to trigger synthesis; the response body is not needed.
        self.session.get(page_url, allow_redirects=True, timeout=60)

        submitted = 0
        attempts_left = 20  # hard cap so a misbehaving server cannot loop forever
        while attempts_left > 0:
            attempts_left -= 1

            page = self.session.get(page_url, timeout=30)
            if page.status_code != 200:
                break

            # When the API is reachable it decides whether anything is left;
            # otherwise a label is submitted regardless.
            progress = self.session.get(api_url, timeout=30)
            if progress.status_code == 200:
                remaining = progress.json().get("unlabeled", 0)
                if remaining == 0:
                    break

            if self.available_labels:
                label = random.choice(self.available_labels)
            else:
                label = "neutral"
            self.session.post(
                page_url,
                data={"label": label},
                allow_redirects=True,
                timeout=30,
            )
            submitted += 1

        outcome.annotations_submitted = submitted
        outcome.success = True
        logger.info(f"Edge case labeling: labeled {submitted} cases")
    except Exception as e:
        outcome.errors.append(f"Edge case labeling error: {e}")
        outcome.success = False
    outcome.duration_seconds = time.time() - began
    return outcome
def _simulate_annotation(self, count: int) -> PhaseResult:
    """Simulate an annotation phase (PARALLEL or ACTIVE).

    Repeats up to ``count`` times: GET /solo/annotate, pull the instance id
    out of the page, POST a (possibly noisy) label and, when the POST is
    answered with a 302 whose Location mentions "disagreement", resolve that
    disagreement inline. The loop stops early when the page carries no
    instance id.

    Args:
        count: Maximum number of annotations to submit.

    Returns:
        PhaseResult for "annotation" with the number of POSTs issued, the
        number of disagreement redirects seen and any per-iteration errors.
        The phase is successful when at least one annotation was sent.
    """
    start = time.time()
    result = PhaseResult(phase="annotation")
    annotations_done = 0
    disagreements = 0

    for i in range(count):
        try:
            # GET next instance
            resp = self.session.get(
                f"{self.server_url}/solo/annotate",
                allow_redirects=True,
                timeout=30,
            )
            if resp.status_code != 200:
                result.errors.append(f"Get annotate failed: {resp.status_code}")
                continue

            instance_id = self._extract_instance_id(resp.text)
            if not instance_id:
                logger.debug("No instance available, stopping annotation")
                break

            annotation = self._generate_annotation(instance_id)

            # Redirects are not followed so the 302 target can be inspected.
            resp = self.session.post(
                f"{self.server_url}/solo/annotate",
                data={"instance_id": instance_id, "annotation": annotation},
                allow_redirects=False,
                timeout=30,
            )
            annotations_done += 1

            # Progress logging every 10 annotations. It is isolated in a
            # helper that never raises, so a malformed status payload cannot
            # skip the disagreement handling below.
            if annotations_done % 10 == 0:
                self._log_annotation_progress(annotations_done, count, disagreements)

            # Simulate realistic annotation time
            if self.config.annotation_delay > 0:
                time.sleep(self.config.annotation_delay)

            # Check if redirected to disagreements
            if resp.status_code == 302:
                location = resp.headers.get("Location", "")
                if "disagreement" in location:
                    disagreements += 1
                    self._handle_disagreement()
        except Exception as e:
            result.errors.append(f"Annotation {i} error: {e}")

    result.annotations_submitted = annotations_done
    result.disagreements_encountered = disagreements
    result.success = annotations_done > 0
    result.duration_seconds = time.time() - start
    logger.info(
        f"Annotation phase: {annotations_done} annotations, "
        f"{disagreements} disagreements"
    )
    return result

def _log_annotation_progress(self, done: int, total: int, disagreements: int) -> None:
    """Log one progress line from the current status payload; never raises.

    Args:
        done: Annotations submitted so far in this phase.
        total: Annotations requested for this phase.
        disagreements: Disagreement redirects seen so far.
    """
    try:
        status = self._get_status()
        # Other readers of this payload in this file use the "agreement" and
        # "labeling" sections; accept either spelling so the log does not
        # silently print zeros when only one of them is present.
        agreement = status.get("agreement_metrics") or status.get("agreement") or {}
        stats = status.get("annotation_stats") or status.get("labeling") or {}

        rate = agreement.get("agreement_rate", 0)
        # A missing/None/textual rate must not break the ".3f" formatting.
        rate_text = f"{rate:.3f}" if isinstance(rate, (int, float)) else str(rate)

        logger.info(
            f"[Progress] {done}/{total} annotations, "
            f"{disagreements} disagreements, "
            f"agreement={rate_text} "
            f"({agreement.get('total_compared', 0)} compared), "
            f"LLM labeled={stats.get('llm_labeled', 0)}"
        )
    except Exception as e:
        logger.debug(f"Progress logging skipped: {e}")
def _handle_disagreement(self) -> None:
    """Resolve the disagreement currently shown on /solo/disagreements.

    This method does not read the competing human/LLM labels from the page,
    so the side to trust is picked at random: "human" with probability
    ``disagree_prefer_human``, "llm" with probability
    ``disagree_prefer_llm``, and otherwise a concrete third label. The third
    label is the instance's gold label when one is known, else a random
    entry of ``available_labels`` (or "neutral" when there are none).

    Nothing is submitted when the page does not answer 200 or carries no
    disagreement id. Any exception is logged as a warning and swallowed so
    the surrounding annotation loop keeps going.
    """
    try:
        resp = self.session.get(
            f"{self.server_url}/solo/disagreements",
            allow_redirects=True,
            timeout=30,
        )
        if resp.status_code != 200:
            return

        disagreement_id = self._extract_disagreement_id(resp.text)
        if not disagreement_id:
            return

        # Id format is "instance_id:schema". Split on the LAST colon so an
        # instance id that itself contains ":" is not truncated.
        instance_id = disagreement_id.rsplit(":", 1)[0]
        gold = self.gold_labels.get(instance_id)

        human_cutoff = self.config.disagree_prefer_human
        llm_cutoff = human_cutoff + self.config.disagree_prefer_llm

        roll = random.random()
        if roll < human_cutoff:
            # "human"/"llm" are sent as-is; presumably the server maps them
            # to the actual label values — confirm against the endpoint.
            resolution = "human"
        elif roll < llm_cutoff:
            resolution = "llm"
        elif gold:
            # Use the gold label directly as a third option
            resolution = gold
        elif self.available_labels:
            resolution = random.choice(self.available_labels)
        else:
            resolution = "neutral"

        self.session.post(
            f"{self.server_url}/solo/disagreements",
            data={
                "disagreement_id": disagreement_id,
                "resolution": resolution,
            },
            allow_redirects=True,
            timeout=30,
        )
        logger.debug(f"Resolved disagreement {disagreement_id}: {resolution}")
    except Exception as e:
        logger.warning(f"Disagreement resolution error: {e}")
def _simulate_periodic_review(self) -> PhaseResult:
    """Simulate PERIODIC_REVIEW: approve or correct the labels under review.

    Loops at most 20 times over /solo/review. Each item is approved with
    probability ``review_approve_rate``; otherwise it is corrected to the
    gold label (or a random available label when no gold label is known).
    The loop ends when the page stops answering 200 or shows no instance id.

    Returns:
        PhaseResult for "periodic_review" with the number of reviews posted.
        The phase is reported as failed when an exception interrupted it,
        matching how the other phase simulators report errors.
    """
    start = time.time()
    result = PhaseResult(phase="periodic_review")
    reviews = 0
    try:
        for _ in range(20):  # max iterations
            resp = self.session.get(
                f"{self.server_url}/solo/review",
                allow_redirects=True,
                timeout=30,
            )
            if resp.status_code != 200:
                break

            instance_id = self._extract_instance_id(resp.text)
            if not instance_id:
                break

            # Decide: approve or correct
            if random.random() < self.config.review_approve_rate:
                decision = "approve"
                corrected_label = ""
            else:
                decision = "correct"
                gold = self.gold_labels.get(instance_id, "")
                corrected_label = gold or (
                    random.choice(self.available_labels) if self.available_labels else ""
                )

            self.session.post(
                f"{self.server_url}/solo/review",
                data={
                    "instance_id": instance_id,
                    "decision": decision,
                    "corrected_label": corrected_label,
                },
                allow_redirects=True,
                timeout=30,
            )
            reviews += 1
    except Exception as e:
        result.errors.append(f"Review error: {e}")

    # Reviews posted before a failure are still counted.
    result.reviews_completed = reviews
    # An interrupted review must not be summarised as "OK".
    result.success = not result.errors
    result.duration_seconds = time.time() - start
    return result
def _simulate_rule_review(self) -> PhaseResult:
    """Simulate RULE_REVIEW: approve or reject each edge case rule category.

    Fetches the category list from /solo/api/rules/categories and, when it
    answers 200, posts an "approve" (with probability ``rule_approve_rate``)
    or "reject" decision for every category.

    Returns:
        PhaseResult for "rule_review". ``metadata["categories_reviewed"]``
        is set only when the category list was retrieved; the phase fails
        only when an exception is raised.
    """
    began = time.time()
    outcome = PhaseResult(phase="rule_review")
    try:
        listing = self.session.get(
            f"{self.server_url}/solo/api/rules/categories",
            timeout=30,
        )
        if listing.status_code == 200:
            categories = listing.json().get("categories", [])
            for category in categories:
                category_id = category.get("id", "")
                if random.random() < self.config.rule_approve_rate:
                    verdict = "approve"
                else:
                    verdict = "reject"
                self.session.post(
                    f"{self.server_url}/solo/api/rules/approve",
                    json={"category_id": category_id, "action": verdict, "notes": ""},
                    timeout=30,
                )
            outcome.metadata["categories_reviewed"] = len(categories)
        outcome.success = True
    except Exception as e:
        outcome.errors.append(f"Rule review error: {e}")
        outcome.success = False
    outcome.duration_seconds = time.time() - began
    return outcome
def _simulate_autonomous_wait(self) -> PhaseResult:
    """Wait for autonomous labeling to complete.

    Starts background labeling, then polls the status payload every
    ``poll_interval`` seconds until its ``labeling.background_running`` flag
    is false or ``max_wait_autonomous`` seconds have passed. A status
    payload without that flag is treated as "still running".

    Returns:
        PhaseResult for "autonomous_labeling". The phase is always reported
        as successful (waiting is best-effort); whether the wait ended by
        timeout rather than completion is exposed in
        ``metadata["timed_out"]``.
    """
    start = time.time()
    result = PhaseResult(phase="autonomous_labeling")

    # Start labeling through the guarded helper: it logs instead of raising,
    # so a transient connection error here cannot abort the whole simulation.
    self._start_llm_labeling()

    # Poll until done or timeout
    deadline = time.time() + self.config.max_wait_autonomous
    finished = False
    while time.time() < deadline:
        labeling = self._get_status().get("labeling", {})
        if not labeling.get("background_running", True):
            finished = True
            break
        time.sleep(self.config.poll_interval)

    if not finished:
        logger.warning(
            f"Autonomous labeling still running after {self.config.max_wait_autonomous}s"
        )
    result.metadata["timed_out"] = not finished
    result.success = True
    result.duration_seconds = time.time() - start
    return result
def _simulate_validation(self) -> PhaseResult:
    """Simulate FINAL_VALIDATION: validate a sample of LLM-only labels.

    Loops at most 100 times over /solo/validation. For an instance with a
    known gold label that label is posted as ``human_label``; otherwise the
    LLM label is approved via ``decision=approve``. The loop ends when the
    page stops answering 200 or shows no instance id.

    Returns:
        PhaseResult for "final_validation" with the number of validations
        posted. The phase is reported as failed when an exception
        interrupted it, matching how the other phase simulators report
        errors.
    """
    start = time.time()
    result = PhaseResult(phase="final_validation")
    validated = 0
    try:
        for _ in range(100):  # max iterations
            resp = self.session.get(
                f"{self.server_url}/solo/validation",
                allow_redirects=True,
                timeout=30,
            )
            if resp.status_code != 200:
                break

            instance_id = self._extract_instance_id(resp.text)
            if not instance_id:
                break

            # Get gold label or approve LLM label
            gold = self.gold_labels.get(instance_id)
            if gold:
                payload = {"instance_id": instance_id, "human_label": gold}
            else:
                payload = {"instance_id": instance_id, "decision": "approve"}

            self.session.post(
                f"{self.server_url}/solo/validation",
                data=payload,
                allow_redirects=True,
                timeout=30,
            )
            validated += 1
    except Exception as e:
        result.errors.append(f"Validation error: {e}")

    # Validations posted before a failure are still counted.
    result.annotations_submitted = validated
    # An interrupted validation must not be summarised as "OK".
    result.success = not result.errors
    result.duration_seconds = time.time() - start
    return result
# === LLM Control Helpers ===
def _start_llm_labeling(self) -> None:
    """Ask the server to start background LLM labeling.

    Posts to /solo/api/start-labeling. The response is not inspected; a
    raised exception is logged as a warning rather than propagated.
    """
    start_url = f"{self.server_url}/solo/api/start-labeling"
    try:
        self.session.post(start_url, timeout=30)
    except Exception as e:
        logger.warning(f"Failed to start LLM labeling: {e}")
    else:
        logger.info("Started LLM labeling")
def _wait_for_predictions(self, min_count: int = 10, timeout: int = 60) -> int:
    """Wait until at least ``min_count`` LLM predictions exist.

    Polls /solo/api/predictions every ``poll_interval`` seconds and reads the
    ``count`` field of the JSON answer.

    Args:
        min_count: Number of predictions to wait for.
        timeout: Maximum time to wait, in seconds.

    Returns:
        Actual prediction count: the count that satisfied ``min_count``, or,
        on timeout, the last count the server reported (0 if it never
        reported one).
    """
    deadline = time.time() + timeout
    last_count = 0
    while time.time() < deadline:
        try:
            resp = self.session.get(
                f"{self.server_url}/solo/api/predictions",
                timeout=30,
            )
            if resp.status_code == 200:
                count = resp.json().get("count", 0)
                if count >= min_count:
                    logger.info(f"LLM has {count} predictions (target: {min_count})")
                    return count
                # Remember progress so a timeout still reports the real count.
                last_count = count
        except Exception as e:
            # Polling is best-effort; keep trying until the deadline.
            logger.debug(f"Prediction poll failed: {e}")

        # Never sleep past the deadline.
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        time.sleep(min(self.config.poll_interval, remaining))

    logger.warning(f"Timeout waiting for {min_count} predictions")
    return last_count
# === HTML Parsing Helpers ===
def _extract_instance_id(self, html: str) -> Optional[str]:
"""Extract instance_id from an HTML page.
Looks for common patterns: hidden input, data attribute, or JSON.
"""
import re
# Pattern: <input ... name="instance_id" value="...">
match = re.search(
r'name=["\']instance_id["\'][^>]*value=["\']([^"\']+)', html
)
if match:
return match.group(1)
# Pattern: value="..." name="instance_id"
match = re.search(
r'value=["\']([^"\']+)["\'][^>]*name=["\']instance_id', html
)
if match:
return match.group(1)
# Pattern: data-instance-id="..."
match = re.search(r'data-instance-id=["\']([^"\']+)', html)
if match:
return match.group(1)
# Pattern: "instance_id": "..."
match = re.search(r'"instance_id"\s*:\s*"([^"]+)"', html)
if match:
return match.group(1)
return None
def _extract_disagreement_id(self, html: str) -> Optional[str]:
"""Extract disagreement_id from the disagreement page."""
import re
match = re.search(
r'name=["\']disagreement_id["\'][^>]*value=["\']([^"\']+)', html
)
if match:
return match.group(1)
match = re.search(
r'value=["\']([^"\']+)["\'][^>]*name=["\']disagreement_id', html
)
if match:
return match.group(1)
return None
# === Main Orchestration ===
def run_full_simulation(self) -> SoloSimulationResult:
    """Run the complete solo mode simulation.

    Drives through the phases in order: login, label discovery (when no
    labels were supplied), setup, prompt review, edge cases (only when the
    server reports an edge-case phase), parallel annotation, active
    annotation, periodic/rule review (only when the server reports such a
    phase), autonomous labeling and final validation. When
    ``force_advance_on_stuck`` is set, the server is force-advanced before
    each annotation/labeling/validation phase.

    An unexpected exception stops the run but is recorded in
    ``total_errors`` instead of being raised. ``end_time`` and the per-phase
    error aggregation are always performed.

    Returns:
        SoloSimulationResult with per-phase metrics
    """
    self.result.start_time = datetime.now()
    # NOTE: this reseeds the process-wide ``random`` generator so that the
    # simulated noise is reproducible from run to run.
    random.seed(42)

    # Phases already present come from an earlier call on this simulator and
    # had their errors aggregated then; only phases added from here on are
    # aggregated below, so repeated runs do not duplicate error messages.
    first_new_phase = len(self.result.phase_results)

    try:
        # Login
        if not self._login():
            return self.result

        # Discover available labels from schemas
        if not self.available_labels:
            self._discover_labels()

        # Phase 1: Setup
        pr = self._simulate_setup()
        self.result.phase_results.append(pr)

        # Phase 2: Prompt Review
        pr = self._simulate_prompt_review()
        self.result.phase_results.append(pr)

        # Phase 3-5: Edge Case Synthesis + Labeling + Validation
        current_phase = self._get_current_phase()
        if "edge" in current_phase.lower():
            pr = self._simulate_edge_case_labeling()
            self.result.phase_results.append(pr)

        # Phase 6: Parallel Annotation
        if self.config.force_advance_on_stuck:
            self._force_advance("parallel-annotation")

        # Start LLM labeling and wait for predictions to accumulate
        self._start_llm_labeling()
        self._wait_for_predictions(
            min_count=20, timeout=self.config.wait_for_predictions_timeout
        )

        pr = self._simulate_annotation(count=self.config.parallel_annotation_count)
        self.result.phase_results.append(pr)
        self.result.total_annotations += pr.annotations_submitted
        self.result.total_disagreements += pr.disagreements_encountered

        # Phase 8: Active Annotation
        if self.config.force_advance_on_stuck:
            self._force_advance("active-annotation")

        # Wait for more LLM predictions
        self._wait_for_predictions(
            min_count=50, timeout=self.config.wait_for_predictions_timeout
        )

        pr = self._simulate_annotation(count=self.config.active_annotation_count)
        self.result.phase_results.append(pr)
        self.result.total_annotations += pr.annotations_submitted
        self.result.total_disagreements += pr.disagreements_encountered

        # Check for periodic review
        current_phase = self._get_current_phase()
        if "review" in current_phase.lower() and "rule" not in current_phase.lower():
            pr = self._simulate_periodic_review()
            self.result.phase_results.append(pr)

        # Check for rule review (uses the phase name fetched just above)
        if "rule" in current_phase.lower():
            pr = self._simulate_rule_review()
            self.result.phase_results.append(pr)

        # Phase 11: Autonomous Labeling
        if self.config.force_advance_on_stuck:
            self._force_advance("autonomous-labeling")
        pr = self._simulate_autonomous_wait()
        self.result.phase_results.append(pr)

        # Phase 12: Final Validation
        if self.config.force_advance_on_stuck:
            self._force_advance("final-validation")
        pr = self._simulate_validation()
        self.result.phase_results.append(pr)

        # Collect final status
        self.result.final_status = self._get_status()
    except Exception as e:
        logger.error(f"Simulation error: {e}")
        self.result.total_errors.append(f"Simulation error: {e}")
    finally:
        self.result.end_time = datetime.now()
        # Collect errors from the phases recorded during this call
        for pr in self.result.phase_results[first_new_phase:]:
            self.result.total_errors.extend(pr.errors)

    return self.result
def _discover_labels(self) -> None:
    """Populate ``available_labels`` from the server's schema API.

    Reads /api/schemas and accepts three payload shapes: a dict with a
    "schemas" entry (itself a list or a dict of schemas), a bare list of
    schemas, or a bare dict whose values are schemas. Every label of every
    schema is appended: string labels as-is, dict labels by their "name"
    (falling back to the dict's string form). Any failure is logged as a
    warning and leaves whatever was collected so far.
    """
    try:
        reply = self.session.get(f"{self.server_url}/api/schemas", timeout=30)
        if reply.status_code == 200:
            payload = reply.json()

            # Normalise the payload into an iterable of schema dicts.
            if isinstance(payload, dict) and "schemas" in payload:
                schemas = payload["schemas"]
                if isinstance(schemas, dict):
                    schemas = list(schemas.values())
            elif isinstance(payload, list):
                schemas = payload
            elif isinstance(payload, dict):
                schemas = list(payload.values())
            else:
                schemas = []

            for schema in schemas:
                for entry in schema.get("labels", []):
                    if isinstance(entry, str):
                        self.available_labels.append(entry)
                    elif isinstance(entry, dict):
                        self.available_labels.append(entry.get("name", str(entry)))

            logger.info(f"Discovered labels: {self.available_labels}")
    except Exception as e:
        logger.warning(f"Failed to discover labels: {e}")
def get_verification_data(self) -> Dict[str, Any]:
    """Collect data for verification checks after simulation.

    Returns:
        Dict that always has "status" and additionally "prompts",
        "predictions" and "disagreements" for each of those /solo/api
        endpoints that answered with HTTP 200.
    """
    collected = {"status": self._get_status()}

    # Each endpoint is optional: a non-200 answer or an exception simply
    # leaves its key out of the result.
    for section in ("prompts", "predictions", "disagreements"):
        try:
            reply = self.session.get(
                f"{self.server_url}/solo/api/{section}", timeout=30
            )
            if reply.status_code == 200:
                collected[section] = reply.json()
        except Exception:
            pass

    return collected