|
|
import json |
|
|
import random |
|
|
import time |
|
|
import sys |
|
|
from typing import List, Dict, Any |
|
|
from synthetic_data.pipeline import SyntheticDataPipeline |
|
|
from synthetic_data.validate import validate_synthetic_data |
|
|
|
|
|
# Relative sampling weights for scenario categories. run_pipeline_batches turns
# the keys and values into parallel lists and feeds them to random.choices, so
# key order matters: it fixes the order of the list being sampled.
# NOTE(review): these weights sum to 1.10, not 1.0. random.choices treats them
# as relative weights, so each category's effective share is weight / 1.10
# (e.g. "none" is ~9.1%, not 10%). Confirm whether that is intended.
CATEGORY_DISTRIBUTION = {
    # Company-scoped categories.
    "company.brand_core": 0.10,
    "company.strategic_signatures": 0.08,
    "company.knowledge_artifacts": 0.08,
    "company.business_priorities": 0.10,
    "company.tools_config": 0.07,
    "company.performance_context": 0.09,
    # User-scoped categories.
    "user.communication_style": 0.10,
    "user.strategic_approach": 0.09,
    "user.role_context": 0.07,
    "user.workflow_patterns": 0.08,
    "user.session_history": 0.06,
    "user.interaction_preferences": 0.08,
    # "none" is never paired with a distractor in run_pipeline_batches.
    "none": 0.10,
}
|
|
|
|
|
def run_pipeline_batches(total_items: int = 100, batch_size: int = 10):
    """Generate synthetic conversations in batches, saving and validating each batch.

    Categories are sampled from ``CATEGORY_DISTRIBUTION`` (used as relative
    weights). Each batch is written to ``synthetic_data/batch_NN.json``, passed
    to ``validate_synthetic_data`` and its metrics printed. The combined dataset
    is written to ``synthetic_data/all_generated_data_<total_items>.json``.

    When ``total_items`` is not a multiple of ``batch_size``, the final batch is
    smaller so that exactly ``total_items`` items are produced.

    Args:
        total_items: Number of conversations to generate; must be positive.
        batch_size: Maximum number of conversations per batch; must be positive.

    Raises:
        ValueError: If ``total_items`` or ``batch_size`` is not positive.

    Note:
        A failed generation attempt is retried (with a freshly sampled
        category) after a 20 s pause, with no upper bound on attempts, so a
        persistently failing pipeline stalls this function.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if total_items <= 0:
        raise ValueError(f"total_items must be positive, got {total_items}")

    pipeline = SyntheticDataPipeline()
    categories = list(CATEGORY_DISTRIBUTION.keys())
    weights = list(CATEGORY_DISTRIBUTION.values())

    batch_sizes = _plan_batch_sizes(total_items, batch_size)
    num_batches = len(batch_sizes)
    all_data = []

    print(f"Starting generation of {total_items} items in {num_batches} batches (Size: {batch_size})...")

    for batch_num, target_size in enumerate(batch_sizes, start=1):
        print(f"\n=== Processing Batch {batch_num}/{num_batches} ===")
        batch_data = []

        while len(batch_data) < target_size:
            conversation = _generate_one_item(
                pipeline,
                categories,
                weights,
                item_number=len(batch_data) + 1,
                target_size=target_size,
            )
            if conversation is None:
                # Back off, then retry with a freshly sampled category.
                time.sleep(20)
                continue

            batch_data.append(conversation)
            print(f" Generated: {conversation.get('scenario_id', 'Unknown ID')}")
            print(" Sleeping for 15s to avoid rate limits...")
            time.sleep(15)

        batch_filename = f"synthetic_data/batch_{batch_num:02d}.json"
        _write_json(batch_filename, batch_data)
        print(f" Saved batch to {batch_filename}")

        print(" Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))

        all_data.extend(batch_data)

    # Name the combined file after the requested item count rather than a
    # hard-coded "100", so runs with other sizes do not overwrite it.
    output_filename = f"synthetic_data/all_generated_data_{total_items}.json"
    _write_json(output_filename, all_data)

    print(f"\nCompleted. Total items generated: {len(all_data)}")
    print(f"Full dataset saved to {output_filename}")


def _plan_batch_sizes(total_items: int, batch_size: int) -> List[int]:
    """Split ``total_items`` into full batches plus one smaller remainder batch."""
    full_batches, remainder = divmod(total_items, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        sizes.append(remainder)
    return sizes


def _pick_distractor(category: str, categories: List[str]):
    """Return a random distractor category for ~30% of non-"none" items, else None."""
    # random.random() is drawn first (even for "none") to keep the RNG call
    # sequence identical to the original inline implementation.
    if random.random() < 0.30 and category != "none":
        possible_distractors = [c for c in categories if c != category and c != "none"]
        if possible_distractors:
            return random.choice(possible_distractors)
    return None


def _generate_one_item(pipeline, categories: List[str], weights: List[float],
                       item_number: int, target_size: int):
    """Sample a category and try once to generate a conversation; None on failure."""
    category = random.choices(categories, weights=weights, k=1)[0]
    print(f" Generating item {item_number}/{target_size} (Category: {category})...")

    distractor = _pick_distractor(category, categories)
    persistence = _get_persistence_for_category(category)
    turns = random.randint(4, 10)

    scenario = pipeline.generate_scenario_spec(
        category=category,
        distractor=distractor,
        persistence=persistence,
        turns=turns,
    )
    if not scenario:
        print(f" Failed to generate scenario for {category}. Retrying...")
        return None

    conversation = pipeline.generate_conversation(scenario, turn_count=turns)
    if not conversation:
        print(f" Failed to generate conversation for {category}. Retrying...")
        return None
    return conversation


def _write_json(path: str, data: Any) -> None:
    """Write ``data`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
|
|
|
|
|
def _get_persistence_for_category(category: str) -> str: |
|
|
if "brand_core" in category or "strategic_signatures" in category or "knowledge_artifacts" in category or "communication_style" in category or "strategic_approach" in category: |
|
|
return "long" |
|
|
elif "tools_config" in category or "role_context" in category or "workflow_patterns" in category: |
|
|
return "medium" |
|
|
elif "business_priorities" in category or "session_history" in category: |
|
|
return "short" |
|
|
elif "performance_context" in category: |
|
|
return "rolling" |
|
|
elif "interaction_preferences" in category: |
|
|
return "evolving" |
|
|
elif "none" in category: |
|
|
return "short" |
|
|
return "medium" |
|
|
|
|
|
if __name__ == "__main__":
    # Optional positional arguments: [total_items] [batch_size].
    cli_args = sys.argv[1:]
    total = int(cli_args[0]) if len(cli_args) >= 1 else 100
    batch = int(cli_args[1]) if len(cli_args) >= 2 else 10
    run_pipeline_batches(total, batch)
|
|
|