File size: 7,410 Bytes
685d968 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import asyncio
import json
import os
import random
import sys
import time
from typing import Any, Dict, List, Optional

from synthetic_data.clean_data import clean_datum
from synthetic_data.pipeline import SyntheticDataPipeline
from synthetic_data.validate import validate_synthetic_data
# Sampling weights over target memory categories for generated items.
# The values sum to 1.0 and are used as the `weights` argument to
# random.choices() when picking a category per item; "none" means the
# conversation has no target memory category.
CATEGORY_DISTRIBUTION = {
    "none": 0.15,
    "user.interaction_preferences": 0.12,
    "user.session_history": 0.10,
    "company.brand_core": 0.10,
    "company.strategic_signatures": 0.07,
    "company.knowledge_artifacts": 0.07,
    "user.communication_style": 0.07,
    "user.strategic_approach": 0.07,
    "user.workflow_patterns": 0.07,
    "company.tools_config": 0.05,
    "company.performance_context": 0.05,
    "company.business_priorities": 0.04,
    "user.role_context": 0.04
}
async def generate_single_item(pipeline: SyntheticDataPipeline, category: str, item_num: int) -> Optional[Dict[str, Any]]:
    """Generate, clean, and return a single conversation item.

    The pipeline's synchronous generation calls are run in the default
    executor so many items can be produced concurrently.

    Args:
        pipeline: Provides scenario and conversation generation.
        category: Target memory category (may be "none").
        item_num: 1-based item index, used only for progress logging.

    Returns:
        The cleaned conversation dict, or None if either generation
        step failed.
    """
    print(f" Starting item {item_num} (Target: {category})...")
    # 30% of non-"none" items also get a distractor category, drawn from
    # the remaining real categories.
    distractor = None
    if random.random() < 0.30 and category != "none":
        possible_distractors = [
            c for c in CATEGORY_DISTRIBUTION if c != category and c != "none"
        ]
        if possible_distractors:
            distractor = random.choice(possible_distractors)
    persistence = _get_persistence_for_category(category)
    turns = random.randint(4, 10)
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    # generate_scenario_spec is synchronous — off-load it to the executor.
    scenario = await loop.run_in_executor(
        None,
        pipeline.generate_scenario_spec,
        category,
        distractor,
        persistence,
        "neutral",
        turns,
        ""
    )
    if not scenario:
        print(f" Failed item {item_num}: scenario generation failed")
        return None
    conversation = await loop.run_in_executor(
        None,
        pipeline.generate_conversation,
        scenario,
        turns,
        category
    )
    if not conversation:
        print(f" Failed item {item_num}: conversation generation failed")
        return None
    # Clean immediately so per-item problems surface here, not at batch save.
    cleaned_conversation = clean_datum(conversation)
    print(f" Completed item {item_num}: {cleaned_conversation.get('scenario_id', 'Unknown')}")
    return cleaned_conversation
async def generate_batch_concurrent(pipeline: SyntheticDataPipeline, batch_size: int, batch_num: int) -> List[Dict[str, Any]]:
    """Generate a full batch of items concurrently, retrying until full.

    Args:
        pipeline: Generation pipeline shared by all concurrent tasks.
        batch_size: Number of items the returned batch must contain.
        batch_num: 1-based batch index, used only for logging.

    Returns:
        Exactly `batch_size` cleaned conversation dicts.

    NOTE(review): there is no cap on retry rounds — if generation never
    succeeds this loop runs forever; consider a max-attempts limit.
    """
    print(f"\n=== Processing Batch {batch_num} (Concurrent) ===")
    categories = list(CATEGORY_DISTRIBUTION.keys())
    weights = list(CATEGORY_DISTRIBUTION.values())
    batch_data: List[Dict[str, Any]] = []
    items_needed = batch_size
    while items_needed > 0:
        # Weighted category draw, one per still-missing item.
        batch_categories = random.choices(categories, weights=weights, k=items_needed)
        tasks = [
            generate_single_item(pipeline, category, len(batch_data) + i + 1)
            for i, category in enumerate(batch_categories)
        ]
        print(f" Launch {items_needed} concurrent tasks...")
        # return_exceptions=True so one failing task cannot cancel the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                print(f" Task exception: {result}")
            elif result is not None:
                batch_data.append(result)
        items_needed = batch_size - len(batch_data)
        if items_needed > 0:
            print(f" Batch incomplete ({len(batch_data)}/{batch_size}). Retrying {items_needed} items in 5s...")
            await asyncio.sleep(5)
    print(f"Batch {batch_num} complete: {len(batch_data)}/{batch_size} items generated")
    return batch_data
async def run_pipeline_batches_async(total_items: int = 100, batch_size: int = 10):
    """Run the full pipeline with concurrent batch processing.

    Existing batch files under synthetic_data/ are reused; missing or
    unreadable batches are (re)generated, validated, and saved. All items
    are finally written to one combined JSONL file.

    Args:
        total_items: Total number of items to produce.
        batch_size: Maximum items per batch (the last batch may be smaller).
    """
    pipeline = SyntheticDataPipeline(max_retries=5)
    all_data = []
    # Ceiling division so a total that is not a multiple of batch_size still
    # yields all requested items (floor division silently dropped the
    # remainder, e.g. 105/10 produced only 100 items).
    num_batches = max(1, -(-total_items // batch_size))
    print(f"Starting CONCURRENT generation of {total_items} items in {num_batches} batches...")
    print(f"Batch size: {batch_size} items (generated in parallel)")
    items_remaining = total_items
    for batch_num in range(1, num_batches + 1):
        # The final batch only generates the leftover count.
        current_batch_size = min(batch_size, items_remaining)
        batch_filename = f"synthetic_data/batch_{batch_num:02d}.jsonl"
        if os.path.exists(batch_filename):
            print(f"Batch {batch_num} already exists ({batch_filename}). Skipping generation...")
            try:
                # Parse into a local list first: a mid-file read failure must
                # not leave partial items in all_data, which would then be
                # duplicated by the regeneration below.
                existing = []
                with open(batch_filename, 'r') as f:
                    for line in f:
                        if line.strip():
                            existing.append(json.loads(line))
                all_data.extend(existing)
                items_remaining -= len(existing)
                print(f"Loaded {len(all_data)} items so far.")
                continue
            except Exception as e:
                print(f"Error reading existing batch {batch_num}: {e}. Regenerating...")
        batch_data = await generate_batch_concurrent(pipeline, current_batch_size, batch_num)
        # Persist the batch as JSONL before validation so it survives crashes.
        with open(batch_filename, "w") as f:
            for item in batch_data:
                f.write(json.dumps(item) + "\n")
        print(f"Saved batch to {batch_filename}")
        print("Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))
        all_data.extend(batch_data)
        items_remaining -= len(batch_data)
        # Brief pause between batches to avoid hammering the backend.
        if batch_num < num_batches:
            print("Waiting 5 seconds before next batch...")
            await asyncio.sleep(5)
    output_file = f"synthetic_data/all_generated_data_{total_items}.jsonl"
    with open(output_file, "w") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")
    print(f"\n{'='*60}")
    print(f"COMPLETED: {len(all_data)} items generated")
    print(f"Full dataset saved to {output_file}")
    print(f"{'='*60}")
def _get_persistence_for_category(category: str) -> str:
"""Map category to its expected persistence level."""
if "brand_core" in category or "strategic_signatures" in category or "knowledge_artifacts" in category or "communication_style" in category or "strategic_approach" in category:
return "long"
elif "tools_config" in category or "role_context" in category or "workflow_patterns" in category:
return "medium"
elif "business_priorities" in category or "session_history" in category:
return "short"
elif "performance_context" in category:
return "rolling"
elif "interaction_preferences" in category:
return "evolving"
elif "none" in category:
return "short"
return "medium"
if __name__ == "__main__":
    # CLI: [total_items] [batch_size]; defaults to 100 items in batches of 10.
    args = sys.argv[1:]
    total = int(args[0]) if args else 100
    batch = int(args[1]) if len(args) > 1 else 10
    asyncio.run(run_pipeline_batches_async(total, batch))
|