Nikitasoni22 committed on
Commit
2b6814d
·
1 Parent(s): dbb0576

updated code

Browse files
Files changed (4) hide show
  1. cicd_debug_env/env.py +22 -2
  2. cicd_debug_env/rewards.py +28 -6
  3. train.py +12 -3
  4. train_colab.ipynb +321 -7
cicd_debug_env/env.py CHANGED
@@ -3,7 +3,15 @@ import random
3
 
4
  from .models import Action, Observation
5
  from .tasks import ALL_TASKS
6
- from .rewards import compute_total_reward
 
 
 
 
 
 
 
 
7
  from .memory.failure_bank import FailureMemoryBank
8
 
9
  try:
@@ -76,7 +84,19 @@ class CICDDebugEnv(_BaseEnv):
76
  self.current_observation.available_actions = self.available_actions()
77
 
78
  self._update_state()
79
- return self.current_observation, reward, self.done, {"task_id": self.current_task["id"], "reward_breakdown": reward}
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def state(self) -> dict:
82
  return self._state_dict
 
3
 
4
  from .models import Action, Observation
5
  from .tasks import ALL_TASKS
6
+ from .rewards import (
7
+ compute_total_reward,
8
+ reward_execution_success,
9
+ reward_fix_correctness,
10
+ reward_step_efficiency,
11
+ reward_format_compliance,
12
+ reward_robustness,
13
+ check_anti_hacking_guards,
14
+ )
15
  from .memory.failure_bank import FailureMemoryBank
16
 
17
  try:
 
84
  self.current_observation.available_actions = self.available_actions()
85
 
86
  self._update_state()
87
+ reward_components = {
88
+ "execution_success": reward_execution_success(self.current_observation, self.current_task),
89
+ "fix_correctness": reward_fix_correctness(self.current_observation, action, self.current_task),
90
+ "step_efficiency": reward_step_efficiency(self.current_observation, self.max_steps),
91
+ "format_compliance": reward_format_compliance(action),
92
+ "robustness": reward_robustness(self.current_observation, self.current_task),
93
+ "anti_hacking": check_anti_hacking_guards(self.current_observation, action),
94
+ "total": reward,
95
+ }
96
+ return self.current_observation, reward, self.done, {
97
+ "task_id": self.current_task["id"],
98
+ "reward_breakdown": reward_components,
99
+ }
100
 
101
  def state(self) -> dict:
102
  return self._state_dict
cicd_debug_env/rewards.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Any, List
2
  from .models import Action, Observation
3
 
4
  def reward_execution_success(state: Observation, task: Dict[str, Any] = None) -> float:
@@ -31,11 +31,33 @@ def reward_format_compliance(action: Action) -> float:
31
  return 1.0
32
  return 0.0
33
 
34
- def reward_robustness(state: Observation, task: Dict[str, Any] = None) -> float:
35
- # After fix, does the pipeline pass 3 adversarial variants? (0, 0.33, 0.66, 1)
36
- if task and state.pipeline_yaml.strip() == task.get("correct_yaml", "").strip():
37
- return 1.0
38
- return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def check_anti_hacking_guards(state: Observation, action: Action) -> float:
41
  penalty = 0.0
 
1
+ from typing import Dict, Any, List, Optional
2
  from .models import Action, Observation
3
 
4
  def reward_execution_success(state: Observation, task: Dict[str, Any] = None) -> float:
 
31
  return 1.0
32
  return 0.0
33
 
34
def reward_robustness(state: "Observation", task: Optional[Dict[str, Any]] = None) -> float:
    """Score how closely the agent's YAML matches the reference under lenient comparisons.

    Three independent equality checks are made between ``state.pipeline_yaml``
    and ``task["correct_yaml"]`` (both stripped of surrounding whitespace):

    1. equal after right-stripping every line and dropping empty lines
       (tolerates trailing whitespace);
    2. equal after dropping every empty line (tolerates any number of
       extra blank lines, not just a single one);
    3. equal after lowercasing the whole document (case-insensitive match;
       note this lowercases values as well as keys).

    Args:
        state: Observation whose ``pipeline_yaml`` holds the agent's current YAML.
        task: Task dict providing the reference under ``"correct_yaml"``.

    Returns:
        The fraction of checks that pass, rounded to 4 decimals
        (0.0, 0.3333, 0.6667 or 1.0). Returns 0.0 when ``task`` is ``None`` or
        when either YAML document is missing or empty.
    """
    if task is None:
        return 0.0

    # `or ""` guards against an explicit None value, which would otherwise
    # raise AttributeError on `.strip()`.
    correct = (task.get("correct_yaml") or "").strip()
    agent_fix = (state.pipeline_yaml or "").strip()

    if not correct or not agent_fix:
        return 0.0

    def strip_trailing_whitespace(yaml_str: str) -> str:
        # Right-strip each line, then discard lines that end up empty.
        stripped = (line.rstrip() for line in yaml_str.splitlines())
        return "\n".join(line for line in stripped if line)

    def drop_empty_lines(yaml_str: str) -> str:
        # Removes every empty line, so runs of 2+ blank lines collapse too
        # (a single replace("\n\n", "\n") pass only removes one per pair).
        return "\n".join(line for line in yaml_str.split("\n") if line)

    checks = (
        strip_trailing_whitespace(agent_fix) == strip_trailing_whitespace(correct),
        drop_empty_lines(agent_fix) == drop_empty_lines(correct),
        agent_fix.lower() == correct.lower(),
    )
    return round(sum(checks) / len(checks), 4)
61
 
62
  def check_anti_hacking_guards(state: Observation, action: Action) -> float:
63
  penalty = 0.0
train.py CHANGED
@@ -109,7 +109,7 @@ def main():
109
  learning_rate=5e-6, max_steps=MAX_STEPS,
110
  num_generations=4, max_new_tokens=MAX_NEW_TOKENS,
111
  logging_steps=5, save_steps=50,
112
- report_to="none", remove_unused_columns=False,
113
  warmup_steps=10, lr_scheduler_type="cosine", optim="adamw_8bit",
114
  )
115
  trainer = GRPOTrainer(
@@ -117,16 +117,25 @@ def main():
117
  train_dataset=dataset, processing_class=tokenizer)
118
 
119
  print("Starting GRPO training...")
 
 
120
  trainer.train()
121
  print("Training complete!")
122
 
123
  save_path = "./cicd_rl_agent_final"
124
  if USE_UNSLOTH:
125
- model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
 
 
 
 
 
 
 
126
  else:
127
  model.save_pretrained(save_path)
128
  tokenizer.save_pretrained(save_path)
129
- print(f"Model saved to {save_path}")
130
 
131
  if __name__ == "__main__":
132
  main()
 
109
  learning_rate=5e-6, max_steps=MAX_STEPS,
110
  num_generations=4, max_new_tokens=MAX_NEW_TOKENS,
111
  logging_steps=5, save_steps=50,
112
+ report_to="wandb", remove_unused_columns=False,
113
  warmup_steps=10, lr_scheduler_type="cosine", optim="adamw_8bit",
114
  )
115
  trainer = GRPOTrainer(
 
117
  train_dataset=dataset, processing_class=tokenizer)
118
 
119
  print("Starting GRPO training...")
120
+ import wandb
121
+ wandb.init(project="cicd-rl-agent", name="grpo-run-1")
122
  trainer.train()
123
  print("Training complete!")
124
 
125
  save_path = "./cicd_rl_agent_final"
126
  if USE_UNSLOTH:
127
+ model.save_pretrained(save_path)
128
+ tokenizer.save_pretrained(save_path)
129
+ print(f"LoRA adapters saved to {save_path}")
130
+ print("Testing post-training inference...")
131
+ FastLanguageModel.for_inference(model)
132
+ test_input = tokenizer("Fix this YAML: steps:\n - run: npm tset", return_tensors="pt").to("cuda")
133
+ out = model.generate(**test_input, max_new_tokens=64)
134
+ print(tokenizer.decode(out[0], skip_special_tokens=True))
135
  else:
136
  model.save_pretrained(save_path)
137
  tokenizer.save_pretrained(save_path)
138
+ print(f"Model saved to {save_path}")
139
 
140
  if __name__ == "__main__":
141
  main()
train_colab.ipynb CHANGED
@@ -1,15 +1,94 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {"colab": {"name": "train_colab.ipynb"}},
5
  "cells": [
 
 
 
 
 
 
 
6
  {
7
  "cell_type": "code",
8
  "execution_count": null,
9
  "metadata": {},
10
  "outputs": [],
11
  "source": [
12
- "!pip install unsloth trl openenv pydantic fastapi uvicorn"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ]
14
  },
15
  {
@@ -18,10 +97,245 @@
18
  "metadata": {},
19
  "outputs": [],
20
  "source": [
 
21
  "from unsloth import FastLanguageModel\n",
22
- "from trl import GRPOTrainer, GRPOConfig\n",
23
- "# Start building dataset and running train script directly\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ]
25
  }
26
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  }
 
1
  {
 
 
 
2
  "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## 🔧 Install Dependencies"
8
+ ]
9
+ },
10
  {
11
  "cell_type": "code",
12
  "execution_count": null,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
16
+ "!pip install unsloth trl transformers datasets torch wandb pydantic"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {},
22
+ "source": [
23
+ "## 📦 Clone Environment & Import Tasks"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "import os\n",
33
+ "import random\n",
34
+ "import sys\n",
35
+ "\n",
36
+ "# Colab: clone your fork, or mount Drive and set CICD_RL_REPO to the project path.\n",
37
+ "REPO_DIR = os.environ.get(\"CICD_RL_REPO\", \"/content/cicd-rl-agent\")\n",
38
+ "# !git clone https://github.com/<your-org>/cicd-rl-agent.git {REPO_DIR} # noqa: E501\n",
39
+ "if os.path.isdir(REPO_DIR) and REPO_DIR not in sys.path:\n",
40
+ " sys.path.insert(0, REPO_DIR)\n",
41
+ "\n",
42
+ "from datasets import Dataset\n",
43
+ "from cicd_debug_env.tasks import ALL_TASKS\n",
44
+ "\n",
45
+ "NUM_SAMPLES = 128\n",
46
+ "random.seed(42)\n",
47
+ "\n",
48
+ "SYSTEM_PROMPT = (\n",
49
+ " \"You are an expert DevOps engineer. \"\n",
50
+ " \"You receive a broken CI/CD pipeline YAML and error details. \"\n",
51
+ " \"Output ONLY the corrected YAML — no explanation, no markdown fences.\"\n",
52
+ ")\n",
53
+ "\n",
54
+ "def build_prompt(task: dict) -> str:\n",
55
+ " return (\n",
56
+ " f\"### Error\\n{task.get('error_message', '')}\\n\\n\"\n",
57
+ " f\"### Broken Pipeline\\n{task['pipeline_yaml']}\\n\\n\"\n",
58
+ " f\"### Fixed Pipeline (YAML only):\\n\"\n",
59
+ " )\n",
60
+ "\n",
61
+ "def build_dataset():\n",
62
+ " easy = [t for t in ALL_TASKS if t[\"difficulty\"] == \"easy\"]\n",
63
+ " medium = [t for t in ALL_TASKS if t[\"difficulty\"] == \"medium\"]\n",
64
+ " hard = [t for t in ALL_TASKS if t[\"difficulty\"] == \"hard\"]\n",
65
+ " records = []\n",
66
+ " for _ in range(NUM_SAMPLES):\n",
67
+ " r = random.random()\n",
68
+ " if r < 0.5:\n",
69
+ " task = random.choice(easy)\n",
70
+ " elif r < 0.8:\n",
71
+ " task = random.choice(medium)\n",
72
+ " else:\n",
73
+ " task = random.choice(hard)\n",
74
+ " records.append({\n",
75
+ " \"prompt\": [\n",
76
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
77
+ " {\"role\": \"user\", \"content\": build_prompt(task)},\n",
78
+ " ],\n",
79
+ " \"correct_yaml\": task.get(\"correct_yaml\", \"\"),\n",
80
+ " \"pipeline_yaml\": task[\"pipeline_yaml\"],\n",
81
+ " })\n",
82
+ " return Dataset.from_list(records)\n",
83
+ "\n",
84
+ "print(f\"Loaded {len(ALL_TASKS)} tasks (easy/medium/hard). Sample task ids:\", [t['id'] for t in ALL_TASKS[:3]], \"...\")"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {},
90
+ "source": [
91
+ "## 🤖 Load Model with Unsloth"
92
  ]
93
  },
94
  {
 
97
  "metadata": {},
98
  "outputs": [],
99
  "source": [
100
+ "import torch\n",
101
  "from unsloth import FastLanguageModel\n",
102
+ "\n",
103
+ "MODEL_ID = \"unsloth/Qwen2.5-0.5B-Instruct\"\n",
104
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
105
+ " model_name=MODEL_ID,\n",
106
+ " max_seq_length=1024,\n",
107
+ " dtype=None,\n",
108
+ " load_in_4bit=True,\n",
109
+ ")\n",
110
+ "model = FastLanguageModel.get_peft_model(\n",
111
+ " model,\n",
112
+ " r=16,\n",
113
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
114
+ " lora_alpha=16,\n",
115
+ " lora_dropout=0.0,\n",
116
+ " bias=\"none\",\n",
117
+ " use_gradient_checkpointing=\"unsloth\",\n",
118
+ " random_state=42,\n",
119
+ ")\n",
120
+ "if tokenizer.pad_token is None:\n",
121
+ " tokenizer.pad_token = tokenizer.eos_token"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "metadata": {},
127
+ "source": [
128
+ "## 📝 Build Training Dataset"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "train_dataset = build_dataset()\n",
138
+ "print(f\"Dataset size: {len(train_dataset)} (target split ~50% easy / 30% medium / 20% hard)\")"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "## 🏆 Define Reward Functions"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):\n",
155
+ " \"\"\"How closely the completion matches the reference `correct_yaml` (full match, partial, unchanged, or wrong).\"\"\"\n",
156
+ " rewards = []\n",
157
+ " for c, correct, broken in zip(completions, correct_yaml, pipeline_yaml):\n",
158
+ " c = c.strip()\n",
159
+ " if c == correct.strip():\n",
160
+ " rewards.append(1.0)\n",
161
+ " elif any(line.strip() in c for line in correct.splitlines() if len(line.strip()) > 8):\n",
162
+ " rewards.append(0.5)\n",
163
+ " elif c == broken.strip():\n",
164
+ " rewards.append(-0.2)\n",
165
+ " else:\n",
166
+ " rewards.append(0.0)\n",
167
+ " return rewards\n",
168
+ "\n",
169
+ "def reward_yaml_structure(completions, prompts, **kwargs):\n",
170
+ " \"\"\"Whether the output looks like valid pipeline YAML (keywords, length bounds).\"\"\"\n",
171
+ " rewards = []\n",
172
+ " for c in completions:\n",
173
+ " t = c.strip()\n",
174
+ " score = (\n",
175
+ " 0.4 * int(any(k in t for k in [\"steps:\", \"jobs:\", \"name:\", \"run:\", \"uses:\"]))\n",
176
+ " + 0.3 * int(len(t) > 10)\n",
177
+ " + 0.3 * int(len(t) < 3000)\n",
178
+ " )\n",
179
+ " rewards.append(score)\n",
180
+ " return rewards\n",
181
+ "\n",
182
+ "def reward_no_hallucination(completions, prompts, **kwargs):\n",
183
+ " \"\"\"Penalizes assistant-style or fenced markdown responses instead of raw YAML.\"\"\"\n",
184
+ " bad = [\n",
185
+ " \"I cannot\", \"I am sorry\", \"As an AI\", \"Here is\", \"```yaml\", \"```\",\n",
186
+ " \"Explanation:\", \"Note:\", \"Sure!\", \"Of course\",\n",
187
+ " ]\n",
188
+ " return [-0.3 if any(p.lower() in c.lower() for p in bad) else 0.3 for c in completions]\n",
189
+ "\n",
190
+ "REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "metadata": {},
196
+ "source": [
197
+ "## 🚀 Configure and Run GRPO Training"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "import wandb\n",
207
+ "from trl import GRPOConfig, GRPOTrainer\n",
208
+ "\n",
209
+ "MAX_NEW_TOKENS = 256\n",
210
+ "args = GRPOConfig(\n",
211
+ " output_dir=\"./cicd_rl_output\",\n",
212
+ " per_device_train_batch_size=2,\n",
213
+ " gradient_accumulation_steps=4,\n",
214
+ " learning_rate=5e-6,\n",
215
+ " max_steps=200,\n",
216
+ " num_generations=4,\n",
217
+ " max_new_tokens=MAX_NEW_TOKENS,\n",
218
+ " logging_steps=5,\n",
219
+ " save_steps=50,\n",
220
+ " report_to=\"wandb\",\n",
221
+ " remove_unused_columns=False,\n",
222
+ " warmup_steps=10,\n",
223
+ " lr_scheduler_type=\"cosine\",\n",
224
+ " optim=\"adamw_8bit\",\n",
225
+ ")\n",
226
+ "trainer = GRPOTrainer(\n",
227
+ " model=model,\n",
228
+ " args=args,\n",
229
+ " reward_funcs=REWARD_FUNCTIONS,\n",
230
+ " train_dataset=train_dataset,\n",
231
+ " processing_class=tokenizer,\n",
232
+ ")\n",
233
+ "wandb.init(project=\"cicd-rl-agent\")\n",
234
+ "trainer.train()"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "markdown",
239
+ "metadata": {},
240
+ "source": [
241
+ "## 📊 Plot Reward Curve"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "import matplotlib.pyplot as plt\n",
251
+ "\n",
252
+ "step_vals, reward_vals = [], []\n",
253
+ "for h in trainer.state.log_history:\n",
254
+ " st = h.get(\"step\")\n",
255
+ " for k, v in h.items():\n",
256
+ " if \"reward\" in k.lower() and isinstance(v, (int, float)):\n",
257
+ " if st is not None:\n",
258
+ " step_vals.append(st)\n",
259
+ " reward_vals.append(float(v))\n",
260
+ " break\n",
261
+ "fig, ax = plt.subplots(figsize=(8, 4))\n",
262
+ "if step_vals and reward_vals:\n",
263
+ " ax.plot(step_vals, reward_vals, marker=\"o\", markersize=2)\n",
264
+ "else:\n",
265
+ " ax.text(0.5, 0.5, \"No reward fields in log_history; check TRL/W&B logs.\", ha=\"center\", va=\"center\")\n",
266
+ "ax.set_xlabel(\"Training Step\")\n",
267
+ "ax.set_ylabel(\"Reward\")\n",
268
+ "ax.set_title(\"GRPO training reward (from log_history)\")\n",
269
+ "plt.tight_layout()\n",
270
+ "plt.savefig(\"reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
271
+ "plt.show()"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "metadata": {},
277
+ "source": [
278
+ "## 🧪 Before/After Inference Demo"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "def generate_yaml(model, tok, task: dict) -> str:\n",
288
+ " FastLanguageModel.for_inference(model)\n",
289
+ " user = build_prompt(task)\n",
290
+ " messages = [\n",
291
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
292
+ " {\"role\": \"user\", \"content\": user},\n",
293
+ " ]\n",
294
+ " text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
295
+ " dev = next(model.parameters()).device\n",
296
+ " inputs = tok(text, return_tensors=\"pt\").to(dev)\n",
297
+ " with torch.inference_mode():\n",
298
+ " out = model.generate(**inputs, max_new_tokens=256)\n",
299
+ " return tok.decode(out[0][inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True).strip()\n",
300
+ "\n",
301
+ "easy_demo = next(t for t in ALL_TASKS if t[\"difficulty\"] == \"easy\")\n",
302
+ "med_demo = next(t for t in ALL_TASKS if t[\"difficulty\"] == \"medium\")\n",
303
+ "\n",
304
+ "base_model, base_tok = FastLanguageModel.from_pretrained(\n",
305
+ " model_name=MODEL_ID,\n",
306
+ " max_seq_length=1024,\n",
307
+ " dtype=None,\n",
308
+ " load_in_4bit=True,\n",
309
+ ")\n",
310
+ "for label, task in [(\"EASY\", easy_demo), (\"MEDIUM\", med_demo)]:\n",
311
+ " print(\"=\" * 60)\n",
312
+ " print(f\"Task [{label}]: {task['id']}\")\n",
313
+ " print(\"\\n--- Broken YAML ---\")\n",
314
+ " print(task[\"pipeline_yaml\"])\n",
315
+ " out_base = generate_yaml(base_model, base_tok, task)\n",
316
+ " out_train = generate_yaml(model, tokenizer, task)\n",
317
+ " ok_base = out_base.strip() == task[\"correct_yaml\"].strip()\n",
318
+ " ok_train = out_train.strip() == task[\"correct_yaml\"].strip()\n",
319
+ " print(\"\\n--- Untrained (base checkpoint) output ---\")\n",
320
+ " print(out_base[:800])\n",
321
+ " print(\"\\n--- Trained model output ---\")\n",
322
+ " print(out_train[:800])\n",
323
+ " print(f\"\\nBase matches correct_yaml: {ok_base}\")\n",
324
+ " print(f\"Trained matches correct_yaml: {ok_train}\")"
325
  ]
326
  }
327
+ ],
328
+ "metadata": {
329
+ "kernelspec": {
330
+ "display_name": "Python 3",
331
+ "language": "python",
332
+ "name": "python3"
333
+ },
334
+ "language_info": {
335
+ "name": "python",
336
+ "version": "3.10.0"
337
+ }
338
+ },
339
+ "nbformat": 4,
340
+ "nbformat_minor": 4
341
  }