|
|
import asyncio
import json
import os
import random
import sys
import time
from typing import Any, Dict, List, Optional

from synthetic_data.pipeline import SyntheticDataPipeline
from synthetic_data.validate import validate_synthetic_data
from synthetic_data.clean_data import clean_datum
|
|
|
|
|
# Sampling weights for each target memory category; the values sum to 1.0.
# Used as the `weights` argument to random.choices() when picking the target
# category for each generated item. "none" marks items that should contain
# no memory-worthy content (negative examples).
CATEGORY_DISTRIBUTION = {
    "none": 0.15,
    "user.interaction_preferences": 0.12,
    "user.session_history": 0.10,
    "company.brand_core": 0.10,

    "company.strategic_signatures": 0.07,
    "company.knowledge_artifacts": 0.07,
    "user.communication_style": 0.07,
    "user.strategic_approach": 0.07,
    "user.workflow_patterns": 0.07,

    "company.tools_config": 0.05,
    "company.performance_context": 0.05,
    "company.business_priorities": 0.04,
    "user.role_context": 0.04
}
|
|
|
|
|
async def generate_single_item(pipeline: SyntheticDataPipeline, category: str, item_num: int) -> Optional[Dict[str, Any]]:
    """Generate a single conversation item asynchronously.

    Runs the blocking pipeline calls (scenario spec, then conversation) in
    the default thread-pool executor so many items can generate concurrently.

    Args:
        pipeline: Configured generation pipeline; its methods are blocking.
        category: Target memory category for this item ("none" for negatives).
        item_num: 1-based item index, used only for progress logging.

    Returns:
        The cleaned conversation dict, or None if either generation step failed.
    """
    print(f" Starting item {item_num} (Target: {category})...")

    categories = list(CATEGORY_DISTRIBUTION.keys())

    # ~30% of non-"none" items also get a distractor category, so downstream
    # consumers must separate the target signal from plausible noise.
    distractor = None
    if random.random() < 0.30 and category != "none":
        possible_distractors = [c for c in categories if c != category and c != "none"]
        if possible_distractors:
            distractor = random.choice(possible_distractors)

    persistence = _get_persistence_for_category(category)
    turns = random.randint(4, 10)

    # The pipeline methods are synchronous, so push them onto the executor.
    # get_running_loop() is the supported call inside a coroutine
    # (get_event_loop() here is deprecated since Python 3.10).
    loop = asyncio.get_running_loop()
    scenario = await loop.run_in_executor(
        None,
        pipeline.generate_scenario_spec,
        category,
        distractor,
        persistence,
        "neutral",
        turns,
        ""
    )

    if not scenario:
        print(f" Failed item {item_num}: scenario generation failed")
        return None

    conversation = await loop.run_in_executor(
        None,
        pipeline.generate_conversation,
        scenario,
        turns,
        category
    )

    if conversation:
        cleaned_conversation = clean_datum(conversation)
        print(f" Completed item {item_num}: {cleaned_conversation.get('scenario_id', 'Unknown')}")
        return cleaned_conversation
    else:
        print(f" Failed item {item_num}: conversation generation failed")
        return None
|
|
|
|
|
async def generate_batch_concurrent(pipeline: SyntheticDataPipeline, batch_size: int, batch_num: int) -> List[Dict[str, Any]]:
    """Generate a full batch of items concurrently, retrying until batch is full.

    Samples target categories per CATEGORY_DISTRIBUTION, launches one task per
    missing item, and keeps re-sampling/retrying failed items (after a short
    pause) until the batch holds exactly ``batch_size`` items.

    NOTE(review): this loops indefinitely if items keep failing; the retry has
    no attempt cap.

    Args:
        pipeline: Configured generation pipeline, shared across all tasks.
        batch_size: Number of items the returned batch must contain.
        batch_num: 1-based batch index, used only for logging.

    Returns:
        List of exactly ``batch_size`` cleaned conversation dicts.
    """
    print(f"\n=== Processing Batch {batch_num} (Concurrent) ===")

    categories = list(CATEGORY_DISTRIBUTION.keys())
    weights = list(CATEGORY_DISTRIBUTION.values())

    batch_data = []
    items_needed = batch_size

    while items_needed > 0:
        # Weighted draw of target categories for the still-missing items.
        batch_categories = random.choices(categories, weights=weights, k=items_needed)

        tasks = [
            generate_single_item(pipeline, category, len(batch_data) + i + 1)
            for i, category in enumerate(batch_categories)
        ]

        print(f" Launch {items_needed} concurrent tasks...")
        # return_exceptions=True keeps one failed task from cancelling the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                print(f" Task exception: {result}")
            elif result is not None:
                batch_data.append(result)

        items_needed = batch_size - len(batch_data)
        if items_needed > 0:
            print(f" Batch incomplete ({len(batch_data)}/{batch_size}). Retrying {items_needed} items in 5s...")
            # Brief back-off so transient upstream failures can clear.
            await asyncio.sleep(5)

    print(f"Batch {batch_num} complete: {len(batch_data)}/{batch_size} items generated")
    return batch_data
|
|
|
|
|
async def run_pipeline_batches_async(total_items: int = 100, batch_size: int = 10):
    """Run the full pipeline with concurrent batch processing.

    Generates ``total_items`` conversations in batches of up to ``batch_size``
    concurrent items, checkpointing each batch to a JSONL file (existing batch
    files are loaded instead of regenerated, enabling resume), validating each
    batch, and finally writing the combined dataset.

    Args:
        total_items: Total number of items to generate.
        batch_size: Maximum number of items generated concurrently per batch.
    """
    pipeline = SyntheticDataPipeline(max_retries=5)

    all_data = []
    # Ceiling division so a total that is not a multiple of batch_size still
    # generates every requested item (the final batch is simply smaller).
    num_batches = max(1, (total_items + batch_size - 1) // batch_size)

    print(f"Starting CONCURRENT generation of {total_items} items in {num_batches} batches...")
    print(f"Batch size: {batch_size} items (generated in parallel)")

    # Batch checkpoints and the final dataset live under synthetic_data/.
    os.makedirs("synthetic_data", exist_ok=True)

    for batch_num in range(1, num_batches + 1):
        batch_filename = f"synthetic_data/batch_{batch_num:02d}.jsonl"

        # Resume support: reuse an existing batch file; fall through to
        # regeneration only if it cannot be read.
        if os.path.exists(batch_filename):
            print(f"Batch {batch_num} already exists ({batch_filename}). Skipping generation...")
            try:
                with open(batch_filename, 'r') as f:
                    for line in f:
                        if line.strip():
                            all_data.append(json.loads(line))
                print(f"Loaded {len(all_data)} items so far.")
                continue
            except Exception as e:
                print(f"Error reading existing batch {batch_num}: {e}. Regenerating...")

        # The last batch may be smaller than batch_size.
        current_batch_size = min(batch_size, total_items - (batch_num - 1) * batch_size)
        batch_data = await generate_batch_concurrent(pipeline, current_batch_size, batch_num)

        # Checkpoint the batch as JSONL before validating.
        with open(batch_filename, "w") as f:
            for item in batch_data:
                f.write(json.dumps(item) + "\n")
        print(f"Saved batch to {batch_filename}")

        print("Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))

        all_data.extend(batch_data)

        if batch_num < num_batches:
            print("Waiting 5 seconds before next batch...")
            await asyncio.sleep(5)

    # Combine every batch (freshly generated or loaded) into one dataset file.
    output_file = f"synthetic_data/all_generated_data_{total_items}.jsonl"
    with open(output_file, "w") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")

    print(f"\n{'='*60}")
    print(f"COMPLETED: {len(all_data)} items generated")
    print(f"Full dataset saved to {output_file}")
    print(f"{'='*60}")
|
|
|
|
|
def _get_persistence_for_category(category: str) -> str: |
|
|
"""Map category to its expected persistence level.""" |
|
|
if "brand_core" in category or "strategic_signatures" in category or "knowledge_artifacts" in category or "communication_style" in category or "strategic_approach" in category: |
|
|
return "long" |
|
|
elif "tools_config" in category or "role_context" in category or "workflow_patterns" in category: |
|
|
return "medium" |
|
|
elif "business_priorities" in category or "session_history" in category: |
|
|
return "short" |
|
|
elif "performance_context" in category: |
|
|
return "rolling" |
|
|
elif "interaction_preferences" in category: |
|
|
return "evolving" |
|
|
elif "none" in category: |
|
|
return "short" |
|
|
return "medium" |
|
|
|
|
|
if __name__ == "__main__":
    # CLI: python <script> [total_items] [batch_size] (defaults: 100, 10)
    cli_args = sys.argv[1:]
    total = int(cli_args[0]) if cli_args else 100
    batch = int(cli_args[1]) if len(cli_args) > 1 else 10

    asyncio.run(run_pipeline_batches_async(total, batch))
|
|
|