"""
Configuration classes for the user simulator.

This module defines all configuration dataclasses used to configure
the simulator behavior, including user competence, timing, and strategies.
"""

import logging
import os
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any, Literal, Union

try:
    import yaml
except ImportError:  # pragma: no cover - exercised only without PyYAML
    # PyYAML is only needed by ``SimulatorConfig.from_yaml``; dict-based
    # configuration (``from_dict`` / ``to_dict``) stays usable without it.
    yaml = None

logger = logging.getLogger(__name__)

# Fields left out when endpoint-style configs are exported by
# ``SimulatorConfig.to_dict``. These are the same fields the ``llm_config``
# export omits.
_UNEXPORTED_FIELDS = ("api_key", "base_url")


def _resolve_env_ref(value: Any) -> Any:
    """Resolve a ``${VAR}`` reference against the process environment.

    Args:
        value: Raw configuration value (typically an API key).

    Returns:
        ``os.environ.get("VAR")`` when ``value`` is a string of the form
        ``${VAR}`` (``None`` if the variable is unset); otherwise ``value``
        unchanged.
    """
    if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
        return os.environ.get(value[2:-1])
    return value


def _section(data: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return the grouping section ``key`` of ``data``, or ``{}`` if absent/null."""
    return data.get(key) or {}


def _export_endpoint_config(cfg: Any) -> Optional[Dict[str, Any]]:
    """Convert an endpoint-style config dataclass to a dict for export.

    Args:
        cfg: A dataclass instance, or ``None``.

    Returns:
        ``None`` when ``cfg`` is ``None``; otherwise its fields as a dict
        with the entries named in ``_UNEXPORTED_FIELDS`` removed.
    """
    if cfg is None:
        return None
    return {
        name: value
        for name, value in asdict(cfg).items()
        if name not in _UNEXPORTED_FIELDS
    }


class CompetenceLevel(Enum):
    """Competence levels for simulated users.

    Each level defines a range of accuracy for the simulated annotator:
    - PERFECT: 100% accuracy (always matches gold standard)
    - GOOD: 80-90% accuracy
    - AVERAGE: 60-70% accuracy
    - POOR: 40-50% accuracy
    - RANDOM: Random selection (~1/N accuracy for N labels)
    - ADVERSARIAL: Intentionally wrong (avoids gold standard)
    """

    PERFECT = "perfect"
    GOOD = "good"
    AVERAGE = "average"
    POOR = "poor"
    RANDOM = "random"
    ADVERSARIAL = "adversarial"


class AnnotationStrategyType(Enum):
    """Annotation generation strategies.

    - RANDOM: Uniform random selection from available labels
    - BIASED: Weighted random selection based on label preferences
    - LLM: Use an LLM to generate annotations based on text content
    - PATTERN: Consistent per-user patterns for testing specific behaviors
    - GOLD_STANDARD: Use gold answer when available, random otherwise
    - AGENT: Vision-capable LLM that reads structured / multi-modal
      instance content (dialogue traces, spreadsheets, image fields) and
      emits a single batched annotation covering every schema for the
      instance.
    """

    RANDOM = "random"
    BIASED = "biased"
    LLM = "llm"
    PATTERN = "pattern"
    GOLD_STANDARD = "gold_standard"
    AGENT = "agent"


@dataclass
class TimingConfig:
    """Configuration for annotation timing behavior.

    Attributes:
        annotation_time_min: Minimum time per annotation in seconds
        annotation_time_max: Maximum time per annotation in seconds
        annotation_time_mean: Mean time for normal distribution
        annotation_time_std: Standard deviation for normal distribution
        distribution: Timing distribution model (uniform, normal, exponential)
        fast_response_threshold: Threshold for flagging suspiciously fast responses
        session_duration_max: Maximum session duration in minutes (optional)
    """

    annotation_time_min: float = 2.0
    annotation_time_max: float = 30.0
    annotation_time_mean: float = 10.0
    annotation_time_std: float = 5.0
    distribution: Literal["uniform", "normal", "exponential"] = "normal"
    fast_response_threshold: float = 1.0
    session_duration_max: Optional[float] = None


@dataclass
class LLMStrategyConfig:
    """Configuration for LLM-based annotation strategy.

    Uses the existing potato.ai endpoint infrastructure.

    Attributes:
        endpoint_type: LLM provider (openai, anthropic, ollama, etc.)
        model: Model name/identifier
        api_key: API key for cloud providers (can use env var reference)
        base_url: Base URL for local providers like Ollama
        temperature: Temperature for generation (0-2)
        max_tokens: Maximum tokens in response
        add_noise: Whether to occasionally add noise to LLM outputs
        noise_rate: Probability of adding noise (0-1)
    """

    endpoint_type: str = "openai"
    model: Optional[str] = None  # Uses provider default if None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 100
    add_noise: bool = True
    noise_rate: float = 0.05


@dataclass
class InteractiveConfig:
    """Configuration for driving live ``interactive_chat`` sessions.

    When enabled, the simulator runs a multi-turn chat against the
    server's ``/agent_chat/*`` routes before annotating each instance
    whose display contains an ``interactive_chat`` field. The simulator
    plays the user role; the server-side ``agent_proxy`` plays the agent
    (echo, OpenAI, HTTP, etc. -- whatever the annotation config specifies).

    Attributes:
        enabled: Whether to attempt an interactive session per instance.
        endpoint_type: AI endpoint used to generate the user persona's
            messages. Defaults to ``ollama`` (text only -- the persona
            usually doesn't need vision).
        model: Persona model name. Defaults to provider default.
        api_key: Optional API key (env-var refs supported).
        base_url: Optional endpoint base URL.
        temperature: Sampling temperature for persona messages.
        max_tokens: Per-message token cap.
        max_turns: Hard upper bound on turn count per session.
        persona_system_prompt: System prompt that defines the user
            persona. Should encourage natural multi-turn behavior and a
            clear ``DONE`` signal when the task is complete.
        done_marker: Substring (case-insensitive) the persona emits when
            it considers the task complete. The runner finishes the
            session immediately when seen.
        first_message_template: Template applied to the persona's first
            message. ``{task}`` is replaced with the task description.
            If None, the persona generates the first message from scratch.
    """

    enabled: bool = False
    endpoint_type: str = "ollama"
    model: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    temperature: float = 0.7
    max_tokens: int = 200
    max_turns: int = 6
    persona_system_prompt: str = (
        "You are a curious end-user testing an AI assistant. "
        "Send concise, natural messages that drive the assistant to "
        "complete the task. When the assistant has fully completed the "
        "task, respond with a short acknowledgement and the literal "
        "marker [DONE]."
    )
    done_marker: str = "[DONE]"
    first_message_template: Optional[str] = (
        "Please help me with this task: {task}"
    )


@dataclass
class AgentStrategyConfig:
    """Configuration for the agent (vision-LLM) annotation strategy.

    Drives a vision-capable LLM that consumes structured / multi-modal
    instance content (dialogue arrays, spreadsheets, image fields) and
    produces a batched annotation over every schema for the instance.

    Attributes:
        endpoint_type: AI endpoint (default ``ollama_vision``). Any vision
            endpoint registered with ``AIEndpointFactory`` works
            (``anthropic_vision``, ``openai_vision``, etc.).
        model: Model identifier (e.g. ``gemma3:4b``, ``llava:latest``,
            ``llama3.2-vision``). Defaults to provider default.
        api_key: Cloud-provider API key (env-var refs supported, e.g.
            ``${ANTHROPIC_API_KEY}``).
        base_url: Custom endpoint URL (Ollama: ``http://localhost:11434``).
        temperature: Sampling temperature.
        max_tokens: Cap on response tokens.
        max_image_dim: Resize images so the longest edge is at most this
            many pixels before sending. ``None`` keeps the original.
        max_image_count: Skip image attachment past this many images per
            instance (some models cap at 1–4).
        include_dialogue_text: Render dialogue arrays as text lines in the
            prompt (presumably ``speaker: text`` style -- TODO confirm the
            exact line format against the strategy implementation).
        include_spreadsheet: Render spreadsheet/table fields as plain text.
        max_dialogue_chars: Truncate long dialogue payloads to this many
            characters in the prompt to fit the model's context window.
        cache_per_instance: When True (default), one LLM call per instance
            answers all schemas; subsequent ``generate_annotation`` calls
            for the same instance return cached results.
        add_noise: Whether to occasionally fall back to a random
            annotation per schema (mirrors ``LLMStrategyConfig`` so
            existing competence modeling still applies).
        noise_rate: Probability used for noise injection (0–1).
    """

    endpoint_type: str = "ollama_vision"
    model: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 800
    max_image_dim: Optional[int] = 1024
    max_image_count: int = 4
    include_dialogue_text: bool = True
    include_spreadsheet: bool = True
    max_dialogue_chars: int = 12000
    cache_per_instance: bool = True
    add_noise: bool = False
    noise_rate: float = 0.0


@dataclass
class BiasedStrategyConfig:
    """Configuration for biased annotation strategy.

    Attributes:
        label_weights: Dictionary mapping label names to selection weights.
            Higher weights mean higher probability of selection.
            Example: {"positive": 0.6, "negative": 0.3, "neutral": 0.1}
    """

    label_weights: Dict[str, float] = field(default_factory=dict)


@dataclass
class PatternStrategyConfig:
    """Configuration for pattern-based annotation strategy.

    Allows defining specific behavior patterns per user.

    Attributes:
        patterns: Dictionary mapping user_id to behavior configuration.
            Each pattern can specify:
            - preferred_label: Label this user tends to select
            - bias_strength: How strongly they prefer it (0-1)
            - keywords: Text patterns that trigger specific labels
    """

    patterns: Dict[str, Dict[str, Any]] = field(default_factory=dict)


@dataclass
class UserConfig:
    """Configuration for a single simulated user.

    Attributes:
        user_id: Unique identifier for this user
        competence: Competence level determining accuracy
        strategy: Annotation strategy type
        timing: Timing configuration for this user
        llm_config: LLM configuration if strategy is LLM
        biased_config: Bias configuration if strategy is BIASED
        pattern_config: Pattern configuration if strategy is PATTERN
        agent_config: Agent configuration if strategy is AGENT
        max_annotations: Maximum annotations for this user (optional)
    """

    user_id: str
    competence: CompetenceLevel = CompetenceLevel.AVERAGE
    strategy: AnnotationStrategyType = AnnotationStrategyType.RANDOM
    timing: TimingConfig = field(default_factory=TimingConfig)
    llm_config: Optional[LLMStrategyConfig] = None
    biased_config: Optional[BiasedStrategyConfig] = None
    pattern_config: Optional[PatternStrategyConfig] = None
    agent_config: Optional[AgentStrategyConfig] = None
    max_annotations: Optional[int] = None


def _parse_timing(timing_data: Optional[Dict[str, Any]]) -> TimingConfig:
    """Build a ``TimingConfig`` from the ``timing`` section.

    Two layouts are accepted: a nested ``annotation_time`` mapping with
    ``min``/``max``/``mean``/``std``/``distribution`` keys, or flat
    ``annotation_time_*`` keys (plus ``distribution``) directly in the
    section. ``fast_response_threshold`` and ``session_duration_max`` are
    always read from the section itself.

    Args:
        timing_data: The ``timing`` section, or ``None``/empty.

    Returns:
        The parsed ``TimingConfig`` (all defaults when the section is empty).
    """
    if not timing_data:
        return TimingConfig()

    if "annotation_time" in timing_data:
        # Nested layout: short key names inside ``annotation_time``.
        source, prefix = (timing_data["annotation_time"] or {}), ""
    else:
        # Flat layout: prefixed key names directly in the section.
        source, prefix = timing_data, "annotation_time_"

    return TimingConfig(
        annotation_time_min=source.get(f"{prefix}min", 2.0),
        annotation_time_max=source.get(f"{prefix}max", 30.0),
        annotation_time_mean=source.get(f"{prefix}mean", 10.0),
        annotation_time_std=source.get(f"{prefix}std", 5.0),
        distribution=source.get("distribution", "normal"),
        fast_response_threshold=timing_data.get("fast_response_threshold", 1.0),
        session_duration_max=timing_data.get("session_duration_max"),
    )


def _parse_llm_config(llm_data: Dict[str, Any]) -> LLMStrategyConfig:
    """Build an ``LLMStrategyConfig`` from the ``llm_config`` section."""
    return LLMStrategyConfig(
        endpoint_type=llm_data.get("endpoint_type", "openai"),
        model=llm_data.get("model"),
        api_key=_resolve_env_ref(llm_data.get("api_key")),
        base_url=llm_data.get("base_url"),
        temperature=llm_data.get("temperature", 0.1),
        max_tokens=llm_data.get("max_tokens", 100),
        add_noise=llm_data.get("add_noise", True),
        noise_rate=llm_data.get("noise_rate", 0.05),
    )


def _parse_interactive_config(ic: Dict[str, Any]) -> InteractiveConfig:
    """Build an ``InteractiveConfig`` from the ``interactive`` section.

    ``enabled`` defaults to True here (the section being present is taken
    as opting in), unlike the dataclass default of False.
    """
    kwargs = {
        "enabled": ic.get("enabled", True),
        "endpoint_type": ic.get("endpoint_type", "ollama"),
        "model": ic.get("model"),
        "api_key": _resolve_env_ref(ic.get("api_key")),
        "base_url": ic.get("base_url"),
        "temperature": ic.get("temperature", 0.7),
        "max_tokens": ic.get("max_tokens", 200),
        "max_turns": ic.get("max_turns", 6),
        "done_marker": ic.get("done_marker", "[DONE]"),
    }
    # Only override these when the key is present, so an explicit ``None``
    # for ``first_message_template`` is preserved while an absent key keeps
    # the dataclass default.
    for optional_key in ("persona_system_prompt", "first_message_template"):
        if optional_key in ic:
            kwargs[optional_key] = ic[optional_key]
    return InteractiveConfig(**kwargs)


def _parse_agent_config(ad: Dict[str, Any]) -> AgentStrategyConfig:
    """Build an ``AgentStrategyConfig`` from the ``agent_config`` section."""
    return AgentStrategyConfig(
        endpoint_type=ad.get("endpoint_type", "ollama_vision"),
        model=ad.get("model"),
        api_key=_resolve_env_ref(ad.get("api_key")),
        base_url=ad.get("base_url"),
        temperature=ad.get("temperature", 0.1),
        max_tokens=ad.get("max_tokens", 800),
        max_image_dim=ad.get("max_image_dim", 1024),
        max_image_count=ad.get("max_image_count", 4),
        include_dialogue_text=ad.get("include_dialogue_text", True),
        include_spreadsheet=ad.get("include_spreadsheet", True),
        max_dialogue_chars=ad.get("max_dialogue_chars", 12000),
        cache_per_instance=ad.get("cache_per_instance", True),
        add_noise=ad.get("add_noise", False),
        noise_rate=ad.get("noise_rate", 0.0),
    )


@dataclass
class SimulatorConfig:
    """Master configuration for the user simulator.

    Attributes:
        user_count: Number of simulated users to create
        competence_distribution: Distribution of competence levels
            (keys are competence level names, values are proportions)
        users: Explicit list of user configurations (overrides user_count)
        timing: Global timing configuration (can be overridden per-user)
        strategy: Default annotation strategy
        llm_config: LLM configuration for LLM strategy
        biased_config: Bias configuration for biased strategy
        agent_config: Agent (vision-LLM) configuration for agent strategy
        interactive: Configuration for driving interactive chat sessions
        gold_standard_file: Path to JSON file with gold standard labels
        parallel_users: Maximum concurrent users
        delay_between_users: Delay between starting users (seconds)
        attention_check_fail_rate: Rate at which users fail attention checks
        respond_fast_rate: Rate of suspiciously fast responses
        simulate_wait: Whether to actually wait between annotations
        output_dir: Directory for output files
        export_format: Output format (json, csv, jsonl)
    """

    # User configuration
    user_count: int = 10
    competence_distribution: Dict[str, float] = field(
        default_factory=lambda: {"good": 0.5, "average": 0.3, "poor": 0.2}
    )
    users: List[UserConfig] = field(default_factory=list)

    # Global timing configuration
    timing: TimingConfig = field(default_factory=TimingConfig)

    # Strategy configuration - default to random
    strategy: AnnotationStrategyType = AnnotationStrategyType.RANDOM
    llm_config: Optional[LLMStrategyConfig] = None
    biased_config: Optional[BiasedStrategyConfig] = None
    agent_config: Optional[AgentStrategyConfig] = None
    interactive: Optional[InteractiveConfig] = None

    # Gold standard data for competence-based accuracy
    gold_standard_file: Optional[str] = None

    # Execution configuration
    parallel_users: int = 5
    delay_between_users: float = 0.5

    # Quality control testing options
    attention_check_fail_rate: float = 0.0
    respond_fast_rate: float = 0.0

    # Whether to actually wait (set False for fast testing)
    simulate_wait: bool = False

    # Output configuration
    output_dir: str = "simulator_output"
    export_format: Literal["json", "csv", "jsonl"] = "json"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "SimulatorConfig":
        """Load configuration from YAML file.

        Args:
            yaml_path: Path to YAML configuration file

        Returns:
            SimulatorConfig instance

        Raises:
            ImportError: If PyYAML is not installed.
        """
        if yaml is None:
            raise ImportError(
                "PyYAML is required to load simulator configuration from YAML files"
            )
        with open(yaml_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document; _parse_config
            # treats that as an empty configuration.
            data = yaml.safe_load(f)
        return cls._parse_config(data)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SimulatorConfig":
        """Load configuration from dictionary.

        Args:
            data: Configuration dictionary

        Returns:
            SimulatorConfig instance
        """
        return cls._parse_config(data)

    @classmethod
    def _parse_config(cls, data: Optional[Dict[str, Any]]) -> "SimulatorConfig":
        """Parse configuration from dictionary.

        Settings may be given either at the top level or grouped under the
        ``users``, ``execution``, ``quality_control`` and ``output``
        sections; a grouped value takes precedence over the top-level one.
        Sections whose value is ``None`` are treated as absent, so the
        output of ``to_dict`` can be fed back in.

        Args:
            data: Raw configuration dictionary (``None`` is treated as empty)

        Returns:
            SimulatorConfig instance
        """
        data = data or {}

        # Handle nested 'simulator' key if present
        if "simulator" in data:
            data = data["simulator"] or {}

        timing = _parse_timing(data.get("timing"))

        # Optional strategy sections: present-but-empty still yields a
        # config object with defaults; null/absent yields None.
        llm_data = data.get("llm_config")
        llm_config = _parse_llm_config(llm_data) if llm_data is not None else None

        biased_data = data.get("biased_config")
        biased_config = None
        if biased_data is not None:
            biased_config = BiasedStrategyConfig(
                label_weights=biased_data.get("label_weights", {})
            )

        interactive_data = data.get("interactive")
        interactive_config = (
            _parse_interactive_config(interactive_data)
            if interactive_data is not None
            else None
        )

        agent_data = data.get("agent_config")
        agent_config = (
            _parse_agent_config(agent_data) if agent_data is not None else None
        )

        # Parse strategy; unknown names fall back to RANDOM (best effort),
        # but the fallback is logged so a typo does not go unnoticed.
        strategy_str = data.get("strategy", "random")
        try:
            strategy = AnnotationStrategyType(strategy_str)
        except ValueError:
            logger.warning(
                "Unknown annotation strategy %r; falling back to 'random'",
                strategy_str,
            )
            strategy = AnnotationStrategyType.RANDOM

        # Parse users section
        users_data = _section(data, "users")
        user_count = users_data.get("count", data.get("user_count", 10))
        competence_dist = users_data.get(
            "competence_distribution",
            data.get(
                "competence_distribution",
                {"good": 0.5, "average": 0.3, "poor": 0.2},
            ),
        )

        # Parse execution config
        # NOTE(review): ``execution.max_annotations_per_user`` is not read
        # here because SimulatorConfig has no field to store it.
        execution = _section(data, "execution")
        parallel_users = execution.get(
            "parallel_users", data.get("parallel_users", 5)
        )
        delay_between = execution.get(
            "delay_between_users", data.get("delay_between_users", 0.5)
        )

        # Parse QC config
        qc_config = _section(data, "quality_control")
        attention_fail_rate = qc_config.get(
            "attention_check_fail_rate", data.get("attention_check_fail_rate", 0.0)
        )
        respond_fast_rate = qc_config.get(
            "respond_fast_rate", data.get("respond_fast_rate", 0.0)
        )

        # Parse output config
        output_config = _section(data, "output")
        output_dir = output_config.get(
            "dir", data.get("output_dir", "simulator_output")
        )
        export_format = output_config.get(
            "format", data.get("export_format", "json")
        )

        return cls(
            user_count=user_count,
            competence_distribution=competence_dist,
            timing=timing,
            strategy=strategy,
            llm_config=llm_config,
            biased_config=biased_config,
            agent_config=agent_config,
            interactive=interactive_config,
            gold_standard_file=data.get("gold_standard_file"),
            parallel_users=parallel_users,
            delay_between_users=delay_between,
            attention_check_fail_rate=attention_fail_rate,
            respond_fast_rate=respond_fast_rate,
            simulate_wait=data.get("simulate_wait", False),
            output_dir=output_dir,
            export_format=export_format,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary.

        The ``llm_config``, ``agent_config`` and ``interactive`` entries
        omit ``api_key`` and ``base_url``. The ``users`` list is not
        exported.

        Returns:
            Configuration as dictionary
        """
        return {
            "user_count": self.user_count,
            "competence_distribution": self.competence_distribution,
            "timing": {
                "annotation_time_min": self.timing.annotation_time_min,
                "annotation_time_max": self.timing.annotation_time_max,
                "annotation_time_mean": self.timing.annotation_time_mean,
                "annotation_time_std": self.timing.annotation_time_std,
                "distribution": self.timing.distribution,
                "fast_response_threshold": self.timing.fast_response_threshold,
                "session_duration_max": self.timing.session_duration_max,
            },
            "strategy": self.strategy.value,
            "llm_config": (
                {
                    "endpoint_type": self.llm_config.endpoint_type,
                    "model": self.llm_config.model,
                    "temperature": self.llm_config.temperature,
                    "max_tokens": self.llm_config.max_tokens,
                    "add_noise": self.llm_config.add_noise,
                    "noise_rate": self.llm_config.noise_rate,
                }
                if self.llm_config
                else None
            ),
            "biased_config": (
                {"label_weights": self.biased_config.label_weights}
                if self.biased_config
                else None
            ),
            "agent_config": _export_endpoint_config(self.agent_config),
            "interactive": _export_endpoint_config(self.interactive),
            "gold_standard_file": self.gold_standard_file,
            "parallel_users": self.parallel_users,
            "delay_between_users": self.delay_between_users,
            "attention_check_fail_rate": self.attention_check_fail_rate,
            "respond_fast_rate": self.respond_fast_rate,
            "simulate_wait": self.simulate_wait,
            "output_dir": self.output_dir,
            "export_format": self.export_format,
        }