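"""Concurrent synthetic-data generation driver.

Samples a target memory category per item according to
CATEGORY_DISTRIBUTION, generates each batch of conversations in
parallel with asyncio, cleans every item, validates each batch, and
writes the results as JSONL files under synthetic_data/.
"""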
import asyncio
import json
import math
import os
import random
import sys
from typing import Any, Dict, List, Optional

from synthetic_data.pipeline import SyntheticDataPipeline
from synthetic_data.validate import validate_synthetic_data
from synthetic_data.clean_data import clean_datum

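# Sampling weights over target memory categories. The weights sum to
# 1.0 and are passed to random.choices() when picking a category.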
CATEGORY_DISTRIBUTION = {
    "none": 0.15,
    "user.interaction_preferences": 0.12,
    "user.session_history": 0.10,
    "company.brand_core": 0.10,
    
    "company.strategic_signatures": 0.07,
    "company.knowledge_artifacts": 0.07,
    "user.communication_style": 0.07,
    "user.strategic_approach": 0.07,
    "user.workflow_patterns": 0.07,
    
    "company.tools_config": 0.05,
    "company.performance_context": 0.05,
    "company.business_priorities": 0.04,
    "user.role_context": 0.04
}

async def generate_single_item(pipeline: SyntheticDataPipeline, category: str, item_num: int) -> Optional[Dict[str, Any]]:
    """Generate a single conversation item asynchronously; returns None on failure."""
    print(f"  Starting item {item_num} (Target: {category})...")
    
    # With ~30% probability, add a distractor category that is distinct
    # from the target (never used when the target itself is "none").
    categories = list(CATEGORY_DISTRIBUTION.keys())
    distractor = None
    if random.random() < 0.30 and category != "none":
        possible_distractors = [c for c in categories if c != category and c != "none"]
        if possible_distractors:
            distractor = random.choice(possible_distractors)
    
    persistence = _get_persistence_for_category(category)
    turns = random.randint(4, 10)
    
    # The pipeline API is synchronous, so run it in the default thread
    # pool executor to avoid blocking the event loop.
    loop = asyncio.get_running_loop()
    scenario = await loop.run_in_executor(
        None,
        pipeline.generate_scenario_spec,
        category,
        distractor,
        persistence,
        "neutral",
        turns,
        ""
    )
    
    if not scenario:
        print(f"  Failed item {item_num}: scenario generation failed")
        return None
    
    # Generate conversation
    conversation = await loop.run_in_executor(
        None,
        pipeline.generate_conversation,
        scenario,
        turns,
        category
    )
    
    if conversation:
        # Clean the item immediately
        cleaned_conversation = clean_datum(conversation)
        print(f"  Completed item {item_num}: {cleaned_conversation.get('scenario_id', 'Unknown')}")
        return cleaned_conversation
    else:
        print(f"  Failed item {item_num}: conversation generation failed")
        return None

async def generate_batch_concurrent(pipeline: SyntheticDataPipeline, batch_size: int, batch_num: int) -> List[Dict[str, Any]]:
    """Generate a full batch of items concurrently, retrying until batch is full."""
    print(f"\n=== Processing Batch {batch_num} (Concurrent) ===")
    
    categories = list(CATEGORY_DISTRIBUTION.keys())
    weights = list(CATEGORY_DISTRIBUTION.values())
    
    batch_data = []
    items_needed = batch_size
    
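    # Launch rounds of concurrent tasks until the batch is full; failed
    # items are re-rolled with freshly sampled categories each round.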
    while items_needed > 0:
        # Select categories for this chunk of work
        batch_categories = random.choices(categories, weights=weights, k=items_needed)
        
        # Create tasks for needed items
        tasks = [
            generate_single_item(pipeline, category, len(batch_data) + i + 1)
            for i, category in enumerate(batch_categories)
        ]
        
        print(f"  Launch {items_needed} concurrent tasks...")
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Collect successes; exceptions and None results are retried below.
        success_count = 0
        for result in results:
            if isinstance(result, Exception):
                print(f"    Task exception: {result}")
            elif result is not None:
                batch_data.append(result)
                success_count += 1
        print(f"  {success_count}/{len(results)} tasks succeeded this round.")

        items_needed = batch_size - len(batch_data)
        if items_needed > 0:
            print(f"  Batch incomplete ({len(batch_data)}/{batch_size}). Retrying {items_needed} items in 5s...")
            await asyncio.sleep(5)
    
    print(f"Batch {batch_num} complete: {len(batch_data)}/{batch_size} items generated")
    return batch_data

async def run_pipeline_batches_async(total_items: int = 100, batch_size: int = 10):
    """Run the full pipeline with concurrent batch processing."""
    pipeline = SyntheticDataPipeline(max_retries=5)
    
    all_data = []
    # Round up so totals that are not a multiple of batch_size still
    # produce every requested item (the last batch is simply smaller).
    num_batches = max(1, math.ceil(total_items / batch_size))
    
    print(f"Starting CONCURRENT generation of {total_items} items in {num_batches} batches...")
    print(f"Batch size: {batch_size} items (generated in parallel)")
    
    for batch_num in range(1, num_batches + 1):
        # Check if batch already exists
        batch_filename = f"synthetic_data/batch_{batch_num:02d}.jsonl"
        if os.path.exists(batch_filename):
            print(f"Batch {batch_num} already exists ({batch_filename}). Skipping generation...")
            # Load the existing batch into the final output. Parse into a
            # local list first so a corrupt file cannot leave partially
            # loaded items in all_data before regeneration.
            try:
                with open(batch_filename, "r") as f:
                    existing = [json.loads(line) for line in f if line.strip()]
                all_data.extend(existing)
                print(f"Loaded {len(all_data)} items so far.")
                continue
            except Exception as e:
                print(f"Error reading existing batch {batch_num}: {e}. Regenerating...")
        
        # Generate the entire batch concurrently; the final batch may be
        # smaller when total_items is not a multiple of batch_size.
        current_batch_size = min(batch_size, total_items - (batch_num - 1) * batch_size)
        batch_data = await generate_batch_concurrent(pipeline, current_batch_size, batch_num)
        
        # Save batch as JSONL
        with open(batch_filename, "w") as f:
            for item in batch_data:
                f.write(json.dumps(item) + "\n")
        print(f"Saved batch to {batch_filename}")
        
        # Validate batch
        print("Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))
        
        all_data.extend(batch_data)
        
        # Wait 5 seconds before next batch
        if batch_num < num_batches:
            print("Waiting 5 seconds before next batch...")
            await asyncio.sleep(5)
    
    # Save all data
    output_file = f"synthetic_data/all_generated_data_{total_items}.jsonl"
    with open(output_file, "w") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")
    
    print(f"\n{'='*60}")
    print(f"COMPLETED: {len(all_data)} items generated")
    print(f"Full dataset saved to {output_file}")
    print(f"{'='*60}")

# Keyword -> persistence level; checked in order, first match wins.
_PERSISTENCE_LEVELS = {
    "brand_core": "long",
    "strategic_signatures": "long",
    "knowledge_artifacts": "long",
    "communication_style": "long",
    "strategic_approach": "long",
    "tools_config": "medium",
    "role_context": "medium",
    "workflow_patterns": "medium",
    "business_priorities": "short",
    "session_history": "short",
    "performance_context": "rolling",
    "interaction_preferences": "evolving",
    "none": "short",
}

def _get_persistence_for_category(category: str) -> str:
    """Map a category to its expected persistence level."""
    for keyword, persistence in _PERSISTENCE_LEVELS.items():
        if keyword in category:
            return persistence
    return "medium"

if __name__ == "__main__":
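    # CLI: optional positional args are total item count and batch size.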
    total = int(sys.argv[1]) if len(sys.argv) > 1 else 100
    batch = int(sys.argv[2]) if len(sys.argv) > 2 else 10
    
    asyncio.run(run_pipeline_batches_async(total, batch))