zzsyppt commited on
Commit
287956c
·
verified ·
1 Parent(s): 9eb3c8c

Add Adacrop Space demo

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -35
  2. .gitignore +4 -0
  3. README.md +58 -15
  4. app.py +196 -0
  5. distillation/common.py +480 -0
  6. ppo_best_val_final_score.pth +3 -0
  7. requirements.txt +4 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ ppo_best_val_final_score.pth
3
+
4
+ ppo_best_val_final_score.pth
README.md CHANGED
@@ -1,15 +1,58 @@
1
- ---
2
- title: Adacrop Demo
3
- emoji: 🐠
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
- app_file: app.py
10
- pinned: false
11
- license: mit
12
- short_description: Demostrating AdaCrop full image cropping model.
13
- ---
14
-
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Adacrop Demo
3
+ emoji: 🐠
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 6.14.0
8
+ python_version: '3.13'
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ short_description: Demonstrating AdaCrop full image cropping model.
13
+ ---
14
+
15
+ # Adacrop Hugging Face Space Demo
16
+
17
+ This Gradio demo loads `ppo_best_val_final_score.pth`, predicts an initial crop with the BBox head, and optionally refines it with the PPO actor policy.
18
+
19
+ ## Required files
20
+
21
+ Deploy the Space with:
22
+
23
+ - `app.py`
24
+ - `requirements.txt`
25
+ - `ppo_best_val_final_score.pth`
26
+ - the existing `distillation/common.py` module from this repository
27
+
28
+ The easiest layout is:
29
+
30
+ ```text
31
+ app.py
32
+ requirements.txt
33
+ ppo_best_val_final_score.pth
34
+ distillation/
35
+ common.py
36
+ ```
37
+
38
+ If the checkpoint has a different path, set the Space environment variable:
39
+
40
+ ```text
41
+ MODEL_PATH=path/to/ppo_best_val_final_score.pth
42
+ ```
43
+
44
+ ## Behavior
45
+
46
+ - `max_steps = 0`: BBox head only.
47
+ - `max_steps > 0`: BBox head initializes the crop, then the actor policy refines it for up to `max_steps` steps, stopping early if the policy chooses `stop`.
48
+ - The UI shows the original image with a red crop box and the cropped result.
49
+
50
+ Optional environment variables:
51
+
52
+ ```text
53
+ FORCE_CPU=1
54
+ DISABLE_CUDNN=1
55
+ IMG_SIZE=224
56
+ ACTION_DELTA=0.05
57
+ DEFAULT_MAX_STEPS=60
58
+ ```
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+ from typing import List, Tuple
7
+
8
+ import gradio as gr
9
+ import torch
10
+ from PIL import Image, ImageDraw
11
+
12
+
13
+ SPACE_DIR = Path(__file__).resolve().parent
14
+ PROJECT_DIR = SPACE_DIR.parent
15
+ for path in (SPACE_DIR, PROJECT_DIR):
16
+ if str(path) not in sys.path:
17
+ sys.path.insert(0, str(path))
18
+
19
+ try:
20
+ from distillation.common import (
21
+ ACTIONS,
22
+ bbox_cxcywh_to_xyxy,
23
+ box_state,
24
+ clamp_xywh,
25
+ load_teacher,
26
+ render_crop,
27
+ render_full_image,
28
+ step_box,
29
+ )
30
+ except ModuleNotFoundError as exc:
31
+ raise ModuleNotFoundError(
32
+ "Cannot import distillation.common. Deploy this demo together with the "
33
+ "Adacrop/distillation directory, or copy distillation/common.py into the Space repo."
34
+ ) from exc
35
+
36
+
37
+ IMG_SIZE = int(os.getenv("IMG_SIZE", "224"))
38
+ ACTION_DELTA = float(os.getenv("ACTION_DELTA", "0.05"))
39
+ DEFAULT_MAX_STEPS = int(os.getenv("DEFAULT_MAX_STEPS", "60"))
40
+ MODEL_ENV = os.getenv("MODEL_PATH", "ppo_best_val_final_score.pth")
41
+
42
+
43
def resolve_model_path() -> Path:
    """Locate the checkpoint file named by ``MODEL_ENV``.

    The configured value is tried as-is when it is absolute, then relative to
    the Space directory and to its parent, and finally by bare file name
    inside a ``models`` folder of either directory. The first candidate that
    is a regular file wins.

    Returns:
        Path of the first candidate that exists as a file.

    Raises:
        FileNotFoundError: If no candidate is a file. The message lists every
            distinct location that was checked.
    """
    raw = Path(MODEL_ENV)
    search = [
        SPACE_DIR / raw,
        PROJECT_DIR / raw,
        SPACE_DIR / "models" / raw.name,
        PROJECT_DIR / "models" / raw.name,
    ]
    if raw.is_absolute():
        search.insert(0, raw)
    # Joining onto an absolute path yields that same path again, so drop
    # duplicates (order preserved) to keep the error message readable.
    candidates = list(dict.fromkeys(search))
    for candidate in candidates:
        # is_file() rather than exists(): a directory with the checkpoint's
        # name must not be handed to torch.load.
        if candidate.is_file():
            return candidate
    checked = "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(
        f"Could not find model checkpoint {MODEL_ENV!r}. Checked:\n{checked}\n"
        "Put ppo_best_val_final_score.pth in the Space root, or set MODEL_PATH."
    )
64
+
65
+
66
def get_device() -> torch.device:
    """Pick the torch device used for inference.

    Returns the CPU device when the ``FORCE_CPU`` environment variable is
    ``"1"`` or when CUDA is unavailable; otherwise returns the CUDA device.
    """
    cpu_forced = os.getenv("FORCE_CPU", "0") == "1"
    if not cpu_forced and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
70
+
71
+
72
@lru_cache(maxsize=1)
def get_model():
    """Load the teacher model once and reuse it on later calls.

    Setting the ``DISABLE_CUDNN`` environment variable to ``"1"`` switches
    cuDNN off before the model is loaded.

    Returns:
        A ``(model, device, model_path)`` tuple. The result is memoised by
        ``lru_cache``, so the checkpoint is only read on the first call.
    """
    cudnn_off = os.getenv("DISABLE_CUDNN", "0") == "1"
    if cudnn_off:
        torch.backends.cudnn.enabled = False
    device = get_device()
    model_path = resolve_model_path()
    return load_teacher(model_path, device), device, model_path
80
+
81
+
82
def predict_bbox(model, image: Image.Image, device: torch.device) -> Tuple[List[float], List[float]]:
    """Predict the initial crop box from the whole image with the BBox head.

    The image is resized to ``IMG_SIZE`` and passed to
    ``model.backbone_forward``. The four outputs are clamped to ``[0, 1]``
    and converted by ``bbox_cxcywh_to_xyxy`` into pixel corners.

    Returns:
        ``(init_box, raw_xyxy)``: the box as clamped ``[x, y, w, h]`` pixels
        (width and height at least 1 before clamping), and the unclamped-size
        corner form ``[x1, y1, x2, y2]`` it was derived from.
    """
    width, height = image.size
    batch = render_full_image(image, IMG_SIZE).unsqueeze(0).to(device)
    with torch.no_grad():
        head_out = model.backbone_forward(batch)
    normalized = head_out.squeeze(0).detach().cpu().clamp(0.0, 1.0).tolist()
    raw_xyxy = bbox_cxcywh_to_xyxy(normalized, width, height)
    left, top, right, bottom = raw_xyxy
    box_w = max(1.0, right - left)
    box_h = max(1.0, bottom - top)
    init_box = clamp_xywh([left, top, box_w, box_h], width, height, delta=ACTION_DELTA)
    return init_box, raw_xyxy
96
+
97
+
98
def predict_action(model, image: Image.Image, box_xywh: List[float], device: torch.device) -> int:
    """Return the index of the most probable action for the current crop.

    The model receives the crop rendered at ``IMG_SIZE`` together with the
    4-value box state; the arg-max over its first output is returned.
    """
    width, height = image.size
    crop_batch = render_crop(image, box_xywh, IMG_SIZE).unsqueeze(0).to(device)
    state_batch = box_state(box_xywh, width, height).unsqueeze(0).to(device)
    with torch.no_grad():
        action_probs, _value = model(crop_batch, state_batch)
    best = action_probs.argmax(dim=1)
    return int(best.item())
105
+
106
+
107
def run_policy(model, image: Image.Image, init_box: List[float], max_steps: int, device: torch.device):
    """Greedily refine ``init_box`` with the actor policy.

    At most ``max_steps`` actions are taken. The loop ends early as soon as
    the policy picks ``"stop"``, which is still recorded in the history.

    Returns:
        ``(box, actions)``: the final ``[x, y, w, h]`` box and the list of
        action names chosen, in order.
    """
    width, height = image.size
    current = list(init_box)  # copy so the caller's list is never mutated
    history = []
    for _step in range(max_steps):
        chosen = predict_action(model, image, current, device)
        name = ACTIONS[chosen]
        history.append(name)
        if name != "stop":
            current = step_box(current, chosen, width, height, delta=ACTION_DELTA)
            continue
        break
    return current, history
119
+
120
+
121
def draw_box(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Return an RGB copy of ``image`` with the box outlined in red.

    The outline is built from concentric 1-pixel rectangles growing outwards;
    their count is 0.6% of the shorter image side, with a minimum of 3.
    """
    canvas = image.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    left, top, box_w, box_h = (float(v) for v in box_xywh)
    right = left + box_w
    bottom = top + box_h
    thickness = max(3, int(min(canvas.size) * 0.006))
    red = (255, 0, 0)
    for grow in range(thickness):
        pen.rectangle([left - grow, top - grow, right + grow, bottom + grow], outline=red)
    return canvas
130
+
131
+
132
def crop_image(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Cut the ``[x, y, w, h]`` region out of ``image`` and return it as RGB."""
    left, top, box_w, box_h = map(float, box_xywh)
    region = (left, top, left + box_w, top + box_h)
    return image.crop(region).convert("RGB")
135
+
136
+
137
def infer(image, max_steps):
    """Gradio callback: crop ``image`` and report how the crop was obtained.

    ``max_steps`` is limited to the range 0..200 and truncated to an int.
    With 0 steps only the BBox head prediction is used; otherwise that
    prediction is refined by the RL policy.

    Returns:
        ``(overlay, cropped, details)``: the image with the final box drawn,
        the cropped image, and a JSON string describing the run.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    def _rounded(values):
        # Three decimals keeps the JSON report readable.
        return [round(float(v), 3) for v in values]

    image = image.convert("RGB")
    max_steps = int(max(0, min(200, max_steps)))
    model, device, model_path = get_model()

    init_box, raw_bbox_xyxy = predict_bbox(model, image, device)
    if max_steps > 0:
        final_box, actions = run_policy(model, image, init_box, max_steps, device)
        mode = "BBox head + RL policy"
    else:
        final_box, actions = init_box, []
        mode = "BBox head only"

    overlay = draw_box(image, final_box)
    cropped = crop_image(image, final_box)
    info = {
        "mode": mode,
        "device": str(device),
        "model_path": str(model_path),
        "image_size": {"width": image.width, "height": image.height},
        "requested_max_steps": max_steps,
        "actual_steps": len(actions),
        "stopped": bool(actions and actions[-1] == "stop"),
        "actions": actions,
        "initial_box_xywh": _rounded(init_box),
        "raw_bbox_head_xyxy": _rounded(raw_bbox_xyxy),
        "final_box_xywh": _rounded(final_box),
    }
    return overlay, cropped, json.dumps(info, indent=2, ensure_ascii=False)
170
+
171
+
172
with gr.Blocks(title="Adacrop Core Policy Demo") as demo:
    # Page header and a one-line usage hint.
    gr.Markdown("# Adacrop Crop Demo")
    gr.Markdown("Upload an image. Set `max_steps = 0` to use only the BBox head; higher values run the RL policy refinement.")

    with gr.Row():
        # Left column: the inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input image")
            # The configured default is forced into the slider's 0..200 range.
            max_steps = gr.Slider(
                label="Max RL steps",
                minimum=0,
                maximum=200,
                step=1,
                value=min(max(DEFAULT_MAX_STEPS, 0), 200),
            )
            run_button = gr.Button("Crop", variant="primary")
        # Right column: the two result images.
        with gr.Column():
            overlay_image = gr.Image(type="pil", label="Original image with crop box")
            cropped_image = gr.Image(type="pil", label="Cropped result")

    # JSON report produced by infer(), shown below both columns.
    info = gr.Code(label="Run details", language="json")
    run_button.click(
        fn=infer,
        inputs=[input_image, max_steps],
        outputs=[overlay_image, cropped_image, info],
    )
193
+
194
+
195
+ if __name__ == "__main__":
196
+ demo.launch()
distillation/common.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import random
4
+ from pathlib import Path
5
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+ from torchvision import models
14
+
15
+
16
+ ACTIONS = ["left", "right", "up", "down", "zoom_in", "zoom_out", "stop"]
17
+
18
+
19
def find_adacrop_root() -> Path:
    """Return the directory two levels above this file (the parent of its folder)."""
    this_file = Path(__file__).resolve()
    return this_file.parents[1]
21
+
22
+
23
def _strip_adacrop_prefix(path_text: str) -> str:
    """Normalise separators and drop a leading ``./`` and then ``Adacrop/``.

    Backslashes become forward slashes first. Each prefix is removed at most
    once, in that order.
    """
    cleaned = path_text.replace("\\", "/")
    for prefix in ("./", "Adacrop/"):
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):]
    return cleaned
30
+
31
+
32
def resolve_image_path(raw_path: str, adacrop_root: Path, source_file: Optional[Path] = None) -> Path:
    """Resolve mixed project paths, including JSONL paths like ./outpainted/a.png.

    Candidate locations are collected in priority order and the first one
    that exists is returned, resolved. When none exists, the first candidate
    is returned (resolved) so the caller still gets a concrete path.

    Args:
        raw_path: Path text as stored in an annotation file.
        adacrop_root: Root directory of the Adacrop project.
        source_file: Annotation file the path came from, if known; relative
            paths are also tried against its directory.
    """
    raw = str(raw_path).replace("\\", "/")
    stripped = _strip_adacrop_prefix(raw)
    candidates: List[Path] = []

    as_given = Path(raw)
    if as_given.is_absolute():
        candidates.append(as_given)

    if source_file is not None:
        annotation_dir = source_file.parent
        candidates.append(annotation_dir / raw)
        if raw.startswith("./"):
            candidates.append(annotation_dir / raw[2:])

    candidates += [adacrop_root / stripped, adacrop_root.parent / raw]

    # Old merged JSONs may contain Adacrop/data/outpainted/foo.png, while this
    # workspace stores those files under data/outpainted_dataset/outpainted.
    legacy_prefix = "data/outpainted/"
    if stripped.startswith(legacy_prefix):
        tail = stripped[len(legacy_prefix):]
        candidates.append(adacrop_root / "data" / "outpainted_dataset" / "outpainted" / tail)

    # The outpainted JSONL stores paths as ./outpainted/foo.png relative to the
    # JSONL file: data/outpainted_dataset/training_pairs.jsonl.
    if stripped.startswith("outpainted/"):
        candidates.append(adacrop_root / "data" / "outpainted_dataset" / stripped)

    found = next((cand for cand in candidates if cand.exists()), None)
    chosen = candidates[0] if found is None else found
    return chosen.resolve()
65
+
66
+
67
def normalize_boxes(value) -> List[List[float]]:
    """Flatten assorted box encodings into a list of 4-float lists.

    Accepted inputs:
      * a dict with ``x1, y1, x2, y2`` (used as-is);
      * a dict with ``x, y, w, h`` (converted to two corners);
      * a list/tuple of exactly four numbers (one box);
      * any other list/tuple, whose items are normalised recursively.

    Anything else, including ``None`` and dicts without a known key set,
    produces an empty list.
    """
    if isinstance(value, dict):
        corner_keys = ("x1", "y1", "x2", "y2")
        if all(k in value for k in corner_keys):
            return [[float(value[k]) for k in corner_keys]]
        size_keys = ("x", "y", "w", "h")
        if all(k in value for k in size_keys):
            x, y, w, h = (float(value[k]) for k in size_keys)
            return [[x, y, x + w, y + h]]
        return []

    if not isinstance(value, (list, tuple)):
        return []

    if len(value) == 4 and all(isinstance(v, (int, float)) for v in value):
        return [[float(v) for v in value]]

    flattened: List[List[float]] = []
    for item in value:
        flattened += normalize_boxes(item)
    return flattened
85
+
86
+
87
def canonical_box_xyxy(box: Sequence[float], width: int, height: int, img_path: Optional[str] = None) -> List[float]:
    """Return a pixel-space [x1,y1,x2,y2] box.

    The outpainted JSONL is xyxy, while the CUHK split files in this workspace
    use yxyx-like coordinates. Use the image path when it is unambiguous, then
    fall back to bounds checks.

    Args:
        box: Four numbers in either xyxy or yxyx order.
        width: Image width in pixels.
        height: Image height in pixels.
        img_path: Optional image path; ``cuhk_images`` selects yxyx, while
            ``outpainted`` or ``gaic_dataset`` selects xyxy (case-insensitive).

    Returns:
        ``[x1, y1, x2, y2]`` with sorted corners clipped to the image. A box
        that collapses to zero size is widened to one pixel, moving the
        low corner inwards when the box sits on the right/bottom edge.
    """
    a, b, c, d = [float(v) for v in box]
    path_text = (img_path or "").replace("\\", "/").lower()

    if "cuhk_images" in path_text:
        x1, y1, x2, y2 = b, a, d, c
    elif "outpainted" in path_text or "gaic_dataset" in path_text:
        x1, y1, x2, y2 = a, b, c, d
    else:
        # No hint from the path: swap only when yxyx fits the image and
        # xyxy does not.
        xyxy_valid = 0 <= a < c <= width and 0 <= b < d <= height
        yxyx_valid = 0 <= b < d <= width and 0 <= a < c <= height
        if yxyx_valid and not xyxy_valid:
            x1, y1, x2, y2 = b, a, d, c
        else:
            x1, y1, x2, y2 = a, b, c, d

    x1, x2 = sorted([x1, x2])
    y1, y2 = sorted([y1, y2])
    x1 = min(max(0.0, x1), float(width))
    x2 = min(max(0.0, x2), float(width))
    y1 = min(max(0.0, y1), float(height))
    y2 = min(max(0.0, y2), float(height))
    if x2 <= x1:
        x2 = min(float(width), x1 + 1.0)
        # On the right edge x2 cannot grow, so pull x1 back instead;
        # otherwise the box would stay zero-width.
        x1 = min(x1, max(0.0, x2 - 1.0))
    if y2 <= y1:
        y2 = min(float(height), y1 + 1.0)
        # Same treatment for a box collapsed onto the bottom edge.
        y1 = min(y1, max(0.0, y2 - 1.0))
    return [x1, y1, x2, y2]
120
+
121
+
122
def load_records(path: Path, adacrop_root: Path, require_images: bool = True) -> List[Dict]:
    """Read an annotation file and return one record per usable row.

    ``.jsonl`` files are parsed line by line (blank lines skipped); any other
    suffix is parsed as a single JSON document.

    Args:
        path: Annotation file to read.
        adacrop_root: Project root used to resolve image paths.
        require_images: When true, rows whose image file does not exist are
            dropped.

    Returns:
        Dicts with ``img`` (resolved path string), ``boxes`` (normalised from
        the row's ``box``, ``boxes`` or ``orig_bbox``) and ``raw`` (the row).
        Rows without an ``img`` or ``file`` entry are skipped.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        if path.suffix.lower() == ".jsonl":
            rows = [json.loads(text) for text in (line.strip() for line in f) if text]
        else:
            rows = json.load(f)

    records: List[Dict] = []
    for row in rows:
        raw_img = row.get("img") or row.get("file")
        if not raw_img:
            continue
        img_path = resolve_image_path(raw_img, adacrop_root, source_file=path)
        if not require_images or img_path.exists():
            raw_boxes = row.get("box") or row.get("boxes") or row.get("orig_bbox")
            records.append({"img": str(img_path), "boxes": normalize_boxes(raw_boxes), "raw": row})
    return records
146
+
147
+
148
def resnet50_no_weights():
    """Build a ResNet-50 without loading pretrained weights.

    ``weights=None`` is tried first; if that keyword is rejected with a
    ``TypeError``, the call is retried with ``pretrained=False``.
    """
    build = models.resnet50
    try:
        net = build(weights=None)
    except TypeError:
        net = build(pretrained=False)
    return net
153
+
154
+
155
def mobilenet_v3_no_weights(arch: str):
    """Build a MobileNetV3 (large or small) without pretrained weights.

    Args:
        arch: ``"mobilenet_v3_large"`` or ``"mobilenet_v3_small"``.

    Raises:
        ValueError: For any other ``arch``.

    ``weights=None`` is tried first; if that keyword is rejected with a
    ``TypeError``, the call is retried with ``pretrained=False``.
    """
    if arch not in ("mobilenet_v3_large", "mobilenet_v3_small"):
        raise ValueError(f"Unsupported student arch: {arch}")
    # The accepted names match the torchvision constructor names.
    build = getattr(models, arch)
    try:
        return build(weights=None)
    except TypeError:
        return build(pretrained=False)
167
+
168
+
169
class TeacherActorCritic(nn.Module):
    """ResNet-50 teacher with actor, critic and bounding-box heads.

    The backbone's ``fc`` layer is replaced by an identity, so it emits
    ``feat_dim`` (2048) features. The actor and critic take those features
    concatenated with a 4-value state; the bbox head takes the features alone.
    """

    def __init__(self, n_actions: int = len(ACTIONS)):
        super().__init__()
        feat_dim = 2048
        trunk = resnet50_no_weights()
        trunk.fc = nn.Identity()
        self.backbone = trunk

        # Attribute names and layer order are kept so that state-dict keys
        # (e.g. ``actor.0.weight``) stay compatible with saved checkpoints.
        joint_dim = feat_dim + 4
        self.actor = self._policy_mlp(joint_dim, n_actions)
        self.critic = self._policy_mlp(joint_dim, 1)
        self.bbox_head = nn.Sequential(nn.Linear(feat_dim, 512), nn.ReLU(), nn.Linear(512, 4))

    @staticmethod
    def _policy_mlp(in_dim: int, out_dim: int) -> nn.Sequential:
        """Three-layer MLP (1024 -> 512 -> out_dim) with ReLU and dropout."""
        return nn.Sequential(
            nn.Linear(in_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, out_dim),
        )

    def forward(self, img_tensor: torch.Tensor, state: torch.Tensor):
        """Return ``(action_probs, value)`` for an image batch and its state.

        ``action_probs`` is a softmax over the actor logits along dim 1.
        """
        joint = torch.cat([self.backbone(img_tensor), state], dim=1)
        action_probs = F.softmax(self.actor(joint), dim=1)
        return action_probs, self.critic(joint)

    def backbone_forward(self, img_tensor: torch.Tensor):
        """Return the raw 4-value bbox head output (no activation applied)."""
        return self.bbox_head(self.backbone(img_tensor))
204
+
205
+
206
class MobileNetPolicy(nn.Module):
    """MobileNetV3 student with an actor head and a bounding-box head.

    The torchvision model's ``features`` and ``avgpool`` are reused. Head
    input width is read from the first layer of its original classifier.
    """

    def __init__(self, arch: str = "mobilenet_v3_small", n_actions: int = len(ACTIONS)):
        super().__init__()
        base = mobilenet_v3_no_weights(arch)
        feat_dim = base.classifier[0].in_features
        self.arch = arch
        self.features = base.features
        self.avgpool = base.avgpool

        # Layer order is kept so that state-dict keys stay compatible.
        actor_layers = [
            nn.Linear(feat_dim + 4, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, n_actions),
        ]
        self.actor = nn.Sequential(*actor_layers)
        bbox_layers = [
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 4),
        ]
        self.bbox_head = nn.Sequential(*bbox_layers)

    def extract_feats(self, img_tensor: torch.Tensor):
        """Run features + average pooling and flatten from dim 1 onwards."""
        pooled = self.avgpool(self.features(img_tensor))
        return torch.flatten(pooled, 1)

    def forward(self, img_tensor: torch.Tensor, state: torch.Tensor):
        """Return ``(action_probs, logits)``; probs are a softmax over dim 1."""
        joint = torch.cat([self.extract_feats(img_tensor), state], dim=1)
        logits = self.actor(joint)
        return F.softmax(logits, dim=1), logits

    def backbone_forward(self, img_tensor: torch.Tensor):
        """Return the bbox head output passed through a sigmoid."""
        raw_box = self.bbox_head(self.extract_feats(img_tensor))
        return torch.sigmoid(raw_box)
243
+
244
+
245
def load_teacher(ckpt_path: Path, device: torch.device) -> TeacherActorCritic:
    """Load a ``TeacherActorCritic`` checkpoint and return it in eval mode.

    The checkpoint may be a dict holding ``model_state_dict`` or the state
    dict itself. Loading is non-strict: unexpected keys are reported, and
    missing ``critic.*`` / ``bbox_head.*`` keys are tolerated but reported,
    because those heads then keep their freshly initialised weights.

    Args:
        ckpt_path: Checkpoint file to read.
        device: Device the model is moved to.

    Raises:
        RuntimeError: If any key outside ``critic.*`` / ``bbox_head.*`` is
            missing from the checkpoint.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    model = TeacherActorCritic(n_actions=len(ACTIONS))
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if unexpected:
        print(f"[teacher] unexpected keys: {unexpected[:8]}")

    optional_prefixes = ("critic.", "bbox_head.")
    missing_required = [k for k in missing if not k.startswith(optional_prefixes)]
    if missing_required:
        raise RuntimeError(f"Teacher checkpoint missing required keys: {missing_required[:8]}")

    # Callers that rely on these heads (e.g. backbone_forward) would otherwise
    # run on untrained weights without any indication.
    missing_optional = [k for k in missing if k.startswith(optional_prefixes)]
    if missing_optional:
        print(f"[teacher] missing optional keys (left at initial values): {missing_optional[:8]}")
    return model.to(device).eval()
256
+
257
+
258
def load_student(ckpt_path: Path, device: torch.device, arch: Optional[str] = None) -> MobileNetPolicy:
    """Load a ``MobileNetPolicy`` checkpoint and return it in eval mode.

    The architecture stored under the checkpoint's ``arch`` key takes
    precedence; ``arch`` (or ``"mobilenet_v3_small"``) is only the fallback.
    Weights come from ``model_state_dict`` when present, otherwise from the
    checkpoint itself, and are loaded strictly.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    fallback_arch = arch or "mobilenet_v3_small"
    student = MobileNetPolicy(arch=ckpt.get("arch", fallback_arch), n_actions=len(ACTIONS))
    student.load_state_dict(ckpt.get("model_state_dict", ckpt))
    return student.to(device).eval()
265
+
266
+
267
def xyxy_to_xywh(box: Sequence[float]) -> List[float]:
    """Convert two corners to ``[x, y, w, h]``.

    Corners may be given in either order. Width and height are at least 1.
    """
    xa, ya, xb, yb = map(float, box)
    if xb < xa:
        xa, xb = xb, xa
    if yb < ya:
        ya, yb = yb, ya
    return [xa, ya, max(1.0, xb - xa), max(1.0, yb - ya)]
272
+
273
+
274
def xywh_to_xyxy(box: Sequence[float]) -> List[float]:
    """Convert ``[x, y, w, h]`` to corners ``[x, y, x + w, y + h]``."""
    left, top, box_w, box_h = map(float, box)
    return [left, top, left + box_w, top + box_h]
277
+
278
+
279
def box_iou_xyxy(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Negative extents count as zero area. Returns 0.0 when the union is
    (numerically) empty.
    """
    ax1, ay1, ax2, ay2 = map(float, a)
    bx1, by1, bx2, by2 = map(float, b)

    def _area(x1: float, y1: float, x2: float, y2: float) -> float:
        return max(0.0, x2 - x1) * max(0.0, y2 - y1)

    inter = _area(max(ax1, bx1), max(ay1, by1), min(ax2, bx2), min(ay2, by2))
    union = _area(ax1, ay1, ax2, ay2) + _area(bx1, by1, bx2, by2) - inter
    if union <= 1e-8:
        return 0.0
    return inter / union
290
+
291
+
292
def clamp_xywh(box: Sequence[float], width: int, height: int, delta: float = 0.05) -> List[float]:
    """Fit an ``[x, y, w, h]`` box inside a ``width`` x ``height`` image.

    Each side is kept between a minimum size and the image extent, and the
    origin is moved so the box stays inside the image. The minimum is the
    larger of 10 px and 5% of the shorter image side, capped per axis at the
    image extent so images smaller than that never yield a box larger than
    the image or a negative origin.

    Args:
        box: ``[x, y, w, h]`` in pixels.
        width: Image width in pixels.
        height: Image height in pixels.
        delta: Fraction of the image extent used as an additional lower
            bound on each side.

    Returns:
        The clamped ``[x, y, w, h]`` as floats.
    """
    x, y, w, h = [float(v) for v in box]
    full_w, full_h = float(width), float(height)
    min_size = max(10.0, min(width, height) * 0.05)
    # A minimum larger than the image would push x/y below zero; cap it.
    min_w = min(min_size, full_w)
    min_h = min(min_size, full_h)
    w = max(min_w, min(w, full_w))
    h = max(min_h, min(h, full_h))
    x = min(max(0.0, x), full_w - w)
    y = min(max(0.0, y), full_h - h)
    # Second pass: apply the delta-based lower bound without leaving the image.
    w = max(min_w, min(full_w - x, max(w, delta * width)))
    h = max(min_h, min(full_h - y, max(h, delta * height)))
    return [x, y, w, h]
302
+
303
+
304
def random_box(width: int, height: int) -> List[float]:
    """Sample a random ``[x, y, w, h]`` box with the image's aspect ratio.

    The longer side of the box covers 30-80% of the matching image side; both
    sides are at least 10 px. The result is passed through ``clamp_xywh``.
    """
    aspect = width / max(1, height)
    scale = random.uniform(0.3, 0.8)
    if aspect < 1:
        # Portrait image: size the height first, derive the width.
        box_h = max(10.0, height * scale)
        box_w = max(10.0, box_h * aspect)
    else:
        box_w = max(10.0, width * scale)
        box_h = max(10.0, box_w / aspect)
    left = random.uniform(0.0, max(1.0, width - box_w))
    top = random.uniform(0.0, max(1.0, height - box_h))
    return clamp_xywh([left, top, box_w, box_h], width, height)
316
+
317
+
318
def jitter_box(box_xywh: Sequence[float], width: int, height: int, jitter: float = 0.12) -> List[float]:
    """Randomly perturb a box and clamp it back into the image.

    The origin shifts by up to ``jitter`` times the image extent on each
    axis, and each side is scaled by a factor in ``[1 - jitter, 1 + jitter]``.
    """
    left, top, box_w, box_h = map(float, box_xywh)
    # Draw order (x, y, w, h) is kept so seeded runs stay reproducible.
    shift_x = random.uniform(-jitter, jitter) * width
    shift_y = random.uniform(-jitter, jitter) * height
    scale_w = random.uniform(1.0 - jitter, 1.0 + jitter)
    scale_h = random.uniform(1.0 - jitter, 1.0 + jitter)
    moved = [left + shift_x, top + shift_y, box_w * scale_w, box_h * scale_h]
    return clamp_xywh(moved, width, height)
325
+
326
+
327
def box_state(box_xywh: Sequence[float], width: int, height: int) -> torch.Tensor:
    """Encode a box as a float32 tensor ``[cx, cy, w, h]`` relative to the image.

    Divisors are at least 1 to avoid division by zero. If any value is not
    finite, the fixed fallback ``[0.5, 0.5, 0.6, 0.6]`` is returned instead.
    """
    x, y, w, h = map(float, box_xywh)
    norm_w = max(1.0, width)
    norm_h = max(1.0, height)
    features = [
        (x + 0.5 * w) / norm_w,
        (y + 0.5 * h) / norm_h,
        w / norm_w,
        h / norm_h,
    ]
    if any(not math.isfinite(v) for v in features):
        features = [0.5, 0.5, 0.6, 0.6]
    return torch.tensor(features, dtype=torch.float32)
338
+
339
+
340
def render_crop(img: Image.Image, box_xywh: Sequence[float], img_size: int) -> torch.Tensor:
    """Crop the box from ``img``, resize to a square and convert to a tensor."""
    left, top, box_w, box_h = map(float, box_xywh)
    region = (left, top, left + box_w, top + box_h)
    resized = img.crop(region).resize((img_size, img_size))
    return T.ToTensor()(resized)
344
+
345
+
346
def render_full_image(img: Image.Image, img_size: int) -> torch.Tensor:
    """Resize the whole image to ``img_size`` x ``img_size`` and convert to a tensor."""
    resized = img.resize((img_size, img_size))
    to_tensor = T.ToTensor()
    return to_tensor(resized)
348
+
349
+
350
def bbox_target_from_xyxy(box_xyxy: Sequence[float], width: int, height: int) -> torch.Tensor:
    """Turn pixel corners into a normalised ``[cx, cy, w, h]`` float32 target.

    Corners may be given in either order. The pixel width and height are at
    least 1, divisors are at least 1, and every value is clipped to [0, 1].
    """
    xa, ya, xb, yb = map(float, box_xyxy)
    if xb < xa:
        xa, xb = xb, xa
    if yb < ya:
        ya, yb = yb, ya
    norm_w = max(1.0, width)
    norm_h = max(1.0, height)
    target = (
        ((xa + xb) * 0.5) / norm_w,
        ((ya + yb) * 0.5) / norm_h,
        max(1.0, xb - xa) / norm_w,
        max(1.0, yb - ya) / norm_h,
    )
    clipped = [min(1.0, max(0.0, v)) for v in target]
    return torch.tensor(clipped, dtype=torch.float32)
361
+
362
+
363
def bbox_cxcywh_to_xyxy(box_cxcywh: Sequence[float], width: int, height: int) -> List[float]:
    """Convert a normalised ``[cx, cy, w, h]`` box to pixel corners.

    Values are scaled by the image size, and each corner is clipped to the
    image bounds independently.
    """
    cx, cy, w, h = map(float, box_cxcywh)
    box_w = w * width
    box_h = h * height
    left = cx * width - 0.5 * box_w
    top = cy * height - 0.5 * box_h
    right = left + box_w
    bottom = top + box_h

    def _clip(value: float, upper: int) -> float:
        return min(max(0.0, value), float(upper))

    return [_clip(left, width), _clip(top, height), _clip(right, width), _clip(bottom, height)]
377
+
378
+
379
def step_box(box_xywh: Sequence[float], action_idx: int, width: int, height: int, delta: float = 0.05) -> List[float]:
    """Apply one action from ``ACTIONS`` to a box and clamp the result.

    Moves shift the box by ``delta`` times its own width or height. Zooms
    scale both sides by ``1 -/+ delta`` around the box centre. Any other
    action (``"stop"``) leaves the box unchanged apart from clamping.
    """
    act = ACTIONS[int(action_idx)]
    x, y, w, h = map(float, box_xywh)
    shift_x = delta * w
    shift_y = delta * h

    if act in ("zoom_in", "zoom_out"):
        # Capture the centre before resizing so the zoom stays centred.
        centre_x = x + 0.5 * w
        centre_y = y + 0.5 * h
        factor = 1.0 - delta if act == "zoom_in" else 1.0 + delta
        w *= factor
        h *= factor
        x = centre_x - 0.5 * w
        y = centre_y - 0.5 * h
    elif act == "left":
        x = max(0.0, x - shift_x)
    elif act == "right":
        x = min(width - w, x + shift_x)
    elif act == "up":
        y = max(0.0, y - shift_y)
    elif act == "down":
        y = min(height - h, y + shift_y)

    return clamp_xywh([x, y, w, h], width, height, delta=delta)
403
+
404
+
405
class PolicyStateDataset(Dataset):
    """Yields ``(crop_tensor, box_state)`` pairs from randomly placed boxes.

    Each item picks either a jittered ground-truth box or a fully random
    box, renders that crop and returns it with its normalised box state.
    """

    def __init__(
        self,
        records: Sequence[Dict],
        img_size: int = 224,
        samples_per_image: int = 1,
        random_box_prob: float = 0.65,
        jitter: float = 0.12,
    ):
        self.records = list(records)
        self.img_size = int(img_size)
        # At least one sample per image, whatever was requested.
        self.samples_per_image = max(1, int(samples_per_image))
        self.random_box_prob = float(random_box_prob)
        self.jitter = float(jitter)

    def __len__(self) -> int:
        return self.samples_per_image * len(self.records)

    def __getitem__(self, idx: int):
        # Indices wrap around so every record is visited samples_per_image times.
        rec = self.records[idx % len(self.records)]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size
        boxes = rec.get("boxes") or []

        # A ground-truth box is used only when one exists and the draw
        # exceeds random_box_prob; no random draw is made without boxes.
        use_ground_truth = bool(boxes) and random.random() > self.random_box_prob
        if use_ground_truth:
            gt_box = canonical_box_xyxy(random.choice(boxes), width, height, img_path=rec["img"])
            box = jitter_box(xyxy_to_xywh(gt_box), width, height, jitter=self.jitter)
        else:
            box = random_box(width, height)

        crop = render_crop(img, box, self.img_size)
        return crop, box_state(box, width, height)
436
+
437
+
438
class BBoxDataset(Dataset):
    """Yields ``(full_image_tensor, bbox_target)`` pairs for training.

    Only records that carry at least one box are kept. One of a record's
    boxes is chosen at random for each item.
    """

    def __init__(self, records: Sequence[Dict], img_size: int = 224, samples_per_image: int = 1):
        self.records = [rec for rec in records if rec.get("boxes")]
        self.img_size = int(img_size)
        self.samples_per_image = max(1, int(samples_per_image))

    def __len__(self) -> int:
        return self.samples_per_image * len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx % len(self.records)]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size
        picked = random.choice(rec["boxes"])
        box = canonical_box_xyxy(picked, width, height, img_path=rec["img"])
        image_tensor = render_full_image(img, self.img_size)
        return image_tensor, bbox_target_from_xyxy(box, width, height)
453
+
454
+
455
class BBoxEvalDataset(Dataset):
    """Yields ``(full_image_tensor, targets)`` with every box of a record.

    Only records that carry at least one box are kept. ``targets`` stacks one
    normalised target per box of the record.
    """

    def __init__(self, records: Sequence[Dict], img_size: int = 224):
        self.records = [rec for rec in records if rec.get("boxes")]
        self.img_size = int(img_size)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        img = Image.open(rec["img"]).convert("RGB")
        width, height = img.size
        per_box = []
        for box in rec["boxes"]:
            corners = canonical_box_xyxy(box, width, height, img_path=rec["img"])
            per_box.append(bbox_target_from_xyxy(corners, width, height))
        targets = torch.stack(per_box)
        return render_full_image(img, self.img_size), targets
474
+
475
+
476
def soften_probs(probs: torch.Tensor, temperature: float) -> torch.Tensor:
    """Flatten a probability distribution with a temperature above 1.

    For ``temperature <= 1`` the input is returned unchanged. Otherwise each
    probability (floored at 1e-8) is raised to ``1 / temperature`` and the
    rows are renormalised along dim 1.
    """
    if temperature <= 1.0:
        return probs
    exponent = 1.0 / temperature
    powered = probs.clamp_min(1e-8) ** exponent
    row_total = powered.sum(dim=1, keepdim=True)
    return powered / row_total
ppo_best_val_final_score.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1336a47608769e622c6539be106c037f43be479313af9cb3dcef33719c68d490
3
+ size 161679377
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ pillow