# davidjurgens's picture
# Deploy: Potato — Codebook Annotation
# aceb1b2 verified
# Raw
# History Blame Contribute Delete
# 9.93 kB
"""
Command-line interface for the user simulator.

Usage:
    python -m potato.simulator --server http://localhost:8000 --users 10
    python -m potato.simulator --config simulator-config.yaml --server http://localhost:8000
"""
import argparse
import logging
import sys
import os
from .config import (
SimulatorConfig,
TimingConfig,
LLMStrategyConfig,
BiasedStrategyConfig,
AnnotationStrategyType,
)
from .simulator_manager import SimulatorManager
def setup_logging(verbose: bool = False) -> None:
    """Set up root logging for a CLI run.

    Args:
        verbose: When True the root level is DEBUG; otherwise it is INFO
            and the "urllib3" and "requests" loggers are raised to WARNING.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )

    if verbose:
        # In verbose mode the HTTP libraries are left at their own levels.
        return

    # Quiet the chatty HTTP-related loggers for normal runs.
    for noisy_name in ("urllib3", "requests"):
        logging.getLogger(noisy_name).setLevel(logging.WARNING)
def parse_args() -> argparse.Namespace:
    """Declare the simulator's command-line options and parse ``sys.argv``.

    The options are listed in a table and registered in order, so the
    generated ``--help`` output follows the table from top to bottom.

    Returns:
        Namespace with one attribute per option declared below.
    """
    # Shown verbatim after the option list (RawDescriptionHelpFormatter).
    usage_examples = """
Examples:
# Basic random simulation
python -m potato.simulator --server http://localhost:8000 --users 10
# With configuration file
python -m potato.simulator --config simulator.yaml --server http://localhost:8000
# LLM-powered simulation with Ollama
python -m potato.simulator --server http://localhost:8000 --users 5 \\
--strategy llm --llm-endpoint ollama --llm-model llama3.2
# Biased simulation
python -m potato.simulator --server http://localhost:8000 --users 20 \\
--strategy biased --bias-weights positive=0.6,negative=0.3,neutral=0.1
# Fast scalability test
python -m potato.simulator --server http://localhost:8000 --users 100 \\
--parallel 20 --max-annotations 5 --fast-mode
"""

    # Each entry is (flag names, keyword arguments for add_argument).
    option_table = (
        # Required arguments
        (
            ("--server", "-s"),
            {
                "required": True,
                "help": "Potato server URL (e.g., http://localhost:8000)",
            },
        ),
        # Configuration file (alternative to CLI args)
        (
            ("--config", "-c"),
            {"help": "Path to YAML configuration file"},
        ),
        # User configuration
        (
            ("--users", "-u"),
            {
                "type": int,
                "default": 10,
                "help": "Number of simulated users (default: 10)",
            },
        ),
        (
            ("--competence",),
            {
                "help": "Competence distribution as comma-separated key=value pairs "
                "(e.g., good=0.5,average=0.3,poor=0.2)",
            },
        ),
        # Strategy configuration
        (
            ("--strategy",),
            {
                "choices": ["random", "biased", "llm", "pattern", "gold_standard"],
                "default": "random",
                "help": "Annotation strategy (default: random)",
            },
        ),
        # LLM configuration
        (
            ("--llm-endpoint",),
            {
                "choices": [
                    "openai",
                    "anthropic",
                    "ollama",
                    "gemini",
                    "huggingface",
                    "vllm",
                ],
                "help": "LLM endpoint type (for --strategy llm)",
            },
        ),
        (
            ("--llm-model",),
            {"help": "LLM model name (for --strategy llm)"},
        ),
        (
            ("--llm-api-key",),
            {"help": "LLM API key (or set via environment variable)"},
        ),
        (
            ("--llm-base-url",),
            {"help": "LLM base URL (for local endpoints like Ollama)"},
        ),
        # Biased strategy configuration
        (
            ("--bias-weights",),
            {
                "help": "Label bias weights as comma-separated key=value pairs "
                "(e.g., positive=0.6,negative=0.3,neutral=0.1)",
            },
        ),
        # Execution configuration
        (
            ("--parallel", "-p"),
            {
                "type": int,
                "default": 5,
                "help": "Maximum concurrent users (default: 5)",
            },
        ),
        (
            ("--max-annotations", "-m"),
            {
                "type": int,
                "help": "Maximum annotations per user (default: unlimited)",
            },
        ),
        (
            ("--sequential",),
            {
                "action": "store_true",
                "help": "Run users sequentially instead of in parallel",
            },
        ),
        # Timing configuration
        (
            ("--fast-mode",),
            {
                "action": "store_true",
                "help": "Disable waiting between annotations (for testing)",
            },
        ),
        (
            ("--timing-min",),
            {
                "type": float,
                "default": 2.0,
                "help": "Minimum annotation time in seconds (default: 2.0)",
            },
        ),
        (
            ("--timing-max",),
            {
                "type": float,
                "default": 30.0,
                "help": "Maximum annotation time in seconds (default: 30.0)",
            },
        ),
        # Quality control testing
        (
            ("--attention-fail-rate",),
            {
                "type": float,
                "default": 0.0,
                "help": "Rate at which to fail attention checks (0-1, default: 0)",
            },
        ),
        (
            ("--fast-response-rate",),
            {
                "type": float,
                "default": 0.0,
                "help": "Rate of suspiciously fast responses (0-1, default: 0)",
            },
        ),
        # Gold standards
        (
            ("--gold-file",),
            {"help": "Path to JSON file with gold standard answers"},
        ),
        # Output configuration
        (
            ("--output-dir", "-o"),
            {
                "default": "simulator_output",
                "help": "Output directory for results (default: simulator_output)",
            },
        ),
        (
            ("--no-export",),
            {
                "action": "store_true",
                "help": "Don't export results to files",
            },
        ),
        # Other options
        (
            ("--verbose", "-v"),
            {
                "action": "store_true",
                "help": "Enable verbose logging",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="User Simulator for Potato Annotation Platform",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    for flag_names, add_kwargs in option_table:
        cli.add_argument(*flag_names, **add_kwargs)

    return cli.parse_args()
def parse_key_value_pairs(s: str) -> dict:
    """Parse comma-separated key=value pairs.

    Keys and values are stripped of surrounding whitespace. A value that
    ``float()`` accepts is stored as a float; any other value is kept as a
    string. Only the first "=" in a fragment separates key from value, so
    values may themselves contain "=".

    Fragments without "=" are skipped. Non-empty ones are reported with a
    warning so that a mistyped option (e.g. "positive:0.6") does not
    silently produce an empty mapping; empty fragments (such as from a
    trailing comma) are skipped quietly.

    Args:
        s: String like "key1=val1,key2=val2". An empty or None value
            yields an empty dictionary.

    Returns:
        Dictionary of parsed pairs
    """
    result = {}
    if not s:
        return result

    for pair in s.split(","):
        key, separator, value = pair.partition("=")
        if not separator:
            # No "=" at all: nothing to store, but tell the user if the
            # fragment had content rather than dropping it silently.
            if pair.strip():
                logging.getLogger(__name__).warning(
                    "Ignoring malformed key=value pair (missing '='): %r",
                    pair.strip(),
                )
            continue

        key = key.strip()
        value = value.strip()
        # Try to convert to float; fall back to the raw string.
        try:
            result[key] = float(value)
        except ValueError:
            result[key] = value

    return result
def build_config_from_args(args: argparse.Namespace) -> SimulatorConfig:
    """Translate parsed CLI options into a SimulatorConfig.

    The starting point is the YAML file named by ``args.config`` when one
    is given, otherwise a default-constructed SimulatorConfig. Command-line
    values are then written on top of it.

    NOTE(review): the scalar and timing assignments below are
    unconditional. When ``args`` comes from ``parse_args()``, which supplies
    defaults for these options, the corresponding values loaded from YAML
    are replaced even if the flag was not given on the command line.

    Args:
        args: Parsed arguments

    Returns:
        SimulatorConfig instance
    """
    cfg = SimulatorConfig.from_yaml(args.config) if args.config else SimulatorConfig()

    # Settings copied straight from the namespace.
    cfg.user_count = args.users
    cfg.parallel_users = args.parallel
    cfg.simulate_wait = not args.fast_mode
    cfg.attention_check_fail_rate = args.attention_fail_rate
    cfg.respond_fast_rate = args.fast_response_rate
    cfg.output_dir = args.output_dir

    # Competence distribution is only touched when the option was given.
    if args.competence:
        cfg.competence_distribution = parse_key_value_pairs(args.competence)

    # Unknown strategy names fall back to the random strategy.
    try:
        cfg.strategy = AnnotationStrategyType(args.strategy)
    except ValueError:
        cfg.strategy = AnnotationStrategyType.RANDOM

    # LLM settings require both the llm strategy and an endpoint type.
    if args.strategy == "llm" and args.llm_endpoint:
        # Environment variable consulted per endpoint when no key is
        # passed explicitly; endpoints not listed get no fallback.
        key_env_by_endpoint = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "huggingface": "HF_TOKEN",
            "gemini": "GOOGLE_API_KEY",
        }
        llm_key = args.llm_api_key
        if not llm_key and args.llm_endpoint in key_env_by_endpoint:
            llm_key = os.environ.get(key_env_by_endpoint[args.llm_endpoint])

        cfg.llm_config = LLMStrategyConfig(
            endpoint_type=args.llm_endpoint,
            model=args.llm_model,
            api_key=llm_key,
            base_url=args.llm_base_url,
        )

    # Label weights apply only to the biased strategy.
    if args.strategy == "biased" and args.bias_weights:
        label_weights = parse_key_value_pairs(args.bias_weights)
        cfg.biased_config = BiasedStrategyConfig(label_weights=label_weights)

    # Timing bounds are always rebuilt from the CLI values.
    cfg.timing = TimingConfig(
        annotation_time_min=args.timing_min,
        annotation_time_max=args.timing_max,
    )

    # Gold standards file
    if args.gold_file:
        cfg.gold_standard_file = args.gold_file

    return cfg
def main() -> int:
    """Main entry point for CLI.

    Parses the command line, configures logging, builds the simulator
    configuration, runs the simulation (sequentially or in parallel),
    prints the summary and, unless ``--no-export`` was given, exports the
    results.

    Returns:
        Exit code (0 for success, 1 for error or user interruption)
    """
    args = parse_args()
    setup_logging(args.verbose)
    logger = logging.getLogger(__name__)

    try:
        # Build configuration
        config = build_config_from_args(args)

        logger.info("Starting simulator with %s users", config.user_count)
        logger.info("Server: %s", args.server)
        logger.info("Strategy: %s", config.strategy.value)

        # Create manager
        manager = SimulatorManager(config, args.server)

        # Run simulation. The return value is not needed here: the summary
        # and export below are produced by the manager itself.
        if args.sequential:
            manager.run_sequential(args.max_annotations)
        else:
            manager.run_parallel(args.max_annotations)

        # Print summary
        manager.print_summary()

        # Export results
        if not args.no_export:
            manager.export_results()

        return 0
    except KeyboardInterrupt:
        logger.info("Simulation interrupted by user")
        return 1
    except Exception as e:
        # Top-level boundary: report the failure and turn it into an exit
        # code. In verbose mode the traceback is attached to the log record
        # so it goes through the configured logging handlers.
        logger.error("Simulation failed: %s", e, exc_info=args.verbose)
        return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())