File size: 6,877 Bytes
0b6ab33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Generate diverse tagged scripts for ElevenLabs v3 TTS synthesis.

Uses async concurrency for fast generation (~2 min for 1000 scripts).
"""

import asyncio
import json
import random
import os
import sys
from pathlib import Path
from anthropic import AsyncAnthropic
from dotenv import load_dotenv
from .tag_taxonomy import ALL_BRACKET_TAGS, SLICE_CONFIG, DOMAINS

load_dotenv()

# System prompt for the script-writing LLM. The `{tags}` and `{domains}`
# placeholders are filled via str.format() in generate_batch(); the literal
# JSON braces in the example output are escaped as `{{`/`}}` so .format()
# leaves them intact.
SYSTEM_PROMPT = """You are a script writer generating realistic speech samples with inline ElevenLabs v3 audio tags.

Audio tags use square brackets: [laughs], [excited], [whispers], [pause], etc.
Emphasis uses CAPITALIZATION of stressed words.
Pauses use ellipses ...

Rules:
- Write natural, diverse dialogue/monologue snippets (15-80 words each)
- Tags must feel organic, not forced
- Vary domains: conversation, podcast, storytelling, argument, etc.
- Include a mix of male/female perspectives and speaking styles
- Each sample should be self-contained (no context needed)
- NEVER use tags that aren't in the available list
- For CAPS emphasis, only capitalize 1-3 key words per sentence, not whole sentences

Available tags (use ONLY these in brackets): {tags}

Output ONLY a valid JSON array of objects. Each object has these fields:
- "tagged_text": the text with inline tags (string)
- "plain_text": same text with all [tags] removed and CAPS converted to lowercase (string)
- "tags_used": list of tag names used without brackets (array of strings)
- "domain": one of {domains} (string)
- "tag_count": number of bracket tags used (integer)

Example output:
[
  {{
    "tagged_text": "[excited] I can't BELIEVE we actually won! [laughs] This is incredible...",
    "plain_text": "I can't believe we actually won! This is incredible...",
    "tags_used": ["excited", "laughs"],
    "domain": "conversation",
    "tag_count": 2
  }}
]"""


async def generate_batch(
    client: AsyncAnthropic,
    slice_name: str,
    batch_num: int,
    count: int,
    config: dict,
    model: str,
) -> list[dict]:
    """Generate a single batch of tagged scripts via one LLM call.

    Args:
        client: Async Anthropic API client.
        slice_name: Key into SLICE_CONFIG (e.g. "plain", "edge").
        batch_num: Zero-based batch index, used only in log messages.
        count: Number of samples to request in this batch.
        config: Slice configuration; must contain "tag_density" and "description".
        model: Anthropic model identifier.

    Returns:
        List of script dicts, each stamped with "slice_type". Returns an
        empty list on API or parse failure — errors are logged and swallowed
        deliberately so sibling batches keep running.
    """
    tag_density = config["tag_density"]

    if slice_name == "plain":
        density_instruction = "Do NOT include any audio tags [brackets] or CAPS emphasis. Write completely plain, natural speech only. No tags at all."
    elif isinstance(tag_density, tuple):
        density_instruction = f"Use exactly {tag_density[0]} to {tag_density[1]} audio tags per sample."
    else:
        density_instruction = f"Use exactly {tag_density} audio tags per sample."

    if slice_name == "edge":
        density_instruction += (
            " Include edge cases: tags at the very start and end, "
            "consecutive tags like [angry] [laughs], "
            "ambiguous emotions, very short utterances with tags, "
            "and long utterances with scattered tags."
        )

    # A random subset of domains per batch keeps output varied across batches.
    domain_subset = random.sample(DOMAINS, min(4, len(DOMAINS)))
    prompt = (
        f"Generate exactly {count} unique speech samples.\n"
        f"Slice type: {slice_name}{config['description']}\n"
        f"Tag density: {density_instruction}\n"
        f"Distribute across these domains: {', '.join(domain_subset)}\n"
        f"Make each sample UNIQUE — different topics, tones, speakers, situations.\n"
        f"Vary sentence length: mix short (15-25 words), medium (25-45 words), and long (45-80 words)."
    )

    try:
        response = await client.messages.create(
            model=model,
            max_tokens=4096,
            system=SYSTEM_PROMPT.format(
                tags=", ".join(ALL_BRACKET_TAGS),
                domains=", ".join(DOMAINS),
            ),
            messages=[{"role": "user", "content": prompt}],
        )

        content = response.content[0].text
        # Strip optional markdown code fences the model may wrap around the JSON.
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0]
        elif "```" in content:
            content = content.split("```")[1].split("```")[0]

        batch = json.loads(content.strip())
        if isinstance(batch, dict):
            # Some responses wrap the array in an object; unwrap common keys.
            batch = batch.get("samples", batch.get("scripts", [batch]))
        if not isinstance(batch, list):
            # Bare scalar/string payload: previously this crashed on .get()
            # and discarded the batch; wrap so filtering below handles it.
            batch = [batch]

        # Drop malformed (non-dict) items instead of letting one bad entry
        # raise TypeError and throw away the entire batch.
        batch = [item for item in batch if isinstance(item, dict)]
        for item in batch:
            item["slice_type"] = slice_name

        print(f"  [{slice_name}] Batch {batch_num + 1}: {len(batch)} scripts", flush=True)
        return batch

    except Exception as e:
        # Best-effort: a failed batch is logged and skipped, not fatal.
        print(f"  [{slice_name}] Batch {batch_num + 1} FAILED: {e}", flush=True)
        return []


async def generate_slice(
    client: AsyncAnthropic,
    slice_name: str,
    count: int,
    model: str,
    max_concurrent: int = 10,
) -> list[dict]:
    """Produce `count` scripts for one slice by fanning out concurrent batches.

    Work is split into batches of up to 25 samples; a semaphore throttles
    the fan-out so at most `max_concurrent` API calls are in flight at once.
    Returns at most `count` scripts (failed batches may yield fewer).
    """
    cfg = SLICE_CONFIG[slice_name]
    per_batch = 25
    gate = asyncio.Semaphore(max_concurrent)

    async def run_one(index: int, size: int) -> list[dict]:
        # Acquire the shared semaphore before issuing the API request.
        async with gate:
            return await generate_batch(client, slice_name, index, size, cfg, model)

    jobs = []
    index = 0
    remaining = count
    while remaining > 0:
        size = min(per_batch, remaining)
        jobs.append(run_one(index, size))
        remaining -= size
        index += 1

    chunks = await asyncio.gather(*jobs)
    flat = [script for chunk in chunks for script in chunk]
    return flat[:count]


async def generate_full_dataset_async(
    total: int = 1000,
    output_path: str = "data/scripts/scripts.json",
    model: str = "claude-sonnet-4-5-20250929",
    max_concurrent: int = 10,
) -> list[dict]:
    """Build the full balanced dataset concurrently and write it as JSON.

    Iterates over SLICE_CONFIG, allocating each slice its configured ratio
    of `total`, shuffles the combined result, persists it to `output_path`,
    and prints a per-slice summary. Returns the shuffled list of scripts.
    """
    client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    dataset: list[dict] = []

    for name, cfg in SLICE_CONFIG.items():
        target = int(total * cfg["ratio"])
        num_batches = (target + 24) // 25
        print(
            f"\nGenerating {target} '{name}' scripts "
            f"({num_batches} batches, {max_concurrent} concurrent)...",
            flush=True,
        )
        produced = await generate_slice(client, name, target, model, max_concurrent)
        dataset.extend(produced)
        print(f"  Total for {name}: {len(produced)}", flush=True)

    random.shuffle(dataset)

    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w") as f:
        json.dump(dataset, f, indent=2)

    banner = "=" * 50
    print(f"\n{banner}", flush=True)
    print("DATASET GENERATION COMPLETE", flush=True)
    print(banner, flush=True)
    print(f"Total scripts: {len(dataset)}", flush=True)
    for name in SLICE_CONFIG:
        matched = sum(1 for item in dataset if item.get("slice_type") == name)
        print(f"  {name}: {matched}", flush=True)
    print(f"Saved to: {output_path}", flush=True)

    return dataset


def generate_full_dataset(
    total: int = 1000,
    output_path: str = "data/scripts/scripts.json",
    model: str = "claude-sonnet-4-5-20250929",
    max_concurrent: int = 10,
) -> list[dict]:
    """Synchronous entry point: run the async generator to completion.

    Backward-compatible generalization: `model` and `max_concurrent` are now
    forwarded to generate_full_dataset_async instead of being silently
    pinned to that function's defaults. The new defaults match the async
    signature, so existing two-argument callers behave identically.
    """
    return asyncio.run(
        generate_full_dataset_async(total, output_path, model, max_concurrent)
    )


# Script entry point: generate the default 1000-script dataset and write it
# to data/scripts/scripts.json.
if __name__ == "__main__":
    generate_full_dataset()