#!/usr/bin/env python3 """ Generate synthetic cybersecurity NER training data from prompt templates. Usage: python generate_synthetic_batch.py --prompt-id url_heavy_malware_infra --n 20 --output out.jsonl python generate_synthetic_batch.py --all --output out.jsonl python generate_synthetic_batch.py --validate data/processed/llm_generated_synthetic.jsonl Requires ANTHROPIC_API_KEY or OPENAI_API_KEY env var for LLM generation. Use --dry-run to print prompts without calling the API. """ import argparse import json import os import re import sys from pathlib import Path PROMPTS_FILE = Path(__file__).parent / "synthetic_prompts.json" ENTITY_TYPES = [ "MALWARE", "THREAT_ACTOR", "TOOL", "VULNERABILITY", "SYSTEM", "ORGANIZATION", "IP_ADDRESS", "DOMAIN", "URL", "HASH", "EMAIL", "CVE_ID", "FILEPATH", ] def load_prompts(path: Path = PROMPTS_FILE) -> list[dict]: with open(path) as f: return json.load(f)["prompts"] def verify_offsets(record: dict) -> list[str]: """Verify all span offsets match the text. Returns list of error strings.""" errors = [] text = record.get("text", "") spans = record.get("spans", {}) for key, offset_list in spans.items(): # Parse "TYPE: value" from key if ": " not in key: errors.append(f"Bad span key format: {key!r}") continue etype, expected_value = key.split(": ", 1) if etype not in ENTITY_TYPES: errors.append(f"Unknown entity type: {etype!r}") for start, end in offset_list: if start < 0 or end > len(text) or start >= end: errors.append(f"Invalid offset [{start},{end}) for text len {len(text)}: {key}") continue actual = text[start:end] if actual != expected_value: errors.append( f"Offset mismatch for {key}: " f"text[{start}:{end}]={actual!r} != {expected_value!r}" ) return errors def try_fix_offsets(record: dict) -> dict: """Attempt to fix span offsets by searching for the entity value in text.""" text = record["text"] fixed_spans = {} for key, offset_list in record.get("spans", {}).items(): if ": " not in key: continue etype, expected_value = key.split(": ", 1) new_offsets = [] for start, end in offset_list: actual = text[start:end] if 0 <= start < end <= len(text) else "" if actual == expected_value: new_offsets.append([start, end]) else: # Try to find the value in text idx = text.find(expected_value) if idx >= 0: new_offsets.append([idx, idx + len(expected_value)]) # Look for additional occurrences if there were multiple if len(offset_list) > 1: search_from = idx + len(expected_value) while True: idx2 = text.find(expected_value, search_from) if idx2 < 0: break new_offsets.append([idx2, idx2 + len(expected_value)]) search_from = idx2 + len(expected_value) break # We handled all occurrences else: new_offsets.append([start, end]) # Keep broken, will be caught by validate if new_offsets: fixed_spans[key] = new_offsets record["spans"] = fixed_spans return record def parse_llm_response(response_text: str) -> list[dict]: """Parse LLM response into list of records. Handles JSONL and JSON arrays.""" records = [] # Try line-by-line JSONL first for line in response_text.strip().split("\n"): line = line.strip() if not line or line.startswith("```"): continue try: obj = json.loads(line) if isinstance(obj, dict) and "text" in obj: records.append(obj) elif isinstance(obj, list): records.extend(r for r in obj if isinstance(r, dict) and "text" in r) except json.JSONDecodeError: continue # If nothing parsed line-by-line, try the whole thing as JSON array if not records: try: # Strip markdown code fences cleaned = re.sub(r"```(?:json)?\n?", "", response_text).strip() obj = json.loads(cleaned) if isinstance(obj, list): records = [r for r in obj if isinstance(r, dict) and "text" in r] except json.JSONDecodeError: pass return records def generate_with_anthropic(prompt: str, n: int) -> str: """Call Anthropic API.""" try: import anthropic except ImportError: sys.exit("pip install anthropic") client = anthropic.Anthropic() msg = client.messages.create( model="claude-sonnet-4-20250514", max_tokens=8192, messages=[{"role": "user", "content": prompt.replace("{n}", str(n))}], system="You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary.", ) return msg.content[0].text def generate_with_openai(prompt: str, n: int) -> str: """Call OpenAI API.""" try: import openai except ImportError: sys.exit("pip install openai") client = openai.OpenAI() resp = client.chat.completions.create( model="gpt-4o", max_tokens=8192, messages=[ {"role": "system", "content": "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."}, {"role": "user", "content": prompt.replace("{n}", str(n))}, ], ) return resp.choices[0].message.content def generate_batch( prompt_id: str | None, n: int, output_path: Path, backend: str = "anthropic", dry_run: bool = False, fix: bool = True, ): prompts = load_prompts() if prompt_id: prompts = [p for p in prompts if p["id"] == prompt_id] if not prompts: sys.exit(f"Unknown prompt_id: {prompt_id}") generate_fn = generate_with_anthropic if backend == "anthropic" else generate_with_openai all_records = [] total_errors = 0 for pdef in prompts: count = n if n else pdef.get("total_target", 20) prompt_text = pdef["prompt"] print(f"\n{'='*60}") print(f"Prompt: {pdef['id']} | Target entities: {pdef['target_entities']} | N={count}") print(f"{'='*60}") if dry_run: print(prompt_text.replace("{n}", str(count))) continue # Generate in batches of 20 batch_size = min(20, count) generated = 0 while generated < count: this_batch = min(batch_size, count - generated) print(f" Generating batch of {this_batch}...") try: raw = generate_fn(prompt_text, this_batch) records = parse_llm_response(raw) except Exception as e: print(f" ERROR: {e}") continue for rec in records: # Fix offsets if requested if fix: rec = try_fix_offsets(rec) # Validate errs = verify_offsets(rec) if errs: total_errors += len(errs) for err in errs: print(f" WARN: {err}") if fix: rec = try_fix_offsets(rec) errs2 = verify_offsets(rec) if errs2: print(f" SKIP (unfixable): {rec.get('info', {}).get('id', '?')}") continue all_records.append(rec) generated += this_batch print(f" Got {len(records)} records (total so far: {len(all_records)})") if dry_run: return # Assign sequential IDs for i, rec in enumerate(all_records, 1): if "info" not in rec: rec["info"] = {} rec["info"]["id"] = f"synth_batch_{i:05d}" # Write output output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "a") as f: for rec in all_records: f.write(json.dumps(rec, ensure_ascii=False) + "\n") # Summary stats entity_counts: dict[str, int] = {} for rec in all_records: for key in rec.get("spans", {}): etype = key.split(": ", 1)[0] if ": " in key else key entity_counts[etype] = entity_counts.get(etype, 0) + len(rec["spans"][key]) print(f"\n{'='*60}") print(f"SUMMARY: {len(all_records)} records written to {output_path}") print(f"Offset errors encountered: {total_errors}") print(f"Entity distribution:") for etype in sorted(entity_counts, key=entity_counts.get, reverse=True): print(f" {etype}: {entity_counts[etype]}") def validate_file(path: Path): """Validate all records in an existing JSONL file.""" total = 0 bad = 0 entity_counts: dict[str, int] = {} with open(path) as f: for i, line in enumerate(f, 1): line = line.strip() if not line: continue try: rec = json.loads(line) except json.JSONDecodeError: print(f"Line {i}: invalid JSON") bad += 1 continue total += 1 errs = verify_offsets(rec) if errs: bad += 1 for err in errs: print(f"Line {i}: {err}") for key, offsets in rec.get("spans", {}).items(): etype = key.split(": ", 1)[0] if ": " in key else key entity_counts[etype] = entity_counts.get(etype, 0) + len(offsets) print(f"\nValidated {total} records, {bad} with errors") print("Entity distribution:") for etype in sorted(entity_counts, key=entity_counts.get, reverse=True): print(f" {etype}: {entity_counts[etype]}") def main(): parser = argparse.ArgumentParser(description="Generate synthetic cybersecurity NER data") parser.add_argument("--prompt-id", help="Run a specific prompt template") parser.add_argument("--all", action="store_true", help="Run all prompt templates") parser.add_argument("--n", type=int, default=0, help="Examples per prompt (0=use template default)") parser.add_argument("--output", type=Path, default=Path("data/processed/llm_generated_synthetic_v2.jsonl")) parser.add_argument("--backend", choices=["anthropic", "openai"], default="anthropic") parser.add_argument("--dry-run", action="store_true", help="Print prompts without calling API") parser.add_argument("--no-fix", action="store_true", help="Skip automatic offset fixing") parser.add_argument("--validate", type=Path, help="Validate an existing JSONL file") args = parser.parse_args() if args.validate: validate_file(args.validate) return if not args.prompt_id and not args.all: parser.error("Specify --prompt-id or --all") if not args.dry_run and not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("OPENAI_API_KEY"): sys.exit("Set ANTHROPIC_API_KEY or OPENAI_API_KEY") if args.all: args.prompt_id = None generate_batch( prompt_id=args.prompt_id, n=args.n, output_path=args.output, backend=args.backend, dry_run=args.dry_run, fix=not args.no_fix, ) if __name__ == "__main__": main()