arcspan / scripts /generate_synthetic_batch.py
chairulridjal's picture
Add files using upload-large-folder tool
038e086 verified
#!/usr/bin/env python3
"""
Generate synthetic cybersecurity NER training data from prompt templates.

Prompt templates are read from ``synthetic_prompts.json`` next to this script.
Each generated record is a JSON object with a ``text`` string and a ``spans``
mapping whose keys look like ``"TYPE: value"`` and whose values are lists of
``[start, end]`` character offsets such that ``text[start:end] == value``.
Generated records are appended to the ``--output`` JSONL file.

Usage:
    python generate_synthetic_batch.py --prompt-id url_heavy_malware_infra --n 20 --output out.jsonl
    python generate_synthetic_batch.py --all --output out.jsonl
    python generate_synthetic_batch.py --validate data/processed/llm_generated_synthetic.jsonl

Requires ANTHROPIC_API_KEY or OPENAI_API_KEY env var for LLM generation
(choose the provider with --backend anthropic|openai).
Use --dry-run to print prompts without calling the API, --no-fix to skip the
automatic repair of span offsets, and --validate to check an existing file.
"""
import argparse
import json
import os
import re
import sys
from pathlib import Path
# Prompt-template definitions live in a JSON file alongside this script.
PROMPTS_FILE = Path(__file__).with_name("synthetic_prompts.json")

# Every label allowed as the TYPE part of a "TYPE: value" span key.
# Order is kept stable; membership is what the validators check.
ENTITY_TYPES = [
    # Concept-level labels
    "MALWARE",
    "THREAT_ACTOR",
    "TOOL",
    "VULNERABILITY",
    "SYSTEM",
    "ORGANIZATION",
    # Literal, pattern-shaped artifacts (addresses, identifiers, paths)
    "IP_ADDRESS",
    "DOMAIN",
    "URL",
    "HASH",
    "EMAIL",
    "CVE_ID",
    "FILEPATH",
]
def load_prompts(path: Path = PROMPTS_FILE) -> list[dict]:
    """Load the prompt-template definitions from a JSON file.

    Args:
        path: JSON file whose top-level object holds a ``"prompts"`` list.
            Defaults to ``PROMPTS_FILE``.

    Returns:
        The list stored under the ``"prompts"`` key.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the top-level object has no ``"prompts"`` key.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    # (prompts containing non-ASCII characters would fail under e.g. cp1252).
    with open(path, encoding="utf-8") as f:
        return json.load(f)["prompts"]
def verify_offsets(record: dict) -> list[str]:
    """Check every span offset in *record* against its text.

    ``record["spans"]`` maps keys of the form ``"TYPE: value"`` to lists of
    ``[start, end]`` pairs. Each pair must be in bounds with ``start < end`` and
    must satisfy ``text[start:end] == value``. TYPE must be listed in ENTITY_TYPES.

    Records come straight from LLM output, so malformed data is reported as an
    error string rather than raised. This covers a non-string ``text``, non-object
    ``spans``, a non-list offset list, and entries that are not two integers.

    Args:
        record: A generated record. A missing ``text`` or ``spans`` is treated as empty.

    Returns:
        Human-readable error strings. An empty list means the record is valid.
    """
    text = record.get("text", "")
    if not isinstance(text, str):
        return [f"Bad text field (expected string): {type(text).__name__}"]
    spans = record.get("spans", {})
    if not isinstance(spans, dict):
        return [f"Bad spans field (expected object): {spans!r}"]

    errors = []
    for key, offset_list in spans.items():
        # Keys are "TYPE: value"; split on the first ": " only, so values may contain it.
        if ": " not in key:
            errors.append(f"Bad span key format: {key!r}")
            continue
        etype, expected_value = key.split(": ", 1)
        if etype not in ENTITY_TYPES:
            errors.append(f"Unknown entity type: {etype!r}")
        if not isinstance(offset_list, (list, tuple)):
            errors.append(f"Bad offset list for {key}: {offset_list!r}")
            continue
        for pair in offset_list:
            # Guard the unpacking and comparisons below against malformed entries.
            if not (
                isinstance(pair, (list, tuple))
                and len(pair) == 2
                and all(isinstance(x, int) and not isinstance(x, bool) for x in pair)
            ):
                errors.append(f"Bad offset entry for {key}: {pair!r}")
                continue
            start, end = pair
            if start < 0 or end > len(text) or start >= end:
                errors.append(f"Invalid offset [{start},{end}) for text len {len(text)}: {key}")
                continue
            actual = text[start:end]
            if actual != expected_value:
                errors.append(
                    f"Offset mismatch for {key}: "
                    f"text[{start}:{end}]={actual!r} != {expected_value!r}"
                )
    return errors
def _is_int_pair(obj) -> bool:
    """Return True if *obj* is a 2-element list/tuple of (non-bool) ints."""
    return (
        isinstance(obj, (list, tuple))
        and len(obj) == 2
        and all(isinstance(x, int) and not isinstance(x, bool) for x in obj)
    )


def _pair_selects(pair, text: str, value: str) -> bool:
    """Return True if *pair* is an in-bounds ``[start, end]`` with ``text[start:end] == value``."""
    return (
        _is_int_pair(pair)
        and 0 <= pair[0] < pair[1] <= len(text)
        and text[pair[0]:pair[1]] == value
    )


def _find_occurrences(text: str, value: str) -> list[list[int]]:
    """Return ``[start, end]`` for each non-overlapping occurrence of *value* (none if empty)."""
    if not value:
        # str.find("") matches at every position; never treat that as a location.
        return []
    found = []
    idx = text.find(value)
    while idx >= 0:
        found.append([idx, idx + len(value)])
        idx = text.find(value, idx + len(value))
    return found


def try_fix_offsets(record: dict) -> dict:
    """Repair span offsets by searching the text for each span's value.

    For every ``"TYPE: value"`` key in ``record["spans"]``:

    * If all offset pairs already select ``value``, they are kept unchanged and in order.
    * Otherwise, when ``value`` occurs in the text:

      - a single-entry list is replaced by the first occurrence;
      - a multi-entry list becomes the sorted, de-duplicated union of the
        correct pairs and every non-overlapping occurrence.

    * When ``value`` is empty or not found, the original entries are kept so
      that ``verify_offsets`` reports them.

    Keys without ``": "`` and keys left with no offsets are dropped. A flat
    ``[start, end]`` given in place of a list of pairs is treated as one pair.
    The record is updated in place and also returned.

    Args:
        record: A record with a ``"text"`` key and an optional ``"spans"`` mapping.

    Returns:
        The same record object with ``"spans"`` replaced.
    """
    text = record["text"]
    if not isinstance(text, str):
        return record  # nothing to search; leave it for verify_offsets to flag
    fixed_spans = {}
    for key, offset_list in record.get("spans", {}).items():
        if ": " not in key:
            continue  # unparseable key: drop the span
        _etype, expected_value = key.split(": ", 1)

        if _is_int_pair(offset_list):
            entries = [offset_list]  # LLM wrote [s, e] instead of [[s, e]]
        elif isinstance(offset_list, list):
            entries = offset_list
        else:
            entries = [offset_list]

        good = [[p[0], p[1]] for p in entries if _pair_selects(p, text, expected_value)]
        if len(good) == len(entries):
            new_offsets = good
        else:
            occurrences = _find_occurrences(text, expected_value)
            if not occurrences:
                new_offsets = list(entries)  # keep broken; verify_offsets will flag
            elif len(entries) > 1:
                merged = {tuple(p) for p in good} | {tuple(p) for p in occurrences}
                new_offsets = [list(p) for p in sorted(merged)]
            else:
                new_offsets = occurrences[:1]

        if new_offsets:
            fixed_spans[key] = new_offsets
    record["spans"] = fixed_spans
    return record
def parse_llm_response(response_text: str) -> list[dict]:
"""Parse LLM response into list of records. Handles JSONL and JSON arrays."""
records = []
# Try line-by-line JSONL first
for line in response_text.strip().split("\n"):
line = line.strip()
if not line or line.startswith("```"):
continue
try:
obj = json.loads(line)
if isinstance(obj, dict) and "text" in obj:
records.append(obj)
elif isinstance(obj, list):
records.extend(r for r in obj if isinstance(r, dict) and "text" in r)
except json.JSONDecodeError:
continue
# If nothing parsed line-by-line, try the whole thing as JSON array
if not records:
try:
# Strip markdown code fences
cleaned = re.sub(r"```(?:json)?\n?", "", response_text).strip()
obj = json.loads(cleaned)
if isinstance(obj, list):
records = [r for r in obj if isinstance(r, dict) and "text" in r]
except json.JSONDecodeError:
pass
return records
def generate_with_anthropic(
    prompt: str,
    n: int,
    *,
    model: str = "claude-sonnet-4-20250514",
    max_tokens: int = 8192,
) -> str:
    """Ask the Anthropic Messages API for *n* examples and return the raw reply text.

    Every ``{n}`` placeholder in *prompt* is replaced with *n* before sending.
    The text of all ``"text"`` content blocks in the reply is concatenated.
    If the reply holds no text block, ``""`` is returned instead of raising
    ``IndexError``.

    Args:
        prompt: Prompt template text.
        n: Number of examples to request.
        model: Anthropic model name.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The concatenated reply text, possibly empty.

    Exits the process with an install hint if ``anthropic`` is not installed.
    API errors propagate to the caller.
    """
    try:
        import anthropic
    except ImportError:
        sys.exit("pip install anthropic")
    client = anthropic.Anthropic()
    msg = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        messages=[{"role": "user", "content": prompt.replace("{n}", str(n))}],
        system="You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary.",
    )
    # The reply is a list of content blocks, which may be empty or contain
    # non-text blocks; keep only the text.
    return "".join(
        block.text for block in msg.content if getattr(block, "type", None) == "text"
    )
def generate_with_openai(
    prompt: str,
    n: int,
    *,
    model: str = "gpt-4o",
    max_tokens: int = 8192,
) -> str:
    """Ask the OpenAI Chat Completions API for *n* examples and return the raw reply text.

    Every ``{n}`` placeholder in *prompt* is replaced with *n* before sending.

    Args:
        prompt: Prompt template text.
        n: Number of examples to request.
        model: OpenAI model name.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content. This is ``""`` when the API returns
        no content (``None``), so callers always receive a string.

    Exits the process with an install hint if ``openai`` is not installed.
    API errors propagate to the caller.
    """
    try:
        import openai
    except ImportError:
        sys.exit("pip install openai")
    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {"role": "system", "content": "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."},
            {"role": "user", "content": prompt.replace("{n}", str(n))},
        ],
    )
    # message.content is nullable; normalize to a string for the parser.
    return resp.choices[0].message.content or ""
def _last_batch_id(output_path: Path) -> int:
    """Return the largest N among ``synth_batch_N`` record ids already in *output_path*.

    Returns 0 when the file does not exist or holds no such id. Blank,
    non-JSON and non-object lines are ignored; the file is only read.
    """
    if not output_path.exists():
        return 0
    highest = 0
    with open(output_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue
            info = rec.get("info") if isinstance(rec, dict) else None
            rec_id = info.get("id") if isinstance(info, dict) else None
            if isinstance(rec_id, str):
                match = re.fullmatch(r"synth_batch_(\d+)", rec_id)
                if match:
                    highest = max(highest, int(match.group(1)))
    return highest


def _screen_records(records: list[dict], fix: bool) -> tuple[list[dict], int]:
    """Optionally repair offsets, validate each record, and return (records to keep, error count).

    With *fix*, records still invalid after ``try_fix_offsets`` are skipped.
    Without *fix*, invalid records are kept and only warnings are printed.
    """
    kept = []
    n_errors = 0
    for rec in records:
        if fix:
            rec = try_fix_offsets(rec)
        errs = verify_offsets(rec)
        if errs:
            n_errors += len(errs)
            for err in errs:
                print(f"    WARN: {err}")
            if fix:
                # Fixing already ran above; whatever is still wrong cannot be repaired.
                info = rec.get("info")
                rec_id = info.get("id", "?") if isinstance(info, dict) else "?"
                print(f"    SKIP (unfixable): {rec_id}")
                continue
        kept.append(rec)
    return kept, n_errors


def generate_batch(
    prompt_id: str | None,
    n: int,
    output_path: Path,
    backend: str = "anthropic",
    dry_run: bool = False,
    fix: bool = True,
    max_failures: int = 3,
):
    """Generate records for one or all prompt templates and append them to *output_path*.

    Args:
        prompt_id: id of the single template to run, or None/empty to run every template.
        n: examples per template. 0 uses the template's ``total_target`` (default 20).
        output_path: JSONL file to append to. Parent directories are created.
        backend: ``"anthropic"``; any other value uses the OpenAI backend.
        dry_run: print the filled-in prompts and return without calling an API or writing.
        fix: repair offsets and drop records that remain invalid. When False,
            invalid records are kept and only warned about.
        max_failures: consecutive generation/parse failures tolerated per template
            before giving up on it. Without this cap, an API that keeps failing
            would be retried forever.

    Records get ids ``synth_batch_NNNNN`` that continue after the highest such
    id already in *output_path*. Repeated runs appending to the same file
    therefore do not reuse ids.

    Exits via ``sys.exit`` if *prompt_id* matches no template.
    """
    prompts = load_prompts()
    if prompt_id:
        prompts = [p for p in prompts if p["id"] == prompt_id]
        if not prompts:
            sys.exit(f"Unknown prompt_id: {prompt_id}")
    generate_fn = generate_with_anthropic if backend == "anthropic" else generate_with_openai

    all_records = []
    total_errors = 0
    for pdef in prompts:
        count = n if n else pdef.get("total_target", 20)
        prompt_text = pdef["prompt"]
        print(f"\n{'='*60}")
        print(f"Prompt: {pdef['id']} | Target entities: {pdef.get('target_entities')} | N={count}")
        print(f"{'='*60}")
        if dry_run:
            print(prompt_text.replace("{n}", str(count)))
            continue

        # Generate in batches of at most 20 examples per API call.
        batch_size = min(20, count)
        generated = 0
        failures = 0
        while generated < count:
            this_batch = min(batch_size, count - generated)
            print(f" Generating batch of {this_batch}...")
            try:
                raw = generate_fn(prompt_text, this_batch)
                records = parse_llm_response(raw)
            except Exception as e:  # API/network boundary: report and retry (bounded)
                failures += 1
                print(f" ERROR: {e}")
                if failures >= max_failures:
                    print(f" Giving up on {pdef['id']} after {failures} consecutive failures")
                    break
                continue
            failures = 0
            kept, n_errors = _screen_records(records, fix)
            total_errors += n_errors
            all_records.extend(kept)
            # Count requested examples so a model returning too few cannot loop forever.
            generated += this_batch
            print(f" Got {len(records)} records (total so far: {len(all_records)})")

    if dry_run:
        return

    # Assign sequential ids, continuing after any already in the output file.
    first_id = _last_batch_id(output_path) + 1
    for i, rec in enumerate(all_records, first_id):
        if not isinstance(rec.get("info"), dict):
            rec["info"] = {}
        rec["info"]["id"] = f"synth_batch_{i:05d}"

    # ensure_ascii=False writes raw non-ASCII, so the file encoding must be explicit.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "a", encoding="utf-8") as f:
        for rec in all_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    entity_counts: dict[str, int] = {}
    for rec in all_records:
        spans = rec.get("spans", {})
        if not isinstance(spans, dict):
            continue
        for key, offsets in spans.items():
            etype = key.split(": ", 1)[0]  # whole key when it has no ": "
            size = len(offsets) if isinstance(offsets, list) else 0
            entity_counts[etype] = entity_counts.get(etype, 0) + size

    print(f"\n{'='*60}")
    print(f"SUMMARY: {len(all_records)} records written to {output_path}")
    print(f"Offset errors encountered: {total_errors}")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
def validate_file(path: Path) -> int:
    """Validate every record in an existing JSONL file and print a report.

    Blank lines are skipped. Some lines are printed with their 1-based line
    number and counted as bad:

    * lines that are not valid JSON;
    * lines whose JSON is not an object;
    * records whose spans fail ``verify_offsets``.

    Per-type span counts are summed over all object records and printed
    at the end.

    Args:
        path: JSONL file to check, read as UTF-8.

    Returns:
        The number of bad lines. 0 means the file is clean.
    """
    total = 0
    bad = 0
    entity_counts: dict[str, int] = {}
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"Line {i}: invalid JSON")
                bad += 1
                continue
            if not isinstance(rec, dict):
                print(f"Line {i}: expected a JSON object, got {type(rec).__name__}")
                bad += 1
                continue
            total += 1
            try:
                errs = verify_offsets(rec)
            except (TypeError, ValueError, AttributeError) as e:
                # Malformed span data should be reported, not abort the whole report.
                errs = [f"malformed spans ({e})"]
            if errs:
                bad += 1
                for err in errs:
                    print(f"Line {i}: {err}")
            spans = rec.get("spans", {})
            if isinstance(spans, dict):
                for key, offsets in spans.items():
                    etype = key.split(": ", 1)[0]  # whole key when it has no ": "
                    size = len(offsets) if isinstance(offsets, list) else 0
                    entity_counts[etype] = entity_counts.get(etype, 0) + size
    print(f"\nValidated {total} records, {bad} with errors")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
    return bad
def main():
    """Command-line entry point: validate an existing JSONL file or generate new data."""
    parser = argparse.ArgumentParser(description="Generate synthetic cybersecurity NER data")
    parser.add_argument("--prompt-id", help="Run a specific prompt template")
    parser.add_argument("--all", action="store_true", help="Run all prompt templates")
    parser.add_argument("--n", type=int, default=0, help="Examples per prompt (0=use template default)")
    parser.add_argument("--output", type=Path, default=Path("data/processed/llm_generated_synthetic_v2.jsonl"))
    parser.add_argument("--backend", choices=["anthropic", "openai"], default="anthropic")
    parser.add_argument("--dry-run", action="store_true", help="Print prompts without calling API")
    parser.add_argument("--no-fix", action="store_true", help="Skip automatic offset fixing")
    parser.add_argument("--validate", type=Path, help="Validate an existing JSONL file")
    args = parser.parse_args()

    if args.validate:
        # Exit non-zero when validation reports bad lines so scripts/CI can react;
        # a falsy result (0 or None) counts as success.
        if validate_file(args.validate):
            sys.exit(1)
        return

    if not args.prompt_id and not args.all:
        parser.error("Specify --prompt-id or --all")
    if args.n < 0:
        parser.error("--n must be >= 0")

    # Only the selected backend's key matters. Accepting the other provider's key
    # would let every API call fail at generation time instead of failing fast here.
    key_var = "ANTHROPIC_API_KEY" if args.backend == "anthropic" else "OPENAI_API_KEY"
    if not args.dry_run and not os.environ.get(key_var):
        sys.exit(f"Set {key_var} (required for --backend {args.backend})")

    if args.all:
        args.prompt_id = None  # --all takes precedence over --prompt-id
    generate_batch(
        prompt_id=args.prompt_id,
        n=args.n,
        output_path=args.output,
        backend=args.backend,
        dry_run=args.dry_run,
        fix=not args.no_fix,
    )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()