| |
| """ |
| Generate synthetic cybersecurity NER training data from prompt templates. |
| |
| Usage: |
| python generate_synthetic_batch.py --prompt-id url_heavy_malware_infra --n 20 --output out.jsonl |
| python generate_synthetic_batch.py --all --output out.jsonl |
| python generate_synthetic_batch.py --validate data/processed/llm_generated_synthetic.jsonl |
| |
| Requires ANTHROPIC_API_KEY or OPENAI_API_KEY env var for LLM generation. |
| Use --dry-run to print prompts without calling the API. |
| """ |
| import argparse |
| import json |
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
# Prompt templates are read from a JSON file that sits next to this script.
PROMPTS_FILE = Path(__file__).with_name("synthetic_prompts.json")


# Entity labels accepted as the "<TYPE>" part of a "<TYPE>: <value>" span key.
ENTITY_TYPES = [
    "MALWARE",
    "THREAT_ACTOR",
    "TOOL",
    "VULNERABILITY",
    "SYSTEM",
    "ORGANIZATION",
    "IP_ADDRESS",
    "DOMAIN",
    "URL",
    "HASH",
    "EMAIL",
    "CVE_ID",
    "FILEPATH",
]
|
|
|
|
def load_prompts(path: Path = PROMPTS_FILE) -> list[dict]:
    """Load the prompt templates stored under the ``"prompts"`` key of a JSON file.

    Args:
        path: JSON file to read; defaults to ``PROMPTS_FILE``.

    Returns:
        The value stored under the top-level ``"prompts"`` key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no ``"prompts"`` key.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows) to decode non-ASCII prompt text.
    with open(path, encoding="utf-8") as f:
        return json.load(f)["prompts"]
|
|
|
|
def verify_offsets(record: dict) -> list[str]:
    """Check that every span offset in *record* slices out the value named in its key.

    Span keys have the form ``"<ENTITY_TYPE>: <value>"`` and map to a list of
    ``[start, end)`` character offsets into ``record["text"]``.

    Records come from LLM output, so structural problems (non-string text,
    non-dict spans, offsets that are not integer pairs) are reported as error
    strings instead of being allowed to raise.

    Args:
        record: Parsed record; ``"text"`` and ``"spans"`` are both optional.

    Returns:
        A list of human-readable error strings; empty when the record is valid.
    """
    text = record.get("text", "")
    spans = record.get("spans", {})
    if not isinstance(text, str):
        return [f"Record text must be a string, got {type(text).__name__}"]
    if not isinstance(spans, dict):
        return [f"Record spans must be an object, got {type(spans).__name__}"]

    errors = []
    for key, offset_list in spans.items():
        if ": " not in key:
            errors.append(f"Bad span key format: {key!r}")
            continue
        etype, expected_value = key.split(": ", 1)
        # An unknown type is reported, but the offsets are still checked.
        if etype not in ENTITY_TYPES:
            errors.append(f"Unknown entity type: {etype!r}")
        if not isinstance(offset_list, (list, tuple)):
            errors.append(
                f"Offsets for {key} must be a list, got {type(offset_list).__name__}"
            )
            continue
        for offset in offset_list:
            is_int_pair = (
                isinstance(offset, (list, tuple))
                and len(offset) == 2
                and all(isinstance(v, int) for v in offset)
            )
            if not is_int_pair:
                errors.append(f"Malformed offset {offset!r}: {key}")
                continue
            start, end = offset
            if start < 0 or end > len(text) or start >= end:
                errors.append(f"Invalid offset [{start},{end}) for text len {len(text)}: {key}")
                continue
            actual = text[start:end]
            if actual != expected_value:
                errors.append(
                    f"Offset mismatch for {key}: "
                    f"text[{start}:{end}]={actual!r} != {expected_value!r}"
                )
    return errors
|
|
|
|
def try_fix_offsets(record: dict) -> dict:
    """Repair span offsets in place by searching the text for each entity value.

    For every ``"<TYPE>: <value>"`` key, offsets that already slice out
    ``<value>`` are kept. At the first offset that does not, the value is
    searched for in the text:

    * found, key has one offset: that offset becomes the first occurrence;
    * found, key has several offsets: every non-overlapping occurrence not
      already recorded is added, and the remaining original offsets are dropped;
    * not found: the bad offset is kept unchanged.

    Keys without a ``": "`` separator are dropped, as are keys left with no
    offsets. Entries that cannot be repaired meaningfully (empty value,
    offsets that are not integer pairs) are left untouched.

    Args:
        record: Record with a ``"text"`` string and optional ``"spans"`` dict.
            It is mutated and also returned.

    Returns:
        The same *record* object, with ``record["spans"]`` replaced.

    Raises:
        KeyError: If *record* has no ``"text"`` key.
    """
    text = record["text"]
    spans = record.get("spans", {})
    if not isinstance(text, str) or not isinstance(spans, dict):
        # Nothing can be repaired in a structurally broken record.
        return record

    fixed_spans = {}
    for key, offset_list in spans.items():
        if ": " not in key:
            continue
        expected_value = key.split(": ", 1)[1]

        well_formed = isinstance(offset_list, (list, tuple)) and all(
            isinstance(item, (list, tuple))
            and len(item) == 2
            and all(isinstance(v, int) for v in item)
            for item in offset_list
        )
        # An empty value matches at every position: searching for it is
        # meaningless and would never advance through the text.
        if not expected_value or not well_formed:
            if offset_list:
                fixed_spans[key] = offset_list
            continue

        width = len(expected_value)
        find_all = len(offset_list) > 1
        new_offsets = []
        for start, end in offset_list:
            in_range = 0 <= start < end <= len(text)
            if in_range and text[start:end] == expected_value:
                new_offsets.append([start, end])
                continue

            pos = text.find(expected_value)
            if pos < 0:
                # Value is absent from the text: keep the bad offset as it is.
                new_offsets.append([start, end])
                continue

            while pos >= 0:
                candidate = [pos, pos + width]
                # Skip occurrences already kept from earlier, valid offsets.
                if candidate not in new_offsets:
                    new_offsets.append(candidate)
                if not find_all:
                    break
                pos = text.find(expected_value, pos + width)
            # The occurrences found replace all remaining original offsets.
            break

        if new_offsets:
            fixed_spans[key] = new_offsets

    record["spans"] = fixed_spans
    return record
|
|
|
|
def parse_llm_response(response_text: str) -> list[dict]:
    """Extract record dicts (objects with a ``"text"`` key) from an LLM reply.

    Three strategies are tried in order, and the first one that yields
    records wins:

    1. Drop blank lines and Markdown fence lines, then parse the rest as one
       JSON document. This covers pretty-printed arrays or objects, which
       line-by-line parsing would only partially recover.
    2. Parse each remaining line on its own (JSONL; a line may also hold a
       JSON array of records).
    3. Strip inline fence markers with a regex and parse the whole reply.

    Args:
        response_text: Raw model output; ``None`` or empty yields ``[]``.

    Returns:
        Dicts containing a ``"text"`` key, in document order. Anything else
        is ignored.
    """
    if not response_text:
        return []

    def extract(obj) -> list[dict]:
        # Accept a single record object or a list of them.
        if isinstance(obj, dict):
            return [obj] if "text" in obj else []
        if isinstance(obj, list):
            return [r for r in obj if isinstance(r, dict) and "text" in r]
        return []

    def load(candidate: str):
        # Unparseable input is expected here; signal it with None.
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return None

    stripped_lines = (ln.strip() for ln in response_text.strip().split("\n"))
    payload = [ln for ln in stripped_lines if ln and not ln.startswith("```")]

    records = extract(load("\n".join(payload)))
    if records:
        return records

    for ln in payload:
        records.extend(extract(load(ln)))
    if records:
        return records

    # Last resort: fences that share a line with the JSON. Any language tag
    # (json, jsonl, ...) is removed together with the backticks.
    cleaned = re.sub(r"```[\w-]*\n?", "", response_text).strip()
    return extract(load(cleaned))
|
|
|
|
def generate_with_anthropic(prompt: str, n: int) -> str:
    """Request a batch from the Anthropic Messages API and return its text.

    Every ``{n}`` placeholder in *prompt* is replaced with *n* before sending.
    The ``anthropic`` package is imported lazily; if it is missing the process
    exits with an install hint.

    Returns:
        The ``text`` attribute of the first content block of the reply.
    """
    try:
        import anthropic
    except ImportError:
        sys.exit("pip install anthropic")

    system_instruction = "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."
    api = anthropic.Anthropic()
    user_turn = {"role": "user", "content": prompt.replace("{n}", str(n))}
    reply = api.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=8192,
        messages=[user_turn],
        system=system_instruction,
    )
    first_block = reply.content[0]
    return first_block.text
|
|
|
|
def generate_with_openai(prompt: str, n: int) -> str:
    """Request a batch from the OpenAI Chat Completions API and return its text.

    Every ``{n}`` placeholder in *prompt* is replaced with *n* before sending.
    The ``openai`` package is imported lazily; if it is missing the process
    exits with an install hint.

    Returns:
        The first choice's message content, or ``""`` when the API returns no
        content, so callers always receive a string.
    """
    try:
        import openai
    except ImportError:
        sys.exit("pip install openai")
    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model="gpt-4o",
        max_tokens=8192,
        messages=[
            {"role": "system", "content": "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."},
            {"role": "user", "content": prompt.replace("{n}", str(n))},
        ],
    )
    # message.content is nullable in the API; honour the declared -> str contract.
    return resp.choices[0].message.content or ""
|
|
|
|
def generate_batch(
    prompt_id: str | None,
    n: int,
    output_path: Path,
    backend: str = "anthropic",
    dry_run: bool = False,
    fix: bool = True,
    max_retries: int = 3,
):
    """Generate synthetic records from prompt templates and append them to a JSONL file.

    Each selected template is sent to the LLM in batches of at most 20
    examples. Returned records are optionally repaired, then verified, and the
    accepted ones are appended to *output_path*. A summary is printed at the end.

    Args:
        prompt_id: Run only the template with this id. A falsy value runs
            every template returned by ``load_prompts``.
        n: Examples to request per template; ``0`` uses the template's
            ``"total_target"`` (20 when absent).
        output_path: JSONL file opened in append mode; missing parent
            directories are created.
        backend: ``"anthropic"`` selects the Anthropic backend; any other
            value selects the OpenAI backend.
        dry_run: Print the filled-in prompts and return without calling the
            API or writing any file.
        fix: Run ``try_fix_offsets`` before verification and skip records that
            still fail. When false, failing records are reported with WARN
            lines but are still written.
        max_retries: Consecutive failed generate/parse attempts tolerated per
            template before that template is abandoned. Without this bound a
            persistent API error would retry forever.

    Raises:
        SystemExit: If *prompt_id* matches no template.
    """
    prompts = load_prompts()
    if prompt_id:
        prompts = [p for p in prompts if p["id"] == prompt_id]
        if not prompts:
            sys.exit(f"Unknown prompt_id: {prompt_id}")

    generate_fn = generate_with_anthropic if backend == "anthropic" else generate_with_openai

    all_records = []
    total_errors = 0

    for pdef in prompts:
        count = n if n else pdef.get("total_target", 20)
        prompt_text = pdef["prompt"]
        print(f"\n{'='*60}")
        print(f"Prompt: {pdef['id']} | Target entities: {pdef['target_entities']} | N={count}")
        print(f"{'='*60}")

        if dry_run:
            print(prompt_text.replace("{n}", str(count)))
            continue

        batch_size = min(20, count)
        generated = 0
        failures = 0
        while generated < count:
            this_batch = min(batch_size, count - generated)
            print(f" Generating batch of {this_batch}...")
            try:
                raw = generate_fn(prompt_text, this_batch)
                records = parse_llm_response(raw)
            except Exception as e:  # boundary around a third-party API call
                print(f" ERROR: {e}")
                failures += 1
                if failures > max_retries:
                    print(f" Giving up on {pdef['id']} after {failures} consecutive failures")
                    break
                continue
            failures = 0

            for rec in records:
                try:
                    if fix:
                        rec = try_fix_offsets(rec)
                    errs = verify_offsets(rec)
                except (AttributeError, KeyError, TypeError, ValueError) as e:
                    # One structurally broken record must not abort the run:
                    # nothing is written to disk until every batch is done.
                    total_errors += 1
                    print(f" SKIP (malformed record): {e}")
                    continue
                if errs:
                    total_errors += len(errs)
                    for err in errs:
                        print(f" WARN: {err}")
                    if fix:
                        # The repair has already run; errors that remain are unfixable.
                        info = rec.get("info")
                        rec_id = info.get("id", "?") if isinstance(info, dict) else "?"
                        print(f" SKIP (unfixable): {rec_id}")
                        continue
                all_records.append(rec)

            # Progress counts requested examples, not records actually received.
            generated += this_batch
            print(f" Got {len(records)} records (total so far: {len(all_records)})")

    if dry_run:
        return

    # NOTE(review): ids restart at 00001 on every run although the output file
    # is appended to, so repeated runs into one file produce duplicate ids.
    for i, rec in enumerate(all_records, 1):
        if not isinstance(rec.get("info"), dict):
            rec["info"] = {}
        rec["info"]["id"] = f"synth_batch_{i:05d}"

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Records are dumped with ensure_ascii=False, so pin the file encoding.
    with open(output_path, "a", encoding="utf-8") as f:
        for rec in all_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    entity_counts: dict[str, int] = {}
    for rec in all_records:
        spans = rec.get("spans", {})
        if not isinstance(spans, dict):
            continue
        for key, offsets in spans.items():
            etype = key.split(": ", 1)[0] if ": " in key else key
            n_offsets = len(offsets) if isinstance(offsets, (list, tuple)) else 0
            entity_counts[etype] = entity_counts.get(etype, 0) + n_offsets

    print(f"\n{'='*60}")
    print(f"SUMMARY: {len(all_records)} records written to {output_path}")
    print(f"Offset errors encountered: {total_errors}")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
|
|
|
|
def validate_file(path: Path):
    """Validate every record of an existing JSONL file and print a report.

    Blank lines are ignored. A line that is not valid JSON, or whose JSON
    value is not an object, is reported and counted as bad. Every other line
    is checked with ``verify_offsets``. The entity distribution is counted
    over all object records, including those that have errors.

    Args:
        path: JSONL file to read (decoded as UTF-8).
    """
    total = 0
    bad = 0
    entity_counts: dict[str, int] = {}

    # Generated files are written with ensure_ascii=False, so decode as UTF-8
    # explicitly instead of relying on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"Line {i}: invalid JSON")
                bad += 1
                continue
            if not isinstance(rec, dict):
                # e.g. a bare list or number: not something that can be verified.
                print(f"Line {i}: not a JSON object")
                bad += 1
                continue
            total += 1
            errs = verify_offsets(rec)
            if errs:
                bad += 1
                for err in errs:
                    print(f"Line {i}: {err}")
            spans = rec.get("spans", {})
            if not isinstance(spans, dict):
                continue
            for key, offsets in spans.items():
                etype = key.partition(": ")[0]
                n_offsets = len(offsets) if isinstance(offsets, (list, tuple)) else 0
                entity_counts[etype] = entity_counts.get(etype, 0) + n_offsets

    print(f"\nValidated {total} records, {bad} with errors")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
|
|
|
|
def main():
    """Command-line entry point: validate an existing file or generate a new batch.

    ``--validate`` takes precedence and runs ``validate_file`` only. Otherwise
    ``--prompt-id`` or ``--all`` is required; when both are given ``--all``
    wins. Unless ``--dry-run`` is set, the API key for the selected
    ``--backend`` must be present in the environment.

    Raises:
        SystemExit: On argument errors or when the required API key is missing.
    """
    parser = argparse.ArgumentParser(description="Generate synthetic cybersecurity NER data")
    parser.add_argument("--prompt-id", help="Run a specific prompt template")
    parser.add_argument("--all", action="store_true", help="Run all prompt templates")
    parser.add_argument("--n", type=int, default=0, help="Examples per prompt (0=use template default)")
    parser.add_argument("--output", type=Path, default=Path("data/processed/llm_generated_synthetic_v2.jsonl"))
    parser.add_argument("--backend", choices=["anthropic", "openai"], default="anthropic")
    parser.add_argument("--dry-run", action="store_true", help="Print prompts without calling API")
    parser.add_argument("--no-fix", action="store_true", help="Skip automatic offset fixing")
    parser.add_argument("--validate", type=Path, help="Validate an existing JSONL file")

    args = parser.parse_args()

    if args.validate:
        validate_file(args.validate)
        return

    if not args.prompt_id and not args.all:
        parser.error("Specify --prompt-id or --all")

    if not args.dry_run:
        # Check the key the chosen backend needs. Accepting either key would
        # let a mismatched setup through, only to fail on every API call.
        key_var = "ANTHROPIC_API_KEY" if args.backend == "anthropic" else "OPENAI_API_KEY"
        if not os.environ.get(key_var):
            sys.exit(f"Set {key_var} (required for --backend {args.backend})")

    generate_batch(
        prompt_id=None if args.all else args.prompt_id,
        n=args.n,
        output_path=args.output,
        backend=args.backend,
        dry_run=args.dry_run,
        fix=not args.no_fix,
    )
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|