Spaces:
Paused
Paused
| """ | |
| Core user simulator class. | |
| This module provides the SimulatedUser class that simulates a single | |
| annotator interacting with the Potato annotation platform via its API. | |
| """ | |
| import logging | |
| import random | |
| from typing import Dict, List, Any, Optional | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| import requests | |
| from .config import ( | |
| UserConfig, | |
| SimulatorConfig, | |
| CompetenceLevel, | |
| AnnotationStrategyType, | |
| TimingConfig, | |
| InteractiveConfig, | |
| ) | |
| from .competence_profiles import CompetenceProfile, create_competence_profile | |
| from .annotation_strategies import AnnotationStrategy, create_strategy | |
| from .timing_models import TimingModel, NoWaitTimingModel | |
| from .interactive_runner import InteractiveSessionRunner | |
# Module-level logger, named after this module so it sits in the package's
# logger hierarchy and inherits any configuration applied there.
logger = logging.getLogger(name=__name__)
# @dataclass is required: records are built with keyword arguments
# (e.g. in SimulatedUser.submit_annotation), which a plain class rejects.
@dataclass
class AnnotationRecord:
    """Record of a single annotation submission.

    Attributes:
        instance_id: ID of the annotated instance
        schema_name: Name of the annotation schema (comma-joined when one
            submission covers several schemas)
        annotation: The annotation data submitted
        response_time: Time taken to annotate (seconds)
        timestamp: When the annotation was submitted
        was_attention_check: Whether this was an attention check item
        attention_check_passed: Result of attention check (if applicable)
        was_gold_standard: Whether this was a gold standard item
        gold_standard_correct: Whether gold standard was answered correctly
    """

    instance_id: str
    schema_name: str
    annotation: Dict[str, Any]
    response_time: float
    timestamp: datetime
    was_attention_check: bool = False
    attention_check_passed: Optional[bool] = None
    was_gold_standard: bool = False
    gold_standard_correct: Optional[bool] = None
# @dataclass is required: field(default_factory=...) only takes effect inside
# a dataclass, and results are built as UserSimulationResult(user_id=...).
@dataclass
class UserSimulationResult:
    """Results from a user simulation session.

    Attributes:
        user_id: ID of the simulated user
        annotations: List of annotation records
        total_time: Total simulation time in seconds
        attention_checks_passed: Number of passed attention checks
        attention_checks_failed: Number of failed attention checks
        gold_standard_correct: Number of correct gold standard answers
        gold_standard_incorrect: Number of incorrect gold standard answers
        errors: List of error messages encountered
        start_time: When simulation started
        end_time: When simulation ended
        was_blocked: Whether user was blocked by quality control
    """

    user_id: str
    # String forward reference keeps this definition independent of whether
    # AnnotationRecord is declared before it.
    annotations: List["AnnotationRecord"] = field(default_factory=list)
    total_time: float = 0.0
    attention_checks_passed: int = 0
    attention_checks_failed: int = 0
    gold_standard_correct: int = 0
    gold_standard_incorrect: int = 0
    errors: List[str] = field(default_factory=list)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    was_blocked: bool = False
class SimulatedUser:
    """Simulates a single user annotating items via the Potato API.

    The SimulatedUser handles:
    - Authentication (registration, then login)
    - Fetching annotation items
    - Generating annotations based on strategy
    - Submitting annotations
    - Navigating between items
    - Tracking quality control results
    """

    # Timeout (seconds) applied to every HTTP request made by this class.
    REQUEST_TIMEOUT = 30

    # run_simulation gives up after this many submissions fail in a row, so a
    # server that keeps rejecting submissions cannot trap the simulator in an
    # endless fetch/submit/navigate cycle.
    MAX_CONSECUTIVE_SUBMIT_FAILURES = 3

    def __init__(
        self,
        user_config: UserConfig,
        server_url: str,
        gold_standards: Optional[Dict[str, Dict[str, Any]]] = None,
        simulate_wait: bool = False,
        attention_check_fail_rate: float = 0.0,
        respond_fast_rate: float = 0.0,
        interactive_config: Optional[InteractiveConfig] = None,
    ):
        """Initialize simulated user.

        Args:
            user_config: Configuration for this user
            server_url: Base URL of the Potato server (trailing "/" is stripped)
            gold_standards: Optional gold standard answers keyed by instance_id
            simulate_wait: Whether to actually wait between annotations
            attention_check_fail_rate: Rate at which to fail attention checks
            respond_fast_rate: Rate of suspiciously fast responses, passed to
                the timing model when generating response times
            interactive_config: Optional interactive_chat configuration; when
                given and enabled, each instance's chat is driven before the
                instance is annotated
        """
        self.config = user_config
        self.server_url = server_url.rstrip("/")
        self.gold_standards = gold_standards or {}
        self.attention_check_fail_rate = attention_check_fail_rate
        self.respond_fast_rate = respond_fast_rate

        # Initialize components
        self.competence = create_competence_profile(user_config.competence)
        self.strategy = self._create_strategy()

        # NoWaitTimingModel still yields response times but skips the sleep.
        if simulate_wait:
            self.timing = TimingModel(user_config.timing)
        else:
            self.timing = NoWaitTimingModel(user_config.timing)

        # Session and state
        self.session = requests.Session()
        self.logged_in = False
        self.current_instance_id: Optional[str] = None
        self.schemas: List[Dict[str, Any]] = []

        # Optional interactive_chat driver. It gets the normalized URL so it
        # is consistent with the URLs this class builds itself.
        self.interactive_runner: Optional[InteractiveSessionRunner] = None
        if interactive_config and interactive_config.enabled:
            self.interactive_runner = InteractiveSessionRunner(
                interactive_config, self.server_url
            )

        # Results tracking
        self.result = UserSimulationResult(user_id=user_config.user_id)

    def _create_strategy(self) -> AnnotationStrategy:
        """Create the annotation strategy for this user.

        Returns:
            AnnotationStrategy instance built from this user's config
        """
        return create_strategy(
            strategy_type=self.config.strategy,
            llm_config=self.config.llm_config,
            biased_config=self.config.biased_config,
            pattern_config=self.config.pattern_config,
            agent_config=self.config.agent_config,
            user_id=self.config.user_id,
        )

    def login(self) -> bool:
        """Register (if needed) and then log in the simulated user.

        Registration is attempted first and its response is ignored; it only
        serves to create the account when it does not exist yet. After the
        login POST, success is verified by requesting ``/annotate`` without
        following redirects: a 302 to an auth/login page means failure.

        Returns:
            True if authentication successful
        """
        password = "simulated_password_123"
        credentials = {"email": self.config.user_id, "pass": password}

        try:
            # Try to register first (in case user doesn't exist); the outcome
            # is deliberately not checked.
            self.session.post(
                f"{self.server_url}/register",
                data={"action": "signup", **credentials},
                allow_redirects=True,
                timeout=self.REQUEST_TIMEOUT,
            )

            # Now try to login
            self.session.post(
                f"{self.server_url}/auth",
                data={"action": "login", **credentials},
                allow_redirects=True,
                timeout=self.REQUEST_TIMEOUT,
            )

            # Check if we're logged in by trying to access annotate page
            check_response = self.session.get(
                f"{self.server_url}/annotate",
                allow_redirects=False,
                timeout=self.REQUEST_TIMEOUT,
            )

            # If we get redirected to login, auth failed
            if check_response.status_code == 302:
                location = check_response.headers.get("Location", "")
                if "auth" in location or "login" in location:
                    logger.warning(f"Login failed for {self.config.user_id}")
                    self.result.errors.append("Login failed - redirected to auth")
                    return False

            self.logged_in = True
            logger.debug(f"User {self.config.user_id} logged in successfully")
            return True

        except requests.exceptions.RequestException as e:
            logger.error(f"Login failed for {self.config.user_id}: {e}")
            self.result.errors.append(f"Login failed: {e}")
            return False

    def get_current_instance(self) -> Optional[Dict[str, Any]]:
        """Get the current instance to annotate.

        Also stores the instance ID in ``self.current_instance_id`` and, when
        the ``/api/spans/<id>`` lookup succeeds, sets ``data["text"]`` from it.

        Returns:
            Instance data dict, or None if there are no more instances or the
            request fails
        """
        try:
            response = self.session.get(
                f"{self.server_url}/api/current_instance",
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code == 404:
                logger.info(f"No more instances for {self.config.user_id}")
                return None
            if response.status_code != 200:
                logger.warning(
                    f"Failed to get instance: {response.status_code} - {response.text}"
                )
                return None

            data = response.json()
            self.current_instance_id = data.get("instance_id")

            # Get the actual text content (best effort: on a non-200 reply the
            # instance data is returned as-is).
            if self.current_instance_id:
                text_response = self.session.get(
                    f"{self.server_url}/api/spans/{self.current_instance_id}",
                    timeout=self.REQUEST_TIMEOUT,
                )
                if text_response.status_code == 200:
                    data["text"] = text_response.json().get("text", "")
            return data

        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a malformed JSON body on requests versions
            # whose JSONDecodeError is not a RequestException subclass.
            logger.error(f"Failed to get current instance: {e}")
            self.result.errors.append(f"Get instance failed: {e}")
            return None

    def get_schemas(self) -> List[Dict[str, Any]]:
        """Get annotation schemas from the server and cache them.

        Accepts a bare list, ``{"schemas": [...]}``, ``{"schemas": {name:
        schema}}`` or a ``{name: schema}`` mapping. On success the list is
        stored in ``self.schemas``.

        Returns:
            List of schema definitions, or [] on failure
        """
        try:
            response = self.session.get(
                f"{self.server_url}/api/schemas",
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code != 200:
                logger.warning(f"Failed to get schemas: {response.status_code}")
                return []

            data = response.json()
            # Handle both list and dict formats
            if isinstance(data, dict):
                if "schemas" in data:
                    schemas = data["schemas"]
                    self.schemas = (
                        list(schemas.values()) if isinstance(schemas, dict) else schemas
                    )
                else:
                    self.schemas = list(data.values())
            else:
                self.schemas = data
            return self.schemas

        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError: malformed JSON on older requests versions.
            logger.error(f"Failed to get schemas: {e}")
            self.result.errors.append(f"Get schemas failed: {e}")
            return []

    def generate_annotations(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Generate annotations for all schemas.

        The strategy receives a shallow copy of ``instance`` with an extra
        ``__all_schemas__`` key, plus, when a gold answer exists for the
        instance, ``{schema_name: gold_value}`` for the schema being annotated.

        Args:
            instance: Instance data including text

        Returns:
            Combined annotation dictionary for all schemas
        """
        instance_id = instance.get("instance_id")
        gold_answer = self.gold_standards.get(instance_id)

        # Attach the full schema set so batching strategies (e.g. AgentSimulatorStrategy)
        # can build a single multi-schema prompt per instance. Other strategies
        # ignore the extra key. Copy first so the caller's dict is untouched.
        instance = dict(instance)
        instance["__all_schemas__"] = self.schemas

        all_annotations: Dict[str, Any] = {}
        for schema in self.schemas:
            schema_name = schema.get("name")
            schema_gold = None
            if gold_answer:
                schema_gold = {schema_name: gold_answer.get(schema_name)}

            annotation = self.strategy.generate_annotation(
                instance, schema, self.competence, schema_gold
            )
            all_annotations.update(annotation)

        return all_annotations

    def submit_annotation(
        self,
        instance_id: str,
        annotations: Dict[str, Any],
        response_time: float,
    ) -> bool:
        """Submit annotations for an instance.

        On HTTP 200 an AnnotationRecord is appended to ``self.result``; any
        ``qc_result`` and ``status == "blocked"`` in the JSON reply update the
        QC counters and blocked flag. A 200 reply without a usable JSON object
        still counts as a successful submission, just without QC information.

        Args:
            instance_id: ID of the instance
            annotations: Annotation data to submit
            response_time: Time taken to annotate

        Returns:
            True if submission successful
        """
        payload = {
            "instance_id": instance_id,
            "annotations": annotations,
            "span_annotations": [],
            "client_timestamp": datetime.now().isoformat(),
        }

        try:
            response = self.session.post(
                f"{self.server_url}/updateinstance",
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.exceptions.RequestException as e:
            logger.error(f"Submit annotation failed: {e}")
            self.result.errors.append(f"Submit failed: {e}")
            return False

        if response.status_code != 200:
            logger.warning(
                f"Annotation submission failed: {response.status_code} - {response.text}"
            )
            self.result.errors.append(f"Submit failed: {response.status_code}")
            return False

        # The server accepted the submission; a missing or non-JSON body must
        # not turn that into a failure (or abort the run).
        try:
            result_data = response.json()
        except ValueError:
            result_data = {}
        if not isinstance(result_data, dict):
            result_data = {}

        record = AnnotationRecord(
            instance_id=instance_id,
            schema_name=",".join(annotations),
            annotation=annotations,
            response_time=response_time,
            timestamp=datetime.now(),
        )

        # Check for quality control results
        self._record_qc_result(record, result_data.get("qc_result"))

        # Check for blocking
        if result_data.get("status") == "blocked":
            self.result.was_blocked = True
            logger.info(f"User {self.config.user_id} was blocked")

        self.result.annotations.append(record)
        return True

    def _record_qc_result(self, record: AnnotationRecord, qc_result: Any) -> None:
        """Copy a server-reported QC outcome onto ``record`` and the totals.

        ``qc_result`` is handled when it is a dict whose ``type`` is
        ``"attention_check"`` (reads ``passed``) or ``"gold_standard"``
        (reads ``correct``); any other value is ignored.
        """
        if not isinstance(qc_result, dict):
            return

        qc_type = qc_result.get("type")
        if qc_type == "attention_check":
            record.was_attention_check = True
            record.attention_check_passed = qc_result.get("passed", False)
            if record.attention_check_passed:
                self.result.attention_checks_passed += 1
            else:
                self.result.attention_checks_failed += 1
        elif qc_type == "gold_standard":
            record.was_gold_standard = True
            record.gold_standard_correct = qc_result.get("correct", False)
            if record.gold_standard_correct:
                self.result.gold_standard_correct += 1
            else:
                self.result.gold_standard_incorrect += 1

    def navigate_next(self) -> bool:
        """Navigate to the next instance.

        Returns:
            True if the server answered with 200 or 302
        """
        try:
            # POST to /annotate with action=next_instance
            response = self.session.post(
                f"{self.server_url}/annotate",
                data={"action": "next_instance"},
                timeout=self.REQUEST_TIMEOUT,
            )
            return response.status_code in (200, 302)

        except requests.exceptions.RequestException as e:
            logger.error(f"Navigate next failed: {e}")
            self.result.errors.append(f"Navigate failed: {e}")
            return False

    def _run_interactive_chat(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Drive the interactive_chat session for ``instance`` before rating.

        Runs the chat so the conversation field is populated before the rating
        strategy reads it. Returns a re-fetched copy of the instance when the
        server still reports the same instance; otherwise returns ``instance``
        with the collected conversation stored in ``data["conversation"]``.
        """
        instance_id = instance.get("instance_id")
        instance_data = instance.get("data")
        if not isinstance(instance_data, dict):
            instance_data = {}
        task_text = (
            instance_data.get("task_description")
            or instance_data.get("text")
            or instance.get("text", "")
        )

        chat_result = self.interactive_runner.run(self.session, instance_id, task_text)
        if chat_result.error:
            self.result.errors.append(f"interactive: {chat_result.error}")

        # Re-fetch the instance so its data reflects the chat
        refreshed = self.get_current_instance()
        if refreshed and refreshed.get("instance_id") == instance_id:
            return refreshed

        # Server moved on; fall back to using the in-memory conversation we
        # just collected. "data" may be missing or None, so setdefault alone
        # is not enough.
        if not isinstance(instance.get("data"), dict):
            instance["data"] = {}
        instance["data"]["conversation"] = chat_result.conversation
        return instance

    def run_simulation(
        self, max_annotations: Optional[int] = None
    ) -> UserSimulationResult:
        """Run the full simulation for this user.

        Logs in, fetches schemas, then repeatedly fetches the current
        instance, optionally drives its interactive chat, waits for the
        generated response time, generates and submits annotations, and
        navigates onward. The loop stops when:
        - the user is blocked;
        - the annotation limit is reached;
        - no instance is returned;
        - navigation fails;
        - the server returns the instance that was just submitted again;
        - ``MAX_CONSECUTIVE_SUBMIT_FAILURES`` submissions in a row fail.

        Args:
            max_annotations: Maximum number of annotations (optional); when
                None, ``self.config.max_annotations`` is used

        Returns:
            UserSimulationResult with all tracking data
        """
        self.result.start_time = datetime.now()
        max_ann = (
            max_annotations
            if max_annotations is not None
            else self.config.max_annotations
        )
        annotation_count = 0
        consecutive_failures = 0
        # Most recently *successfully submitted* instance ID, used to detect
        # a server that no longer advances after navigation.
        last_submitted_id: Optional[str] = None

        try:
            # Login
            if not self.login():
                logger.warning(f"User {self.config.user_id} failed to login")
                return self.result

            # Get schemas (a failure is recorded but does not stop the run)
            if not self.get_schemas():
                logger.warning(f"User {self.config.user_id} failed to get schemas")
                self.result.errors.append("Failed to get schemas")

            # Main annotation loop
            while True:
                if self.result.was_blocked:
                    logger.info(f"User {self.config.user_id} is blocked, stopping")
                    break

                if max_ann is not None and annotation_count >= max_ann:
                    logger.debug(
                        f"User {self.config.user_id} reached annotation limit ({max_ann})"
                    )
                    break

                instance = self.get_current_instance()
                if not instance or not instance.get("instance_id"):
                    logger.info(f"No more instances for {self.config.user_id}")
                    break
                instance_id = instance.get("instance_id")

                # After submit + successful navigation, getting the same
                # instance back means the server did not advance (e.g. the
                # queue is exhausted); re-annotating it would loop forever
                # and inflate the annotation count.
                if instance_id == last_submitted_id:
                    logger.info(
                        f"User {self.config.user_id} received instance "
                        f"{instance_id} again after navigating; stopping"
                    )
                    break

                if self.interactive_runner is not None:
                    instance = self._run_interactive_chat(instance)

                # Generate timing; NoWaitTimingModel skips the actual wait.
                response_time = self.timing.get_response_time(self.respond_fast_rate)
                self.timing.wait(response_time)

                annotations = self.generate_annotations(instance)

                if self.submit_annotation(instance_id, annotations, response_time):
                    annotation_count += 1
                    consecutive_failures = 0
                    last_submitted_id = instance_id
                    logger.debug(
                        f"User {self.config.user_id} annotated {annotation_count} items"
                    )
                else:
                    consecutive_failures += 1
                    if consecutive_failures >= self.MAX_CONSECUTIVE_SUBMIT_FAILURES:
                        logger.warning(
                            f"User {self.config.user_id} stopping after "
                            f"{consecutive_failures} consecutive submit failures"
                        )
                        self.result.errors.append(
                            f"Stopped after {consecutive_failures} consecutive submit failures"
                        )
                        break

                if not self.navigate_next():
                    logger.debug(f"User {self.config.user_id} navigation failed")
                    break

        except Exception as e:
            # Top-level boundary: record the failure in the result instead of
            # propagating it; logger.exception keeps the traceback.
            logger.exception(f"Simulation error for {self.config.user_id}: {e}")
            self.result.errors.append(f"Simulation error: {e}")

        finally:
            self.result.end_time = datetime.now()
            self.result.total_time = (
                self.result.end_time - self.result.start_time
            ).total_seconds()
            logger.info(
                f"User {self.config.user_id} completed: "
                f"{len(self.result.annotations)} annotations in {self.result.total_time:.1f}s"
            )

        return self.result