S-Dreamer committed on
Commit
b9ed97d
·
verified ·
1 Parent(s): 178abc4

Upload 4 files

Browse files
Files changed (4) hide show
  1. evaluators.py +229 -0
  2. hf-sync.yml +151 -0
  3. logging.py +55 -0
  4. pipeline.py +464 -0
evaluators.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluator implementations for code generation metrics.
3
+
4
+ Each evaluator exposes a single method:
5
+ evaluate(model, tokenizer, dataset) -> float
6
+
7
+ Scores are always in [0, 1].
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import ast
13
+ import multiprocessing
14
+ import textwrap
15
+ from abc import ABC, abstractmethod
16
+ from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeoutError
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+ import torch
21
+ from datasets import Dataset
22
+ from sacrebleu.metrics import BLEU
23
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Base
28
+ # ---------------------------------------------------------------------------
29
class BaseEvaluator(ABC):
    """Abstract base for all metric evaluators.

    Subclasses implement :meth:`evaluate`; :meth:`_generate_batch` is the
    shared helper they use to obtain completions from the model.
    """

    @abstractmethod
    def evaluate(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
    ) -> float:
        """Score ``model`` on ``dataset``.

        Per the module contract, implementations return a score in [0, 1].
        """

    def _generate_batch(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        prompts: list[str],
        max_new_tokens: int = 256,
        num_return_sequences: int = 1,
        temperature: float = 0.2,
        top_p: float = 0.95,
        max_prompt_tokens: int = 512,
    ) -> list[list[str]]:
        """Generate completions for each prompt, one prompt at a time.

        Args:
            model: Model whose ``generate`` method produces the completions.
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            prompts: Prompt strings to complete.
            max_new_tokens: Upper bound on generated tokens per completion.
            num_return_sequences: Completions requested per prompt.
            temperature: Sampling temperature; a value ``<= 0`` selects
                deterministic (non-sampling) decoding.
            top_p: Nucleus-sampling threshold, used only when sampling.
            max_prompt_tokens: Prompts are truncated to this many tokens.

        Returns:
            One inner list per prompt (same order), each holding the decoded
            completions with the prompt tokens stripped.

        Raises:
            ValueError: If several sequences are requested without sampling;
                deterministic decoding would only yield duplicates.
        """
        sampling = temperature > 0
        if not sampling and num_return_sequences != 1:
            raise ValueError(
                "num_return_sequences > 1 requires temperature > 0 "
                f"(got num_return_sequences={num_return_sequences}, "
                f"temperature={temperature})"
            )

        generation_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "num_return_sequences": num_return_sequences,
            "do_sample": sampling,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if sampling:
            # Sampling knobs are only meaningful (and only forwarded) when
            # do_sample is on; passing them to greedy decoding is noise.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        results: list[list[str]] = []
        device = next(model.parameters()).device

        for prompt in prompts:
            encoded = tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_prompt_tokens,
            )
            inputs = {key: value.to(device) for key, value in encoded.items()}

            with torch.no_grad():
                outputs = model.generate(**inputs, **generation_kwargs)

            # generate() returns prompt + continuation; keep the continuation.
            prompt_len = inputs["input_ids"].shape[1]
            results.append(
                [
                    tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
                    for sequence in outputs
                ]
            )

        return results
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Pass@k
79
+ # ---------------------------------------------------------------------------
80
class PassAtKEvaluator(BaseEvaluator):
    """
    Unbiased pass@k estimator from Chen et al. (2021):
        pass@k = 1 - C(n-c, k) / C(n, k)
    where n = total samples, c = correct samples.

    The estimator is only defined for 1 <= k <= n; ``evaluate`` raises
    ``ValueError`` otherwise instead of reporting a meaningless 1.0.
    """

    def __init__(self, k: int = 1, n: int = 10) -> None:
        """Store ``k`` (the pass@k order) and ``n`` (samples drawn per problem)."""
        self.k = k
        self.n = n

    def evaluate(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
        num_problems: int = 50,
    ) -> float:
        """Return the mean pass@k over the first ``num_problems`` examples.

        Prompts come from the ``prompt`` field (falling back to ``content``)
        and references from ``canonical_solution`` (same fallback).

        Returns:
            Mean pass@k in [0, 1]; ``0.0`` when the dataset is empty.

        Raises:
            ValueError: If ``k < 1`` or ``n < k``.
        """
        # Validate before generating so a bad configuration fails cheaply.
        if self.k < 1 or self.n < self.k:
            raise ValueError(
                f"pass@k requires 1 <= k <= n (got n={self.n}, k={self.k})"
            )

        problems = dataset.select(range(min(num_problems, len(dataset))))
        prompts = [str(ex.get("prompt", ex.get("content", ""))) for ex in problems]
        references = [
            str(ex.get("canonical_solution", ex.get("content", "")))
            for ex in problems
        ]
        if not prompts:
            # np.mean([]) would yield NaN; the other evaluators in this
            # module score an empty dataset as 0.0.
            return 0.0

        all_completions = self._generate_batch(
            model, tokenizer, prompts,
            num_return_sequences=self.n,
            temperature=0.8,  # diversity for pass@k
        )

        scores: list[float] = []
        for completions, reference in zip(all_completions, references):
            correct = sum(1 for c in completions if self._is_correct(c, reference))
            # Use the number of samples actually returned as n.
            scores.append(self._pass_at_k(n=len(completions), c=correct, k=self.k))

        return float(np.mean(scores))

    @staticmethod
    def _pass_at_k(n: int, c: int, k: int) -> float:
        """Compute 1 - C(n-c, k) / C(n, k) as a numerically stable product.

        Raises:
            ValueError: If ``k`` is outside ``[1, n]`` or ``c`` outside ``[0, n]``.
        """
        if k < 1 or n < k:
            raise ValueError(f"pass@k requires 1 <= k <= n (got n={n}, k={k})")
        if not 0 <= c <= n:
            raise ValueError(f"correct count must be in [0, n] (got c={c}, n={n})")
        if n - c < k:
            # Fewer than k incorrect samples: every k-subset contains a hit.
            return 1.0
        return 1.0 - float(np.prod([(n - c - i) / (n - i) for i in range(k)]))

    @staticmethod
    def _is_correct(completion: str, reference: str) -> bool:
        """Return True when ``completion`` parses and equals ``reference``.

        Basic syntactic check — override with an execution check for
        HumanEval-style problems.
        """
        try:
            # Completions are often an indented function body; dedent first so
            # uniform leading indentation alone is not treated as a syntax error.
            ast.parse(textwrap.dedent(completion))
        except (SyntaxError, ValueError):
            # ValueError: e.g. source containing null bytes on older Pythons.
            return False
        return completion.strip() == reference.strip()
132
+
133
+
134
+ # ---------------------------------------------------------------------------
135
+ # BLEU
136
+ # ---------------------------------------------------------------------------
137
class BleuEvaluator(BaseEvaluator):
    """Corpus-level BLEU between generated completions and references."""

    def __init__(self, max_new_tokens: int = 256) -> None:
        """Create the sacrebleu scorer and store the generation length limit."""
        self._max_new_tokens = max_new_tokens
        self._bleu = BLEU(effective_order=True)

    def evaluate(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
        num_samples: int = 100,
    ) -> float:
        """Return corpus BLEU over the first ``num_samples`` examples.

        Prompts come from the ``prompt`` field (falling back to ``content``)
        and references from ``canonical_solution`` (same fallback). One
        completion is generated per prompt.

        Returns:
            BLEU normalised to [0, 1]; ``0.0`` when the dataset is empty.
        """
        subset = dataset.select(range(min(num_samples, len(dataset))))
        prompts = [str(ex.get("prompt", ex.get("content", ""))) for ex in subset]
        references = [
            str(ex.get("canonical_solution", ex.get("content", "")))
            for ex in subset
        ]
        if not prompts:
            # Nothing to score: do not hand sacrebleu an empty corpus, and stay
            # consistent with the other evaluators' empty-dataset result.
            return 0.0

        completions_batch = self._generate_batch(
            model, tokenizer, prompts, max_new_tokens=self._max_new_tokens
        )
        hypotheses = [batch[0] for batch in completions_batch]

        # A single reference stream: one reference per hypothesis.
        result = self._bleu.corpus_score(hypotheses, [references])
        # sacrebleu returns score in [0, 100]; normalise to [0, 1]
        return result.score / 100.0
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # Execution accuracy
165
+ # ---------------------------------------------------------------------------
166
+ def _run_code_safe(code: str, timeout: int) -> bool:
167
+ """Run in a subprocess to enforce timeout and isolate crashes."""
168
+ try:
169
+ exec(compile(code, "<string>", "exec"), {}) # noqa: S102
170
+ return True
171
+ except Exception:
172
+ return False
173
+
174
+
175
class ExecutionAccuracyEvaluator(BaseEvaluator):
    """Share of generated code snippets that run without raising."""

    def __init__(self, timeout: int = 10, max_new_tokens: int = 256) -> None:
        """Remember the per-snippet timeout (seconds) and generation length."""
        self._max_new_tokens = max_new_tokens
        self._timeout = timeout

    def evaluate(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
        num_samples: int = 50,
    ) -> float:
        """Generate one completion per prompt and execute each in a worker pool.

        Prompts are read from the ``prompt`` field, falling back to
        ``content``. Returns ``passed / total``, or ``0.0`` when there are no
        snippets.
        """
        limit = min(num_samples, len(dataset))
        subset = dataset.select(range(limit))

        prompts = []
        for example in subset:
            prompts.append(str(example.get("prompt", example.get("content", ""))))

        generated = self._generate_batch(
            model, tokenizer, prompts, max_new_tokens=self._max_new_tokens
        )
        codes = [candidates[0] for candidates in generated]

        successes = 0
        with ProcessPoolExecutor(max_workers=4) as executor:
            pending = [
                executor.submit(_run_code_safe, snippet, self._timeout)
                for snippet in codes
            ]
            for job in pending:
                # NOTE(review): result(timeout=...) only stops *waiting*; it
                # does not by itself stop a worker that is still running.
                try:
                    succeeded = job.result(timeout=self._timeout + 1)
                except Exception:
                    # Timeouts and worker failures both count as "not passed".
                    continue
                if succeeded:
                    successes += 1

        if not codes:
            return 0.0
        return successes / len(codes)
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # Exact match
212
+ # ---------------------------------------------------------------------------
213
class ExactMatchEvaluator(BaseEvaluator):
    """Fraction of completions that equal the reference after stripping."""

    def evaluate(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
        num_samples: int = 100,
    ) -> float:
        """Compare one generated completion per prompt with its reference.

        Prompts come from ``prompt`` (falling back to ``content``) and
        references from ``canonical_solution`` (same fallback). Surrounding
        whitespace is ignored on both sides. Returns ``0.0`` when there are
        no references.
        """
        limit = min(num_samples, len(dataset))
        subset = dataset.select(range(limit))

        prompts: list[str] = []
        references: list[str] = []
        for example in subset:
            prompts.append(str(example.get("prompt", example.get("content", ""))))
            references.append(
                str(example.get("canonical_solution", example.get("content", "")))
            )

        generated = self._generate_batch(model, tokenizer, prompts)
        hypotheses = [candidates[0].strip() for candidates in generated]

        if not references:
            return 0.0

        matches = 0
        for hypothesis, reference in zip(hypotheses, references):
            if hypothesis == reference.strip():
                matches += 1
        return matches / len(references)
hf-sync.yml ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
name: HF ↔ GitHub Sync

# Triggers:
#   - push to main     → GitHub → HF direction
#   - hourly schedule  → pull any HF-side changes back (HF → GitHub direction)
#   - manual dispatch  → optionally force a direction
on:
  push:
    branches: [main]
  schedule:
    - cron: '0 * * * *'  # hourly HF pull-check
  workflow_dispatch:
    inputs:
      force_direction:
        description: 'Force sync direction (hf-wins | gh-wins | auto)'
        required: false
        default: 'auto'

# Serialise runs: a push-triggered sync and the hourly sync must never
# force-push over each other.
concurrency:
  group: hf-sync
  cancel-in-progress: false

env:
  HF_REPO_TYPE: space           # model | dataset | space
  HF_REPO: ${{ vars.HF_REPO }}  # e.g. your-org/codecraftlab

jobs:
  sync:
    name: Sync HuggingFace ↔ GitHub
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - name: Checkout (full history)
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Configure git identity
        run: |
          git config user.email "sync-bot@codecraftlab.noreply"
          git config user.name "CodeCraftLab Sync Bot"

      - name: Install git-lfs
        run: |
          # Refresh the package index first so the install does not fail on
          # a stale runner image.
          sudo apt-get update
          sudo apt-get install -y git-lfs
          git lfs install

      - name: Add HuggingFace remote
        env:
          # Passed through the environment rather than interpolated into the
          # script text.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          set -euo pipefail

          if [ -z "${HF_REPO}" ]; then
            echo "::error::Repository variable HF_REPO is not set"
            exit 1
          fi

          # Model repos live at the site root; datasets and Spaces are
          # namespaced under /datasets/ and /spaces/.
          case "${HF_REPO_TYPE}" in
            model)   HF_PATH="${HF_REPO}" ;;
            dataset) HF_PATH="datasets/${HF_REPO}" ;;
            space)   HF_PATH="spaces/${HF_REPO}" ;;
            *)
              echo "::error::Unsupported HF_REPO_TYPE '${HF_REPO_TYPE}'"
              exit 1
              ;;
          esac

          git remote add hf "https://user:${HF_TOKEN}@huggingface.co/${HF_PATH}"
          git fetch hf --prune

      - name: Detect divergence and resolve
        id: sync
        env:
          FORCE_DIRECTION: ${{ github.event.inputs.force_direction || 'auto' }}
        run: |
          set -euo pipefail

          # decide <action> <force> <reason>: publish the decision as step
          # outputs and echo the reason into the job log.
          decide() {
            {
              echo "action=$1"
              echo "force=$2"
              echo "reason=$3"
            } >> "$GITHUB_OUTPUT"
            echo "reason=$3"
          }

          # --verify --quiet prints nothing for an unknown ref; a plain
          # rev-parse echoes the ref name to stdout before failing, so the
          # result could never compare equal to "NONE".
          HF_HEAD=$(git rev-parse --verify --quiet "hf/main^{commit}" || echo "NONE")
          GH_HEAD=$(git rev-parse HEAD)

          if [ "$HF_HEAD" = "NONE" ]; then
            decide push-to-hf false "HF remote has no main branch — initial push"
            exit 0
          fi

          BASE=$(git merge-base HEAD hf/main 2>/dev/null || echo "NONE")

          if [ "$FORCE_DIRECTION" = "hf-wins" ]; then
            decide hf-wins false "Forced HF-wins override"
          elif [ "$FORCE_DIRECTION" = "gh-wins" ]; then
            # A plain push is rejected once the histories have diverged, so a
            # forced GH-wins must be allowed to overwrite HF.
            decide push-to-hf true "Forced GH-wins override"
          elif [ "$HF_HEAD" = "$GH_HEAD" ]; then
            decide in-sync false "Already in sync"
          elif [ "$BASE" = "$GH_HEAD" ]; then
            # HF is ahead — pull HF → GitHub
            decide hf-wins false "GitHub is behind HF — fast-forward"
          elif [ "$BASE" = "$HF_HEAD" ]; then
            # GitHub is ahead — push GitHub → HF
            decide push-to-hf false "HF is behind GitHub — pushing"
          else
            # Both diverged — HF is source of truth
            decide hf-wins false "CONFLICT: both diverged — HF wins (source of truth)"
          fi

      - name: "[In-sync] Nothing to do"
        if: steps.sync.outputs.action == 'in-sync'
        run: echo "✅ HF and GitHub are in sync — no action required."

      - name: "[Push] GitHub → HuggingFace"
        if: steps.sync.outputs.action == 'push-to-hf'
        env:
          FORCE_PUSH: ${{ steps.sync.outputs.force }}
        run: |
          echo "📤 Pushing GitHub → HuggingFace"
          if [ "${FORCE_PUSH}" = "true" ]; then
            git push --force-with-lease hf main
          else
            git push hf main
          fi

      - name: "[HF Wins] HuggingFace → GitHub"
        if: steps.sync.outputs.action == 'hf-wins'
        run: |
          echo "📥 HuggingFace wins — overwriting GitHub main"
          git reset --hard hf/main
          # No unconditional --force fallback: if the lease fails, someone
          # pushed to GitHub after this job fetched. Fail here and let the
          # next run re-evaluate instead of discarding that commit.
          git push origin main --force-with-lease

      - name: Summary
        if: always()
        run: |
          echo "### Sync Result" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Action:** ${{ steps.sync.outputs.action }}" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Reason:** ${{ steps.sync.outputs.reason }}" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Trigger:** ${{ github.event_name }}" >> "$GITHUB_STEP_SUMMARY"
          echo "- **Branch:** main" >> "$GITHUB_STEP_SUMMARY"

  # ------------------------------------------------------------------
  # Validate HF Space config on every workflow run (push, schedule, manual)
  # ------------------------------------------------------------------
  validate-space-config:
    name: Validate HF Space README config
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Check README frontmatter
        run: |
          python3 - <<'EOF'
          import re
          import sys

          # utf-8-sig tolerates a leading BOM, which would otherwise hide the
          # opening '---' from the anchored match below.
          with open("README.md", encoding="utf-8-sig") as f:
              content = f.read()

          match = re.match(r"^---\r?\n(.*?)\r?\n---", content, re.DOTALL)
          if not match:
              print("❌ README is missing HF Space YAML frontmatter")
              sys.exit(1)

          frontmatter = match.group(1)
          required_keys = ["title", "sdk", "app_port", "license"]
          # Anchor each key to the start of a line so that, for example,
          # "short_title:" does not satisfy the "title" requirement.
          missing = [
              k for k in required_keys
              if not re.search(rf"^{re.escape(k)}\s*:", frontmatter, re.MULTILINE)
          ]
          if missing:
              print(f"❌ Missing frontmatter keys: {missing}")
              sys.exit(1)

          if re.search(r"^sdk\s*:\s*['\"]?streamlit\b", frontmatter, re.MULTILINE):
              print("❌ sdk is still 'streamlit' — should be 'docker' for FastAPI")
              sys.exit(1)

          print("✅ HF Space frontmatter is valid")
          EOF
logging.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Structured logging setup via structlog.
3
+
4
+ Production: JSON output, machine-parseable.
5
+ Development: colourised console output.
6
+
7
+ Call configure_logging() once at application startup before any loggers are created.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import sys
14
+
15
+ import structlog
16
+
17
+
18
def configure_logging(log_level: str = "INFO", env: str = "development") -> None:
    """Configure structlog with environment-appropriate rendering.

    Args:
        log_level: Name of a stdlib logging level (case-insensitive).
            Unknown names fall back to ``INFO``.
        env: ``"production"`` selects JSON output with structured
            tracebacks; any other value selects the colourised console
            renderer.
    """
    # NOTE: structlog.stdlib.add_logger_name is deliberately not used here.
    # It reads ``logger.name``, which the PrintLogger instances created by
    # PrintLoggerFactory below do not provide.
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]

    processors: list[structlog.types.Processor]
    if env == "production":
        processors = [
            *shared_processors,
            structlog.processors.dict_tracebacks,
            structlog.processors.JSONRenderer(),
        ]
    else:
        processors = [
            *shared_processors,
            structlog.dev.ConsoleRenderer(colors=True),
        ]

    # getattr may resolve to a non-level attribute of the logging module
    # (e.g. "basic_format" -> BASIC_FORMAT); accept integer levels only.
    level = getattr(logging, log_level.upper(), logging.INFO)
    if not isinstance(level, int):
        level = logging.INFO

    structlog.configure(
        processors=processors,
        wrapper_class=structlog.make_filtering_bound_logger(level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(sys.stdout),
        cache_logger_on_first_use=True,
    )

    # Silence noisy third-party loggers
    for noisy in ("uvicorn.access", "httpx", "transformers", "datasets"):
        logging.getLogger(noisy).setLevel(logging.WARNING)
pipeline.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tuning pipeline with structured logging and eval hooks.
3
+
4
+ Pipeline stages:
5
+ 1. Preflight validation — config, GPU, disk, token
6
+ 2. Dataset preparation — load, tokenize, split
7
+ 3. Model initialisation — base model + LoRA adapters
8
+ 4. Training — Trainer with custom callbacks
9
+ 5. Evaluation — post-training metric suite
10
+ 6. Checkpoint export — save + optional HF Hub push
11
+
12
+ Each stage emits structured log events. Eval hooks are composable and
13
+ run both during training (via TrainerCallback) and post-training.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ import shutil
21
+ import time
22
+ from dataclasses import dataclass, field
23
+ from pathlib import Path
24
+ from typing import Any
25
+
26
+ import structlog
27
+ import torch
28
+ from datasets import Dataset, DatasetDict, load_dataset
29
+ from peft import LoraConfig, TaskType, get_peft_model
30
+ from transformers import (
31
+ AutoModelForCausalLM,
32
+ AutoTokenizer,
33
+ DataCollatorForLanguageModeling,
34
+ PreTrainedModel,
35
+ PreTrainedTokenizerBase,
36
+ Trainer,
37
+ TrainerCallback,
38
+ TrainerControl,
39
+ TrainerState,
40
+ TrainingArguments,
41
+ )
42
+
43
+ from training.config import EvalMetric, EvalStrategy, TrainingJobConfig
44
+ from training.evaluators import (
45
+ BleuEvaluator,
46
+ ExecutionAccuracyEvaluator,
47
+ ExactMatchEvaluator,
48
+ PassAtKEvaluator,
49
+ )
50
+
51
+ log = structlog.get_logger(__name__)
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Eval result container
56
+ # ---------------------------------------------------------------------------
57
@dataclass
class EvalResults:
    """Outcome of one evaluation pass: scores, failures and timing."""

    job_name: str
    epoch: float
    step: int
    metrics: dict[str, float] = field(default_factory=dict)
    errors: list[str] = field(default_factory=list)
    duration_seconds: float = 0.0

    def log(self, bound_log: structlog.BoundLogger) -> None:
        """Emit one ``eval.completed`` event, then one warning per error.

        Metric scores are flattened into the event as individual key/value
        pairs next to epoch, step and the rounded duration.
        """
        summary: dict[str, Any] = {
            "epoch": self.epoch,
            "step": self.step,
            "duration_seconds": round(self.duration_seconds, 2),
        }
        bound_log.info("eval.completed", **summary, **self.metrics)

        for failure in self.errors:
            bound_log.warning("eval.error", message=failure)

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""
        names = ("job_name", "epoch", "step", "metrics", "errors", "duration_seconds")
        return {name: getattr(self, name) for name in names}
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Eval hook registry
90
+ # ---------------------------------------------------------------------------
91
+ class EvalHookRunner:
92
+ """
93
+ Runs the configured evaluation metrics against a model + dataset.
94
+
95
+ Evaluators are resolved from the job config at construction time.
96
+ Each evaluator is independent; failures in one do not abort others.
97
+ """
98
+
99
+ def __init__(self, config: TrainingJobConfig, tokenizer: PreTrainedTokenizerBase) -> None:
100
+ self._config = config
101
+ self._tokenizer = tokenizer
102
+ self._evaluators = self._build_evaluators()
103
+ self._log = log.bind(job=config.job_name)
104
+
105
+ def _build_evaluators(self) -> dict[EvalMetric, Any]:
106
+ evals: dict[EvalMetric, Any] = {}
107
+ eval_cfg = self._config.evaluation
108
+ for metric in eval_cfg.metrics:
109
+ match metric:
110
+ case EvalMetric.PASS_AT_1:
111
+ evals[metric] = PassAtKEvaluator(k=1, n=eval_cfg.num_samples_per_problem)
112
+ case EvalMetric.PASS_AT_10:
113
+ evals[metric] = PassAtKEvaluator(k=10, n=eval_cfg.num_samples_per_problem)
114
+ case EvalMetric.BLEU:
115
+ evals[metric] = BleuEvaluator()
116
+ case EvalMetric.EXECUTION_ACCURACY:
117
+ evals[metric] = ExecutionAccuracyEvaluator(
118
+ timeout=self._config.evaluation.timeout_seconds
119
+ )
120
+ case EvalMetric.EXACT_MATCH:
121
+ evals[metric] = ExactMatchEvaluator()
122
+ return evals
123
+
124
+ def run(
125
+ self,
126
+ model: PreTrainedModel,
127
+ eval_dataset: Dataset,
128
+ epoch: float,
129
+ step: int,
130
+ ) -> EvalResults:
131
+ start = time.perf_counter()
132
+ results = EvalResults(job_name=self._config.job_name, epoch=epoch, step=step)
133
+
134
+ model.eval()
135
+ with torch.no_grad():
136
+ for metric, evaluator in self._evaluators.items():
137
+ try:
138
+ score = evaluator.evaluate(
139
+ model=model,
140
+ tokenizer=self._tokenizer,
141
+ dataset=eval_dataset,
142
+ )
143
+ results.metrics[metric.value] = round(score, 4)
144
+ self._log.info("eval.metric", metric=metric.value, score=score)
145
+ except Exception as exc: # noqa: BLE001
146
+ msg = f"{metric.value}: {exc}"
147
+ results.errors.append(msg)
148
+ self._log.warning("eval.metric_failed", metric=metric.value, error=str(exc))
149
+
150
+ results.duration_seconds = time.perf_counter() - start
151
+ results.log(self._log)
152
+ return results
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Custom training callback
157
+ # ---------------------------------------------------------------------------
158
class CodeCraftLabCallback(TrainerCallback):
    """
    Bridges the HF Trainer loop to structured logging and the eval hooks.

    At the end of every epoch the hook runner is invoked and the accumulated
    results are rewritten to ``results_path`` as JSON.
    """

    def __init__(
        self,
        hook_runner: EvalHookRunner,
        eval_dataset: Dataset,
        results_path: Path,
    ) -> None:
        """Keep the runner, the dataset to evaluate on and the output path."""
        self._log = log
        self._all_results: list[dict[str, Any]] = []
        self._results_path = results_path
        self._eval_dataset = eval_dataset
        self._runner = hook_runner

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: PreTrainedModel,
        **kwargs: Any,
    ) -> TrainerControl:
        """Log the epoch boundary, run the eval hooks and persist the results."""
        history = state.log_history
        # Loss of the most recent log entry, if any entry exists.
        last_loss = history[-1].get("loss") if history else None
        self._log.info(
            "training.epoch_end",
            epoch=state.epoch,
            step=state.global_step,
            loss=last_loss,
        )

        outcome = self._runner.run(
            model=model,
            eval_dataset=self._eval_dataset,
            epoch=state.epoch or 0.0,
            step=state.global_step,
        )
        self._all_results.append(outcome.to_dict())
        self._persist_results()
        return control

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: dict[str, float],
        **kwargs: Any,
    ) -> TrainerControl:
        """Forward the Trainer's log payload as a ``training.log`` event."""
        current_step = state.global_step
        self._log.info("training.log", step=current_step, **logs)
        return control

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Any,
    ) -> TrainerControl:
        """Emit a final event with the total step count and FLOs."""
        self._log.info(
            "training.completed",
            total_steps=state.global_step,
            total_flos=state.total_flos,
        )
        return control

    def _persist_results(self) -> None:
        """Overwrite the results file with every result collected so far."""
        serialised = json.dumps(self._all_results, indent=2)
        self._results_path.write_text(serialised, encoding="utf-8")
228
+
229
+
230
+ # ---------------------------------------------------------------------------
231
+ # Pipeline
232
+ # ---------------------------------------------------------------------------
233
class FineTuningPipeline:
    """
    Orchestrates the full fine-tuning lifecycle.

    Usage:
        config = TrainingJobConfig.model_validate(raw_dict)
        pipeline = FineTuningPipeline(config)
        pipeline.run()
    """

    def __init__(self, config: TrainingJobConfig) -> None:
        """Bind a job-scoped logger and derive ``<output_dir>/<job_name>``."""
        self._config = config
        self._log = log.bind(job=config.job_name, model=config.base_model)
        self._output_dir = Path(config.checkpoint.output_dir) / config.job_name

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def run(self) -> Path:
        """Execute all pipeline stages. Returns the final checkpoint path."""
        self._log.info("pipeline.started")
        self._preflight()
        datasets = self._prepare_datasets()
        model, tokenizer = self._load_model()
        self._train(model, tokenizer, datasets)
        final_path = self._export(model, tokenizer)
        self._log.info("pipeline.finished", output=str(final_path))
        return final_path

    # ------------------------------------------------------------------
    # Stage 1: Preflight
    # ------------------------------------------------------------------
    def _preflight(self) -> None:
        """Validate the environment and create the output directory.

        Checks run in order: config re-validation, GPU presence (warning
        only), Hub token, then output directory creation and a free-disk
        check (warning only).

        Raises:
            EnvironmentError: If ``hub.push_to_hub`` is set but ``HF_TOKEN``
                is missing from the environment.
        """
        self._log.info("pipeline.preflight")

        # Validate config (already done at submission, but be defensive)
        self._config.model_validate(self._config.model_dump())

        # GPU check
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            self._log.info("preflight.gpu", device=device_name, vram_gb=round(vram_gb, 1))
        else:
            self._log.warning("preflight.no_gpu", message="Training on CPU — will be slow")

        # HF token if pushing — checked before touching the filesystem so a
        # misconfigured job fails without leaving directories behind.
        if self._config.hub.push_to_hub and not os.environ.get("HF_TOKEN"):
            raise EnvironmentError("HF_TOKEN is required when hub.push_to_hub=true")

        # Create the output directory first: shutil.disk_usage raises
        # FileNotFoundError for a path that does not exist yet.
        self._output_dir.mkdir(parents=True, exist_ok=True)

        # Disk space (rough check — 20 GB minimum)
        free_gb = shutil.disk_usage(self._output_dir).free / 1e9
        if free_gb < 20:
            self._log.warning("preflight.disk_low", free_gb=round(free_gb, 1))

        self._log.info("preflight.passed")

    # ------------------------------------------------------------------
    # Stage 2: Dataset preparation
    # ------------------------------------------------------------------
    def _prepare_datasets(self) -> DatasetDict:
        """Load the raw dataset and split it into ``train`` and ``eval``.

        Dataset ids starting with ``ds_`` are loaded from ``./data/<id>``;
        anything else is loaded through ``load_dataset`` (``train`` split).
        ``max_samples`` is applied before shuffling, and the first
        ``split_ratio`` share of the rows becomes the training split.
        """
        self._log.info("pipeline.dataset_prep")
        ds_cfg = self._config.dataset

        # Load — support both HF Hub paths and internal dataset IDs
        raw: Dataset
        if ds_cfg.dataset_id.startswith("ds_"):
            # Internal dataset — load from local store
            raw = Dataset.load_from_disk(f"./data/{ds_cfg.dataset_id}")
        else:
            raw = load_dataset(ds_cfg.dataset_id, split="train")  # type: ignore[assignment]

        if ds_cfg.max_samples:
            raw = raw.select(range(min(ds_cfg.max_samples, len(raw))))

        if ds_cfg.shuffle:
            raw = raw.shuffle(seed=ds_cfg.shuffle_seed)

        n_train = int(len(raw) * ds_cfg.split_ratio)
        splits = DatasetDict(
            {
                "train": raw.select(range(n_train)),
                "eval": raw.select(range(n_train, len(raw))),
            }
        )
        self._log.info(
            "dataset.prepared",
            train_size=len(splits["train"]),
            eval_size=len(splits["eval"]),
            column=ds_cfg.text_column,
        )
        return splits

    # ------------------------------------------------------------------
    # Stage 3: Model initialisation
    # ------------------------------------------------------------------
    def _load_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:
        """Load the tokenizer and base model, wrapping it in LoRA if enabled.

        The torch dtype follows ``training.precision`` (defaulting to
        bfloat16 for unknown values). A tokenizer without a pad token reuses
        its EOS token for padding.
        """
        self._log.info("pipeline.model_load")
        hp = self._config.training

        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        torch_dtype = dtype_map.get(hp.precision.value, torch.bfloat16)

        tokenizer = AutoTokenizer.from_pretrained(self._config.base_model)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            self._config.base_model,
            torch_dtype=torch_dtype,
            device_map="auto" if torch.cuda.is_available() else "cpu",
        )

        if self._config.lora and self._config.lora.enabled:
            lora_cfg = self._config.lora
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=lora_cfg.r,
                lora_alpha=lora_cfg.alpha,
                lora_dropout=lora_cfg.dropout,
                target_modules=lora_cfg.target_modules,
                bias=lora_cfg.bias,  # type: ignore[arg-type]
            )
            model = get_peft_model(model, peft_config)
            trainable, total = model.get_nb_trainable_parameters()
            self._log.info(
                "model.lora_applied",
                trainable_params=trainable,
                total_params=total,
                trainable_pct=round(100 * trainable / total, 2),
            )
        else:
            self._log.info("model.full_finetune")

        return model, tokenizer  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Stage 4: Training
    # ------------------------------------------------------------------
    def _train(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        datasets: DatasetDict,
    ) -> None:
        """Tokenize the splits and run the HF ``Trainer``.

        The Trainer receives the tokenized splits, while the epoch-end eval
        callback receives the raw (untokenized) ``eval`` split. Per-epoch
        eval results are written to ``<output_dir>/eval_results.json``.
        """
        self._log.info("pipeline.training_start")
        hp = self._config.training
        ckpt = self._config.checkpoint
        eval_cfg = self._config.evaluation

        def tokenize(examples: dict[str, list[str]]) -> dict[str, Any]:
            # Padding is left to the data collator (dynamic, per batch).
            return tokenizer(
                examples[self._config.dataset.text_column],
                truncation=True,
                max_length=hp.max_seq_length,
                padding=False,
            )

        tokenized = datasets.map(tokenize, batched=True, remove_columns=datasets["train"].column_names)

        # NOTE(review): recent transformers releases rename
        # ``evaluation_strategy`` to ``eval_strategy`` — confirm against the
        # pinned transformers version before upgrading.
        training_args = TrainingArguments(
            output_dir=str(self._output_dir),
            num_train_epochs=hp.num_epochs,
            per_device_train_batch_size=hp.batch_size,
            per_device_eval_batch_size=hp.batch_size,
            gradient_accumulation_steps=hp.gradient_accumulation_steps,
            learning_rate=hp.learning_rate,
            weight_decay=hp.weight_decay,
            warmup_ratio=hp.warmup_ratio,
            max_grad_norm=hp.max_grad_norm,
            optim=hp.optimizer.value,
            lr_scheduler_type=hp.lr_scheduler,
            fp16=hp.precision.value == "fp16",
            bf16=hp.precision.value == "bf16",
            evaluation_strategy=eval_cfg.strategy.value,
            eval_steps=eval_cfg.eval_steps,
            save_strategy=ckpt.save_strategy.value,
            save_steps=ckpt.save_steps,
            save_total_limit=ckpt.save_total_limit,
            load_best_model_at_end=eval_cfg.load_best_model_at_end,
            metric_for_best_model=eval_cfg.metric_for_best_model.value,
            greater_is_better=eval_cfg.greater_is_better,
            seed=hp.seed,
            dataloader_num_workers=hp.dataloader_num_workers,
            report_to="none",  # structlog handles all logging
            logging_steps=10,
            resume_from_checkpoint=ckpt.resume_from_checkpoint,
            push_to_hub=False,  # push handled separately in export stage
        )

        hook_runner = EvalHookRunner(self._config, tokenizer)
        results_path = self._output_dir / "eval_results.json"
        callback = CodeCraftLabCallback(
            hook_runner=hook_runner,
            eval_dataset=datasets["eval"],
            results_path=results_path,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized["train"],
            eval_dataset=tokenized["eval"],
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            callbacks=[callback],
        )

        trainer.train(resume_from_checkpoint=ckpt.resume_from_checkpoint)

    # ------------------------------------------------------------------
    # Final stage: Export + Hub push
    # ------------------------------------------------------------------
    def _export(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) -> Path:
        """Save model and tokenizer to ``<output_dir>/final``; optionally push.

        The Hub push only happens when ``hub.push_to_hub`` is set and a
        ``hub.repo_id`` is configured.

        Returns:
            The directory the final artefacts were written to.
        """
        self._log.info("pipeline.export")
        final_path = self._output_dir / "final"
        model.save_pretrained(str(final_path))
        tokenizer.save_pretrained(str(final_path))
        self._log.info("model.saved", path=str(final_path))

        hub_cfg = self._config.hub
        if hub_cfg.push_to_hub and hub_cfg.repo_id:
            self._log.info("hub.pushing", repo_id=hub_cfg.repo_id)
            model.push_to_hub(hub_cfg.repo_id, private=hub_cfg.private)
            tokenizer.push_to_hub(hub_cfg.repo_id, private=hub_cfg.private)
            self._log.info("hub.pushed", repo_id=hub_cfg.repo_id)

        return final_path