#!/usr/bin/env python3
"""
Generate synthetic cybersecurity NER training data from prompt templates.
Usage:
python generate_synthetic_batch.py --prompt-id url_heavy_malware_infra --n 20 --output out.jsonl
python generate_synthetic_batch.py --all --output out.jsonl
python generate_synthetic_batch.py --validate data/processed/llm_generated_synthetic.jsonl
Requires ANTHROPIC_API_KEY or OPENAI_API_KEY env var for LLM generation.
Use --dry-run to print prompts without calling the API.
"""
import argparse
import json
import os
import re
import sys
from pathlib import Path
# Default prompt-template file: synthetic_prompts.json in the same directory as this script.
PROMPTS_FILE = Path(__file__).parent / "synthetic_prompts.json"
# Entity labels accepted as the TYPE part of a "TYPE: value" span key;
# verify_offsets reports any other label as an unknown entity type.
ENTITY_TYPES = [
    "MALWARE", "THREAT_ACTOR", "TOOL", "VULNERABILITY", "SYSTEM",
    "ORGANIZATION", "IP_ADDRESS", "DOMAIN", "URL", "HASH",
    "EMAIL", "CVE_ID", "FILEPATH",
]
def load_prompts(path: Path = PROMPTS_FILE) -> list[dict]:
    """Load the prompt template definitions from a JSON file.

    Args:
        path: JSON file whose top-level object has a "prompts" key.

    Returns:
        The value stored under the "prompts" key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the top-level object has no "prompts" key.
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding (prompts may contain non-ASCII characters).
    with open(path, encoding="utf-8") as f:
        return json.load(f)["prompts"]
def verify_offsets(record: dict) -> list[str]:
    """Check that every span offset in a record points at the expected text.

    A record's "spans" maps keys of the form "TYPE: value" to a list of
    [start, end) character offsets into the record's "text". Structural
    problems (non-dict spans, non-list offsets, malformed pairs) are reported
    as errors rather than raised, so one bad record cannot abort a whole run.

    Args:
        record: Parsed record; missing "text" / "spans" are treated as empty.

    Returns:
        A list of human-readable error strings; empty when the record is valid.
    """
    errors: list[str] = []
    text = record.get("text", "")
    if not isinstance(text, str):
        return [f"Record text is not a string: {type(text).__name__}"]
    spans = record.get("spans", {})
    if not isinstance(spans, dict):
        return [f"Spans must be a dict, got {type(spans).__name__}"]

    for key, offset_list in spans.items():
        # Parse "TYPE: value" from key
        if ": " not in key:
            errors.append(f"Bad span key format: {key!r}")
            continue
        etype, expected_value = key.split(": ", 1)
        if etype not in ENTITY_TYPES:
            # Not fatal for this key: the offsets are still checked below.
            errors.append(f"Unknown entity type: {etype!r}")
        if not isinstance(offset_list, (list, tuple)):
            errors.append(
                f"Offsets for {key} must be a list, got {type(offset_list).__name__}"
            )
            continue
        for pair in offset_list:
            well_formed = (
                isinstance(pair, (list, tuple))
                and len(pair) == 2
                and all(isinstance(v, int) for v in pair)
            )
            if not well_formed:
                errors.append(f"Malformed offset {pair!r}: {key}")
                continue
            start, end = pair
            if start < 0 or end > len(text) or start >= end:
                errors.append(f"Invalid offset [{start},{end}) for text len {len(text)}: {key}")
                continue
            actual = text[start:end]
            if actual != expected_value:
                errors.append(
                    f"Offset mismatch for {key}: "
                    f"text[{start}:{end}]={actual!r} != {expected_value!r}"
                )
    return errors
def _offset_matches(text: str, pair, value: str) -> bool:
    """Return True when pair is an in-range [start, end) whose slice of text equals value."""
    if not (isinstance(pair, (list, tuple)) and len(pair) == 2):
        return False
    start, end = pair
    if not (isinstance(start, int) and isinstance(end, int)):
        return False
    return 0 <= start < end <= len(text) and text[start:end] == value


def try_fix_offsets(record: dict) -> dict:
    """Attempt to repair span offsets by searching for the entity value in text.

    For each "TYPE: value" key, offsets that already point at the value are
    kept. At the first offset that does not, the value is searched for in the
    text: the first occurrence is used, and when the key had several offsets
    every later non-overlapping occurrence is added too; the remaining
    original offsets for that key are then dropped. Offsets whose value cannot
    be found are kept unchanged so validation can still report them.
    Keys without the ": " separator are dropped.

    The record is modified in place (its "spans" is replaced) and returned.

    Args:
        record: Parsed record; must contain "text".

    Returns:
        The same record object with repaired "spans".

    Raises:
        KeyError: If the record has no "text" key.
    """
    text = record["text"]
    spans = record.get("spans", {})
    if not isinstance(text, str) or not isinstance(spans, dict):
        # Nothing can be repaired safely; leave the record for validation to flag.
        return record

    fixed_spans = {}
    for key, offset_list in spans.items():
        if ": " not in key:
            continue
        expected_value = key.split(": ", 1)[1]
        if not isinstance(offset_list, (list, tuple)):
            # Keep the broken value so validation reports it.
            fixed_spans[key] = offset_list
            continue

        width = len(expected_value)
        new_offsets = []
        for pair in offset_list:
            if _offset_matches(text, pair, expected_value):
                new_offsets.append([pair[0], pair[1]])
                continue

            # An empty value "matches" everywhere and would never advance the
            # search below, so treat it as not found.
            first = text.find(expected_value) if expected_value else -1
            if first < 0:
                new_offsets.append(pair)  # Keep broken, will be caught by validate
                continue

            found = [[first, first + width]]
            # Look for additional occurrences if there were multiple
            if len(offset_list) > 1:
                pos = text.find(expected_value, first + width)
                while pos >= 0:
                    found.append([pos, pos + width])
                    pos = text.find(expected_value, pos + width)
            # Skip occurrences that an earlier valid offset already covers,
            # otherwise the same span would be emitted twice.
            fresh = [hit for hit in found if hit not in new_offsets]
            new_offsets.extend(fresh)
            break  # We handled all occurrences

        if new_offsets:
            fixed_spans[key] = new_offsets
    record["spans"] = fixed_spans
    return record
def _records_in(obj) -> list[dict]:
    """Return the record dicts (dicts with a "text" key) contained in a parsed JSON value."""
    if isinstance(obj, dict):
        return [obj] if "text" in obj else []
    if isinstance(obj, list):
        return [r for r in obj if isinstance(r, dict) and "text" in r]
    return []


def parse_llm_response(response_text: str) -> list[dict]:
    """Parse an LLM response into a list of records.

    Handles JSONL, JSON arrays (single-line or spread over several lines),
    a single pretty-printed JSON object, and markdown code fences around any
    of those. Only dicts that contain a "text" key count as records.

    Args:
        response_text: Raw model output; None or empty yields no records.

    Returns:
        The records found, in document order; empty if nothing parses.
    """
    if not response_text:
        return []

    stripped = (ln.strip() for ln in response_text.strip().split("\n"))
    lines = [ln for ln in stripped if ln and not ln.startswith("```")]

    # 1. Try the whole document first. Going line-by-line first would accept
    #    only the final element of an array printed one object per line (the
    #    other lines end in a comma and fail to parse on their own).
    try:
        records = _records_in(json.loads("\n".join(lines)))
    except json.JSONDecodeError:
        records = []
    if records:
        return records

    # 2. JSONL: one JSON value per line; unparseable lines are skipped.
    for ln in lines:
        try:
            records.extend(_records_in(json.loads(ln)))
        except json.JSONDecodeError:
            continue
    if records:
        return records

    # 3. Fences that share a line with the JSON itself: strip the fence
    #    markers from the raw text and parse what is left.
    cleaned = re.sub(r"```(?:json)?\n?", "", response_text).strip()
    try:
        return _records_in(json.loads(cleaned))
    except json.JSONDecodeError:
        return []
def generate_with_anthropic(prompt: str, n: int) -> str:
    """Send one generation request to the Anthropic Messages API.

    The literal "{n}" in the prompt is replaced by the requested example
    count before sending. Exits the process with an install hint if the
    `anthropic` package is not importable.

    Args:
        prompt: Prompt template text, possibly containing "{n}".
        n: Number of examples to ask for.

    Returns:
        The text of the first content block of the reply.
    """
    try:
        import anthropic
    except ImportError:
        sys.exit("pip install anthropic")

    # No credentials are passed here; presumably the client reads
    # ANTHROPIC_API_KEY from the environment (see the module docstring).
    api = anthropic.Anthropic()
    system_prompt = "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."
    user_turn = {"role": "user", "content": prompt.replace("{n}", str(n))}
    reply = api.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=8192,
        messages=[user_turn],
        system=system_prompt,
    )
    first_block = reply.content[0]
    return first_block.text
def generate_with_openai(prompt: str, n: int) -> str:
    """Send one generation request to the OpenAI Chat Completions API.

    The literal "{n}" in the prompt is replaced by the requested example
    count before sending. Exits the process with an install hint if the
    `openai` package is not importable.

    Args:
        prompt: Prompt template text, possibly containing "{n}".
        n: Number of examples to ask for.

    Returns:
        The content of the first choice's message, or "" when the API
        returns no text content.
    """
    try:
        import openai
    except ImportError:
        sys.exit("pip install openai")
    client = openai.OpenAI()
    resp = client.chat.completions.create(
        model="gpt-4o",
        max_tokens=8192,
        messages=[
            {"role": "system", "content": "You are a cybersecurity data generation assistant. Output ONLY valid JSONL — one JSON object per line, no markdown fences, no commentary."},
            {"role": "user", "content": prompt.replace("{n}", str(n))},
        ],
    )
    # message.content is nullable in the API; normalise to "" so the
    # declared str return type holds and callers can parse it safely.
    return resp.choices[0].message.content or ""
def generate_batch(
    prompt_id: str | None,
    n: int,
    output_path: Path,
    backend: str = "anthropic",
    dry_run: bool = False,
    fix: bool = True,
    max_retries: int = 3,
):
    """Generate records for one or all prompt templates and append them to a JSONL file.

    For each selected prompt, requests are issued in batches of at most 20
    examples until the target count has been requested. Each returned record
    is optionally repaired with try_fix_offsets, validated with
    verify_offsets, given a sequential "synth_batch_NNNNN" id and appended to
    output_path. A summary of entity counts is printed at the end.

    Args:
        prompt_id: Id of the template to run; a falsy value runs all templates.
        n: Examples per prompt; 0 uses the template's "total_target" (default 20).
        output_path: JSONL file to append to; parent directories are created.
        backend: "anthropic" selects the Anthropic backend; any other value
            selects OpenAI.
        dry_run: Print the rendered prompts and return without calling the API
            or writing anything.
        fix: Repair offsets before validating. When True, records that still
            fail validation are skipped; when False, they are warned about
            but still written.
        max_retries: Consecutive failed API/parse attempts tolerated per
            prompt before that prompt is abandoned.

    Raises:
        SystemExit: If prompt_id is given but matches no template.
    """
    prompts = load_prompts()
    if prompt_id:
        prompts = [p for p in prompts if p["id"] == prompt_id]
        if not prompts:
            sys.exit(f"Unknown prompt_id: {prompt_id}")
    generate_fn = generate_with_anthropic if backend == "anthropic" else generate_with_openai
    all_records = []
    total_errors = 0
    for pdef in prompts:
        count = n if n else pdef.get("total_target", 20)
        prompt_text = pdef["prompt"]
        print(f"\n{'='*60}")
        print(f"Prompt: {pdef['id']} | Target entities: {pdef['target_entities']} | N={count}")
        print(f"{'='*60}")
        if dry_run:
            print(prompt_text.replace("{n}", str(count)))
            continue
        # Generate in batches of 20
        batch_size = min(20, count)
        generated = 0
        failures = 0
        while generated < count:
            this_batch = min(batch_size, count - generated)
            print(f" Generating batch of {this_batch}...")
            try:
                raw = generate_fn(prompt_text, this_batch)
                records = parse_llm_response(raw)
            except Exception as e:
                # Top-level boundary: any API/client error lands here. Retry a
                # bounded number of times; an unbounded retry would loop
                # forever on a persistent failure such as a bad API key.
                failures += 1
                print(f" ERROR: {e}")
                if failures > max_retries:
                    print(f" Giving up on {pdef['id']} after {failures} consecutive failures")
                    break
                continue
            failures = 0
            for rec in records:
                try:
                    # Fix offsets if requested
                    if fix:
                        rec = try_fix_offsets(rec)
                    # Validate
                    errs = verify_offsets(rec)
                except (AttributeError, KeyError, TypeError, ValueError) as e:
                    # Structurally malformed model output: drop this record
                    # instead of losing everything generated so far.
                    total_errors += 1
                    print(f" SKIP (malformed record): {e}")
                    continue
                if errs:
                    total_errors += len(errs)
                    for err in errs:
                        print(f" WARN: {err}")
                    if fix:
                        # try_fix_offsets already ran above, so whatever is
                        # still wrong could not be repaired.
                        info = rec.get("info")
                        rec_id = info.get("id", "?") if isinstance(info, dict) else "?"
                        print(f" SKIP (unfixable): {rec_id}")
                        continue
                all_records.append(rec)
            generated += this_batch
            print(f" Got {len(records)} records (total so far: {len(all_records)})")
    if dry_run:
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # The file is opened in append mode, so continue numbering after the
    # records already present rather than reusing ids from 1.
    first_id = 1
    if output_path.exists():
        with open(output_path, encoding="utf-8", errors="replace") as f:
            first_id += sum(1 for line in f if line.strip())
    # Assign sequential IDs
    for i, rec in enumerate(all_records, first_id):
        if not isinstance(rec.get("info"), dict):
            rec["info"] = {}
        rec["info"]["id"] = f"synth_batch_{i:05d}"
    # Write output; UTF-8 is explicit because ensure_ascii=False emits raw
    # non-ASCII characters that a locale default encoding may not support.
    with open(output_path, "a", encoding="utf-8") as f:
        for rec in all_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    # Summary stats
    entity_counts: dict[str, int] = {}
    for rec in all_records:
        spans = rec.get("spans")
        if not isinstance(spans, dict):
            continue
        for key, offsets in spans.items():
            etype = key.split(": ", 1)[0] if ": " in key else key
            n_offsets = len(offsets) if isinstance(offsets, (list, tuple)) else 0
            entity_counts[etype] = entity_counts.get(etype, 0) + n_offsets
    print(f"\n{'='*60}")
    print(f"SUMMARY: {len(all_records)} records written to {output_path}")
    print(f"Offset errors encountered: {total_errors}")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
def validate_file(path: Path):
    """Validate all records in an existing JSONL file and print a report.

    Blank lines are ignored. Each remaining line is parsed as JSON and checked
    with verify_offsets; every problem is printed with its line number,
    followed by totals and a per-entity-type count of span offsets.

    The printed "with errors" figure counts lines that are invalid JSON, lines
    that are not JSON objects, and records with offset errors; the "Validated"
    figure counts only lines that parsed as JSON.

    Args:
        path: JSONL file to check.

    Raises:
        OSError: If the file cannot be opened.
    """
    total = 0
    bad = 0
    entity_counts: dict[str, int] = {}
    # Explicit UTF-8: generated files contain raw non-ASCII characters.
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"Line {i}: invalid JSON")
                bad += 1
                continue
            total += 1
            if not isinstance(rec, dict):
                print(f"Line {i}: record is not a JSON object")
                bad += 1
                continue
            try:
                errs = verify_offsets(rec)
            except (AttributeError, TypeError, ValueError) as e:
                # Report structurally broken spans instead of aborting the scan.
                errs = [f"Malformed spans: {e}"]
            if errs:
                bad += 1
                for err in errs:
                    print(f"Line {i}: {err}")
            spans = rec.get("spans")
            if not isinstance(spans, dict):
                continue
            for key, offsets in spans.items():
                etype = key.split(": ", 1)[0] if ": " in key else key
                n_offsets = len(offsets) if isinstance(offsets, (list, tuple)) else 0
                entity_counts[etype] = entity_counts.get(etype, 0) + n_offsets
    print(f"\nValidated {total} records, {bad} with errors")
    print("Entity distribution:")
    for etype in sorted(entity_counts, key=entity_counts.get, reverse=True):
        print(f" {etype}: {entity_counts[etype]}")
def main():
    """Command-line entry point: parse arguments and validate or generate.

    With --validate the given JSONL file is checked and nothing is generated.
    Otherwise one of --prompt-id / --all is required (--all takes precedence
    when both are given), and unless --dry-run is set the API key environment
    variable for the selected --backend must be present.

    Raises:
        SystemExit: On argument errors or when the required API key is missing.
    """
    parser = argparse.ArgumentParser(description="Generate synthetic cybersecurity NER data")
    parser.add_argument("--prompt-id", help="Run a specific prompt template")
    parser.add_argument("--all", action="store_true", help="Run all prompt templates")
    parser.add_argument("--n", type=int, default=0, help="Examples per prompt (0=use template default)")
    parser.add_argument("--output", type=Path, default=Path("data/processed/llm_generated_synthetic_v2.jsonl"))
    parser.add_argument("--backend", choices=["anthropic", "openai"], default="anthropic")
    parser.add_argument("--dry-run", action="store_true", help="Print prompts without calling API")
    parser.add_argument("--no-fix", action="store_true", help="Skip automatic offset fixing")
    parser.add_argument("--validate", type=Path, help="Validate an existing JSONL file")
    args = parser.parse_args()
    if args.validate:
        validate_file(args.validate)
        return
    if not args.prompt_id and not args.all:
        parser.error("Specify --prompt-id or --all")
    if args.n < 0:
        parser.error("--n must be >= 0")
    if not args.dry_run:
        # Check the key that the chosen backend needs: having only the other
        # provider's key set would pass an either/or check and then fail on
        # every API call.
        key_var = "ANTHROPIC_API_KEY" if args.backend == "anthropic" else "OPENAI_API_KEY"
        if not os.environ.get(key_var):
            sys.exit(f"Set {key_var} to use the {args.backend} backend (or pass --dry-run)")
    if args.all:
        args.prompt_id = None
    generate_batch(
        prompt_id=args.prompt_id,
        n=args.n,
        output_path=args.output,
        backend=args.backend,
        dry_run=args.dry_run,
        fix=not args.no_fix,
    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|