|
|
""" |
|
|
Prepare training data by merging datasets and preprocessing for Tinker. |
|
|
|
|
|
This script: |
|
|
1. Merges the original dataset with the new diverse dataset |
|
|
2. Validates and cleans the data |
|
|
3. Converts to the format expected by train_v2.py |
|
|
4. Splits into train/test sets |
|
|
5. Analyzes category distribution |
|
|
""" |
|
|
|
|
|
import json
import os
import random
from collections import Counter
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
# Input datasets (JSONL, one synthetic item per line) and output locations
# for the processed train/test splits consumed by train_v2.py.
ORIGINAL_DATASET = "synthetic_data/training_dataset_1000.jsonl"
DIVERSE_DATASET = "synthetic_data/diverse_dataset_20251124_192207.jsonl"
OUTPUT_DIR = "training/processed_data"
TRAIN_OUTPUT = os.path.join(OUTPUT_DIR, "train_data.json")
TEST_OUTPUT = os.path.join(OUTPUT_DIR, "test_data.json")
|
|
|
|
|
|
|
|
# System prompt embedded as the first message of every training example.
# It defines the full label taxonomy the model must choose from; keep it in
# sync with VALID_CATEGORIES below.
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors
- company.strategic_signatures: Decision frameworks, strategic heuristics
- company.knowledge_artifacts: Docs, style guides, playbooks
- company.business_priorities: Quarterly/seasonal goals, active campaigns
- company.tools_config: Integrations, API keys, workflow settings
- company.performance_context: Campaign metrics, retrospectives, learnings
- user.communication_style: Tone, verbosity, format expectations
- user.strategic_approach: Personal priorities, success definitions
- user.role_context: Title, scope, decision authority
- user.workflow_patterns: Review cadence, collaboration norms
- user.session_history: Immediate context, recent asks
- user.interaction_preferences: Coaching style, feedback expectations
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""
|
|
|
|
|
# Canonical label set; clean_categories() silently drops any label not
# listed here. Must mirror the categories described in SYSTEM_PROMPT.
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}
|
|
|
|
|
|
|
|
def load_jsonl(path: str) -> List[Dict]:
    """Load a JSONL file, skipping blank and malformed lines.

    Args:
        path: Path to the JSONL file.

    Returns:
        A list of parsed JSON objects, one per valid non-empty line.
        Invalid lines are reported with their line number and skipped.
    """
    data = []
    # Explicit UTF-8: the datasets may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to handle it.
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Include the line number so bad records are easy to locate.
                print(f"Warning: Skipping invalid JSON line {line_no}: {e}")
    return data
|
|
|
|
|
|
|
|
def clean_categories(categories: List[str]) -> List[str]:
    """Normalize, validate, and deduplicate a list of category labels.

    Labels are lowercased and stripped; anything outside VALID_CATEGORIES
    is discarded. "none" is removed whenever real labels are present, and
    an empty result falls back to ["none"].
    """
    # Keep only recognized labels, normalized to lowercase.
    normalized = (cat.strip().lower() for cat in categories)
    valid = [c for c in normalized if c in VALID_CATEGORIES]

    # "none" is only meaningful on its own; drop it alongside real labels.
    if "none" in valid and len(valid) > 1:
        valid = [c for c in valid if c != "none"]

    # Deduplicate while preserving first-seen order.
    deduped = list(dict.fromkeys(valid))

    return deduped if deduped else ["none"]
|
|
|
|
|
|
|
|
def convert_to_training_format(item: Dict) -> Optional[Dict]:
    """Convert a synthetic data item into a chat-format training example.

    Args:
        item: Raw synthetic item; expected keys are "conversation" (list of
            role/content dicts or plain strings), optional "labels" with a
            "categories" list, "scenario_id", and "metadata" — TODO confirm
            against the generator's schema.

    Returns:
        None when the item has no usable conversation text, otherwise:
        {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": "...conversation..."},
                {"role": "assistant", "content": "category1, category2"}
            ],
            "categories": ["category1", "category2"],
            "scenario_id": "...",
            "metadata": {...}
        }
    """
    conversation = item.get("conversation", [])
    if not conversation:
        return None

    # Flatten turns into "ROLE: content" lines; plain-string turns are
    # passed through unchanged. join() avoids quadratic += concatenation.
    parts = []
    for turn in conversation:
        if isinstance(turn, dict):
            role = turn.get("role", "unknown")
            content = turn.get("content", "")
            parts.append(f"{role.upper()}: {content}\n")
        elif isinstance(turn, str):
            parts.append(f"{turn}\n")
    conv_text = "".join(parts)

    if not conv_text.strip():
        return None

    # Prefer explicit labels; fall back to the metadata primary category.
    categories = item.get("labels", {}).get("categories", [])
    if not categories:
        categories = [item.get("metadata", {}).get("primary_category", "none")]

    categories = clean_categories(categories)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze this conversation and determine which memory categories apply:\n\n{conv_text.strip()}"},
        {"role": "assistant", "content": ", ".join(categories)}
    ]

    return {
        "messages": messages,
        "categories": categories,
        "scenario_id": item.get("scenario_id", ""),
        "metadata": item.get("metadata", {})
    }
|
|
|
|
|
|
|
|
def analyze_distribution(data: List[Dict]) -> Dict[str, int]:
    """Count how often each category label appears across all items."""
    tally = Counter()
    for record in data:
        # Counter.update adds one count per label in the item's list.
        tally.update(record.get("categories", []))
    return dict(tally)
|
|
|
|
|
|
|
|
def main():
    """End-to-end data prep: load, merge, convert, split, analyze, save."""
    print("=" * 70)
    print("PREPARING TRAINING DATA")
    print("=" * 70)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Load source datasets ---
    print(f"\nLoading original dataset: {ORIGINAL_DATASET}")
    original_data = load_jsonl(ORIGINAL_DATASET)
    print(f" Loaded {len(original_data)} items")

    print(f"\nLoading diverse dataset: {DIVERSE_DATASET}")
    diverse_data = load_jsonl(DIVERSE_DATASET)
    print(f" Loaded {len(diverse_data)} items")

    # --- Convert to training format, tagging each item's source dataset ---
    print("\nConverting to training format...")

    all_data = []
    skipped = 0
    # One loop handles both datasets; only the "source" tag differs.
    for source_name, dataset in (("original", original_data), ("diverse", diverse_data)):
        for item in dataset:
            converted = convert_to_training_format(item)
            if converted:
                converted["source"] = source_name
                all_data.append(converted)
            else:
                skipped += 1

    print(f" Converted: {len(all_data)}")
    print(f" Skipped: {skipped}")

    # --- Deterministic shuffle so the 90/10 split is reproducible ---
    random.seed(42)
    random.shuffle(all_data)

    split_idx = int(len(all_data) * 0.9)
    train_data = all_data[:split_idx]
    test_data = all_data[split_idx:]

    print(f"\nSplit:")
    print(f" Train: {len(train_data)}")
    print(f" Test: {len(test_data)}")

    # --- Category distribution report ---
    print("\n" + "-" * 50)
    print("CATEGORY DISTRIBUTION (Train)")
    print("-" * 50)

    train_dist = analyze_distribution(train_data)
    total = sum(train_dist.values())

    for cat in sorted(train_dist.keys()):
        count = train_dist[cat]
        # Guard against division by zero on a degenerate distribution.
        pct = count / total * 100 if total else 0.0
        bar = "█" * int(pct / 2) + "░" * (50 - int(pct / 2))
        print(f"{cat:<35} {count:>4} ({pct:>5.1f}%) {bar[:30]}")

    print(f"\nTotal labels: {total}")
    print(f"Unique categories: {len(train_dist)}")

    # --- Balance check; min()/max() raise ValueError on an empty dict,
    # which previously crashed the script when all items were skipped ---
    if train_dist:
        min_count = min(train_dist.values())
        max_count = max(train_dist.values())
        imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')
        print(f"\nImbalance ratio: {imbalance_ratio:.1f}x (max/min)")

        if imbalance_ratio < 3:
            print(" Status: GOOD - Dataset is reasonably balanced")
        elif imbalance_ratio < 5:
            print(" Status: OK - Some imbalance but acceptable")
        else:
            print(" Status: WARNING - Dataset is imbalanced")
    else:
        print("\nWarning: no training labels found; skipping balance check")

    # --- Save outputs ---
    print(f"\nSaving to {OUTPUT_DIR}/...")

    with open(TRAIN_OUTPUT, 'w') as f:
        json.dump(train_data, f, indent=2)
    print(f" Saved train_data.json ({len(train_data)} items)")

    with open(TEST_OUTPUT, 'w') as f:
        json.dump(test_data, f, indent=2)
    print(f" Saved test_data.json ({len(test_data)} items)")

    print("\n" + "=" * 70)
    print("DATA PREPARATION COMPLETE")
    print("=" * 70)
    print(f"Train: {TRAIN_OUTPUT}")
    print(f"Test: {TEST_OUTPUT}")
    print(f"\nReady for training with train_v2.py")


if __name__ == "__main__":
    main()
|
|
|
|
|
|