bee / scripts /distill_domain_seeds.py
Bee Deploy
HF Space backend deploy [de0cba5]
5e21013
"""Generate teacher-distilled Q&A seed data for the 10 Tier-1 Bee domains.
Why this exists
---------------
The Kaggle training cron is producing flat loss (~3.84 across 5 runs) because
the only training data we have is 20 hand-written bootstrap rows about Bee's
identity — all in the "general" domain. With the cron now rotating through
all 10 Tier-1 domains, every domain except "general" will return zero rows
and report `partial`.
Distillation closes the gap: a strong teacher LLM generates realistic
domain-specific Q&A pairs. The trained adapter actually learns domain
patterns, loss drops, benchmark scores improve.
Methodology (auditable, no fake numbers)
----------------------------------------
For each Tier-1 domain in `bee/domains.py:TIER_1_DOMAINS` we ask the teacher
to produce N realistic user-question + high-quality-answer pairs. Each row
records the exact teacher provider + model that produced it (in the `source`
field) so we can reproduce, audit, or revoke specific rows later.
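Default teacher: DeepSeek (`--provider deepseek`). The model tier is picked
per domain: V4 Pro for the reasoning-heavy domains in DOMAIN_TIER, V4 Flash
for the rest (overridable with --pro-domains / --flash-domains). Google
Gemini Flash ($0.40/M output tokens for 2.0 Flash) is the cheapest of the
four configured teachers per `bee/teacher_providers.py`. Override the
provider via --provider {anthropic,deepseek,openai,google}.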
Cost estimate at default settings:
10 domains × 200 pairs/domain × ~250 tokens/pair = ~500k output tokens
Gemini Flash: 500k × $0.40/M = ~$0.20 total
DeepSeek: 500k × $2.19/M = ~$1.10 total
Claude: 500k × $15/M = ~$7.50 total
Output rows
-----------
Same schema as scripts/seed_bee_interactions.py with `kind=distilled` and
`source` pointing at the exact teacher response. Uploaded to
cuilabs/bee-interactions/data/<domain>.jsonl. The Kaggle training kernel's
filter (assistant + not-downvoted + matching domain) picks them up
automatically.
Usage
-----
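HF_TOKEN=hf_xxx \\
BEE_DEEPSEEK_API_KEY=... \\
python scripts/distill_domain_seeds.py [options]
(Other providers read BEE_GOOGLE_API_KEY, BEE_OPENAI_API_KEY, or, for
Anthropic, BEE_TEACHER_API_KEY.)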
Options
-------
--domains general,programming,... Only generate for these (default: all 10 Tier-1)
--pairs N Pairs per domain (default 200)
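--batch N Pairs per teacher call (default 50)
--provider <name> Teacher provider (default: deepseek; chain order: deepseek > google > openai > anthropic)
--workers N Parallel domain workers (default 2)
--pro-domains a,b,... Force these domains onto the pro tier
--flash-domains a,b,... Force these domains onto the flash tier
--dry-run Print plan, don't call teachers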
--skip-existing Skip domains already in the HF dataset
"""
from __future__ import annotations
import argparse
import datetime
import json
import os
import sys
import tempfile
import threading
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterator
REPO_ROOT = Path(__file__).resolve().parent.parent

# Tier-1 domains, in the order main() processes them by default. This
# restates bee/domains.py:TIER_1_DOMAINS (the canonical list) rather than
# importing it, because importing bee.domains drags in heavy backend
# dependencies. Keep the two lists in sync.
TIER_1_DOMAINS = """
    general programming ai cybersecurity quantum
    fintech blockchain infrastructure research business
""".split()
# Default teacher tier for each domain. "pro" selects the strong reasoning
# model (DeepSeek V4 Pro); "flash" selects the cheap workhorse (V4 Flash).
# Domains whose questions benefit from chain-of-thought reasoning get pro;
# pattern-based / operational domains get flash. The --pro-domains and
# --flash-domains CLI flags override this per domain at runtime.
DOMAIN_TIER: dict[str, str] = {
    # Strong-reasoning domains: threat models and attack chains
    # (cybersecurity), math and algorithm analysis (quantum), methodology,
    # paper critique and statistical depth (research).
    **dict.fromkeys(("cybersecurity", "quantum", "research"), "pro"),
    # Pattern-based / operational domains: flash is plenty.
    **dict.fromkeys(
        (
            "general",
            "programming",
            "ai",
            "fintech",
            "blockchain",
            "infrastructure",
            "business",
        ),
        "flash",
    ),
}
# Prompt context injected into PROMPT_TEMPLATE for each domain: a short,
# realistic description of what a working professional in that domain would
# ask about. Deliberately contains no invented statistics.
DOMAIN_CONTEXT: dict[str, str] = dict(
    general=(
        "general technical assistance β€” clear, well-grounded answers across "
        "common professional and personal computing topics"
    ),
    programming=(
        "software engineering β€” code review, architecture, debugging, "
        "language-specific patterns (Python, TypeScript, Go, Rust, etc.), "
        "build tooling, testing, and CI/CD"
    ),
    ai=(
        "AI/ML β€” model architecture, training, inference, evaluation, RAG, "
        "fine-tuning, prompt engineering, LLM tooling (HuggingFace, PyTorch, "
        "vLLM, transformers), and the practical tradeoffs between approaches"
    ),
    cybersecurity=(
        "cybersecurity β€” threat modeling, vulnerability analysis, secure code "
        "review, OWASP, network security, cryptography (including post-quantum), "
        "incident response, and security tooling. Focus on defensive use; "
        "refuse weaponizable specifics"
    ),
    quantum=(
        "quantum computing β€” Qiskit, circuit design, quantum algorithms (Shor, "
        "Grover, VQE, QAOA), error correction, hardware (IBM Heron, IonQ, "
        "Quantinuum), post-quantum cryptography (FIPS 203/204/205), and "
        "the realistic limits of NISQ-era devices"
    ),
    fintech=(
        "financial technology β€” payments, trading systems, market data, "
        "regulatory compliance (PCI-DSS, KYC/AML), accounting concepts, "
        "DeFi mechanics, risk management. Generic explanations only β€” "
        "explicitly NOT personalized investment advice"
    ),
    blockchain=(
        "blockchain and distributed ledgers β€” Bitcoin/Ethereum mechanics, "
        "smart contract design (Solidity, Anchor), L2 scaling, consensus "
        "(PoS, PoW, BFT), cryptographic primitives, MEV, and honest framing "
        "of tradeoffs vs traditional databases"
    ),
    infrastructure=(
        "cloud + infrastructure β€” AWS/GCP/Azure, Kubernetes, Terraform, "
        "observability (Prometheus, OpenTelemetry), service mesh, "
        "reliability engineering, capacity planning, and cost optimization"
    ),
    research=(
        "research methodology β€” literature review, experimental design, "
        "statistics, reproducibility, paper structure, peer review, and "
        "specifically how to read and critique ML/CS papers from arXiv"
    ),
    business=(
        "business operations and strategy for technical founders β€” pricing, "
        "GTM, hiring, fundraising mechanics, term-sheet basics, "
        "incorporation, and how to evaluate technical tradeoffs against "
        "business constraints"
    ),
)
# HF dataset repo that receives the generated rows as data/<domain>.jsonl
# (written by upload_domain_jsonl, listed by existing_data_files).
DATASET_REPO = "cuilabs/bee-interactions"
# Teacher prompt, filled in distill_domain via
# str.format(domain_label=..., domain_context=..., n=...). Literal braces in
# the JSON example are doubled ({{ / }}) so str.format emits single braces.
# The {"pairs": [...]} wrapper is the shape parse_pairs' first fast path
# expects. NOTE: the exact text below is the teacher-facing contract; edit it
# deliberately, since any change alters the generated data.
PROMPT_TEMPLATE = """You are generating training data for Bee, a domain-specialized AI assistant built by CUI Labs.
Domain: {domain_label}
Domain context: {domain_context}
Generate {n} distinct user-question + high-quality-answer pairs that a working professional in this domain might genuinely ask an AI assistant.
Requirements:
- Questions must be REALISTIC and SPECIFIC (no generic "what is X?" puffballs).
- Mix difficulty: ~30% beginner, ~50% intermediate, ~20% expert.
- Answers must be ACCURATE, CONCISE (2-6 paragraphs typical), and admit uncertainty when appropriate.
- Include code, equations, or commands where natural β€” but only if correct.
- Cover a wide range of subtopics within the domain.
- DO NOT invent statistics, dates, or proprietary product claims you cannot verify.
- DO NOT pretend to have personal experiences. Speak as a knowledgeable assistant.
- DO NOT include disclaimers like "I am an AI" β€” just answer.
Output STRICT JSON, a single object with this exact shape:
{{
"pairs": [
{{"prompt": "...", "content": "..."}},
{{"prompt": "...", "content": "..."}}
]
}}
The `pairs` array must contain exactly {n} elements. No markdown fences, no preamble, no trailing text β€” just the JSON object.
Generate now."""
def call_teacher(provider: str, prompt: str, model_override: str | None = None) -> tuple[str, dict]:
    """Send one prompt to the configured teacher provider and return ``(text, telemetry)``.

    Every provider except Anthropic is called through its OpenAI-compatible
    ``/chat/completions`` endpoint; Anthropic uses ``/v1/messages`` with an
    ``x-api-key`` header. This mirrors bee/teacher_providers.py but is kept
    inline so the script doesn't import the full backend module tree.

    Args:
        provider: one of "anthropic", "deepseek", "openai", "google".
        prompt: the complete user message sent to the teacher.
        model_override: explicit model name. When None, the provider's
            BEE_<PROVIDER>_MODEL env var (or the hard-coded default) is used.

    Returns:
        ``(text, telemetry)``. ``text`` is the assistant's text ("" if the
        response carried none). ``telemetry`` has ``model`` (the requested
        model), ``input_tokens``, ``output_tokens`` and ``elapsed_s`` (wall
        time of the successful HTTP round-trip, rounded to 0.01 s).

    Raises:
        ValueError: unknown ``provider``, or a response body that is not JSON.
        KeyError: the provider's API-key env var is unset, or the response
            lacks the expected ``choices`` structure.
        urllib.error.HTTPError: a non-retryable status, or a retryable one
            that persisted through every attempt.
        OSError: a network error (URLError, timeout, connection reset) that
            persisted through every attempt.
    """
    if provider == "anthropic":
        api_key = os.environ["BEE_TEACHER_API_KEY"]
        url = "https://api.anthropic.com/v1/messages"
        # Output is capped at 16384 tokens. BEE_ANTHROPIC_MODEL (or
        # model_override) picks the model independently of the runtime
        # config; the default is the cheap Haiku 4.5.
        model = model_override or os.environ.get("BEE_ANTHROPIC_MODEL", "claude-haiku-4-5")
        body = {
            "model": model,
            "max_tokens": 16384,
            "messages": [{"role": "user", "content": prompt}],
        }
        headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }
    elif provider == "deepseek":
        api_key = os.environ["BEE_DEEPSEEK_API_KEY"]
        url = "https://api.deepseek.com/v1/chat/completions"
        # DeepSeek V4 model names: deepseek-v4-flash | deepseek-v4-pro.
        # Legacy aliases (deepseek-chat, deepseek-reasoner) both route to
        # v4-flash now; explicit names keep distillation provenance honest.
        # Default is flash for cost; BEE_DEEPSEEK_MODEL=deepseek-v4-pro (or
        # model_override) selects pro.
        #
        # max_tokens: per api-docs.deepseek.com (checked 2026-04-29) V4 has a
        # 1M-token context and a 384K max-output ceiling per call. 128K
        # leaves room for a full batch plus V4 Pro's reasoning_tokens
        # overhead while bounding what a single runaway response can cost.
        #
        # response_format=json_object forces clean JSON, so parse_pairs'
        # recovery path is only needed when output is truncated.
        model = model_override or os.environ.get("BEE_DEEPSEEK_MODEL", "deepseek-v4-flash")
        body = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 131072,
            "response_format": {"type": "json_object"},
        }
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    elif provider == "openai":
        api_key = os.environ["BEE_OPENAI_API_KEY"]
        url = "https://api.openai.com/v1/chat/completions"
        # Default model matches teacher_providers.py
        # PROVIDERS["openai"].default_model; BEE_OPENAI_MODEL can select a
        # newer (more expensive) one. The output cap is sent as
        # max_completion_tokens: OpenAI deprecated max_tokens on chat
        # completions, and reasoning models (o-series, GPT-5) reject
        # requests that set it.
        model = model_override or os.environ.get("BEE_OPENAI_MODEL", "gpt-5")
        body = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_completion_tokens": 131072,
            "response_format": {"type": "json_object"},
        }
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    elif provider == "google":
        api_key = os.environ["BEE_GOOGLE_API_KEY"]
        url = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
        # Gemini 2.5 Flash supports 64K output tokens. gemini-2.0-flash
        # sunsets 2026-06-01 per ai.google.dev/gemini-api/docs/pricing, so
        # the default is 2.5. BEE_GOOGLE_MODEL=gemini-2.5-pro for quality.
        model = model_override or os.environ.get("BEE_GOOGLE_MODEL", "gemini-2.5-flash")
        body = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 65536,
            "response_format": {"type": "json_object"},
        }
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    else:
        raise ValueError(f"unknown provider: {provider}")

    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )

    # Retry on:
    #   - transient network errors: connection resets (seen back-to-back from
    #     Anthropic on long runs), timeouts, and other URLError/OSError
    #   - HTTP 429 (rate limiting; DeepSeek throttles dynamically per
    #     api-docs.deepseek.com/quick_start/rate_limit, no fixed RPM cap)
    #   - HTTP 502 / 503 / 504 (gateway / overload under sustained load)
    # Any other HTTP status (400, 401, 403, 404, ...) is fatal and re-raised
    # at once. Waits between attempts are 5, 10, 20, 40 s, or the server's
    # integer Retry-After on an HTTP error. The final failure is re-raised
    # immediately, with no pointless trailing sleep.
    transient_http = {429, 502, 503, 504}
    max_attempts = 5
    raw = ""
    elapsed = 0.0
    for attempt in range(1, max_attempts + 1):
        is_last = attempt == max_attempts
        t0 = time.monotonic()
        try:
            with urllib.request.urlopen(req, timeout=300) as resp:
                raw = resp.read().decode("utf-8")
        except urllib.error.HTTPError as e:
            if e.code not in transient_http or is_last:
                raise  # fatal status (auth, quota, schema) or retries exhausted
            err_headers = getattr(e, "headers", None)
            retry_after = err_headers.get("Retry-After") if err_headers is not None else None
            try:
                ra = int(retry_after) if retry_after else None
            except ValueError:
                ra = None  # HTTP-date form or garbage: use exponential backoff
            use_ra = ra is not None and ra > 0
            backoff = ra if use_ra else 5 * 2 ** (attempt - 1)
            print(
                f" http {e.code} ({e.reason}); retry {attempt}/{max_attempts - 1} in {backoff}s"
                + (" (Retry-After)" if use_ra else "")
            )
            time.sleep(backoff)
        except (ConnectionResetError, urllib.error.URLError, TimeoutError, OSError) as e:
            if is_last:
                raise
            backoff = 5 * 2 ** (attempt - 1)
            print(
                f" transient error ({type(e).__name__}: {e}); "
                f"retry {attempt}/{max_attempts - 1} in {backoff}s"
            )
            time.sleep(backoff)
        else:
            elapsed = time.monotonic() - t0
            break

    parsed = json.loads(raw)
    # `usage` may be absent or explicitly null; treat both as "no counts".
    usage = parsed.get("usage") or {}
    if provider == "anthropic":
        text = "".join(
            b.get("text", "") for b in parsed.get("content", []) if b.get("type") == "text"
        )
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)
    else:
        # `content` can be null (e.g. a refusal); normalize to "" so callers
        # always receive a string.
        text = parsed["choices"][0]["message"].get("content") or ""
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
    telemetry = {
        "model": model,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "elapsed_s": round(elapsed, 2),
    }
    return text, telemetry
def _coerce_pair(item: object) -> dict | None:
    """Normalize one decoded JSON value into a ``{"prompt", "content"}`` dict.

    Returns None unless ``item`` is a dict whose ``prompt`` and ``content``
    are both strings that are non-empty after stripping. Non-string values
    (numbers, lists, nested objects) are rejected instead of crashing on
    ``.strip()``.
    """
    if not isinstance(item, dict):
        return None
    prompt = item.get("prompt")
    content = item.get("content")
    if not (isinstance(prompt, str) and isinstance(content, str)):
        return None
    prompt = prompt.strip()
    content = content.strip()
    if not (prompt and content):
        return None
    return {"prompt": prompt, "content": content}


def _coerce_pairs(items: list) -> list[dict]:
    """Return the usable pairs from a decoded JSON array, preserving order."""
    return [pair for pair in map(_coerce_pair, items) if pair is not None]


def _iter_brace_objects(s: str, start: int) -> Iterator[str]:
    """Yield each brace-balanced ``{...}`` substring of ``s`` found at or after ``start``.

    Double-quoted strings (with backslash escapes) are tracked so braces
    inside quoted text don't affect nesting. Iteration stops at the first
    object that is never closed, i.e. output truncated mid-object.
    """
    i, n = start, len(s)
    while i < n:
        if s[i] != "{":
            i += 1
            continue
        obj_start = i
        depth = 0
        in_string = False
        escape = False
        while i < n:
            c = s[i]
            if escape:
                escape = False
            elif c == "\\":
                escape = True
            elif c == '"':
                in_string = not in_string
            elif not in_string:
                if c == "{":
                    depth += 1
                elif c == "}":
                    depth -= 1
                    if depth == 0:
                        yield s[obj_start : i + 1]
                        break
            i += 1
        else:
            return  # ran off the end inside an object: truncated output
        i += 1  # step past the closing brace


def parse_pairs(raw: str) -> list[dict]:
    """Extract Q&A pairs from teacher output.

    Strategies, tried in order on the fence-stripped text:

    1. The whole text as ``{"pairs": [...]}``, the shape requested via
       ``response_format=json_object``.
    2. Legacy: the span from the first ``[`` to the last ``]`` as a bare
       JSON array.
    3. Recovery: every brace-balanced ``{...}`` object after the first
       ``[`` (or from the start if there is none) is decoded on its own.
       This salvages output cut off by ``max_tokens`` mid-array.

    In every strategy, entries whose ``prompt``/``content`` are missing,
    empty, or not strings are dropped.

    Raises:
        ValueError: no usable pair could be recovered.
    """
    s = raw.strip()
    # Strip a markdown fence: drop the opening "```lang" line and a closing "```".
    if s.startswith("```"):
        s = s.split("\n", 1)[1] if "\n" in s else s
        if s.endswith("```"):
            s = s.rsplit("```", 1)[0]
        s = s.strip()
        if s.startswith("json\n"):
            s = s[5:]

    # Fast path 1: top-level {"pairs": [...]} object.
    try:
        obj = json.loads(s)
    except json.JSONDecodeError:
        obj = None
    if isinstance(obj, dict) and isinstance(obj.get("pairs"), list):
        pairs = _coerce_pairs(obj["pairs"])
        if pairs:
            return pairs

    # Fast path 2 (legacy): top-level array, possibly wrapped in prose.
    a = s.find("[")
    b = s.rfind("]")
    if a != -1 and b != -1:
        try:
            arr = json.loads(s[a : b + 1])
        except json.JSONDecodeError:
            arr = None
        if isinstance(arr, list):
            pairs = _coerce_pairs(arr)
            if pairs:
                return pairs

    # Recovery: decode each complete sub-object independently; start inside
    # the array (if any) so a {"pairs": ...} wrapper isn't taken as one object.
    pairs = []
    for chunk in _iter_brace_objects(s, 0 if a == -1 else a + 1):
        try:
            candidate = json.loads(chunk)
        except json.JSONDecodeError:
            continue
        pair = _coerce_pair(candidate)
        if pair is not None:
            pairs.append(pair)
    if not pairs:
        raise ValueError(f"no parsable Q&A objects in teacher output (first 200 chars: {raw[:200]!r})")
    return pairs
# Single process-wide lock serializing log lines from parallel domain workers.
_print_lock = threading.Lock()


def _emit(s: str) -> None:
    """Write ``s`` as one flushed stdout line while holding ``_print_lock``.

    Parallel domain workers log through this helper so that their messages
    do not interleave with one another.
    """
    with _print_lock:
        print(s, file=sys.stdout, flush=True)
def resolve_model(provider: str, tier: str) -> str | None:
    """Map a ``(provider, tier)`` pair to an explicit model name.

    Only DeepSeek currently has an automatable tier split: tier "pro"
    selects ``deepseek-v4-pro`` and any other tier selects
    ``deepseek-v4-flash``. Every other provider yields None, meaning the
    caller should fall back to BEE_<PROVIDER>_MODEL or the script default.
    """
    if provider != "deepseek":
        # Tier mapping for openai/google/anthropic is not wired up yet.
        return None
    deepseek_by_tier = {"pro": "deepseek-v4-pro"}
    return deepseek_by_tier.get(tier, "deepseek-v4-flash")
def distill_domain(
    domain: str,
    total: int,
    batch: int,
    provider: str,
    dry_run: bool,
    tier: str = "flash",
    max_stalled_calls: int = 3,
) -> tuple[list[dict], dict]:
    """Generate up to ``total`` Q&A rows for ``domain``, asking for at most ``batch`` pairs per call.

    ``tier`` selects model strength when the provider supports it (currently
    DeepSeek: "pro" | "flash", via resolve_model). Each row's ``source``
    field records the provider and the model reported in call_teacher's
    telemetry, so per-row provenance survives when domains use different
    teachers. Prompts already collected for this domain are skipped.

    Generation stops early, returning the rows collected so far, when:
    - a teacher call fails after call_teacher's own retries;
    - a response cannot be parsed into any pair;
    - ``max_stalled_calls`` consecutive calls add no new (non-duplicate)
      prompt. Without this cap a teacher that keeps repeating itself would
      be called, and billed, forever.

    In dry-run mode no teacher is called: one plan line is emitted and
    ``([], telemetry)`` is returned.

    Returns:
        ``(rows, telemetry)``. ``telemetry`` sums calls, tokens and elapsed
        time over every completed teacher call, including calls whose output
        then failed to parse.
    """
    rows: list[dict] = []
    telemetry: dict = {
        "calls": 0, "input_tokens": 0, "output_tokens": 0, "elapsed_s": 0.0,
        "provider": provider, "tier": tier,
    }
    seen_prompts: set[str] = set()
    model_override = resolve_model(provider, tier)
    stalled = 0
    while len(rows) < total:
        n = min(batch, total - len(rows))
        prompt = PROMPT_TEMPLATE.format(
            domain_label=domain, domain_context=DOMAIN_CONTEXT[domain], n=n
        )
        if dry_run:
            _emit(f" [dry-run] would call {provider}/{tier} for {n} pairs ({domain})")
            return [], telemetry
        try:
            text, tele = call_teacher(provider, prompt, model_override=model_override)
        except (OSError, ValueError, KeyError, IndexError, TypeError) as e:
            # OSError covers URLError/HTTPError plus the raw TimeoutError /
            # ConnectionResetError re-raised once retries run out; the others
            # cover a non-JSON or oddly shaped response body and a missing
            # API-key env var. Stopping here, rather than letting the error
            # escape, keeps the rows already generated (and paid for).
            _emit(f" [{domain}] teacher call failed: {type(e).__name__}: {e}; aborting domain")
            break
        # Count the call before parsing: its tokens are billed either way.
        telemetry["calls"] += 1
        telemetry["input_tokens"] += tele["input_tokens"]
        telemetry["output_tokens"] += tele["output_tokens"]
        telemetry["elapsed_s"] += tele["elapsed_s"]
        try:
            pairs = parse_pairs(text)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            _emit(f" [{domain}] parse failed: {e}; aborting domain")
            break
        added = 0
        for p in pairs:
            if len(rows) >= total:
                break
            if p["prompt"] in seen_prompts:
                continue
            seen_prompts.add(p["prompt"])
            rows.append({
                "messages": [
                    {"role": "user", "content": p["prompt"]},
                    {"role": "assistant", "content": p["content"]},
                ],
                "role": "assistant",
                "prompt": p["prompt"],
                "content": p["content"],
                "feedback": None,
                "source": f"teacher_distillation:{provider}:{tele['model']}",
                "domain": domain,
                "kind": "distilled",
            })
            added += 1
        _emit(
            f" [{domain}] +{added:3d} new of {len(pairs)} parsed "
            f"({len(rows):3d}/{total}, +{tele['output_tokens']} out tok, "
            f"{tele['elapsed_s']:.1f}s, model {tele['model']})"
        )
        if added:
            stalled = 0
        else:
            stalled += 1
            if stalled >= max_stalled_calls:
                _emit(
                    f" [{domain}] {stalled} consecutive calls added no new pairs; "
                    f"stopping at {len(rows)}/{total}"
                )
                break
    return rows, telemetry
def write_jsonl(rows: list[dict]) -> str:
    """Serialize ``rows`` as JSON Lines: one JSON object per line, newline-terminated.

    ``ensure_ascii=False`` keeps non-ASCII text as-is instead of
    ``\\uXXXX``-escaping it. An empty ``rows`` yields a lone newline.
    """
    encoded_rows = [json.dumps(row, ensure_ascii=False) for row in rows]
    return "\n".join(encoded_rows) + "\n"
def upload_domain_jsonl(domain: str, jsonl: str, hf_token: str) -> str:
    """Upload ``jsonl`` to ``data/<domain>.jsonl`` in the DATASET_REPO dataset.

    Args:
        domain: domain name; becomes the file stem inside ``data/``.
        jsonl: the complete JSON Lines payload.
        hf_token: Hugging Face token with write access to DATASET_REPO.

    Returns:
        The browse URL of the uploaded file on the ``main`` branch.

    The payload is passed to ``HfApi.upload_file`` as UTF-8 bytes instead of
    going through a temporary file. ``Path.write_text`` writes in text mode,
    which on Windows would translate every LF into CRLF and upload JSONL with
    platform-dependent line endings; bytes are uploaded exactly as encoded.
    """
    from huggingface_hub import HfApi  # type: ignore[import-not-found]

    api = HfApi(token=hf_token)
    api.upload_file(
        path_or_fileobj=jsonl.encode("utf-8"),
        path_in_repo=f"data/{domain}.jsonl",
        repo_id=DATASET_REPO,
        repo_type="dataset",
        commit_message=f"distill: teacher-generated {domain} seeds",
    )
    return f"https://huggingface.co/datasets/{DATASET_REPO}/blob/main/data/{domain}.jsonl"
def existing_data_files(hf_token: str) -> set[str]:
    """Return the stems of all ``data/...*.jsonl`` files in the DATASET_REPO dataset.

    Used by ``--skip-existing``: a returned stem such as ``"general"`` means
    that domain already has a data file in the repo.
    """
    from huggingface_hub import HfApi  # type: ignore[import-not-found]

    repo_files = HfApi(token=hf_token).list_repo_files(
        repo_id=DATASET_REPO, repo_type="dataset"
    )
    stems: set[str] = set()
    for repo_path in repo_files:
        if repo_path.startswith("data/") and repo_path.endswith(".jsonl"):
            stems.add(Path(repo_path).stem)
    return stems
def _process_one_domain(
    domain: str, args: argparse.Namespace, hf_token: str | None,
    pro_set: set[str], flash_set: set[str],
) -> tuple[str, list[dict], dict]:
    """Distill one domain and upload the result; intended as a ThreadPoolExecutor task.

    Tier precedence: membership in ``pro_set`` (--pro-domains), then
    ``flash_set`` (--flash-domains), then DOMAIN_TIER, defaulting to
    "flash". The upload is skipped on dry runs, when no rows were produced,
    or when there is no HF token. Depends only on its arguments and
    module-level state.

    Returns:
        ``(domain, rows, telemetry)``, with rows/telemetry from distill_domain.
    """
    tier = (
        "pro" if domain in pro_set
        else "flash" if domain in flash_set
        else DOMAIN_TIER.get(domain, "flash")
    )
    _emit(f"=== {domain} ({args.provider}/{tier}) ===")
    rows, tele = distill_domain(
        domain, args.pairs, args.batch, args.provider, args.dry_run, tier=tier
    )
    if args.dry_run or not rows or not hf_token:
        return domain, rows, tele
    url = upload_domain_jsonl(domain, write_jsonl(rows), hf_token)
    _emit(f" [{domain}] uploaded {len(rows)} rows β†’ {url}")
    return domain, rows, tele
def _split_csv(raw: str) -> list[str]:
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def main() -> None:
    """CLI entry point: validate flags, print the plan, distill domains in parallel, summarize.

    Each selected domain runs end-to-end (generate, then upload) in its own
    ThreadPoolExecutor worker. A worker that raises is reported without
    stopping the others. If any worker failed, the process exits with a
    non-zero status after printing the summary, so schedulers can notice.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--domains", default="", help="comma-separated subset (default: all 10)")
    p.add_argument("--pairs", type=int, default=200, help="pairs per domain (default 200)")
    p.add_argument(
        "--batch", type=int, default=50,
        help="pairs per teacher call (default 50); each call's output must fit "
        "the provider's max_tokens cap",
    )
    p.add_argument("--provider", default="deepseek", choices=["deepseek", "google", "openai", "anthropic"])
    p.add_argument(
        "--workers", type=int, default=2,
        help="parallel domain workers (default 2). Each worker handles one "
        "domain end-to-end. Increase cautiously to avoid teacher RPM caps.",
    )
    p.add_argument(
        "--pro-domains", default="",
        help="comma-separated list of domains to FORCE onto the pro tier. "
        "Otherwise the per-domain default in DOMAIN_TIER applies.",
    )
    p.add_argument(
        "--flash-domains", default="",
        help="comma-separated list of domains to FORCE onto the flash tier.",
    )
    p.add_argument("--dry-run", action="store_true")
    p.add_argument("--skip-existing", action="store_true",
                   help="skip domains already in cuilabs/bee-interactions/data/")
    args = p.parse_args()

    # Sizes that can never yield rows are user errors, not silent no-ops
    # (--batch 0 would ask the teacher for zero pairs on every call).
    if args.pairs < 1:
        p.error("--pairs must be >= 1")
    if args.batch < 1:
        p.error("--batch must be >= 1")

    if args.domains:
        domains = _split_csv(args.domains)
        bad = [d for d in domains if d not in TIER_1_DOMAINS]
        if bad:
            sys.exit(f"unknown domains: {bad}. Valid: {TIER_1_DOMAINS}")
    else:
        domains = list(TIER_1_DOMAINS)

    pro_set = set(_split_csv(args.pro_domains))
    flash_set = set(_split_csv(args.flash_domains))
    # A typo in a tier override would otherwise be silently ignored.
    bad_tier = sorted((pro_set | flash_set) - set(TIER_1_DOMAINS))
    if bad_tier:
        sys.exit(
            f"unknown domains in --pro-domains/--flash-domains: {bad_tier}. "
            f"Valid: {TIER_1_DOMAINS}"
        )
    overlap = sorted(pro_set & flash_set)
    if overlap:
        print(f"note: {overlap} given in both --pro-domains and --flash-domains; pro wins")

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token and not args.dry_run:
        sys.exit("HF_TOKEN required (set in env or .env)")

    skip: set[str] = set()
    if args.skip_existing and not args.dry_run:
        try:
            skip = existing_data_files(hf_token)  # type: ignore[arg-type]
            print(f"skip-existing: dataset already has {sorted(skip)}")
        except Exception as e:  # best-effort: listing failure just disables skipping
            print(f"could not list existing files: {e}; not skipping any")
    todo = [d for d in domains if d not in skip]

    print(
        f"\nplan: provider={args.provider}, pairs/domain={args.pairs}, "
        f"batch={args.batch}, workers={args.workers}\n"
        f" todo: {todo}\n"
        f" tier per domain:"
    )
    for d in todo:
        if d in pro_set:
            tier = "pro (forced)"
        elif d in flash_set:
            tier = "flash (forced)"
        else:
            tier = DOMAIN_TIER.get(d, "flash")
        print(f" {d:<18} β†’ {tier}")
    print()

    started = datetime.datetime.now(datetime.timezone.utc).isoformat()
    overall = {"calls": 0, "input_tokens": 0, "output_tokens": 0, "elapsed_s": 0.0, "rows": 0}
    failed: list[str] = []
    wall_t0 = time.monotonic()
    # Threads fit here: workers are I/O-bound (HTTP round-trips to teacher
    # APIs) and the GIL is released during socket reads.
    with ThreadPoolExecutor(max_workers=max(1, args.workers)) as ex:
        futures = {ex.submit(_process_one_domain, d, args, hf_token, pro_set, flash_set): d for d in todo}
        for fut in as_completed(futures):
            domain = futures[fut]
            try:
                _, rows, tele = fut.result()
            except Exception as e:  # worker boundary: report, keep the others going
                failed.append(domain)
                _emit(f" [{domain}] worker failed: {type(e).__name__}: {e}")
                continue
            for k in ("calls", "input_tokens", "output_tokens", "elapsed_s"):
                overall[k] += tele[k]
            overall["rows"] += len(rows)
    wall_s = time.monotonic() - wall_t0

    # elapsed_s is summed per-call teacher time, which exceeds wall time when
    # several workers run in parallel; report both.
    print(
        f"\nDONE. started={started}\n"
        f" total rows: {overall['rows']}\n"
        f" teacher calls: {overall['calls']}\n"
        f" input tokens: {overall['input_tokens']}, output tokens: {overall['output_tokens']}\n"
        f" elapsed: {wall_s:.1f}s wall ({overall['elapsed_s']:.1f}s summed teacher time)"
    )
    if failed:
        sys.exit(f"{len(failed)} domain worker(s) failed: {sorted(failed)}")


if __name__ == "__main__":
    main()