"""
Population sampler for querying multiple persona variants.

Handles parallel querying with rate limiting and progress tracking.
"""

import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List, Optional

from ..personas.models import Persona
from ..llm.anthropic_client import AnthropicClient
from ..llm.prompt_builder import PromptBuilder
from ..context.models import EnvironmentalContext
from .variant_generator import generate_variants, VariationLevel


class PopulationResponse:
    """Response from a single persona variant."""

    def __init__(
        self,
        variant_id: int,
        question: str,
        response: str,
        persona: Persona,
    ):
        # Index of the variant within the generated population; used as a
        # stable sort key so results keep submission order.
        self.variant_id = variant_id
        self.question = question
        self.response = response
        self.persona = persona


class PopulationSampler:
    """Sample responses from a population of persona variants."""

    def __init__(
        self,
        llm_client: Optional[AnthropicClient] = None,
        max_workers: int = 5,  # Parallel requests (be mindful of API rate limits)
        requests_per_minute: int = 50,  # Rate limiting
    ):
        """
        Initialize population sampler.

        Args:
            llm_client: LLM client to use (creates default if None)
            max_workers: Max parallel API requests
            requests_per_minute: Rate limit for API calls. Values <= 0
                disable throttling (previously raised ZeroDivisionError).
        """
        self.llm_client = llm_client or AnthropicClient()
        self.prompt_builder = PromptBuilder()
        self.max_workers = max_workers
        self.requests_per_minute = requests_per_minute
        # Minimum seconds between consecutive request starts (0 = unthrottled).
        self.min_delay = (
            60.0 / requests_per_minute if requests_per_minute > 0 else 0.0
        )

    def query_population(
        self,
        base_persona: Persona,
        question: str,
        population_size: int = 100,
        variation_level: VariationLevel = VariationLevel.MODERATE,
        context: Optional[EnvironmentalContext] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> List[PopulationResponse]:
        """
        Query a population of persona variants with a question.

        Args:
            base_persona: Base persona to create variants from
            question: Question to ask all variants
            population_size: Number of variants to generate
            variation_level: How much variation to apply
            context: Optional environmental context
            progress_callback: Optional callback for progress updates
                (current, total)

        Returns:
            List of PopulationResponse objects, ordered by variant_id
        """
        # Generate population
        print(f"Generating population of {population_size} variants...")
        variants = generate_variants(
            base_persona, population_size, variation_level
        )

        # Query all variants
        print(f"Querying {len(variants)} personas...")
        return self._query_variants(
            variants, question, context, progress_callback
        )

    def _query_variants(
        self,
        variants: List[Persona],
        question: str,
        context: Optional[EnvironmentalContext],
        progress_callback: Optional[Callable[[int, int], None]],
    ) -> List[PopulationResponse]:
        """Query all variants with rate limiting and progress tracking."""
        responses: List[PopulationResponse] = []
        total = len(variants)
        completed = 0

        # BUG FIX: the original mutated a shared `last_request_time` via
        # `nonlocal` from multiple worker threads with no synchronization,
        # so with max_workers > 1 several workers could read the same stale
        # timestamp, all skip the sleep, and blow through the rate limit.
        # Reserve each request's send slot under a lock instead.
        rate_lock = threading.Lock()
        last_request_time = 0.0

        def query_single(idx: int, persona: Persona) -> PopulationResponse:
            """Query a single persona variant (runs in a worker thread)."""
            nonlocal last_request_time

            # Rate limiting: holding the lock while we wait intentionally
            # serializes request starts, spacing them by >= min_delay.
            with rate_lock:
                wait = self.min_delay - (time.time() - last_request_time)
                if wait > 0:
                    time.sleep(wait)
                last_request_time = time.time()

            # Build prompt
            system_prompt = self.prompt_builder.build_persona_system_prompt(
                persona=persona, context=context
            )

            # Query LLM
            response_text = self.llm_client.generate_response(
                system_prompt=system_prompt,
                user_message=question,
            )

            return PopulationResponse(
                variant_id=idx,
                question=question,
                response=response_text,
                persona=persona,
            )

        # Use ThreadPoolExecutor for parallel queries
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all queries
            future_to_idx = {
                executor.submit(query_single, idx, persona): idx
                for idx, persona in enumerate(variants)
            }

            # Process completed queries
            for future in as_completed(future_to_idx):
                try:
                    responses.append(future.result())
                except Exception as e:
                    # BUG FIX: include the failing variant's index
                    # (the original message was anonymous).
                    print(f"Error querying variant {future_to_idx[future]}: {e}")
                finally:
                    # BUG FIX: count progress for failures too, so the
                    # callback's `current` actually reaches `total`.
                    completed += 1
                    if progress_callback:
                        progress_callback(completed, total)

        # Sort by variant_id to maintain order
        responses.sort(key=lambda r: r.variant_id)
        return responses