|
|
| """
|
| Codette Full Training Pipeline
|
| =================================
|
|
|
| End-to-end pipeline orchestration for the Codette training lab.
|
| Runs dataset generation, validation, reasoning forge enhancement,
|
| adapter training, evaluation benchmarks, and observatory logging.
|
|
|
| Each stage can be run independently or as part of the full pipeline.
|
|
|
| Usage:
|
| # Run everything
|
| python scripts/run_full_pipeline.py --all
|
|
|
| # Run specific stages
|
| python scripts/run_full_pipeline.py --generate --validate
|
| python scripts/run_full_pipeline.py --forge --train
|
| python scripts/run_full_pipeline.py --evaluate
|
|
|
| # Select specific adapters
|
| python scripts/run_full_pipeline.py --all --adapters newton davinci quantum
|
| """
|
|
|
| import argparse
|
| import json
|
| import logging
|
| import os
|
| import sys
|
| import time
|
| from datetime import datetime
|
| from pathlib import Path
|
|
|
|
|
|
|
|
|
# Make the project root importable so the sibling modules used below
# (dataset_engine, reasoning_forge, training, evaluation) resolve when this
# file is run as a script.
_project_root = str(Path(__file__).resolve().parent.parent)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
|
|
|
| import yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
| def setup_pipeline_logging() -> logging.Logger:
|
| """Configure the pipeline logger with file and console handlers.
|
|
|
| Returns:
|
| Configured logger instance.
|
| """
|
| log_dir = Path("logs")
|
| log_dir.mkdir(parents=True, exist_ok=True)
|
|
|
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| log_file = log_dir / f"pipeline_{timestamp}.log"
|
|
|
| logger = logging.getLogger("codette.pipeline")
|
| logger.setLevel(logging.DEBUG)
|
| logger.handlers.clear()
|
|
|
| fh = logging.FileHandler(str(log_file), encoding="utf-8")
|
| fh.setLevel(logging.DEBUG)
|
| fh.setFormatter(logging.Formatter(
|
| "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
| datefmt="%Y-%m-%d %H:%M:%S",
|
| ))
|
| logger.addHandler(fh)
|
|
|
| ch = logging.StreamHandler(sys.stdout)
|
| ch.setLevel(logging.INFO)
|
| ch.setFormatter(logging.Formatter(
|
| "%(asctime)s | %(levelname)-8s | %(message)s",
|
| datefmt="%H:%M:%S",
|
| ))
|
| logger.addHandler(ch)
|
|
|
| return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
| def load_pipeline_config(config_path: str = "configs/pipeline_config.yaml") -> dict:
|
| """Load the pipeline configuration from YAML.
|
|
|
| Args:
|
| config_path: Path to the pipeline config file.
|
|
|
| Returns:
|
| Parsed configuration dictionary.
|
| """
|
| path = Path(config_path)
|
| if not path.exists():
|
| raise FileNotFoundError(f"Pipeline config not found: {config_path}")
|
|
|
| with open(path, "r", encoding="utf-8") as f:
|
| return yaml.safe_load(f)
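

# For reference, a minimal pipeline_config.yaml sketch covering only the keys
# this script actually reads (illustrative, not the authoritative schema):
#
#   pipeline:
#     seed: 42
#     dataset_output_dir: ./datasets
#     adapter_output_dir: ./adapters
#   generation: {}        # read but not currently consumed
#   validation:
#     min_tokens: 40
#     max_duplicate_similarity: 0.85
#     required_roles: [system, user, assistant]
#   evaluation: {}        # read but not currently consumed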
|
|
|
|
|
| def load_adapter_registry(config_path: str = "configs/adapter_registry.yaml") -> dict:
|
| """Load the adapter registry from YAML.
|
|
|
| Args:
|
| config_path: Path to the adapter registry file.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to configurations.
|
| """
|
| path = Path(config_path)
|
| if not path.exists():
|
| raise FileNotFoundError(f"Adapter registry not found: {config_path}")
|
|
|
| with open(path, "r", encoding="utf-8") as f:
|
| config = yaml.safe_load(f)
|
|
|
| return config.get("adapters", {})
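

# Registry sketch (illustrative; only the fields this script reads are shown):
#
#   adapters:
#     newton:
#       dataset: datasets/newton.jsonl      # hypothetical path
#       target_examples: 2000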
|
|
|
|
|
|
|
|
|
|
|
|
|
| class ObservatoryLogger:
|
| """Centralized metrics logger for the Codette observatory.
|
|
|
| Accumulates metrics from all pipeline stages and writes them
|
| to a JSON file for dashboard consumption.
|
| """
|
|
|
| def __init__(self, output_path: str = "observatory_metrics.json"):
|
| self.output_path = Path(output_path)
|
| self.metrics: list[dict] = []
|
| self.pipeline_start = datetime.now()
|
|
|
|
|
        # Append to metrics from earlier runs when the output file already
        # exists; start fresh if it is unreadable or malformed.
        if self.output_path.exists():
            try:
                with open(self.output_path, "r", encoding="utf-8") as f:
                    existing = json.load(f)
                    if isinstance(existing, list):
                        self.metrics = existing
            except (json.JSONDecodeError, IOError):
                self.metrics = []
|
|
|
| def log(self, stage: str, adapter: str | None, data: dict) -> None:
|
| """Log a metrics entry.
|
|
|
| Args:
|
| stage: Pipeline stage name.
|
| adapter: Adapter name (or None for global metrics).
|
| data: Dictionary of metric values.
|
| """
|
| entry = {
|
| "stage": stage,
|
| "adapter": adapter,
|
| "timestamp": datetime.now().isoformat(),
|
| "pipeline_run": self.pipeline_start.isoformat(),
|
| **data,
|
| }
|
| self.metrics.append(entry)
|
|
|
| def save(self) -> None:
|
| """Write all metrics to disk."""
|
| with open(self.output_path, "w", encoding="utf-8") as f:
|
| json.dump(self.metrics, f, indent=2)
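

# Typical usage (illustrative):
#     obs = ObservatoryLogger()
#     obs.log("train", "newton", {"final_loss": 1.23})
#     obs.save()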
|
|
|
|
|
|
|
|
|
|
|
|
|
| def stage_generate(
|
| registry: dict,
|
| pipeline_config: dict,
|
| adapter_names: list[str],
|
| observatory: ObservatoryLogger,
|
| logger: logging.Logger,
|
| ) -> dict[str, dict]:
|
| """Generate training datasets for selected adapters.
|
|
|
| Uses the dataset_engine module to produce JSONL files
|
| with chat-format training examples.
|
|
|
| Args:
|
| registry: Adapter registry configuration.
|
| pipeline_config: Pipeline configuration.
|
| adapter_names: List of adapter names to generate for.
|
| observatory: Metrics logger.
|
| logger: Logger instance.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to generation results.
|
| """
|
| logger.info("=" * 60)
|
| logger.info("STAGE 1: Dataset Generation")
|
| logger.info("=" * 60)
|
|
|
    gen_config = pipeline_config.get("generation", {})  # read but not currently consumed
|
| output_dir = pipeline_config.get("pipeline", {}).get(
|
| "dataset_output_dir", "./datasets"
|
| )
|
| Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
|
| results = {}
|
|
|
| try:
|
| from dataset_engine import DatasetGenerator
|
| except ImportError:
|
| logger.warning(
|
| "dataset_engine module not available. "
|
| "Checking for existing dataset files instead."
|
| )
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
| exists = Path(dataset_path).exists()
|
| count = 0
|
| if exists:
|
| with open(dataset_path, "r", encoding="utf-8") as f:
|
| count = sum(1 for line in f if line.strip())
|
| results[name] = {
|
| "status": "exists" if exists else "missing",
|
| "examples": count,
|
| "path": dataset_path,
|
| }
|
| observatory.log("generate", name, results[name])
|
| if exists:
|
| logger.info(f" {name}: found {count} existing examples")
|
| else:
|
| logger.warning(f" {name}: dataset missing at {dataset_path}")
|
| return results
|
|
|
| seed = pipeline_config.get("pipeline", {}).get("seed", 42)
|
| generator = DatasetGenerator(output_dir=output_dir, seed=seed)
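
    # DatasetGenerator and generate_adapter are assumed to accept these
    # keyword arguments and to return the path of the written JSONL file
    # (see dataset_engine for the authoritative signatures).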
|
|
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
| target_examples = adapter_cfg.get("target_examples", 2000)
|
|
|
| logger.info(f"Generating dataset for: {name}")
|
| logger.info(f" Target: {target_examples} examples")
|
| logger.info(f" Output: {dataset_path}")
|
|
|
| start_time = time.time()
|
| try:
|
| generated_path = generator.generate_adapter(
|
| adapter=name,
|
| count=target_examples,
|
| )
|
|
|
| count = 0
|
| with open(generated_path, "r", encoding="utf-8") as f:
|
| count = sum(1 for line in f if line.strip())
|
| elapsed = time.time() - start_time
|
|
|
| results[name] = {
|
| "status": "generated",
|
| "examples": count,
|
| "path": generated_path,
|
| "time_seconds": elapsed,
|
| }
|
| logger.info(
|
| f" Generated {count} examples in {elapsed:.1f}s"
|
| )
|
|
|
| except Exception as e:
|
| elapsed = time.time() - start_time
|
| results[name] = {
|
| "status": "error",
|
| "error": str(e),
|
| "time_seconds": elapsed,
|
| }
|
| logger.error(f" Generation failed for {name}: {e}")
|
|
|
| observatory.log("generate", name, results[name])
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| def stage_validate(
|
| registry: dict,
|
| pipeline_config: dict,
|
| adapter_names: list[str],
|
| observatory: ObservatoryLogger,
|
| logger: logging.Logger,
|
| ) -> dict[str, dict]:
|
| """Validate generated datasets for quality and correctness.
|
|
|
| Checks for proper JSON structure, required message roles,
|
| minimum token counts, and duplicate detection.
|
|
|
| Args:
|
| registry: Adapter registry configuration.
|
| pipeline_config: Pipeline configuration.
|
| adapter_names: List of adapter names to validate.
|
| observatory: Metrics logger.
|
| logger: Logger instance.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to validation results.
|
| """
|
| logger.info("=" * 60)
|
| logger.info("STAGE 2: Dataset Validation")
|
| logger.info("=" * 60)
|
|
|
    val_config = pipeline_config.get("validation", {})
    min_tokens = val_config.get("min_tokens", 40)  # whitespace word count as a token proxy
    max_dup_sim = val_config.get("max_duplicate_similarity", 0.85)  # reserved; not yet enforced
    required_roles = set(val_config.get("required_roles", ["system", "user", "assistant"]))
|
|
|
| results = {}
|
|
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
|
|
| logger.info(f"Validating: {name} ({dataset_path})")
|
|
|
| if not Path(dataset_path).exists():
|
| results[name] = {
|
| "status": "missing",
|
| "error": f"Dataset file not found: {dataset_path}",
|
| }
|
| observatory.log("validate", name, results[name])
|
| logger.warning(f" SKIP: dataset file not found")
|
| continue
|
|
|
| total = 0
|
| valid = 0
|
| errors = {
|
| "json_parse": 0,
|
| "missing_messages": 0,
|
| "missing_roles": 0,
|
| "too_short": 0,
|
| }
|
|
|
| try:
|
| with open(dataset_path, "r", encoding="utf-8") as f:
|
| for line_num, line in enumerate(f, 1):
|
| line = line.strip()
|
| if not line:
|
| continue
|
| total += 1
|
|
|
|
|
| try:
|
| record = json.loads(line)
|
| except json.JSONDecodeError:
|
| errors["json_parse"] += 1
|
| continue
|
|
|
|
|
| messages = record.get("messages")
|
| if not isinstance(messages, list) or len(messages) < 2:
|
| errors["missing_messages"] += 1
|
| continue
|
|
|
|
|
| found_roles = {m.get("role") for m in messages if isinstance(m, dict)}
|
| if not required_roles.issubset(found_roles):
|
| errors["missing_roles"] += 1
|
| continue
|
|
|
|
|
| total_words = sum(
|
| len(m.get("content", "").split())
|
| for m in messages
|
| if isinstance(m, dict)
|
| )
|
| if total_words < min_tokens:
|
| errors["too_short"] += 1
|
| continue
|
|
|
| valid += 1
|
|
|
| error_count = sum(errors.values())
|
| pass_rate = (valid / total * 100) if total > 0 else 0
|
|
|
| results[name] = {
|
| "status": "valid" if pass_rate > 90 else "warning",
|
| "total_records": total,
|
| "valid_records": valid,
|
| "error_records": error_count,
|
| "pass_rate": round(pass_rate, 2),
|
| "errors": errors,
|
| }
|
|
|
| level = logging.INFO if pass_rate > 90 else logging.WARNING
|
| logger.log(
|
| level,
|
| f" {name}: {valid}/{total} valid "
|
| f"({pass_rate:.1f}% pass rate)",
|
| )
|
| if error_count > 0:
|
| for error_type, count in errors.items():
|
| if count > 0:
|
| logger.log(level, f" {error_type}: {count}")
|
|
|
| except Exception as e:
|
| results[name] = {
|
| "status": "error",
|
| "error": str(e),
|
| }
|
| logger.error(f" Validation failed for {name}: {e}")
|
|
|
| observatory.log("validate", name, results[name])
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| def stage_forge(
|
| registry: dict,
|
| pipeline_config: dict,
|
| adapter_names: list[str],
|
| observatory: ObservatoryLogger,
|
| logger: logging.Logger,
|
| ) -> dict[str, dict]:
|
| """Run the reasoning forge to enhance datasets with multi-agent reasoning.
|
|
|
| Each dataset is processed through the forge's multi-agent pipeline,
|
| which adds analytical depth from multiple perspectives.
|
|
|
| Args:
|
| registry: Adapter registry configuration.
|
| pipeline_config: Pipeline configuration.
|
| adapter_names: List of adapter names to process.
|
| observatory: Metrics logger.
|
| logger: Logger instance.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to forge results.
|
| """
|
| logger.info("=" * 60)
|
| logger.info("STAGE 3: Reasoning Forge")
|
| logger.info("=" * 60)
|
|
|
| results = {}
|
|
|
| try:
|
| from reasoning_forge import ForgeEngine
|
| except ImportError:
|
| logger.warning(
|
| "reasoning_forge module not available. Skipping forge stage."
|
| )
|
| for name in adapter_names:
|
| results[name] = {"status": "skipped", "reason": "module_not_available"}
|
| observatory.log("forge", name, results[name])
|
| return results
|
|
|
| try:
|
| forge = ForgeEngine()
|
| except Exception as e:
|
| logger.error(f"Failed to initialize forge engine: {e}")
|
| for name in adapter_names:
|
| results[name] = {"status": "error", "error": str(e)}
|
| observatory.log("forge", name, results[name])
|
| return results
|
|
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
|
|
| if not Path(dataset_path).exists():
|
| results[name] = {"status": "skipped", "reason": "dataset_missing"}
|
| observatory.log("forge", name, results[name])
|
| logger.warning(f" SKIP {name}: dataset not found")
|
| continue
|
|
|
| logger.info(f"Forging: {name}")
|
| start_time = time.time()
|
|
|
| try:
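            # Load the full dataset into memory (JSONL, one record per line).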
|
|
|
| examples = []
|
| with open(dataset_path, "r", encoding="utf-8") as f:
|
| for line in f:
|
| line = line.strip()
|
| if line:
|
| examples.append(json.loads(line))
|
|
|
| enhanced_count = 0
|
| enhanced_examples = []
|
|
|
            for example in examples:
|
| messages = example.get("messages", [])
|
|
|
| user_msg = next(
|
| (m["content"] for m in messages if m.get("role") == "user"),
|
| None,
|
| )
|
| if not user_msg:
|
| enhanced_examples.append(example)
|
| continue
|
|
|
| try:
|
| forge_result = forge.forge_single(user_msg)
|
| synthesis = None
|
| if forge_result:
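                        # forge_single is expected to return a chat-format
                        # dict; pull the assistant turn out as the synthesis.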
|
|
|
|
|
| for m in forge_result.get("messages", []):
|
| if m.get("role") == "assistant":
|
| synthesis = m.get("content")
|
| break
|
| if synthesis:
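                            # Append the synthesis to the original assistant
                            # reply rather than replacing it.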
|
|
|
| for msg in messages:
|
| if msg.get("role") == "assistant":
|
| original = msg["content"]
|
| msg["content"] = (
|
| f"{original}\n\n"
|
| f"[Multi-perspective synthesis]: {synthesis}"
|
| )
|
| enhanced_count += 1
|
| break
|
| except Exception:
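                    # Per-example forge failures are non-fatal; the original
                    # example is kept unchanged.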
|
| pass
|
|
|
| enhanced_examples.append(example)
|
|
|
|
|
| with open(dataset_path, "w", encoding="utf-8") as f:
|
| for ex in enhanced_examples:
|
| f.write(json.dumps(ex, ensure_ascii=False) + "\n")
|
|
|
| elapsed = time.time() - start_time
|
| results[name] = {
|
| "status": "success",
|
| "total_examples": len(examples),
|
| "enhanced_examples": enhanced_count,
|
| "time_seconds": elapsed,
|
| }
|
| logger.info(
|
| f" {name}: enhanced {enhanced_count}/{len(examples)} "
|
| f"examples in {elapsed:.1f}s"
|
| )
|
|
|
| except Exception as e:
|
| elapsed = time.time() - start_time
|
| results[name] = {
|
| "status": "error",
|
| "error": str(e),
|
| "time_seconds": elapsed,
|
| }
|
| logger.error(f" Forge failed for {name}: {e}")
|
|
|
| observatory.log("forge", name, results[name])
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| def stage_train(
|
| registry: dict,
|
| pipeline_config: dict,
|
| adapter_names: list[str],
|
| observatory: ObservatoryLogger,
|
| logger: logging.Logger,
|
| ) -> dict[str, dict]:
|
| """Train LoRA adapters for selected perspectives.
|
|
|
| Delegates to training.train_all_adapters for the actual
|
| training loop.
|
|
|
| Args:
|
| registry: Adapter registry configuration.
|
| pipeline_config: Pipeline configuration.
|
| adapter_names: List of adapter names to train.
|
| observatory: Metrics logger.
|
| logger: Logger instance.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to training results.
|
| """
|
| logger.info("=" * 60)
|
| logger.info("STAGE 4: Adapter Training")
|
| logger.info("=" * 60)
|
|
|
| results = {}
|
|
|
| try:
|
| from training.train_all_adapters import (
|
| load_training_config,
|
| train_single_adapter,
|
| )
|
| except ImportError:
|
| logger.error("training module not available")
|
| for name in adapter_names:
|
| results[name] = {"status": "error", "error": "module_not_available"}
|
| observatory.log("train", name, results[name])
|
| return results
|
|
|
| training_defaults = load_training_config()
|
| output_dir = pipeline_config.get("pipeline", {}).get(
|
| "adapter_output_dir", "./adapters"
|
| )
|
|
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
|
|
| if not Path(dataset_path).exists():
|
| results[name] = {"status": "skipped", "reason": "dataset_missing"}
|
| observatory.log("train", name, results[name])
|
| logger.warning(f" SKIP {name}: dataset not found at {dataset_path}")
|
| continue
|
|
|
| logger.info(f"Training adapter: {name}")
|
| metrics = train_single_adapter(
|
| adapter_name=name,
|
| adapter_config=adapter_cfg,
|
| training_defaults=training_defaults,
|
| output_base_dir=output_dir,
|
| logger=logger,
|
| )
|
| results[name] = metrics
|
| observatory.log("train", name, metrics)
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| def stage_evaluate(
|
| registry: dict,
|
| pipeline_config: dict,
|
| adapter_names: list[str],
|
| observatory: ObservatoryLogger,
|
| logger: logging.Logger,
|
| ) -> dict[str, dict]:
|
| """Run evaluation benchmarks on trained adapters.
|
|
|
| Uses the evaluation module to run reasoning tests and
|
| compute quality metrics.
|
|
|
| Args:
|
| registry: Adapter registry configuration.
|
| pipeline_config: Pipeline configuration.
|
| adapter_names: List of adapter names to evaluate.
|
| observatory: Metrics logger.
|
| logger: Logger instance.
|
|
|
| Returns:
|
| Dictionary mapping adapter names to evaluation results.
|
| """
|
| logger.info("=" * 60)
|
| logger.info("STAGE 5: Evaluation")
|
| logger.info("=" * 60)
|
|
|
    eval_config = pipeline_config.get("evaluation", {})  # read but not currently consumed
|
| results = {}
|
|
|
| try:
|
| from evaluation import ReasoningMetrics
|
| except ImportError:
|
| logger.warning(
|
| "evaluation module not fully available. "
|
| "Running basic dataset statistics instead."
|
| )
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
|
|
| if not Path(dataset_path).exists():
|
| results[name] = {"status": "skipped", "reason": "dataset_missing"}
|
| observatory.log("evaluate", name, results[name])
|
| continue
|
|
|
|
|
| total = 0
|
| total_words = 0
|
| total_turns = 0
|
|
|
| try:
|
| with open(dataset_path, "r", encoding="utf-8") as f:
|
| for line in f:
|
| line = line.strip()
|
| if not line:
|
| continue
|
| record = json.loads(line)
|
| messages = record.get("messages", [])
|
| total += 1
|
| total_turns += len(messages)
|
| for msg in messages:
|
| if isinstance(msg, dict):
|
| total_words += len(
|
| msg.get("content", "").split()
|
| )
|
|
|
| avg_words = total_words / total if total > 0 else 0
|
| avg_turns = total_turns / total if total > 0 else 0
|
|
|
| results[name] = {
|
| "status": "basic_stats",
|
| "total_examples": total,
|
| "avg_words_per_example": round(avg_words, 1),
|
| "avg_turns_per_example": round(avg_turns, 1),
|
| "total_words": total_words,
|
| }
|
| logger.info(
|
| f" {name}: {total} examples, "
|
| f"avg {avg_words:.0f} words, "
|
| f"avg {avg_turns:.1f} turns"
|
| )
|
|
|
| except Exception as e:
|
| results[name] = {"status": "error", "error": str(e)}
|
| logger.error(f" Evaluation failed for {name}: {e}")
|
|
|
| observatory.log("evaluate", name, results[name])
|
|
|
| return results
|
|
|
|
|
|
|
| metrics = ReasoningMetrics()
|
|
|
| for name in adapter_names:
|
| adapter_cfg = registry.get(name, {})
|
| dataset_path = adapter_cfg.get("dataset", "")
|
|
|
| if not Path(dataset_path).exists():
|
| results[name] = {"status": "skipped", "reason": "dataset_missing"}
|
| observatory.log("evaluate", name, results[name])
|
| logger.warning(f" SKIP {name}: dataset not found")
|
| continue
|
|
|
| logger.info(f"Evaluating adapter: {name}")
|
| start_time = time.time()
|
|
|
| try:
|
|
|
| responses: list[str] = []
|
| with open(dataset_path, "r", encoding="utf-8") as f:
|
| for line in f:
|
| line = line.strip()
|
| if not line:
|
| continue
|
| record = json.loads(line)
|
| for msg in record.get("messages", []):
|
| if msg.get("role") == "assistant":
|
| responses.append(msg["content"])
|
|
|
|
|
| batch_scores = metrics.score_batch(responses)
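
            # Average every numeric score dimension across the batch.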
|
|
|
|
|
| if batch_scores:
|
                dim_keys = [
                    k for k in batch_scores[0]
                    if isinstance(batch_scores[0][k], (int, float))
                ]
|
| avg_scores = {
|
| k: round(sum(s[k] for s in batch_scores) / len(batch_scores), 4)
|
| for k in dim_keys
|
| }
|
| else:
|
| avg_scores = {}
|
|
|
| elapsed = time.time() - start_time
|
| results[name] = {
|
| "status": "evaluated",
|
| "total_responses": len(responses),
|
| "scores": avg_scores,
|
| "time_seconds": elapsed,
|
| }
|
| logger.info(
|
| f" {name}: scored {len(responses)} responses, "
|
| f"overall={avg_scores.get('overall', 0):.3f} "
|
| f"in {elapsed:.1f}s"
|
| )
|
|
|
| except Exception as e:
|
| elapsed = time.time() - start_time
|
| results[name] = {
|
| "status": "error",
|
| "error": str(e),
|
| "time_seconds": elapsed,
|
| }
|
| logger.error(f" Evaluation failed for {name}: {e}")
|
|
|
| observatory.log("evaluate", name, results[name])
|
|
|
| return results
|
|
|
|
|
|
|
|
|
|
|
|
|
| def print_dashboard(
|
| all_results: dict[str, dict[str, dict]],
|
| total_time: float,
|
| logger: logging.Logger,
|
| ) -> None:
|
| """Print a comprehensive pipeline dashboard.
|
|
|
| Args:
|
| all_results: Nested dictionary of {stage: {adapter: results}}.
|
| total_time: Total pipeline execution time in seconds.
|
| logger: Logger instance.
|
| """
|
| logger.info("")
|
| logger.info("=" * 72)
|
| logger.info(" CODETTE TRAINING PIPELINE DASHBOARD")
|
| logger.info("=" * 72)
|
| logger.info(f" Total time: {total_time:.1f}s ({total_time / 60:.1f} min)")
|
| logger.info(f" Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| logger.info("")
|
|
|
|
|
| all_adapters = set()
|
| for stage_results in all_results.values():
|
| all_adapters.update(stage_results.keys())
|
| all_adapters = sorted(all_adapters)
|
|
|
| stages = ["generate", "validate", "forge", "train", "evaluate"]
|
|
|
|
|
| header = f"{'Adapter':<20}"
|
| for stage in stages:
|
| if stage in all_results:
|
| header += f" {stage[:8]:^10}"
|
| logger.info(header)
|
| logger.info("-" * 72)
|
|
|
|
|
| for adapter in all_adapters:
|
| row = f"{adapter:<20}"
|
| for stage in stages:
|
| if stage not in all_results:
|
| continue
|
| result = all_results.get(stage, {}).get(adapter, {})
|
| status = result.get("status", "---")
|
|
|
|
|
| if status in ("success", "generated", "valid", "evaluated", "exists"):
|
| symbol = "OK"
|
| elif status in ("warning", "basic_stats"):
|
| symbol = "WARN"
|
| elif status in ("skipped",):
|
| symbol = "SKIP"
|
| elif status in ("error", "missing"):
|
| symbol = "FAIL"
|
| else:
|
| symbol = status[:8]
|
|
|
| row += f" {symbol:^10}"
|
|
|
| logger.info(row)
|
|
|
| logger.info("-" * 72)
|
|
|
|
|
| logger.info("")
|
| for stage_name, stage_results in all_results.items():
|
| if not stage_results:
|
| continue
|
| ok = sum(
|
| 1 for r in stage_results.values()
|
| if r.get("status") in ("success", "generated", "valid", "evaluated", "exists", "basic_stats")
|
| )
|
| fail = sum(
|
| 1 for r in stage_results.values()
|
| if r.get("status") in ("error", "missing")
|
| )
|
| skip = sum(
|
| 1 for r in stage_results.values()
|
| if r.get("status") == "skipped"
|
| )
|
| logger.info(
|
| f" {stage_name:<12}: {ok} ok, {fail} failed, {skip} skipped"
|
| )
|
|
|
|
|
| train_results = all_results.get("train", {})
|
| if train_results:
|
| logger.info("")
|
| logger.info(" Training Details:")
|
| for name, metrics in train_results.items():
|
| if metrics.get("status") == "success":
|
| loss = metrics.get("final_loss", 0)
|
| steps = metrics.get("total_steps", 0)
|
| t = metrics.get("training_time_seconds", 0)
|
| logger.info(
|
| f" {name:<16}: loss={loss:.4f}, "
|
| f"steps={steps}, time={t:.1f}s"
|
| )
|
|
|
|
|
| val_results = all_results.get("validate", {})
|
| if val_results:
|
| logger.info("")
|
| logger.info(" Validation Details:")
|
| for name, metrics in val_results.items():
|
| if "pass_rate" in metrics:
|
| total = metrics.get("total_records", 0)
|
| valid = metrics.get("valid_records", 0)
|
| rate = metrics.get("pass_rate", 0)
|
| logger.info(
|
| f" {name:<16}: {valid}/{total} valid ({rate:.1f}%)"
|
| )
|
|
|
| logger.info("")
|
| logger.info("=" * 72)
|
|
|
|
|
|
|
|
|
|
|
|
|
| def parse_args() -> argparse.Namespace:
|
| """Parse command-line arguments."""
|
| parser = argparse.ArgumentParser(
|
| description="Codette Full Training Pipeline",
|
| formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| )
|
|
|
|
|
| parser.add_argument("--all", action="store_true", help="Run all stages")
|
| parser.add_argument(
|
| "--generate", action="store_true", help="Stage 1: Generate datasets"
|
| )
|
| parser.add_argument(
|
| "--validate", action="store_true", help="Stage 2: Validate datasets"
|
| )
|
| parser.add_argument(
|
| "--forge", action="store_true", help="Stage 3: Run reasoning forge"
|
| )
|
| parser.add_argument(
|
| "--train", action="store_true", help="Stage 4: Train adapters"
|
| )
|
| parser.add_argument(
|
| "--evaluate", action="store_true", help="Stage 5: Run evaluations"
|
| )
|
|
|
|
|
| parser.add_argument(
|
| "--adapters",
|
| nargs="+",
|
| default=None,
|
| help="Specific adapters to process (default: all in registry)",
|
| )
|
| parser.add_argument(
|
| "--pipeline-config",
|
| type=str,
|
| default="configs/pipeline_config.yaml",
|
| help="Path to pipeline configuration",
|
| )
|
| parser.add_argument(
|
| "--adapter-registry",
|
| type=str,
|
| default="configs/adapter_registry.yaml",
|
| help="Path to adapter registry",
|
| )
|
| parser.add_argument(
|
| "--seed",
|
| type=int,
|
| default=None,
|
| help="Random seed (overrides config)",
|
| )
|
|
|
| return parser.parse_args()
|
|
|
|
|
| def main():
|
| """Main entry point for the Codette training pipeline."""
|
| args = parse_args()
|
|
|
|
|
| run_all = args.all
|
| stages = {
|
| "generate": args.generate or run_all,
|
| "validate": args.validate or run_all,
|
| "forge": args.forge or run_all,
|
| "train": args.train or run_all,
|
| "evaluate": args.evaluate or run_all,
|
| }
|
|
|
| if not any(stages.values()):
|
| print(
|
| "No stages selected. Use --all or specify stages "
|
| "(--generate, --validate, --forge, --train, --evaluate)"
|
| )
|
| sys.exit(1)
|
|
|
|
|
| logger = setup_pipeline_logging()
|
| logger.info("=== Codette Training Pipeline ===")
|
| logger.info(f"Stages: {[s for s, enabled in stages.items() if enabled]}")
|
|
|
|
|
| try:
|
| pipeline_config = load_pipeline_config(args.pipeline_config)
|
| registry = load_adapter_registry(args.adapter_registry)
|
| except FileNotFoundError as e:
|
| logger.error(f"Configuration error: {e}")
|
| sys.exit(1)
|
|
|
|
|
    # Seed every RNG we might touch. `is not None` (rather than `or`) keeps
    # an explicit `--seed 0` from falling through to the config default.
    seed = args.seed if args.seed is not None else pipeline_config.get(
        "pipeline", {}
    ).get("seed", 42)
    import random

    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass
    logger.info(f"Random seed: {seed}")
|
|
|
|
|
| if args.adapters:
|
| adapter_names = args.adapters
|
| unknown = [n for n in adapter_names if n not in registry]
|
| if unknown:
|
| logger.error(
|
| f"Unknown adapters: {unknown}. "
|
| f"Available: {list(registry.keys())}"
|
| )
|
| sys.exit(1)
|
| else:
|
| adapter_names = list(registry.keys())
|
|
|
| logger.info(f"Adapters ({len(adapter_names)}): {adapter_names}")
|
|
|
|
|
| observatory = ObservatoryLogger()
|
|
|
|
|
| all_results: dict[str, dict[str, dict]] = {}
|
| pipeline_start = time.time()
|
|
|
| if stages["generate"]:
|
| all_results["generate"] = stage_generate(
|
| registry, pipeline_config, adapter_names, observatory, logger
|
| )
|
|
|
| if stages["validate"]:
|
| all_results["validate"] = stage_validate(
|
| registry, pipeline_config, adapter_names, observatory, logger
|
| )
|
|
|
| if stages["forge"]:
|
| all_results["forge"] = stage_forge(
|
| registry, pipeline_config, adapter_names, observatory, logger
|
| )
|
|
|
| if stages["train"]:
|
| all_results["train"] = stage_train(
|
| registry, pipeline_config, adapter_names, observatory, logger
|
| )
|
|
|
| if stages["evaluate"]:
|
| all_results["evaluate"] = stage_evaluate(
|
| registry, pipeline_config, adapter_names, observatory, logger
|
| )
|
|
|
| total_time = time.time() - pipeline_start
|
|
|
|
|
| observatory.log("pipeline", None, {
|
| "total_time_seconds": total_time,
|
| "stages_run": [s for s, enabled in stages.items() if enabled],
|
| "adapters_processed": adapter_names,
|
| })
|
| observatory.save()
|
| logger.info(f"Observatory metrics saved to: {observatory.output_path}")
|
|
|
|
|
| print_dashboard(all_results, total_time, logger)
|
|
|
|
|
| results_path = Path("logs") / "pipeline_results.json"
|
| with open(results_path, "w", encoding="utf-8") as f:
|
| json.dump(
|
| {
|
| "timestamp": datetime.now().isoformat(),
|
| "total_time_seconds": total_time,
|
| "seed": seed,
|
| "stages": {s: e for s, e in stages.items()},
|
| "adapters": adapter_names,
|
| "results": all_results,
|
| },
|
| f,
|
| indent=2,
|
| )
|
| logger.info(f"Pipeline results saved to: {results_path}")
|
|
|
|
|
    # Exit non-zero if any stage reported an error for any adapter.
    has_failures = any(
        result.get("status") == "error"
        for stage_results in all_results.values()
        for result in stage_results.values()
    )
|
|
|
| if has_failures:
|
| logger.warning("Pipeline completed with errors. Check logs for details.")
|
| sys.exit(1)
|
| else:
|
| logger.info("Pipeline completed successfully.")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|