"""
Dataset Generator for Codette LoRA Training
===========================================

Main orchestrator that combines TemplateRegistry and AnswerGenerator
to produce chat-format JSONL files for fine-tuning Llama 3.1 8B
with LoRA adapters.

Features:
- Deduplication: tracks all generated prompts to prevent duplicates
- Reproducible: seed-based RNG for deterministic output
- CLI interface: generate for one adapter or all adapters
- Progress reporting: logs generation progress
- Validation: checks answer quality before writing

Usage:
    python -m dataset_engine.dataset_generator --adapter newton --count 3000
    python -m dataset_engine.dataset_generator --all
    python -m dataset_engine.dataset_generator --adapter philosophy --count 2000 --seed 42
"""
import argparse
import json
import logging
import time
from pathlib import Path
from typing import Optional, Set

from dataset_engine.template_registry import TemplateRegistry
from dataset_engine.answer_generator import AnswerGenerator

logger = logging.getLogger("dataset_generator")


class DatasetGenerator:
"""Generates JSONL training datasets for Codette LoRA adapters."""
def __init__(self, output_dir: str = "datasets", seed: Optional[int] = None):
"""Initialize the generator.
Args:
output_dir: Directory for output JSONL files.
seed: Random seed for reproducibility. None for non-deterministic.
"""
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.seed = seed
self.registry = TemplateRegistry(seed=seed)
self.answer_gen = AnswerGenerator(seed=seed)
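        # Normalized question texts seen so far; used to skip duplicate prompts.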
self._seen_questions: Set[str] = set()
self._stats = {
"total_generated": 0,
"duplicates_skipped": 0,
"counterexamples": 0,
}

    def reset_dedup(self):
"""Clear the deduplication set (use between adapters)."""
self._seen_questions.clear()

    def reset_stats(self):
"""Reset generation statistics."""
self._stats = {
"total_generated": 0,
"duplicates_skipped": 0,
"counterexamples": 0,
}

    def generate_adapter(self, adapter: str, count: Optional[int] = None) -> str:
        """Generate a JSONL dataset for a single adapter.

        Args:
            adapter: Adapter name (e.g. 'newton', 'philosophy').
            count: Number of examples to generate. Defaults to the
                adapter's target size from the registry.

        Returns:
            Path to the generated JSONL file.
        """
if adapter not in self.registry.get_adapter_names():
raise ValueError(
f"Unknown adapter '{adapter}'. "
f"Available: {self.registry.get_adapter_names()}"
)
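        # Fall back to the registry's per-adapter target size when no count is given.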
        target = count if count is not None else self.registry.get_target(adapter)
output_path = self.output_dir / f"{adapter}_reasoning.jsonl"
self.reset_dedup()
self.reset_stats()
logger.info(
"Generating %d examples for adapter '%s' -> %s",
target, adapter, output_path,
)
start_time = time.time()
examples = []
max_attempts = target * 5 # Safety valve against infinite loops
attempts = 0
while len(examples) < target and attempts < max_attempts:
attempts += 1
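            # sample_question returns (question_text, topic, subtopic, question_type)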
question, topic, subtopic, qtype = self.registry.sample_question(adapter)
# Deduplicate
q_normalized = question.strip().lower()
if q_normalized in self._seen_questions:
self._stats["duplicates_skipped"] += 1
continue
self._seen_questions.add(q_normalized)
# Generate answer
answer = self.answer_gen.generate(
adapter=adapter,
topic=topic,
subtopic=subtopic,
question=question,
question_type=qtype,
)
# Validate answer quality
if not self._validate_answer(answer):
continue
# Build chat-format message
message = {
"messages": [
{
"role": "system",
"content": self.registry.SYSTEM_PROMPT,
},
{
"role": "user",
"content": question,
},
{
"role": "assistant",
"content": answer,
},
]
}
examples.append(message)
if qtype == "counterexample":
self._stats["counterexamples"] += 1
self._stats["total_generated"] = len(examples)
# Progress reporting
if len(examples) > 0 and len(examples) % 500 == 0:
elapsed = time.time() - start_time
rate = len(examples) / elapsed if elapsed > 0 else 0
logger.info(
" [%s] %d / %d examples (%.1f/sec, %d duplicates skipped)",
adapter, len(examples), target, rate,
self._stats["duplicates_skipped"],
)
# Write output
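        # One JSON object per line; ensure_ascii=False preserves non-ASCII text.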
with open(output_path, "w", encoding="utf-8") as f:
for example in examples:
f.write(json.dumps(example, ensure_ascii=False) + "\n")
elapsed = time.time() - start_time
counter_pct = (
(self._stats["counterexamples"] / len(examples) * 100)
if examples else 0
)
logger.info(
"Completed '%s': %d examples in %.1fs "
"(%.1f%% counterexamples, %d duplicates skipped)",
adapter, len(examples), elapsed, counter_pct,
self._stats["duplicates_skipped"],
)
if len(examples) < target:
logger.warning(
"Only generated %d / %d examples for '%s'. "
"Consider expanding template pools.",
len(examples), target, adapter,
)
return str(output_path)

    def generate_all(self) -> dict:
        """Generate datasets for all adapters.

        Returns:
            Dict mapping adapter names to output file paths.
        """
results = {}
total_start = time.time()
for adapter in self.registry.get_adapter_names():
try:
path = self.generate_adapter(adapter)
results[adapter] = path
except Exception as e:
logger.error("Failed to generate '%s': %s", adapter, e)
results[adapter] = f"ERROR: {e}"
total_elapsed = time.time() - total_start
total_examples = sum(
self._count_lines(p) for p in results.values()
if not p.startswith("ERROR")
)
logger.info(
"All adapters complete: %d total examples in %.1fs",
total_examples, total_elapsed,
)
return results

    @staticmethod
    def _validate_answer(answer: str) -> bool:
        """Check that an answer meets minimum quality standards."""
if not answer or not answer.strip():
return False
words = answer.split()
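        # Heuristic length floor: very short answers rarely carry enough
        # reasoning to be useful fine-tuning targets.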
if len(words) < 40:
return False
# Reject answers that are just the topic name repeated
        unique_words = {w.lower() for w in words}
if len(unique_words) < 20:
return False
return True

    @staticmethod
    def _count_lines(filepath: str) -> int:
        """Count lines in a file."""
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                return sum(1 for _ in f)
        except OSError:
            return 0
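
# Programmatic use (a minimal sketch; adapter names come from the registry):
#   gen = DatasetGenerator(output_dir="datasets", seed=42)
#   path = gen.generate_adapter("newton", count=100)
#   paths = gen.generate_all()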


def main():
    """CLI entry point."""
parser = argparse.ArgumentParser(
description="Generate JSONL training datasets for Codette LoRA adapters.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=(
"Examples:\n"
" python -m dataset_engine.dataset_generator --adapter newton --count 3000\n"
" python -m dataset_engine.dataset_generator --all\n"
" python -m dataset_engine.dataset_generator --all --seed 42\n"
" python -m dataset_engine.dataset_generator --adapter philosophy --output-dir ./my_datasets\n"
),
)
parser.add_argument(
"--adapter",
type=str,
help="Adapter name to generate for (e.g. newton, philosophy).",
)
parser.add_argument(
"--all",
action="store_true",
help="Generate datasets for ALL adapters with their target sizes.",
)
parser.add_argument(
"--count",
type=int,
default=None,
help="Number of examples to generate (overrides default target).",
)
parser.add_argument(
"--output-dir",
type=str,
default="datasets",
help="Output directory for JSONL files (default: datasets).",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducible generation.",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging.",
)
args = parser.parse_args()
# Configure logging
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(
level=log_level,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
    if not args.adapter and not args.all:
        parser.error("Specify --adapter NAME or --all")
    if args.adapter and args.all:
        parser.error("Use either --adapter NAME or --all, not both")
generator = DatasetGenerator(
output_dir=args.output_dir,
seed=args.seed,
)
if args.all:
results = generator.generate_all()
print("\n--- Generation Summary ---")
for adapter, path in results.items():
if path.startswith("ERROR"):
print(f" {adapter}: {path}")
else:
count = generator._count_lines(path)
print(f" {adapter}: {count} examples -> {path}")
else:
path = generator.generate_adapter(args.adapter, args.count)
count = generator._count_lines(path)
print(f"\nGenerated {count} examples -> {path}")
if __name__ == "__main__":
main()