File size: 5,557 Bytes
1581d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Population sampler for querying multiple persona variants

Handles parallel querying with rate limiting and progress tracking.
"""

import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List, Optional

from ..personas.models import Persona
from ..llm.anthropic_client import AnthropicClient
from ..llm.prompt_builder import PromptBuilder
from ..context.models import EnvironmentalContext
from .variant_generator import generate_variants, VariationLevel


class PopulationResponse:
    """Response from a single persona variant.

    Plain value object pairing the question asked, the text the LLM
    returned, and the persona variant that produced it.
    """

    def __init__(
        self,
        variant_id: int,
        question: str,
        response: str,
        persona: Persona
    ):
        """
        Args:
            variant_id: Index of the variant within the generated population
            question: The question this variant was asked
            response: The variant's response text
            persona: The persona variant that produced the response
        """
        self.variant_id = variant_id
        self.question = question
        self.response = response
        self.persona = persona

    def __repr__(self) -> str:
        # Truncate long responses so debug output stays readable.
        preview = self.response[:40] + ("..." if len(self.response) > 40 else "")
        return (
            f"{type(self).__name__}(variant_id={self.variant_id}, "
            f"question={self.question!r}, response={preview!r})"
        )


class PopulationSampler:
    """Sample responses from a population of persona variants.

    Fans one question out to many generated persona variants via a thread
    pool, throttling API calls to stay under a requests-per-minute budget.
    """

    def __init__(
        self,
        llm_client: Optional[AnthropicClient] = None,
        max_workers: int = 5,  # Parallel requests (be mindful of API rate limits)
        requests_per_minute: int = 50,  # Rate limiting
    ):
        """
        Initialize population sampler

        Args:
            llm_client: LLM client to use (creates default if None)
            max_workers: Max parallel API requests
            requests_per_minute: Rate limit for API calls
        """
        self.llm_client = llm_client or AnthropicClient()
        self.prompt_builder = PromptBuilder()
        self.max_workers = max_workers
        self.requests_per_minute = requests_per_minute
        self.min_delay = 60.0 / requests_per_minute  # Seconds between requests

    def query_population(
        self,
        base_persona: Persona,
        question: str,
        population_size: int = 100,
        variation_level: VariationLevel = VariationLevel.MODERATE,
        context: Optional[EnvironmentalContext] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> List[PopulationResponse]:
        """
        Query a population of persona variants with a question

        Args:
            base_persona: Base persona to create variants from
            question: Question to ask all variants
            population_size: Number of variants to generate
            variation_level: How much variation to apply
            context: Optional environmental context
            progress_callback: Optional callback for progress updates (current, total)

        Returns:
            List of PopulationResponse objects, ordered by variant_id.
            Variants whose query raised an exception are omitted.
        """
        # Generate population
        print(f"Generating population of {population_size} variants...")
        variants = generate_variants(
            base_persona,
            population_size,
            variation_level
        )

        # Query all variants
        print(f"Querying {len(variants)} personas...")
        responses = self._query_variants(
            variants,
            question,
            context,
            progress_callback
        )

        return responses

    def _query_variants(
        self,
        variants: List[Persona],
        question: str,
        context: Optional[EnvironmentalContext],
        progress_callback: Optional[Callable[[int, int], None]],
    ) -> List[PopulationResponse]:
        """Query all variants with rate limiting and progress tracking.

        Rate limiting is enforced with a reserved-slot scheme: each worker
        atomically claims the next available start time (spaced min_delay
        apart) under a lock, then sleeps outside the lock until its slot.

        BUG FIX: the previous implementation mutated a shared timestamp from
        multiple worker threads without synchronization, so concurrent
        workers could read a stale value and collectively exceed the
        configured requests_per_minute.
        """
        responses: List[PopulationResponse] = []
        total = len(variants)
        completed = 0

        # Shared rate-limiter state, guarded by rate_lock.
        rate_lock = threading.Lock()
        next_slot = 0.0  # Earliest time the next request may start

        def query_single(idx: int, persona: Persona) -> PopulationResponse:
            """Query a single persona variant"""
            nonlocal next_slot

            # Claim a start slot atomically; sleep outside the lock so other
            # workers can claim their (later) slots concurrently.
            with rate_lock:
                slot = max(time.time(), next_slot)
                next_slot = slot + self.min_delay
            delay = slot - time.time()
            if delay > 0:
                time.sleep(delay)

            # Build prompt
            system_prompt = self.prompt_builder.build_persona_system_prompt(
                persona=persona,
                context=context
            )

            # Query LLM
            response_text = self.llm_client.generate_response(
                system_prompt=system_prompt,
                user_message=question,
            )

            return PopulationResponse(
                variant_id=idx,
                question=question,
                response=response_text,
                persona=persona
            )

        # Use ThreadPoolExecutor for parallel queries
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all queries
            future_to_idx = {
                executor.submit(query_single, idx, persona): idx
                for idx, persona in enumerate(variants)
            }

            # Process completed queries
            for future in as_completed(future_to_idx):
                try:
                    response = future.result()
                    responses.append(response)
                    completed += 1

                    # Progress callback
                    if progress_callback:
                        progress_callback(completed, total)

                except Exception as e:
                    # Best-effort: a failed variant is logged and dropped
                    # from the results rather than aborting the whole run.
                    print(f"Error querying variant: {e}")
                    completed += 1

        # as_completed yields in completion order; sort to restore variant order
        responses.sort(key=lambda r: r.variant_id)

        return responses