# MuratcanKoylan's picture
# Upload folder using huggingface_hub
# 685d968 verified
import json
import random
import time
import sys
from typing import List, Dict, Any
from synthetic_data.pipeline import SyntheticDataPipeline
from synthetic_data.validate import validate_synthetic_data
# Sampling weight per category label. run_pipeline_batches hands these values
# to random.choices as relative weights, so only their ratios matter; as
# written they add up to 1.10 rather than 1.0.
CATEGORY_DISTRIBUTION = {
    # "company.*" categories
    "company.brand_core": 0.10,
    "company.strategic_signatures": 0.08,
    "company.knowledge_artifacts": 0.08,
    "company.business_priorities": 0.10,
    "company.tools_config": 0.07,
    "company.performance_context": 0.09,
    # "user.*" categories
    "user.communication_style": 0.10,
    "user.strategic_approach": 0.09,
    "user.role_context": 0.07,
    "user.workflow_patterns": 0.08,
    "user.session_history": 0.06,
    "user.interaction_preferences": 0.08,
    # The "none" label: never paired with a distractor and never used as one.
    "none": 0.10,
}
def run_pipeline_batches(total_items: int = 100, batch_size: int = 10):
    """Generate ``total_items`` synthetic conversations in batches.

    For every item a category is drawn from ``CATEGORY_DISTRIBUTION`` (values
    are used as relative weights), optionally paired with a distractor
    category, and turned into a scenario spec and then a conversation by the
    pipeline. A failed scenario or conversation is retried after a 20 second
    pause, without any retry limit, until the batch is full.

    Each finished batch is written to ``synthetic_data/batch_NN.json`` and
    passed to ``validate_synthetic_data``; the resulting metrics are printed.
    When all batches are done, the combined data is written to
    ``synthetic_data/all_generated_data_100.json``.

    Exactly ``total_items`` conversations are produced: the batch count is
    rounded up and the final batch is shortened to the remainder, so a total
    that is not a multiple of ``batch_size`` is neither truncated nor padded.

    Args:
        total_items: Total number of conversations to generate; must be >= 1.
        batch_size: Maximum number of conversations per batch file; must be
            >= 1.

    Raises:
        ValueError: If ``total_items`` or ``batch_size`` is smaller than 1.
    """
    # Fail before touching the pipeline or any output file.
    if total_items < 1:
        raise ValueError(f"total_items must be at least 1, got {total_items}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    pipeline = SyntheticDataPipeline()
    categories = list(CATEGORY_DISTRIBUTION)
    weights = list(CATEGORY_DISTRIBUTION.values())
    all_data = []

    # Ceiling division so a trailing partial batch is not dropped.
    num_batches = (total_items + batch_size - 1) // batch_size
    print(f"Starting generation of {total_items} items in {num_batches} batches (Size: {batch_size})...")

    for batch_num in range(1, num_batches + 1):
        print(f"\n=== Processing Batch {batch_num}/{num_batches} ===")
        # The last batch only holds whatever is still missing from the total.
        batch_target = min(batch_size, total_items - len(all_data))
        batch_data = []

        while len(batch_data) < batch_target:
            category = random.choices(categories, weights=weights, k=1)[0]
            current_count = len(batch_data) + 1
            print(f" Generating item {current_count}/{batch_target} (Category: {category})...")

            # Determine if we should add a distractor (30% chance).
            # "none" never receives a distractor and is never used as one.
            distractor = None
            if random.random() < 0.30 and category != "none":
                possible_distractors = [c for c in categories if c != category and c != "none"]
                if possible_distractors:
                    distractor = random.choice(possible_distractors)

            persistence = _get_persistence_for_category(category)
            turns = random.randint(4, 10)

            scenario = pipeline.generate_scenario_spec(
                category=category,
                distractor=distractor,
                persistence=persistence,
                turns=turns,
            )
            if not scenario:
                print(f" Failed to generate scenario for {category}. Retrying...")
                time.sleep(20)
                continue

            conversation = pipeline.generate_conversation(scenario, turn_count=turns)
            if not conversation:
                print(f" Failed to generate conversation for {category}. Retrying...")
                time.sleep(20)
                continue

            batch_data.append(conversation)
            print(f" Generated: {conversation.get('scenario_id', 'Unknown ID')}")

            print(" Sleeping for 15s to avoid rate limits...")
            time.sleep(15)

        # Save batch
        batch_filename = f"synthetic_data/batch_{batch_num:02d}.json"
        with open(batch_filename, "w", encoding="utf-8") as f:
            json.dump(batch_data, f, indent=2)
        print(f" Saved batch to {batch_filename}")

        # Validate batch
        print(" Validating batch...")
        metrics = validate_synthetic_data(batch_filename)
        print(json.dumps(metrics, indent=2))

        all_data.extend(batch_data)

    # Save all data.
    # NOTE(review): the file name contains "100" regardless of total_items; it
    # is kept unchanged here so existing consumers of this path keep working.
    with open("synthetic_data/all_generated_data_100.json", "w", encoding="utf-8") as f:
        json.dump(all_data, f, indent=2)

    print(f"\nCompleted. Total items generated: {len(all_data)}")
    print("Full dataset saved to synthetic_data/all_generated_data_100.json")
def _get_persistence_for_category(category: str) -> str:
if "brand_core" in category or "strategic_signatures" in category or "knowledge_artifacts" in category or "communication_style" in category or "strategic_approach" in category:
return "long"
elif "tools_config" in category or "role_context" in category or "workflow_patterns" in category:
return "medium"
elif "business_priorities" in category or "session_history" in category:
return "short"
elif "performance_context" in category:
return "rolling"
elif "interaction_preferences" in category:
return "evolving"
elif "none" in category:
return "short"
return "medium"
if __name__ == "__main__":
    # Optional positional CLI arguments: <total_items> <batch_size>.
    # Missing arguments fall back to 100 items in batches of 10.
    cli_args = sys.argv[1:]
    total = int(cli_args[0]) if cli_args else 100
    batch = int(cli_args[1]) if len(cli_args) > 1 else 10
    run_pipeline_batches(total, batch)