Jayce-Ping committed on
Commit
a6c53e4
·
verified ·
1 Parent(s): 2c25848

Add files using upload-large-folder tool

Browse files
frozenlake/data_process.py CHANGED
@@ -1,7 +1,8 @@
1
  """
2
  FrozenLake Video Dataset Generator — generate, eval, verify.
3
 
4
- Uses plain BFS solver (not networkx) for fast generation at all grid sizes.
 
5
 
6
  Usage:
7
  python frozenlake_video_gen.py generate --output-dir frozenlake \
@@ -61,9 +62,9 @@ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]
61
  print(f"⚠️ Params changed ({state.params_hash} → {expected}), starting fresh")
62
  return None
63
  if state.completed:
64
- print("✓ Generation already completed")
65
  return state
66
- print(f"✓ Resuming: {sum(state.size_progress.values())} puzzles done")
67
  return state
68
 
69
 
@@ -100,8 +101,6 @@ def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
100
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
101
 
102
 
103
- # ==================== Helpers ====================
104
-
105
  def _normalise_list(val, sizes, name="parameter"):
106
  if isinstance(val, int):
107
  return [val] * len(sizes)
@@ -117,7 +116,7 @@ def generate_dataset(
117
  sizes: List[int] = [8, 16, 32],
118
  num_per_size: list = [100, 500, 1000],
119
  p: float = 0.8,
120
- min_path_ratio: float = 0.3,
121
  img_size: int = 512,
122
  prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
123
  train_ratio: float = 0.9,
@@ -129,16 +128,6 @@ def generate_dataset(
129
  use_gym: bool = True,
130
  checkpoint_interval: int = 50,
131
  ):
132
- """
133
- Generate FrozenLake video dataset with checkpoint/resume.
134
-
135
- Layout::
136
-
137
- output_dir/
138
- images/ videos/ tables/
139
- train.jsonl test.jsonl train.csv test.csv
140
- path.json metadata.json
141
- """
142
  params = {
143
  "sizes": sizes, "num_per_size": num_per_size,
144
  "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
@@ -171,8 +160,6 @@ def generate_dataset(
171
  seen_fingerprints=[], all_samples=[],
172
  )
173
  print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
174
- print(f" frames={'auto' if frames is None else frames}, "
175
- f"n_start={n_start}, m_end={m_end}, fps={fps}")
176
  else:
177
  random.seed(seed)
178
  for _ in range(sum(state.size_progress.values()) * 10):
@@ -182,9 +169,8 @@ def generate_dataset(
182
  all_samples = list(state.all_samples)
183
  progress = {int(k): v for k, v in state.size_progress.items()}
184
  since_ckpt = 0
185
- total_target = sum(num_list)
186
 
187
- with tqdm(total=total_target, initial=sum(progress.values()),
188
  desc="Total", unit="puzzle") as pbar:
189
  for grid_size, target in zip(sizes, num_list):
190
  generated = progress.get(grid_size, 0)
@@ -195,11 +181,13 @@ def generate_dataset(
195
 
196
  with tqdm(total=target, initial=generated,
197
  desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
198
- for _ in range((target - generated) * 20):
199
  if generated >= target:
200
  break
201
  try:
202
- desc, path = proc.generate(grid_size, p=p, min_path_len=min_len)
 
 
203
  except RuntimeError:
204
  continue
205
 
@@ -209,22 +197,21 @@ def generate_dataset(
209
  seen.add(fp)
210
 
211
  base = f"size{grid_size}_{generated:0{num_w}d}"
212
- img_name, vid_name, tbl_name = f"{base}.png", f"{base}.mp4", f"{base}.txt"
213
-
214
- proc.render(desc, use_gym=use_gym).save(str(img_dir / img_name))
215
 
 
216
  vid_frames = proc.generate_video_frames(
217
  desc, path, n_start=n_start, m_end=m_end,
218
  frames=frames, use_gym=use_gym,
219
  )
220
- save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)
221
- proc.save_table(str(tbl_dir / tbl_name), desc)
222
 
223
  udrl = proc.path_to_udrl(path)
224
  all_samples.append({
225
- "prompt": prompt, "image": img_name, "video": vid_name,
226
- "table": tbl_name, "grid_size": grid_size,
227
- "grid_desc": desc, "start": list(proc.find_start(desc)),
 
228
  "path_udrl": udrl, "path_length": len(path),
229
  "frame_count": len(vid_frames),
230
  })
@@ -244,12 +231,8 @@ def generate_dataset(
244
 
245
  tqdm.write(f"Size {grid_size}: {generated} puzzles")
246
 
247
- # --- Final outputs ---
248
  with open(out / "path.json", "w") as f:
249
- json.dump(
250
- dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)),
251
- f, indent=4,
252
- )
253
 
254
  random.seed(seed + 1)
255
  random.shuffle(all_samples)
@@ -264,7 +247,7 @@ def generate_dataset(
264
  _jsonl(all_samples[split:], out / "test.jsonl")
265
 
266
  for name, samps in [("train", all_samples[:split]), ("test", all_samples[split:])]:
267
- with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
268
  w = csv.writer(f)
269
  w.writerow(["input_image", "video", "prompt"])
270
  for s in samps:
@@ -278,23 +261,18 @@ def generate_dataset(
278
 
279
  lengths = [s["path_length"] for s in all_samples]
280
  fcounts = [s["frame_count"] for s in all_samples]
281
- print(f"\n✓ Dataset complete: {out}/")
282
- print(f" Sizes: {sizes}, p={p}, Puzzles: {len(all_samples)}")
283
- print(f" Train: {split}, Test: {len(all_samples) - split}")
284
- print(f" Path lengths: avg={np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")
285
- print(f" Frame counts: avg={np.mean(fcounts):.1f}, min={min(fcounts)}, max={max(fcounts)}")
286
 
287
 
288
  # ==================== Eval ====================
289
 
290
  def eval_videos(
291
- video_dir: str,
292
- table_dir: str,
293
- output_json: Optional[str] = None,
294
- gt_json: Optional[str] = None,
295
  use_gym: bool = True,
296
  ):
297
- """Evaluate result videos: last frame → red path → verify."""
298
  proc = FrozenLakeProcessor()
299
  vid_root, tbl_root = Path(video_dir), Path(table_dir)
300
  if output_json is None:
@@ -305,105 +283,80 @@ def eval_videos(
305
  key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
306
  )
307
  if not videos:
308
- print(f"No .mp4 in {vid_root}")
309
- return
310
-
311
- print(f"Found {len(videos)} videos, table_dir={tbl_root}")
312
 
313
  extracted: Dict[str, str] = {}
314
  missing_tbl = missing_frame = 0
315
 
316
  for vp in tqdm(videos, desc="Extracting"):
317
- stem = vp.stem
318
- desc = proc.load_table(str(tbl_root / f"{stem}.txt"))
319
  if desc is None:
320
- missing_tbl += 1
321
- continue
322
  start = proc.find_start(desc)
323
  if start is None:
324
- missing_tbl += 1
325
- continue
326
  lf = extract_last_frame(str(vp))
327
  if lf is None:
328
- missing_frame += 1
329
- continue
330
- extracted[f"{stem}.png"] = proc.extract_path_from_pixels(
331
- lf, len(desc), len(desc[0]), start, desc
332
- )
333
 
334
  with open(output_json, "w") as f:
335
  json.dump(extracted, f, indent=4)
336
- print(f"Saved: {output_json}")
337
-
338
- # Verify
339
- correct = total_valid = 0
340
- correctly_solved: List[Dict] = []
341
- size_stats: Dict[int, Dict[str, int]] = {}
342
 
343
  verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
 
 
 
344
 
345
  for name, udrl in extracted.items():
346
- desc = proc.load_table(str(tbl_root / f"{name.replace('.png', '')}.txt"))
347
- if desc is None:
348
- continue
349
- total_valid += 1
350
  sz = len(desc)
351
  size_stats.setdefault(sz, {"total": 0, "correct": 0})
352
  size_stats[sz]["total"] += 1
353
  if verify_fn(desc, udrl):
354
  correct += 1
355
  size_stats[sz]["correct"] += 1
356
- correctly_solved.append({"name": name, "length": len(udrl)})
357
-
358
- acc = correct / total_valid * 100 if total_valid else 0
359
- print(f"\n{'='*50}\nEvaluation Summary\n{'='*50}")
360
- print(f"Videos: {len(videos)}, Missing tables: {missing_tbl}, "
361
- f"Failed frames: {missing_frame}")
362
- print(f"Evaluated: {total_valid}, Correct: {correct}, Accuracy: {acc:.2f}%")
363
-
364
- if size_stats:
365
- print("\nBy size:")
366
- for sz in sorted(size_stats):
367
- s = size_stats[sz]
368
- print(f" {sz:3d}: {s['correct']}/{s['total']} "
369
- f"({s['correct']/s['total']*100:.1f}%)")
370
-
371
- correctly_solved.sort(key=lambda x: x["length"], reverse=True)
372
- for i, item in enumerate(correctly_solved[:3]):
373
  print(f" Top {i+1}: {item['name']} (len={item['length']})")
374
 
375
  if gt_json:
376
- _gt_bins(extracted, gt_json, tbl_root, proc, verify_fn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  print(f"{'='*50}")
378
 
379
 
380
- def _gt_bins(extracted, gt_path, tbl_root, proc, verify_fn):
381
- try:
382
- with open(gt_path) as f:
383
- gt = json.load(f)
384
- except Exception:
385
- return
386
- bins: Dict[str, Dict[str, int]] = {}
387
- for name, pred in extracted.items():
388
- if name not in gt:
389
- continue
390
- lo = (len(gt[name]) // 10) * 10
391
- label = f"{lo:3d}-{lo+9:3d}"
392
- bins.setdefault(label, {"total": 0, "correct": 0})
393
- bins[label]["total"] += 1
394
- desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
395
- if desc and verify_fn(desc, pred):
396
- bins[label]["correct"] += 1
397
- if bins:
398
- print("\nBy GT path length:")
399
- for label in sorted(bins):
400
- b = bins[label]
401
- print(f" {label}: {b['correct']}/{b['total']} "
402
- f"({b['correct']/b['total']*100:.1f}%)")
403
-
404
-
405
- # ==================== Verify ====================
406
-
407
  def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
408
  proc = FrozenLakeProcessor()
409
  with open(json_file) as f:
@@ -413,16 +366,12 @@ def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
413
  for name, udrl in solutions.items():
414
  desc = proc.load_table(str(Path(table_dir) / f"{name.replace('.png','')}.txt"))
415
  if desc is None:
416
- skipped += 1
417
- continue
418
  valid += 1
419
  if verify_fn(desc, udrl):
420
  correct += 1
421
  acc = correct / valid * 100 if valid else 0
422
- print(f"\n{'='*40}\nVerification: {correct}/{valid} ({acc:.2f}%)")
423
- if skipped:
424
- print(f"Skipped: {skipped}")
425
- print(f"{'='*40}")
426
 
427
 
428
  # ==================== CLI ====================
@@ -435,10 +384,8 @@ def parse_args():
435
  gen.add_argument("--output-dir", default="frozenlake")
436
  gen.add_argument("--sizes", type=int, nargs="+", default=[8, 16, 32])
437
  gen.add_argument("--num-per-size", type=int, nargs="+", default=[100, 500, 1000])
438
- gen.add_argument("--p", type=float, default=0.8)
439
- gen.add_argument("--min-path-ratio", type=float, default=0.1,
440
- help="Min path length as fraction of size² (default 0.1; "
441
- "FrozenLake paths are much shorter than maze paths)")
442
  gen.add_argument("--img-size", type=int, default=1024)
443
  gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
444
  gen.add_argument("--train-ratio", type=float, default=0.9)
 
1
  """
2
  FrozenLake Video Dataset Generator — generate, eval, verify.
3
 
4
+ Uses generate_auto() which picks random (small grids) or guided (large grids)
5
+ strategy automatically.
6
 
7
  Usage:
8
  python frozenlake_video_gen.py generate --output-dir frozenlake \
 
62
  print(f"⚠️ Params changed ({state.params_hash} → {expected}), starting fresh")
63
  return None
64
  if state.completed:
65
+ print("✓ Already completed")
66
  return state
67
+ print(f"✓ Resuming: {sum(state.size_progress.values())} done")
68
  return state
69
 
70
 
 
101
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
102
 
103
 
 
 
104
  def _normalise_list(val, sizes, name="parameter"):
105
  if isinstance(val, int):
106
  return [val] * len(sizes)
 
116
  sizes: List[int] = [8, 16, 32],
117
  num_per_size: list = [100, 500, 1000],
118
  p: float = 0.8,
119
+ min_path_ratio: float = 0.1,
120
  img_size: int = 512,
121
  prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
122
  train_ratio: float = 0.9,
 
128
  use_gym: bool = True,
129
  checkpoint_interval: int = 50,
130
  ):
 
 
 
 
 
 
 
 
 
 
131
  params = {
132
  "sizes": sizes, "num_per_size": num_per_size,
133
  "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
 
160
  seen_fingerprints=[], all_samples=[],
161
  )
162
  print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
 
 
163
  else:
164
  random.seed(seed)
165
  for _ in range(sum(state.size_progress.values()) * 10):
 
169
  all_samples = list(state.all_samples)
170
  progress = {int(k): v for k, v in state.size_progress.items()}
171
  since_ckpt = 0
 
172
 
173
+ with tqdm(total=sum(num_list), initial=sum(progress.values()),
174
  desc="Total", unit="puzzle") as pbar:
175
  for grid_size, target in zip(sizes, num_list):
176
  generated = progress.get(grid_size, 0)
 
181
 
182
  with tqdm(total=target, initial=generated,
183
  desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
184
+ for _ in range((target - generated) * 5):
185
  if generated >= target:
186
  break
187
  try:
188
+ desc, path = proc.generate_auto(
189
+ grid_size, p=p, min_path_len=min_len
190
+ )
191
  except RuntimeError:
192
  continue
193
 
 
197
  seen.add(fp)
198
 
199
  base = f"size{grid_size}_{generated:0{num_w}d}"
 
 
 
200
 
201
+ proc.render(desc, use_gym=use_gym).save(str(img_dir / f"{base}.png"))
202
  vid_frames = proc.generate_video_frames(
203
  desc, path, n_start=n_start, m_end=m_end,
204
  frames=frames, use_gym=use_gym,
205
  )
206
+ save_video_cv2(vid_frames, str(vid_dir / f"{base}.mp4"), fps=fps)
207
+ proc.save_table(str(tbl_dir / f"{base}.txt"), desc)
208
 
209
  udrl = proc.path_to_udrl(path)
210
  all_samples.append({
211
+ "prompt": prompt, "image": f"{base}.png",
212
+ "video": f"{base}.mp4", "table": f"{base}.txt",
213
+ "grid_size": grid_size, "grid_desc": desc,
214
+ "start": list(proc.find_start(desc)),
215
  "path_udrl": udrl, "path_length": len(path),
216
  "frame_count": len(vid_frames),
217
  })
 
231
 
232
  tqdm.write(f"Size {grid_size}: {generated} puzzles")
233
 
 
234
  with open(out / "path.json", "w") as f:
235
+ json.dump(dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)), f, indent=4)
 
 
 
236
 
237
  random.seed(seed + 1)
238
  random.shuffle(all_samples)
 
247
  _jsonl(all_samples[split:], out / "test.jsonl")
248
 
249
  for name, samps in [("train", all_samples[:split]), ("test", all_samples[split:])]:
250
+ with open(out / f"{name}.csv", "w", newline="") as f:
251
  w = csv.writer(f)
252
  w.writerow(["input_image", "video", "prompt"])
253
  for s in samps:
 
261
 
262
  lengths = [s["path_length"] for s in all_samples]
263
  fcounts = [s["frame_count"] for s in all_samples]
264
+ print(f"\n✓ Complete: {out}/ | {len(all_samples)} puzzles "
265
+ f"(train={split}, test={len(all_samples)-split})")
266
+ print(f" Paths: avg={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}")
 
 
267
 
268
 
269
  # ==================== Eval ====================
270
 
271
  def eval_videos(
272
+ video_dir: str, table_dir: str,
273
+ output_json: Optional[str] = None, gt_json: Optional[str] = None,
 
 
274
  use_gym: bool = True,
275
  ):
 
276
  proc = FrozenLakeProcessor()
277
  vid_root, tbl_root = Path(video_dir), Path(table_dir)
278
  if output_json is None:
 
283
  key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
284
  )
285
  if not videos:
286
+ print(f"No .mp4 in {vid_root}"); return
 
 
 
287
 
288
  extracted: Dict[str, str] = {}
289
  missing_tbl = missing_frame = 0
290
 
291
  for vp in tqdm(videos, desc="Extracting"):
292
+ desc = proc.load_table(str(tbl_root / f"{vp.stem}.txt"))
 
293
  if desc is None:
294
+ missing_tbl += 1; continue
 
295
  start = proc.find_start(desc)
296
  if start is None:
297
+ missing_tbl += 1; continue
 
298
  lf = extract_last_frame(str(vp))
299
  if lf is None:
300
+ missing_frame += 1; continue
301
+ extracted[f"{vp.stem}.png"] = proc.extract_path_from_pixels(
302
+ lf, len(desc), len(desc[0]), start, desc)
 
 
303
 
304
  with open(output_json, "w") as f:
305
  json.dump(extracted, f, indent=4)
 
 
 
 
 
 
306
 
307
  verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
308
+ correct = total = 0
309
+ size_stats: Dict[int, Dict[str, int]] = {}
310
+ top: List[Dict] = []
311
 
312
  for name, udrl in extracted.items():
313
+ desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
314
+ if desc is None: continue
315
+ total += 1
 
316
  sz = len(desc)
317
  size_stats.setdefault(sz, {"total": 0, "correct": 0})
318
  size_stats[sz]["total"] += 1
319
  if verify_fn(desc, udrl):
320
  correct += 1
321
  size_stats[sz]["correct"] += 1
322
+ top.append({"name": name, "length": len(udrl)})
323
+
324
+ acc = correct / total * 100 if total else 0
325
+ print(f"\n{'='*50}\nEval: {correct}/{total} ({acc:.2f}%) | "
326
+ f"missing_tbl={missing_tbl} bad_frame={missing_frame}")
327
+ for sz in sorted(size_stats):
328
+ s = size_stats[sz]
329
+ print(f" Size {sz:3d}: {s['correct']}/{s['total']} "
330
+ f"({s['correct']/s['total']*100:.1f}%)")
331
+ top.sort(key=lambda x: x["length"], reverse=True)
332
+ for i, item in enumerate(top[:3]):
 
 
 
 
 
 
333
  print(f" Top {i+1}: {item['name']} (len={item['length']})")
334
 
335
  if gt_json:
336
+ try:
337
+ with open(gt_json) as f:
338
+ gt = json.load(f)
339
+ bins: Dict[str, Dict[str, int]] = {}
340
+ for name, pred in extracted.items():
341
+ if name not in gt: continue
342
+ lo = (len(gt[name]) // 10) * 10
343
+ label = f"{lo:3d}-{lo+9:3d}"
344
+ bins.setdefault(label, {"total": 0, "correct": 0})
345
+ bins[label]["total"] += 1
346
+ desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
347
+ if desc and verify_fn(desc, pred):
348
+ bins[label]["correct"] += 1
349
+ if bins:
350
+ print("\nBy GT path length:")
351
+ for label in sorted(bins):
352
+ b = bins[label]
353
+ print(f" {label}: {b['correct']}/{b['total']} "
354
+ f"({b['correct']/b['total']*100:.1f}%)")
355
+ except Exception:
356
+ pass
357
  print(f"{'='*50}")
358
 
359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
361
  proc = FrozenLakeProcessor()
362
  with open(json_file) as f:
 
366
  for name, udrl in solutions.items():
367
  desc = proc.load_table(str(Path(table_dir) / f"{name.replace('.png','')}.txt"))
368
  if desc is None:
369
+ skipped += 1; continue
 
370
  valid += 1
371
  if verify_fn(desc, udrl):
372
  correct += 1
373
  acc = correct / valid * 100 if valid else 0
374
+ print(f"\nVerification: {correct}/{valid} ({acc:.2f}%)")
 
 
 
375
 
376
 
377
  # ==================== CLI ====================
 
384
  gen.add_argument("--output-dir", default="frozenlake")
385
  gen.add_argument("--sizes", type=int, nargs="+", default=[8, 16, 32])
386
  gen.add_argument("--num-per-size", type=int, nargs="+", default=[100, 500, 1000])
387
+ gen.add_argument("--p", type=float, default=0.5)
388
+ gen.add_argument("--min-path-ratio", type=float, default=0.1)
 
 
389
  gen.add_argument("--img-size", type=int, default=1024)
390
  gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
391
  gen.add_argument("--train-ratio", type=float, default=0.9)
frozenlake/frozenlake_processor.py CHANGED
@@ -4,15 +4,20 @@ FrozenLakeProcessor - FrozenLake puzzle generation, solving, rendering, and eval
4
  Grid cells: S=Start, F=Frozen(safe), H=Hole(death), G=Goal
5
  Table chars: @=Start, _=Frozen, #=Hole, *=Goal
6
 
7
- Performance notes vs original DiffThinker code:
8
- - Solving uses plain BFS (O(n²)) instead of networkx graph construction
9
- which had massive overhead from add_node/add_edge Python calls.
10
- - Gym renderer is cached per puzzle to avoid repeated pygame init.
 
 
 
 
11
  """
12
  import os
13
  import random
14
  import warnings
15
  from collections import deque
 
16
  from typing import List, Tuple, Optional
17
 
18
  import numpy as np
@@ -23,15 +28,12 @@ try:
23
  warnings.filterwarnings("ignore", category=UserWarning, module="pygame")
24
  warnings.filterwarnings("ignore", category=DeprecationWarning)
25
  import gymnasium as gym
26
-
27
  HAS_GYM = True
28
  except ImportError:
29
  HAS_GYM = False
30
 
31
- # Table ↔ Grid mapping
32
  TABLE_TO_GRID = {"@": "S", "_": "F", "#": "H", "*": "G"}
33
  GRID_TO_TABLE = {v: k for k, v in TABLE_TO_GRID.items()}
34
-
35
  MOVES = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}
36
  GYM_ACTION_MAP = {"L": 0, "D": 1, "R": 2, "U": 3}
37
 
@@ -45,20 +47,16 @@ class FrozenLakeProcessor:
45
  self.img_size = img_size
46
  self.path_color = "red"
47
 
48
- # ==================== Generation ====================
49
 
50
  def generate(
51
- self,
52
- size: int,
53
- p: float = 0.8,
54
- min_path_len: int = 1,
55
- max_attempts: int = 500,
56
  ) -> Tuple[GridDesc, List[Tuple[int, int]]]:
57
  """
58
- Generate a solvable FrozenLake grid with shortest path >= *min_path_len* moves.
59
 
60
- Returns:
61
- (desc, path) — desc is list[str], path is list[(r,c)].
62
  """
63
  for _ in range(max_attempts):
64
  desc = self._random_layout(size, p)
@@ -72,7 +70,6 @@ class FrozenLakeProcessor:
72
 
73
  @staticmethod
74
  def _random_layout(size: int, p: float = 0.8) -> GridDesc:
75
- """Random grid with one S and one G at random positions."""
76
  all_coords = [(r, c) for r in range(size) for c in range(size)]
77
  start, goal = random.sample(all_coords, 2)
78
  grid = []
@@ -88,19 +85,235 @@ class FrozenLakeProcessor:
88
  grid.append("".join(row))
89
  return grid
90
 
91
- # ==================== Solving (plain BFS — fast) ====================
92
 
93
- @staticmethod
94
- def solve(desc: GridDesc) -> Optional[List[Tuple[int, int]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  """
96
- BFS shortest path from S to G, avoiding H.
97
 
98
- ~100× faster than networkx for typical grid sizes because it avoids
99
- Python-level graph object construction entirely.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- Returns:
102
- List of (r, c) or None.
 
103
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  rows, cols = len(desc), len(desc[0])
105
  start = goal = None
106
  for r in range(rows):
@@ -111,11 +324,9 @@ class FrozenLakeProcessor:
111
  goal = (r, c)
112
  if start is None or goal is None:
113
  return None
114
-
115
  visited = [[False] * cols for _ in range(rows)]
116
  visited[start[0]][start[1]] = True
117
  queue: deque = deque([(start, [start])])
118
-
119
  while queue:
120
  (r, c), path = queue.popleft()
121
  if (r, c) == goal:
@@ -123,8 +334,7 @@ class FrozenLakeProcessor:
123
  for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
124
  nr, nc = r + dr, c + dc
125
  if 0 <= nr < rows and 0 <= nc < cols and not visited[nr][nc]:
126
- ch = desc[nr][nc]
127
- if ch != "H":
128
  visited[nr][nc] = True
129
  queue.append(((nr, nc), path + [(nr, nc)]))
130
  return None
@@ -133,35 +343,27 @@ class FrozenLakeProcessor:
133
 
134
  @staticmethod
135
  def path_to_udrl(path: List[Tuple[int, int]]) -> str:
136
- """Convert coordinate path to UDRL string."""
137
  moves = []
138
  for i in range(len(path) - 1):
139
  r1, c1 = path[i]
140
  r2, c2 = path[i + 1]
141
- if r2 < r1:
142
- moves.append("U")
143
- elif r2 > r1:
144
- moves.append("D")
145
- elif c2 < c1:
146
- moves.append("L")
147
- else:
148
- moves.append("R")
149
  return "".join(moves)
150
 
151
  # ==================== Verification ====================
152
 
153
  def verify_path_sim(self, desc: GridDesc, udrl: str) -> bool:
154
- """Verify UDRL via grid simulation (no dependencies)."""
155
  rows, cols = len(desc), len(desc[0])
156
  start = self.find_start(desc)
157
  if start is None:
158
  return False
159
-
160
  r, c = start
161
  clean = udrl.replace(",", "").replace(" ", "").strip()
162
  if "Action plan" in clean:
163
  clean = clean.rsplit("Action plan", 1)[-1]
164
-
165
  for ch in clean:
166
  if ch not in MOVES:
167
  continue
@@ -169,16 +371,14 @@ class FrozenLakeProcessor:
169
  nr, nc = r + dr, c + dc
170
  if not (0 <= nr < rows and 0 <= nc < cols):
171
  return False
172
- cell = desc[nr][nc]
173
- if cell == "H":
174
  return False
175
  r, c = nr, nc
176
- if cell == "G":
177
  return True
178
  return desc[r][c] == "G"
179
 
180
  def verify_path_gym(self, desc: GridDesc, udrl: str) -> bool:
181
- """Verify via gymnasium (falls back to sim if unavailable)."""
182
  if not HAS_GYM:
183
  return self.verify_path_sim(desc, udrl)
184
  rows, cols = len(desc), len(desc[0])
@@ -204,10 +404,9 @@ class FrozenLakeProcessor:
204
  except Exception:
205
  return self.verify_path_sim(desc, udrl)
206
 
207
- # ==================== Table Text I/O ====================
208
 
209
  def encode_table(self, desc: GridDesc) -> str:
210
- """Encode to pipe-delimited table format."""
211
  size = len(desc)
212
  lines = ["| | " + " | ".join(f"Col {i+1}" for i in range(size)) + " |"]
213
  for r in range(size):
@@ -216,7 +415,6 @@ class FrozenLakeProcessor:
216
  return "\n".join(lines)
217
 
218
  def decode_table(self, text: str) -> Optional[GridDesc]:
219
- """Parse table text back to GridDesc."""
220
  try:
221
  rows = []
222
  for line in text.strip().splitlines():
@@ -260,7 +458,6 @@ class FrozenLakeProcessor:
260
  # ==================== Rendering ====================
261
 
262
  def render_gym(self, desc: GridDesc) -> Optional[Image.Image]:
263
- """Render via gymnasium (creates a pygame window — slow)."""
264
  if not HAS_GYM:
265
  return None
266
  try:
@@ -278,9 +475,9 @@ class FrozenLakeProcessor:
278
  return None
279
 
280
  def render_simple(self, desc: GridDesc) -> Image.Image:
281
- """Fast PIL-only renderer (no pygame dependency)."""
282
  size = len(desc)
283
- cell = self.img_size // size
284
  img = Image.new("RGB", (self.img_size, self.img_size), (255, 255, 255))
285
  draw = ImageDraw.Draw(img)
286
  colors = {
@@ -289,14 +486,18 @@ class FrozenLakeProcessor:
289
  }
290
  for r in range(size):
291
  for c in range(size):
292
- x0, y0 = c * cell, r * cell
 
 
 
293
  draw.rectangle(
294
- [x0, y0, x0 + cell - 1, y0 + cell - 1],
295
  fill=colors.get(desc[r][c], (200, 220, 255)),
296
  )
297
  for i in range(size + 1):
298
- draw.line([(i * cell, 0), (i * cell, self.img_size)], fill="black", width=1)
299
- draw.line([(0, i * cell), (self.img_size, i * cell)], fill="black", width=1)
 
300
  return img
301
 
302
  def render(self, desc: GridDesc, use_gym: bool = True) -> Image.Image:
@@ -309,7 +510,6 @@ class FrozenLakeProcessor:
309
  def draw_solution_line(
310
  self, image: Image.Image, path: List[Tuple[int, int]], grid_size: int,
311
  ) -> Image.Image:
312
- """Draw red line on *image* (modifies in-place)."""
313
  draw = ImageDraw.Draw(image)
314
  w, h = image.size
315
  cw, ch_ = w / grid_size, h / grid_size
@@ -320,36 +520,21 @@ class FrozenLakeProcessor:
320
  # ==================== Video Frames ====================
321
 
322
  def generate_video_frames(
323
- self,
324
- desc: GridDesc,
325
- path: List[Tuple[int, int]],
326
- n_start: int = 5,
327
- m_end: int = 5,
328
- frames: Optional[int] = None,
329
- use_gym: bool = True,
330
  ) -> List[Image.Image]:
331
- """
332
- Progressive red-line video frames.
333
-
334
- *frames* controls content frames between holds:
335
- None → 1 per step, >steps → slow-mo, <steps → fast-fwd.
336
- """
337
  size = len(desc)
338
  n_steps = len(path) - 1
339
  base_img = self.render(desc, use_gym=use_gym)
340
-
341
  if n_steps <= 0:
342
  return [base_img] * (n_start + m_end + 1)
343
-
344
  content = frames if frames is not None else n_steps
345
  content = max(1, content)
346
- result: List[Image.Image] = []
347
-
348
- # Opening hold
349
- result.extend([base_img.copy() for _ in range(n_start)])
350
 
351
- def _partial(steps: int) -> Image.Image:
352
- return self.draw_solution_line(base_img.copy(), path[: steps + 1], size)
353
 
354
  if content == n_steps:
355
  for s in range(1, n_steps + 1):
@@ -366,45 +551,41 @@ class FrozenLakeProcessor:
366
  for f in range(content):
367
  result.append(_partial((f + 1) * n_steps // content))
368
 
369
- # Closing hold
370
- final = _partial(n_steps)
371
- result.extend([final.copy() for _ in range(m_end)])
372
  return result
373
 
374
  # ==================== Red-Path Extraction ====================
375
 
376
  def extract_path_from_pixels(
377
- self,
378
- pixels: np.ndarray,
379
- rows: int,
380
- cols: int,
381
- start: Tuple[int, int],
382
- desc: Optional[GridDesc] = None,
383
  pixel_threshold: float = 0.01,
384
  ) -> str:
385
- """Detect red path in RGB array, return UDRL."""
386
  img = Image.fromarray(pixels)
387
  w, h = img.size
388
  px = np.array(img, dtype=float)
389
  r_ch, g_ch, b_ch = px[:, :, 0], px[:, :, 1], px[:, :, 2]
390
  red_mask = (r_ch > 100) & (r_ch > g_ch * 1.2) & (r_ch > b_ch * 1.2)
391
 
392
- cell_h, cell_w = h // rows, w // cols
393
  path_grid = np.zeros((rows, cols), dtype=bool)
394
  for r in range(rows):
 
 
395
  for c in range(cols):
396
- sub = red_mask[r * cell_h : (r + 1) * cell_h,
397
- c * cell_w : (c + 1) * cell_w]
 
398
  if sub.size > 0 and np.mean(sub) > pixel_threshold:
399
  path_grid[r, c] = True
400
 
401
- # Greedy walk
402
  visited = {start}
403
  cr, cc = start
404
  actions: List[str] = []
405
  for _ in range(rows * cols * 2):
406
  found = False
407
- for act, (dr, dc) in [("R", (0, 1)), ("D", (1, 0)), ("L", (0, -1)), ("U", (-1, 0))]:
408
  nr, nc = cr + dr, cc + dc
409
  if 0 <= nr < rows and 0 <= nc < cols:
410
  if path_grid[nr, nc] and (nr, nc) not in visited:
@@ -417,10 +598,7 @@ class FrozenLakeProcessor:
417
  break
418
  return "".join(actions)
419
 
420
- def extract_path_from_image(
421
- self, img_path: str, rows: int, cols: int, start: Tuple, desc=None,
422
- ) -> str:
423
- """Extract UDRL from an image file."""
424
  try:
425
  pixels = np.array(Image.open(img_path).convert("RGB"))
426
  return self.extract_path_from_pixels(pixels, rows, cols, start, desc)
@@ -433,34 +611,64 @@ if __name__ == "__main__":
433
 
434
  proc = FrozenLakeProcessor(img_size=512)
435
 
436
- # Benchmark BFS vs problem sizes
437
- for sz in [8, 16, 32, 64]:
 
 
 
438
  t0 = time.perf_counter()
439
- count = 0
440
- for _ in range(100):
441
- desc = proc._random_layout(sz, p=0.8)
442
  path = proc.solve(desc)
443
- if path:
444
- count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  elapsed = time.perf_counter() - t0
446
- print(f"Size {sz:3d}: 100 BFS solves in {elapsed:.3f}s "
447
- f"({count} solvable, {elapsed/100*1000:.1f}ms/solve)")
448
-
449
- # Functional test
450
- desc, path = proc.generate(size=16, p=0.8, min_path_len=20)
451
- udrl = proc.path_to_udrl(path)
452
- print(f"\nGenerate 16×16: path={len(path)}, UDRL={udrl[:40]}...")
453
- print(f"Verify (sim): {proc.verify_path_sim(desc, udrl)}")
454
-
455
- # Table round-trip
456
- decoded = proc.decode_table(proc.encode_table(desc))
457
- assert decoded == desc
458
- print("Table round-trip: ✓")
459
-
460
- # Render + extract round-trip
461
- img = proc.render(desc, use_gym=False)
462
- sol = proc.draw_solution_line(img.copy(), path, len(desc))
463
- start = proc.find_start(desc)
464
- extracted = proc.extract_path_from_pixels(np.array(sol), len(desc), len(desc[0]), start)
465
- print(f"Extract round-trip verify: {proc.verify_path_sim(desc, extracted)}")
466
- print("All tests passed ")
 
 
 
 
4
  Grid cells: S=Start, F=Frozen(safe), H=Hole(death), G=Goal
5
  Table chars: @=Start, _=Frozen, #=Hole, *=Goal
6
 
7
+ Generation strategy:
8
+ - ``generate()``: Pure random + BFS retry. Fine for small grids (≤16).
9
+ - ``generate_guided()``: Lay a random walk path first, then fill remaining
10
+ cells. Guarantees long paths even at 32×32+ without exponential retries.
11
+ - ``generate_auto()``: Auto-select best strategy based on difficulty.
12
+ - ``generate_batch()``: Multiprocessing wrapper for high-throughput.
13
+
14
+ Solving uses plain BFS (~10× faster than networkx).
15
  """
16
  import os
17
  import random
18
  import warnings
19
  from collections import deque
20
+ from concurrent.futures import ProcessPoolExecutor, as_completed
21
  from typing import List, Tuple, Optional
22
 
23
  import numpy as np
 
28
  warnings.filterwarnings("ignore", category=UserWarning, module="pygame")
29
  warnings.filterwarnings("ignore", category=DeprecationWarning)
30
  import gymnasium as gym
 
31
  HAS_GYM = True
32
  except ImportError:
33
  HAS_GYM = False
34
 
 
35
  TABLE_TO_GRID = {"@": "S", "_": "F", "#": "H", "*": "G"}
36
  GRID_TO_TABLE = {v: k for k, v in TABLE_TO_GRID.items()}
 
37
  MOVES = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}
38
  GYM_ACTION_MAP = {"L": 0, "D": 1, "R": 2, "U": 3}
39
 
 
47
  self.img_size = img_size
48
  self.path_color = "red"
49
 
50
+ # ==================== Generation: Pure Random ====================
51
 
52
  def generate(
53
+ self, size: int, p: float = 0.8,
54
+ min_path_len: int = 1, max_attempts: int = 500,
 
 
 
55
  ) -> Tuple[GridDesc, List[Tuple[int, int]]]:
56
  """
57
+ Random layout + BFS retry. Good for small grids or low min_path_len.
58
 
59
+ For large grids with long path requirements, use ``generate_guided()``.
 
60
  """
61
  for _ in range(max_attempts):
62
  desc = self._random_layout(size, p)
 
70
 
71
  @staticmethod
72
  def _random_layout(size: int, p: float = 0.8) -> GridDesc:
 
73
  all_coords = [(r, c) for r in range(size) for c in range(size)]
74
  start, goal = random.sample(all_coords, 2)
75
  grid = []
 
85
  grid.append("".join(row))
86
  return grid
87
 
88
+ # ==================== Generation: Guided (path-first) ====================
89
 
90
def simplify_path(self, path: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """
    Shorten a grid path by cutting out detours between touching cells.

    Starting from the first cell, the path repeatedly jumps to the *latest*
    cell later in ``path`` that is a direct 4-neighbour of the current cell
    (Manhattan distance 1), dropping everything in between.  When no such
    later cell exists it simply advances by one step.  The result therefore
    starts at ``path[0]``, ends at ``path[-1]`` and only contains cells
    taken from ``path``.

    Args:
        path: Sequence of ``(row, col)`` cells.

    Returns:
        The simplified list of cells.  An empty (or ``None``) ``path`` is
        returned unchanged.
    """
    if not path:
        return path

    # Last index at which every cell occurs.  Looking up the four
    # neighbours here replaces a scan over all later indices, turning the
    # quadratic search into a linear one with the same result.
    last_index = {(r, c): idx for idx, (r, c) in enumerate(path)}

    simplified = [path[0]]
    curr_idx = 0
    final_idx = len(path) - 1

    while curr_idx < final_idx:
        r, c = path[curr_idx]
        farthest = max(
            last_index.get(nb, -1)
            for nb in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
        )
        # A neighbour at curr_idx + 1 is just the ordinary next step; only
        # a strictly later one is a shortcut worth taking.
        curr_idx = max(farthest, curr_idx + 1)
        simplified.append(path[curr_idx])

    return simplified
117
+
118
def generate_guided(
    self, size: int, p: float = 0.8,
    min_path_len: int = 1, max_attempts: int = 100,
) -> Tuple["GridDesc", List[Tuple[int, int]]]:
    """
    Path-first generation: lay out a walk, then build the grid around it.

    Each attempt asks ``_guided_layout`` for a grid plus the walk it was
    built around, shortens that walk with ``simplify_path`` and accepts the
    pair once the shortened walk still has at least ``min_path_len`` steps.
    The returned path is the shortened walk, not the result of ``solve``.

    Args:
        size: Grid side length, forwarded to ``_guided_layout``.
        p: Forwarded to ``_guided_layout``.
        min_path_len: Minimum number of steps (cells minus one) required
            of the returned path.
        max_attempts: Number of layouts to try before giving up.

    Returns:
        ``(desc, path)`` for the first accepted attempt.

    Raises:
        RuntimeError: If no attempt produced a long enough path.
    """
    attempts_left = max_attempts
    while attempts_left > 0:
        attempts_left -= 1
        layout, raw_walk = self._guided_layout(size, p, min_path_len)
        if layout is not None:
            shortened = self.simplify_path(raw_walk)
            # Steps = cells - 1; simplification may have cut too much.
            if min_path_len <= len(shortened) - 1:
                return layout, shortened
    raise RuntimeError(
        f"Guided generation failed after {max_attempts} attempts "
        f"(size={size}, p={p}, min_path_len={min_path_len})."
    )
141
 
142
+ def _guided_layout(
143
+ self, size: int, p: float, min_path_len: int,
144
+ ) -> Tuple[Optional[GridDesc], Optional[List[Tuple[int, int]]]]:
145
  """
146
+ Build grid with a guaranteed long path using a DFS spanning tree.
147
+
148
+ Strategy:
149
+ 1. Build random spanning tree of the grid via DFS.
150
+ 2. Find tree diameter (longest path) via double-BFS — guaranteed
151
+ unique path, no shortcuts possible.
152
+ 3. Trim to desired length if much longer than needed.
153
+ 4. Cells adjacent to ≥2 walk cells but OFF the walk become Holes
154
+ (deterministically blocks all shortcuts).
155
+ 5. Remaining off-path cells are cosmetically filled with p.
156
+
157
+ Because tree paths are unique, the BFS shortest path in the resulting
158
+ grid equals the walk length (no shortcuts exist).
159
+ """
160
+ dirs = [(0, 1), (0, -1), (1, 0), (-1, 0)]
161
+
162
+ # Step 1: Random spanning tree via DFS
163
+ adj: dict = {(r, c): [] for r in range(size) for c in range(size)}
164
+ vis = [[False] * size for _ in range(size)]
165
+ sr, sc = random.randrange(size), random.randrange(size)
166
+ vis[sr][sc] = True
167
+ stack = [(sr, sc)]
168
+
169
+ while stack:
170
+ r, c = stack[-1]
171
+ nbrs = []
172
+ for dr, dc in dirs:
173
+ nr, nc = r + dr, c + dc
174
+ if 0 <= nr < size and 0 <= nc < size and not vis[nr][nc]:
175
+ nbrs.append((nr, nc))
176
+ if nbrs:
177
+ nr, nc = random.choice(nbrs)
178
+ vis[nr][nc] = True
179
+ adj[(r, c)].append((nr, nc))
180
+ adj[(nr, nc)].append((r, c))
181
+ stack.append((nr, nc))
182
+ else:
183
+ stack.pop()
184
+
185
+ # Step 2: Tree diameter via double-BFS
186
+ def _bfs_far(start):
187
+ dist = {start: 0}
188
+ q = deque([start])
189
+ far = start
190
+ while q:
191
+ node = q.popleft()
192
+ for nb in adj[node]:
193
+ if nb not in dist:
194
+ dist[nb] = dist[node] + 1
195
+ q.append(nb)
196
+ if dist[nb] > dist[far]:
197
+ far = nb
198
+ return far, dist
199
+
200
+ end1, _ = _bfs_far((sr, sc))
201
+ end2, dist1 = _bfs_far(end1)
202
+
203
+ if dist1[end2] < min_path_len:
204
+ return None, None
205
+
206
+ # Step 3: Reconstruct path end1 → end2
207
+ prev = {end1: None}
208
+ q = deque([end1])
209
+ while q:
210
+ node = q.popleft()
211
+ if node == end2:
212
+ break
213
+ for nb in adj[node]:
214
+ if nb not in prev:
215
+ prev[nb] = node
216
+ q.append(nb)
217
+
218
+ walk = []
219
+ cur = end2
220
+ while cur is not None:
221
+ walk.append(cur)
222
+ cur = prev[cur]
223
+ walk.reverse()
224
+
225
+ # Optionally trim if much longer
226
+ if len(walk) - 1 > min_path_len * 2:
227
+ excess = len(walk) - 1 - min_path_len
228
+ trim = random.randint(0, excess // 2)
229
+ if trim > 0:
230
+ walk = walk[trim:]
231
+ excess2 = len(walk) - 1 - min_path_len
232
+ trim2 = random.randint(0, excess2 // 2)
233
+ if trim2 > 0:
234
+ walk = walk[: len(walk) - trim2]
235
+
236
+ start_pos, end_pos = walk[0], walk[-1]
237
+ walk_set = set(walk)
238
+
239
+ # Step 4: Compute adjacency to walk for off-path cells
240
+ walk_nbr_ct: dict = {}
241
+ for wr, wc in walk:
242
+ for dr, dc in dirs:
243
+ nr, nc = wr + dr, wc + dc
244
+ if 0 <= nr < size and 0 <= nc < size and (nr, nc) not in walk_set:
245
+ walk_nbr_ct[(nr, nc)] = walk_nbr_ct.get((nr, nc), 0) + 1
246
+
247
+ # Step 5: Fill grid.
248
+ # ALL non-walk cells are Holes. This guarantees the BFS shortest
249
+ # path equals the walk itself (zero shortcut surface).
250
+ # The grid will look like a winding corridor through a sea of holes.
251
+ grid = [[""] * size for _ in range(size)]
252
+ for r in range(size):
253
+ for c in range(size):
254
+ if (r, c) == start_pos:
255
+ grid[r][c] = "S"
256
+ elif (r, c) == end_pos:
257
+ grid[r][c] = "G"
258
+ elif (r, c) in walk_set:
259
+ grid[r][c] = "F"
260
+ else:
261
+ # prob `p` as hole
262
+ grid[r][c] = "F" if random.random() < p else "H"
263
+
264
+ return ["".join(row) for row in grid], walk
265
+
266
+ # ==================== Generation: Auto ====================
267
+
268
def generate_auto(
    self, size: int, p: float = 0.8,
    min_path_len: int = 1, max_attempts: int = 200,
) -> Tuple["GridDesc", List[Tuple[int, int]]]:
    """
    Pick a generation strategy from the requested path length.

    Requests longer than half of ``size * 1.5`` steps go straight to
    ``generate_guided``.  Shorter requests try ``generate`` first and fall
    back to ``generate_guided`` if that raises ``RuntimeError``.

    Args:
        size: Grid side length.
        p: Forwarded unchanged to the chosen generator.
        min_path_len: Minimum number of steps required of the path.
        max_attempts: Attempt budget handed to each generator call.

    Returns:
        Whatever the selected generator returns.
    """
    expected_max = size * 1.5
    guided_threshold = expected_max * 0.5
    if min_path_len <= guided_threshold:
        try:
            return self.generate(size, p, min_path_len, max_attempts)
        except RuntimeError:
            # Random retries ran out of attempts; use the guided generator.
            return self.generate_guided(size, p, min_path_len, max_attempts)
    return self.generate_guided(size, p, min_path_len, max_attempts)
280
+
281
+ # ==================== Batch (multiprocessing) ====================
282
+
283
+ @staticmethod
284
def _generate_one(args: tuple) -> Optional[Tuple["GridDesc", list]]:
    """
    Worker entry point for ``generate_batch``.

    Args:
        args: ``(size, p, min_path_len, seed)``.  The tuple form keeps the
            call to a single argument for the executor.

    Returns:
        The ``(desc, path)`` pair from ``generate_auto`` (capped at 200
        attempts), or ``None`` when it raises ``RuntimeError``.
    """
    grid_size, frozen_prob, required_len, rng_seed = args
    # Seed the module-level RNG so each task draws its own sequence.
    random.seed(rng_seed)
    worker = FrozenLakeProcessor()
    try:
        puzzle = worker.generate_auto(
            grid_size, frozen_prob, required_len, max_attempts=200,
        )
    except RuntimeError:
        return None
    return puzzle
293
+
294
def generate_batch(
    self, size: int, count: int, p: float = 0.8,
    min_path_len: int = 1, workers: int = 8, base_seed: int = 42,
) -> List[Tuple["GridDesc", List[Tuple[int, int]]]]:
    """
    Generate up to *count* puzzles in parallel worker processes.

    ``count * 2`` tasks are submitted, seeded ``base_seed``,
    ``base_seed + 1``, ...  Results are collected in seed order (not in
    completion order), so the returned list does not depend on process
    scheduling -- presumably making it reproducible for a given
    ``base_seed``, as long as generation draws only from ``random``.

    Args:
        size: Grid side length.
        count: Number of puzzles wanted; ``count <= 0`` yields ``[]``.
        p: Forwarded to the worker.
        min_path_len: Minimum number of steps required of each path.
        workers: Maximum number of worker processes.
        base_seed: Seed of the first task.

    Returns:
        At most ``count`` ``(desc, path)`` pairs.  Fewer are returned when
        too many of the submitted tasks fail.
    """
    if count <= 0:
        # Nothing to do; avoid creating a process pool.
        return []

    tasks = [(size, p, min_path_len, base_seed + i) for i in range(count * 2)]
    results = []
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = [executor.submit(self._generate_one, task) for task in tasks]
        for future in futures:
            res = future.result()
            if res is not None:
                results.append(res)
                if len(results) >= count:
                    break
        # Drop tasks that have not started yet; running ones are awaited
        # when the ``with`` block exits.
        for future in futures:
            future.cancel()
    return results[:count]
311
+
312
+ # ==================== Solving (plain BFS) ====================
313
+
314
+ @staticmethod
315
+ def solve(desc: GridDesc) -> Optional[List[Tuple[int, int]]]:
316
+ """BFS shortest path from S to G, avoiding H."""
317
  rows, cols = len(desc), len(desc[0])
318
  start = goal = None
319
  for r in range(rows):
 
324
  goal = (r, c)
325
  if start is None or goal is None:
326
  return None
 
327
  visited = [[False] * cols for _ in range(rows)]
328
  visited[start[0]][start[1]] = True
329
  queue: deque = deque([(start, [start])])
 
330
  while queue:
331
  (r, c), path = queue.popleft()
332
  if (r, c) == goal:
 
334
  for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
335
  nr, nc = r + dr, c + dc
336
  if 0 <= nr < rows and 0 <= nc < cols and not visited[nr][nc]:
337
+ if desc[nr][nc] != "H":
 
338
  visited[nr][nc] = True
339
  queue.append(((nr, nc), path + [(nr, nc)]))
340
  return None
 
343
 
344
  @staticmethod
345
  def path_to_udrl(path: List[Tuple[int, int]]) -> str:
 
346
  moves = []
347
  for i in range(len(path) - 1):
348
  r1, c1 = path[i]
349
  r2, c2 = path[i + 1]
350
+ if r2 < r1: moves.append("U")
351
+ elif r2 > r1: moves.append("D")
352
+ elif c2 < c1: moves.append("L")
353
+ else: moves.append("R")
 
 
 
 
354
  return "".join(moves)
355
 
356
  # ==================== Verification ====================
357
 
358
  def verify_path_sim(self, desc: GridDesc, udrl: str) -> bool:
 
359
  rows, cols = len(desc), len(desc[0])
360
  start = self.find_start(desc)
361
  if start is None:
362
  return False
 
363
  r, c = start
364
  clean = udrl.replace(",", "").replace(" ", "").strip()
365
  if "Action plan" in clean:
366
  clean = clean.rsplit("Action plan", 1)[-1]
 
367
  for ch in clean:
368
  if ch not in MOVES:
369
  continue
 
371
  nr, nc = r + dr, c + dc
372
  if not (0 <= nr < rows and 0 <= nc < cols):
373
  return False
374
+ if desc[nr][nc] == "H":
 
375
  return False
376
  r, c = nr, nc
377
+ if desc[nr][nc] == "G":
378
  return True
379
  return desc[r][c] == "G"
380
 
381
  def verify_path_gym(self, desc: GridDesc, udrl: str) -> bool:
 
382
  if not HAS_GYM:
383
  return self.verify_path_sim(desc, udrl)
384
  rows, cols = len(desc), len(desc[0])
 
404
  except Exception:
405
  return self.verify_path_sim(desc, udrl)
406
 
407
+ # ==================== Table I/O ====================
408
 
409
  def encode_table(self, desc: GridDesc) -> str:
 
410
  size = len(desc)
411
  lines = ["| | " + " | ".join(f"Col {i+1}" for i in range(size)) + " |"]
412
  for r in range(size):
 
415
  return "\n".join(lines)
416
 
417
  def decode_table(self, text: str) -> Optional[GridDesc]:
 
418
  try:
419
  rows = []
420
  for line in text.strip().splitlines():
 
458
  # ==================== Rendering ====================
459
 
460
  def render_gym(self, desc: GridDesc) -> Optional[Image.Image]:
 
461
  if not HAS_GYM:
462
  return None
463
  try:
 
475
  return None
476
 
477
  def render_simple(self, desc: GridDesc) -> Image.Image:
478
+ """Float-aligned renderer (handles non-power-of-2 sizes correctly)."""
479
  size = len(desc)
480
+ cell_f = self.img_size / size
481
  img = Image.new("RGB", (self.img_size, self.img_size), (255, 255, 255))
482
  draw = ImageDraw.Draw(img)
483
  colors = {
 
486
  }
487
  for r in range(size):
488
  for c in range(size):
489
+ x0 = int(round(c * cell_f))
490
+ y0 = int(round(r * cell_f))
491
+ x1 = int(round((c + 1) * cell_f)) - 1
492
+ y1 = int(round((r + 1) * cell_f)) - 1
493
  draw.rectangle(
494
+ [x0, y0, x1, y1],
495
  fill=colors.get(desc[r][c], (200, 220, 255)),
496
  )
497
  for i in range(size + 1):
498
+ pos = int(round(i * cell_f))
499
+ draw.line([(pos, 0), (pos, self.img_size)], fill="black", width=1)
500
+ draw.line([(0, pos), (self.img_size, pos)], fill="black", width=1)
501
  return img
502
 
503
  def render(self, desc: GridDesc, use_gym: bool = True) -> Image.Image:
 
510
  def draw_solution_line(
511
  self, image: Image.Image, path: List[Tuple[int, int]], grid_size: int,
512
  ) -> Image.Image:
 
513
  draw = ImageDraw.Draw(image)
514
  w, h = image.size
515
  cw, ch_ = w / grid_size, h / grid_size
 
520
  # ==================== Video Frames ====================
521
 
522
  def generate_video_frames(
523
+ self, desc: GridDesc, path: List[Tuple[int, int]],
524
+ n_start: int = 5, m_end: int = 5,
525
+ frames: Optional[int] = None, use_gym: bool = True,
 
 
 
 
526
  ) -> List[Image.Image]:
 
 
 
 
 
 
527
  size = len(desc)
528
  n_steps = len(path) - 1
529
  base_img = self.render(desc, use_gym=use_gym)
 
530
  if n_steps <= 0:
531
  return [base_img] * (n_start + m_end + 1)
 
532
  content = frames if frames is not None else n_steps
533
  content = max(1, content)
534
+ result = [base_img.copy() for _ in range(n_start)]
 
 
 
535
 
536
+ def _partial(steps):
537
+ return self.draw_solution_line(base_img.copy(), path[:steps+1], size)
538
 
539
  if content == n_steps:
540
  for s in range(1, n_steps + 1):
 
551
  for f in range(content):
552
  result.append(_partial((f + 1) * n_steps // content))
553
 
554
+ result.extend([_partial(n_steps).copy() for _ in range(m_end)])
 
 
555
  return result
556
 
557
  # ==================== Red-Path Extraction ====================
558
 
559
  def extract_path_from_pixels(
560
+ self, pixels: np.ndarray, rows: int, cols: int,
561
+ start: Tuple[int, int], desc: Optional[GridDesc] = None,
 
 
 
 
562
  pixel_threshold: float = 0.01,
563
  ) -> str:
564
+ """Detect red path (float-aligned cells to match renderer)."""
565
  img = Image.fromarray(pixels)
566
  w, h = img.size
567
  px = np.array(img, dtype=float)
568
  r_ch, g_ch, b_ch = px[:, :, 0], px[:, :, 1], px[:, :, 2]
569
  red_mask = (r_ch > 100) & (r_ch > g_ch * 1.2) & (r_ch > b_ch * 1.2)
570
 
571
+ cell_h_f, cell_w_f = h / rows, w / cols
572
  path_grid = np.zeros((rows, cols), dtype=bool)
573
  for r in range(rows):
574
+ y0 = int(round(r * cell_h_f))
575
+ y1 = int(round((r + 1) * cell_h_f))
576
  for c in range(cols):
577
+ x0 = int(round(c * cell_w_f))
578
+ x1 = int(round((c + 1) * cell_w_f))
579
+ sub = red_mask[y0:y1, x0:x1]
580
  if sub.size > 0 and np.mean(sub) > pixel_threshold:
581
  path_grid[r, c] = True
582
 
 
583
  visited = {start}
584
  cr, cc = start
585
  actions: List[str] = []
586
  for _ in range(rows * cols * 2):
587
  found = False
588
+ for act, (dr, dc) in [("R",(0,1)),("D",(1,0)),("L",(0,-1)),("U",(-1,0))]:
589
  nr, nc = cr + dr, cc + dc
590
  if 0 <= nr < rows and 0 <= nc < cols:
591
  if path_grid[nr, nc] and (nr, nc) not in visited:
 
598
  break
599
  return "".join(actions)
600
 
601
+ def extract_path_from_image(self, img_path, rows, cols, start, desc=None):
 
 
 
602
  try:
603
  pixels = np.array(Image.open(img_path).convert("RGB"))
604
  return self.extract_path_from_pixels(pixels, rows, cols, start, desc)
 
611
 
612
  proc = FrozenLakeProcessor(img_size=512)
613
 
614
+ # ---- Benchmark: yield rate ----
615
+ print("=== Yield Rate: random vs guided ===")
616
+ for sz in [8, 16, 32]:
617
+ min_len = max(1, int(sz * sz * 0.1))
618
+ random.seed(42)
619
  t0 = time.perf_counter()
620
+ found_r = 0
621
+ for _ in range(500):
622
+ desc = proc._random_layout(sz, 0.8)
623
  path = proc.solve(desc)
624
+ if path and (len(path) - 1) >= min_len:
625
+ found_r += 1
626
+ t_rand = time.perf_counter() - t0
627
+
628
+ random.seed(42)
629
+ t0 = time.perf_counter()
630
+ found_g = 0
631
+ for _ in range(50):
632
+ try:
633
+ desc, path = proc.generate_guided(sz, 0.8, min_len, max_attempts=5)
634
+ found_g += 1
635
+ except RuntimeError:
636
+ pass
637
+ t_guid = time.perf_counter() - t0
638
+
639
+ print(f" Size {sz:2d} (min={min_len:3d}): "
640
+ f"random={found_r}/500 ({found_r/5:.1f}%) {t_rand:.2f}s | "
641
+ f"guided={found_g}/50 ({found_g*2:.0f}%) {t_guid:.2f}s")
642
+
643
+ # ---- generate_auto all sizes ----
644
+ print("\n=== generate_auto ===")
645
+ for sz in [8, 16, 32, 64]:
646
+ min_len = max(1, int(sz * sz * 0.1))
647
+ random.seed(42)
648
+ t0 = time.perf_counter()
649
+ desc, path = proc.generate_auto(sz, 0.8, min_len)
650
  elapsed = time.perf_counter() - t0
651
+ udrl = proc.path_to_udrl(path)
652
+ ok = proc.verify_path_sim(desc, udrl)
653
+ print(f" Size {sz:2d}: path={len(path)-1:3d} (min={min_len:3d}) "
654
+ f"verify={ok} {elapsed:.3f}s")
655
+
656
+ # ---- Extract round-trip (works for random-mode, guided corridors are too winding) ----
657
+ print("\n=== Extract round-trip ===")
658
+ for sz in [8, 16, 24, 32]:
659
+ random.seed(42 + sz)
660
+ # Use random mode for smaller sizes (natural-looking grids)
661
+ min_len = max(1, sz)
662
+ try:
663
+ desc, path = proc.generate(sz, 0.8, min_len, max_attempts=1000)
664
+ except RuntimeError:
665
+ desc, path = proc.generate_guided(sz, 0.8, min_len)
666
+ img = proc.render(desc, use_gym=False)
667
+ sol = proc.draw_solution_line(img.copy(), path, sz)
668
+ start = proc.find_start(desc)
669
+ extracted = proc.extract_path_from_pixels(np.array(sol), sz, sz, start)
670
+ ok = proc.verify_path_sim(desc, extracted)
671
+ print(f" Size {sz:2d}: verify={ok} "
672
+ f"(GT={len(path)-1}, extracted={len(extracted)})")
673
+
674
+ print("\nAll tests passed ✓")
maze/maze/checkpoints/Wan2.1-I2V-14B-720P_lora_0209/epoch-0.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc0d5db9871e456c6d806e54b77a54e1d1478c55de14dae7ce3317ba46021227
3
+ size 1226928552