Spaces:
Sleeping
Sleeping
| """ | |
| Population sampler for querying multiple persona variants | |
| Handles parallel querying with rate limiting and progress tracking. | |
| """ | |
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Callable, List, Optional

from ..context.models import EnvironmentalContext
from ..llm.anthropic_client import AnthropicClient
from ..llm.prompt_builder import PromptBuilder
from ..personas.models import Persona
from .variant_generator import VariationLevel, generate_variants
@dataclass
class PopulationResponse:
    """Response from a single persona variant.

    Plain record type: converting the hand-written __init__ boilerplate to a
    dataclass keeps the same positional/keyword constructor while gaining
    __repr__ and __eq__ for free.

    Attributes:
        variant_id: Index of the variant within the generated population.
        question: The question that was posed to this variant.
        response: The variant's answer text returned by the LLM.
        persona: The persona variant that produced the response.
    """
    variant_id: int
    question: str
    response: str
    persona: Persona
class PopulationSampler:
    """Sample responses from a population of persona variants.

    Queries are fanned out over a thread pool. Rate-limit bookkeeping is
    guarded by a lock: the previous closure-based counter was mutated via
    ``nonlocal`` from multiple worker threads without synchronisation, so
    concurrent workers could all observe the same stale timestamp and burst
    past the configured rate limit.
    """

    def __init__(
        self,
        llm_client: Optional[AnthropicClient] = None,
        max_workers: int = 5,  # Parallel requests (be mindful of API rate limits)
        requests_per_minute: int = 50,  # Rate limiting
    ):
        """
        Initialize population sampler.

        Args:
            llm_client: LLM client to use (creates default if None)
            max_workers: Max parallel API requests
            requests_per_minute: Rate limit for API calls (must be positive)

        Raises:
            ValueError: If requests_per_minute is not positive (previously this
                surfaced as an opaque ZeroDivisionError).
        """
        if requests_per_minute <= 0:
            raise ValueError("requests_per_minute must be positive")
        self.llm_client = llm_client or AnthropicClient()
        self.prompt_builder = PromptBuilder()
        self.max_workers = max_workers
        self.requests_per_minute = requests_per_minute
        self.min_delay = 60.0 / requests_per_minute  # Seconds between requests
        # Shared across worker threads; guarded by _rate_lock. Keeping the
        # timestamp on the instance also enforces the rate limit across
        # successive query_population() calls, not just within one call.
        self._rate_lock = threading.Lock()
        self._last_request_time = 0.0

    def _throttle(self) -> None:
        """Sleep just long enough to honour the minimum inter-request delay.

        The lock makes the read/sleep/update sequence atomic, so concurrent
        workers queue behind it and are spaced out correctly instead of racing
        on the shared timestamp.
        """
        with self._rate_lock:
            wait = self.min_delay - (time.time() - self._last_request_time)
            if wait > 0:
                time.sleep(wait)
            self._last_request_time = time.time()

    def query_population(
        self,
        base_persona: Persona,
        question: str,
        population_size: int = 100,
        variation_level: VariationLevel = VariationLevel.MODERATE,
        context: Optional[EnvironmentalContext] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> List[PopulationResponse]:
        """
        Query a population of persona variants with a question.

        Args:
            base_persona: Base persona to create variants from
            question: Question to ask all variants
            population_size: Number of variants to generate
            variation_level: How much variation to apply
            context: Optional environmental context
            progress_callback: Optional callback for progress updates (current, total)

        Returns:
            List of PopulationResponse objects, ordered by variant_id.
            Variants whose query raised are omitted (best-effort sampling).
        """
        # Generate population
        print(f"Generating population of {population_size} variants...")
        variants = generate_variants(base_persona, population_size, variation_level)

        # Query all variants
        print(f"Querying {len(variants)} personas...")
        return self._query_variants(variants, question, context, progress_callback)

    def _query_single(
        self,
        idx: int,
        persona: Persona,
        question: str,
        context: Optional[EnvironmentalContext],
    ) -> PopulationResponse:
        """Throttle, then query one persona variant and wrap the result."""
        self._throttle()
        # Build prompt
        system_prompt = self.prompt_builder.build_persona_system_prompt(
            persona=persona,
            context=context,
        )
        # Query LLM
        response_text = self.llm_client.generate_response(
            system_prompt=system_prompt,
            user_message=question,
        )
        return PopulationResponse(
            variant_id=idx,
            question=question,
            response=response_text,
            persona=persona,
        )

    def _query_variants(
        self,
        variants: List[Persona],
        question: str,
        context: Optional[EnvironmentalContext],
        progress_callback: Optional[Callable[[int, int], None]],
    ) -> List[PopulationResponse]:
        """Query all variants in parallel with rate limiting and progress tracking.

        Failed queries are logged and skipped rather than aborting the whole
        population (deliberate best-effort behaviour). Progress is advanced in
        a ``finally`` so the callback reaches (total, total) even when some
        variants fail — previously failures were counted but never reported,
        leaving callers' progress bars stuck below 100%.
        """
        responses: List[PopulationResponse] = []
        total = len(variants)
        completed = 0

        # Use ThreadPoolExecutor for parallel queries
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all queries
            future_to_idx = {
                executor.submit(self._query_single, idx, persona, question, context): idx
                for idx, persona in enumerate(variants)
            }
            # Process completed queries
            for future in as_completed(future_to_idx):
                try:
                    responses.append(future.result())
                except Exception as e:
                    # Best-effort: drop the failed variant but keep going.
                    print(f"Error querying variant {future_to_idx[future]}: {e}")
                finally:
                    completed += 1
                    if progress_callback:
                        progress_callback(completed, total)

        # as_completed yields in completion order; restore population order.
        responses.sort(key=lambda r: r.variant_id)
        return responses