Spaces:
Paused
Paused
| """ | |
| Simulator manager for orchestrating multiple simulated users. | |
| This module provides the SimulatorManager class that manages multiple | |
| SimulatedUser instances, handling parallel execution and result aggregation. | |
| """ | |
| import json | |
| import logging | |
| import random | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import Dict, List, Any, Optional | |
| from .config import ( | |
| SimulatorConfig, | |
| UserConfig, | |
| CompetenceLevel, | |
| AnnotationStrategyType, | |
| ) | |
| from .user_simulator import SimulatedUser, UserSimulationResult | |
| from .reporting import SimulationReporter | |
| logger = logging.getLogger(__name__) | |
| class SimulatorManager: | |
| """Orchestrates multiple simulated users. | |
| The SimulatorManager handles: | |
| - Generating user configurations based on competence distribution | |
| - Running simulations in parallel or sequentially | |
| - Aggregating results across all users | |
| - Exporting results via SimulationReporter | |
| """ | |
| def __init__( | |
| self, | |
| config: SimulatorConfig, | |
| server_url: str, | |
| gold_standards: Optional[Dict[str, Dict[str, Any]]] = None, | |
| ): | |
| """Initialize simulator manager. | |
| Args: | |
| config: Simulator configuration | |
| server_url: Base URL of the Potato server | |
| gold_standards: Optional gold standard answers keyed by instance_id | |
| """ | |
| self.config = config | |
| self.server_url = server_url.rstrip("/") | |
| self.gold_standards = gold_standards or {} | |
| # Load gold standards from file if specified | |
| if config.gold_standard_file and not gold_standards: | |
| self.gold_standards = self._load_gold_standards(config.gold_standard_file) | |
| # Generate user configs if not provided | |
| self.user_configs = self._generate_user_configs() | |
| # Results tracking | |
| self.results: Dict[str, UserSimulationResult] = {} | |
| self.reporter = SimulationReporter(config.output_dir) | |
| def _load_gold_standards(self, filepath: str) -> Dict[str, Dict[str, Any]]: | |
| """Load gold standards from JSON file. | |
| Expected format: | |
| [ | |
| {"id": "instance_001", "label_field": "value", ...}, | |
| ... | |
| ] | |
| Args: | |
| filepath: Path to JSON file | |
| Returns: | |
| Gold standards dict keyed by instance ID | |
| """ | |
| try: | |
| with open(filepath, "r") as f: | |
| items = json.load(f) | |
| gold_standards = {} | |
| for item in items: | |
| item_id = item.pop("id", None) | |
| if item_id: | |
| gold_standards[item_id] = item | |
| logger.info(f"Loaded {len(gold_standards)} gold standards from {filepath}") | |
| return gold_standards | |
| except Exception as e: | |
| logger.warning(f"Failed to load gold standards from {filepath}: {e}") | |
| return {} | |
| def _generate_user_configs(self) -> List[UserConfig]: | |
| """Generate user configurations based on competence distribution. | |
| If explicit user configs are provided, uses those. | |
| Otherwise, generates based on user_count and competence_distribution. | |
| Returns: | |
| List of UserConfig instances | |
| """ | |
| if self.config.users: | |
| return self.config.users | |
| users = [] | |
| # Get competence distribution | |
| competence_levels = list(self.config.competence_distribution.keys()) | |
| competence_weights = list(self.config.competence_distribution.values()) | |
| # Normalize weights | |
| total_weight = sum(competence_weights) | |
| if total_weight > 0: | |
| competence_weights = [w / total_weight for w in competence_weights] | |
| for i in range(self.config.user_count): | |
| # Select competence level based on distribution | |
| competence_str = random.choices( | |
| competence_levels, weights=competence_weights, k=1 | |
| )[0] | |
| try: | |
| competence = CompetenceLevel(competence_str) | |
| except ValueError: | |
| competence = CompetenceLevel.AVERAGE | |
| users.append( | |
| UserConfig( | |
| user_id=f"sim_user_{i:04d}", | |
| competence=competence, | |
| strategy=self.config.strategy, | |
| timing=self.config.timing, | |
| llm_config=self.config.llm_config, | |
| biased_config=self.config.biased_config, | |
| agent_config=self.config.agent_config, | |
| ) | |
| ) | |
| logger.info(f"Generated {len(users)} user configurations") | |
| return users | |
| def run_single_user( | |
| self, user_config: UserConfig, max_annotations: Optional[int] = None | |
| ) -> UserSimulationResult: | |
| """Run simulation for a single user. | |
| Args: | |
| user_config: Configuration for the user | |
| max_annotations: Maximum annotations for this user | |
| Returns: | |
| UserSimulationResult with tracking data | |
| """ | |
| user = SimulatedUser( | |
| user_config=user_config, | |
| server_url=self.server_url, | |
| gold_standards=self.gold_standards, | |
| simulate_wait=self.config.simulate_wait, | |
| attention_check_fail_rate=self.config.attention_check_fail_rate, | |
| respond_fast_rate=self.config.respond_fast_rate, | |
| interactive_config=self.config.interactive, | |
| ) | |
| result = user.run_simulation(max_annotations) | |
| self.results[user_config.user_id] = result | |
| return result | |
| def run_parallel( | |
| self, max_annotations_per_user: Optional[int] = None | |
| ) -> Dict[str, UserSimulationResult]: | |
| """Run simulation for all users in parallel. | |
| Args: | |
| max_annotations_per_user: Maximum annotations per user | |
| Returns: | |
| Dict mapping user_id to UserSimulationResult | |
| """ | |
| logger.info( | |
| f"Starting parallel simulation with {len(self.user_configs)} users " | |
| f"({self.config.parallel_users} concurrent)" | |
| ) | |
| with ThreadPoolExecutor(max_workers=self.config.parallel_users) as executor: | |
| futures = {} | |
| for i, user_config in enumerate(self.user_configs): | |
| # Stagger user starts | |
| if i > 0 and self.config.delay_between_users > 0: | |
| time.sleep(self.config.delay_between_users) | |
| future = executor.submit( | |
| self.run_single_user, user_config, max_annotations_per_user | |
| ) | |
| futures[future] = user_config.user_id | |
| # Wait for completion | |
| completed = 0 | |
| for future in as_completed(futures): | |
| user_id = futures[future] | |
| completed += 1 | |
| try: | |
| result = future.result() | |
| logger.info( | |
| f"[{completed}/{len(futures)}] User {user_id} completed: " | |
| f"{len(result.annotations)} annotations" | |
| ) | |
| except Exception as e: | |
| logger.error(f"User {user_id} failed: {e}") | |
| logger.info(f"Parallel simulation completed: {len(self.results)} users") | |
| return self.results | |
| def run_sequential( | |
| self, max_annotations_per_user: Optional[int] = None | |
| ) -> Dict[str, UserSimulationResult]: | |
| """Run simulation for all users sequentially. | |
| Args: | |
| max_annotations_per_user: Maximum annotations per user | |
| Returns: | |
| Dict mapping user_id to UserSimulationResult | |
| """ | |
| logger.info( | |
| f"Starting sequential simulation with {len(self.user_configs)} users" | |
| ) | |
| for i, user_config in enumerate(self.user_configs): | |
| result = self.run_single_user(user_config, max_annotations_per_user) | |
| logger.info( | |
| f"[{i+1}/{len(self.user_configs)}] User {user_config.user_id} " | |
| f"completed: {len(result.annotations)} annotations" | |
| ) | |
| logger.info(f"Sequential simulation completed: {len(self.results)} users") | |
| return self.results | |
| def get_summary(self) -> Dict[str, Any]: | |
| """Get summary statistics for all users. | |
| Returns: | |
| Summary dictionary with aggregate statistics | |
| """ | |
| if not self.results: | |
| return {"error": "No results available"} | |
| total_annotations = sum(len(r.annotations) for r in self.results.values()) | |
| total_time = sum(r.total_time for r in self.results.values()) | |
| total_attention_passed = sum( | |
| r.attention_checks_passed for r in self.results.values() | |
| ) | |
| total_attention_failed = sum( | |
| r.attention_checks_failed for r in self.results.values() | |
| ) | |
| total_gold_correct = sum( | |
| r.gold_standard_correct for r in self.results.values() | |
| ) | |
| total_gold_incorrect = sum( | |
| r.gold_standard_incorrect for r in self.results.values() | |
| ) | |
| blocked_users = sum(1 for r in self.results.values() if r.was_blocked) | |
| users_with_errors = sum(1 for r in self.results.values() if r.errors) | |
| # Calculate response time statistics | |
| all_response_times = [ | |
| record.response_time | |
| for result in self.results.values() | |
| for record in result.annotations | |
| ] | |
| response_time_stats = {} | |
| if all_response_times: | |
| response_time_stats = { | |
| "min": min(all_response_times), | |
| "max": max(all_response_times), | |
| "mean": sum(all_response_times) / len(all_response_times), | |
| } | |
| # Competence level distribution in results | |
| competence_distribution = {} | |
| for user_id in self.results: | |
| for config in self.user_configs: | |
| if config.user_id == user_id: | |
| level = config.competence.value | |
| competence_distribution[level] = ( | |
| competence_distribution.get(level, 0) + 1 | |
| ) | |
| break | |
| return { | |
| "user_count": len(self.results), | |
| "total_annotations": total_annotations, | |
| "total_time_seconds": total_time, | |
| "average_annotations_per_user": ( | |
| total_annotations / len(self.results) if self.results else 0 | |
| ), | |
| "average_time_per_user": ( | |
| total_time / len(self.results) if self.results else 0 | |
| ), | |
| "attention_checks": { | |
| "passed": total_attention_passed, | |
| "failed": total_attention_failed, | |
| "pass_rate": ( | |
| total_attention_passed | |
| / (total_attention_passed + total_attention_failed) | |
| if (total_attention_passed + total_attention_failed) > 0 | |
| else None | |
| ), | |
| }, | |
| "gold_standards": { | |
| "correct": total_gold_correct, | |
| "incorrect": total_gold_incorrect, | |
| "accuracy": ( | |
| total_gold_correct / (total_gold_correct + total_gold_incorrect) | |
| if (total_gold_correct + total_gold_incorrect) > 0 | |
| else None | |
| ), | |
| }, | |
| "blocked_users": blocked_users, | |
| "users_with_errors": users_with_errors, | |
| "response_time_stats": response_time_stats, | |
| "competence_distribution": competence_distribution, | |
| "per_user": { | |
| user_id: { | |
| "annotations": len(r.annotations), | |
| "total_time": r.total_time, | |
| "attention_passed": r.attention_checks_passed, | |
| "attention_failed": r.attention_checks_failed, | |
| "gold_correct": r.gold_standard_correct, | |
| "gold_incorrect": r.gold_standard_incorrect, | |
| "was_blocked": r.was_blocked, | |
| "errors": len(r.errors), | |
| } | |
| for user_id, r in self.results.items() | |
| }, | |
| } | |
| def export_results(self) -> str: | |
| """Export all results using the reporter. | |
| Returns: | |
| Path to the output directory | |
| """ | |
| self.reporter.export_results(self.results, self.get_summary()) | |
| return self.config.output_dir | |
| def print_summary(self) -> None: | |
| """Print a summary of results to stdout.""" | |
| summary = self.get_summary() | |
| print("\n" + "=" * 60) | |
| print("SIMULATION SUMMARY") | |
| print("=" * 60) | |
| print(f"\nUsers: {summary['user_count']}") | |
| print(f"Total annotations: {summary['total_annotations']}") | |
| print(f"Total time: {summary['total_time_seconds']:.1f}s") | |
| print( | |
| f"Avg annotations/user: {summary['average_annotations_per_user']:.1f}" | |
| ) | |
| print(f"Avg time/user: {summary['average_time_per_user']:.1f}s") | |
| ac = summary["attention_checks"] | |
| if ac["passed"] or ac["failed"]: | |
| print(f"\nAttention Checks:") | |
| print(f" Passed: {ac['passed']}") | |
| print(f" Failed: {ac['failed']}") | |
| if ac["pass_rate"] is not None: | |
| print(f" Pass rate: {ac['pass_rate']:.1%}") | |
| gs = summary["gold_standards"] | |
| if gs["correct"] or gs["incorrect"]: | |
| print(f"\nGold Standards:") | |
| print(f" Correct: {gs['correct']}") | |
| print(f" Incorrect: {gs['incorrect']}") | |
| if gs["accuracy"] is not None: | |
| print(f" Accuracy: {gs['accuracy']:.1%}") | |
| if summary["blocked_users"]: | |
| print(f"\nBlocked users: {summary['blocked_users']}") | |
| if summary["users_with_errors"]: | |
| print(f"Users with errors: {summary['users_with_errors']}") | |
| if summary["competence_distribution"]: | |
| print(f"\nCompetence distribution:") | |
| for level, count in summary["competence_distribution"].items(): | |
| print(f" {level}: {count}") | |
| print("\n" + "=" * 60) | |