SynLayers committed on
Commit
a1286e3
·
verified ·
1 Parent(s): cc605e0

Upload demo/infer/run_caption_bbox_infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/infer/run_caption_bbox_infer.py +286 -0
demo/infer/run_caption_bbox_infer.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run whole-caption + bbox inference and save portable JSONL results."""
3
+
4
+ import os
5
+ import json
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from PIL import Image, ImageDraw
11
+
12
+ try:
13
+ from demo.infer.vlm_bbox_inference import (
14
+ get_model_and_processor,
15
+ parse_bbox_output,
16
+ )
17
+ except ImportError:
18
+ from vlm_bbox_inference import (
19
+ get_model_and_processor,
20
+ parse_bbox_output,
21
+ )
22
+
23
# Two directory levels above the folder that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]


def resolve_default_bbox_model() -> str:
    """Pick the default bbox/caption model identifier.

    Resolution order:
      1. ``SYNLAYERS_BBOX_MODEL``, then ``SYNLAYERS_BBOX_MODEL_REPO`` (first
         non-empty environment value wins).
      2. ``PROJECT_ROOT`` itself, when it contains both ``config.json`` and
         ``tokenizer_config.json``.
      3. ``PROJECT_ROOT / "Bbox-caption-8b"``, when that path exists.
      4. The hub-style id ``"SynLayers/Bbox-caption-8b"``.

    Returns:
        A local path (as a string) or a repository id.
    """
    for variable in ("SYNLAYERS_BBOX_MODEL", "SYNLAYERS_BBOX_MODEL_REPO"):
        override = os.environ.get(variable)
        if override:
            return override

    # The project root counts as a checkpoint only if both marker files exist.
    root_is_checkpoint = all(
        (PROJECT_ROOT / marker).exists()
        for marker in ("config.json", "tokenizer_config.json")
    )
    if root_is_checkpoint:
        return str(PROJECT_ROOT)

    bundled = PROJECT_ROOT / "Bbox-caption-8b"
    if bundled.exists():
        return str(bundled)

    return "SynLayers/Bbox-caption-8b"
39
+
40
+
41
# Instruction sent with every image. It fixes the coordinate convention
# (origin at the top-left, y growing downward) and asks for one JSON object
# with the keys "whole_caption" and "boxes".
# NOTE(review): the text hard-codes a 1024x1024 image size; nothing visible in
# this file resizes inputs to match -- confirm against the data pipeline.
CAPTION_BBOX_PROMPT_TOP_LEFT = (
    "<image>This image is 1024 pixels in width and 1024 pixels in height. "
    "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward. "
    "First describe the whole image in one detailed caption (whole_caption). "
    "Then list the bounding box for each visible layer or object. "
    "Each box is [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward). "
    'Output a single JSON object with exactly two keys: "whole_caption" (string) and "boxes" (list of [x_left,y_top,x_right,y_bottom] arrays). '
    "Output only this JSON, no other text or markdown."
)

# Resolved once at import time; serves as the default of the --model option.
DEFAULT_BBOX_MODEL = resolve_default_bbox_model()

# Lower-case suffixes treated as images when scanning a flat directory.
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
54
+
55
+
56
# Keys that mark a decoded JSON object as a caption/bbox payload.
_CAPTION_BBOX_KEYS = ("whole_caption", "caption", "boxes", "bboxes")


def _extract_caption_and_boxes(obj):
    """Return `(caption, boxes)` from a decoded JSON value, or None if unusable."""
    if not isinstance(obj, dict):
        return None
    caption = obj.get("whole_caption") or obj.get("caption") or ""
    boxes = obj.get("boxes") or obj.get("bboxes") or []
    if not isinstance(boxes, list):
        return None
    return caption, boxes


def _scan_for_caption_object(text):
    """Decode the first JSON object in `text` that carries a caption/bbox key."""
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", text):
        try:
            obj, _end = decoder.raw_decode(text, brace.start())
        except json.JSONDecodeError:
            continue
        # Ignore unrelated (e.g. nested) objects that hold none of our keys.
        if isinstance(obj, dict) and any(key in obj for key in _CAPTION_BBOX_KEYS):
            parsed = _extract_caption_and_boxes(obj)
            if parsed is not None:
                return parsed
    return None


def parse_json_caption_bbox(text: str):
    """Parse model output into `(whole_caption, boxes)`.

    Strategies are tried in order:
      1. Each fenced section (split on triple backticks) that starts with
         ``{``, after an optional leading ``json`` tag.
      2. The span from the first ``{`` to the last ``}`` of the whole text.
      3. If that span is not valid JSON (for instance because more text
         containing ``}`` follows the object), the first embedded JSON object
         that has a caption or box key.
      4. `parse_bbox_output(text)` with an empty caption.

    Args:
        text: Raw decoded model output; ``None`` or empty is accepted.

    Returns:
        A `(caption, boxes)` tuple. `boxes` is returned as found in the JSON
        (a list), or whatever `parse_bbox_output` yields in the fallback case.
    """
    text = (text or "").strip()

    if "```" in text:
        for part in text.split("```"):
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if not part.startswith("{"):
                continue
            try:
                parsed = _extract_caption_and_boxes(json.loads(part))
            except json.JSONDecodeError:
                continue
            if parsed is not None:
                return parsed

    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            obj = json.loads(match.group(0))
        except json.JSONDecodeError:
            # The greedy span swallowed trailing text; look for a
            # self-contained object instead of discarding the caption.
            parsed = _scan_for_caption_object(text)
        else:
            parsed = _extract_caption_and_boxes(obj)
        if parsed is not None:
            return parsed

    boxes = parse_bbox_output(text)
    return "", boxes
91
+
92
+
93
def format_image_record_path(image_path: Path, data_dir: Path) -> str:
    """Return the path string stored in an output record for an image.

    Args:
        image_path: Location of the image.
        data_dir: Directory the recorded path should be relative to.

    Returns:
        `image_path` relative to `data_dir` with forward slashes, or only the
        file name when `image_path` is not located under `data_dir`.
    """
    try:
        relative = image_path.relative_to(data_dir)
    except ValueError:
        # Not under data_dir: fall back to the bare file name.
        return image_path.name
    return relative.as_posix()
98
+
99
+
100
def collect_images(data_dir: Path, max_samples: int | None, target_samples: set | None = None):
    """Collect images and keep a relative path for JSONL output.

    Two layouts are supported:
      * ``sample_*`` sub-directories, each contributing its
        ``whole_image.png`` (directories without that file are skipped);
      * if that yields nothing, image files directly inside `data_dir` whose
        suffix is listed in ``IMAGE_EXTS``.

    Args:
        data_dir: Directory to scan.
        max_samples: Stop after this many entries; ``None`` or ``0`` means no
            limit.
        target_samples: Optional set of directory names (first layout) or file
            stems (second layout) to keep.

    Returns:
        A list of ``(name, image_path, record_path)`` tuples, where
        `record_path` comes from `format_image_record_path`.
    """
    data_dir = Path(data_dir)
    out = []

    for d in sorted(data_dir.glob("sample_*")):
        if not d.is_dir():
            continue
        if target_samples is not None and d.name not in target_samples:
            continue
        whole = d / "whole_image.png"
        if whole.exists():
            out.append((d.name, whole, format_image_record_path(whole, data_dir)))
            if max_samples and len(out) >= max_samples:
                return out

    if out:
        return out

    def _sort_key(p: Path):
        # Order "name_<number>" stems numerically. The file name is the final
        # tie-breaker so files sharing a stem (a.png / a.jpg) come out in a
        # fixed order instead of the set-iteration order of IMAGE_EXTS.
        parts = p.stem.rsplit("_", 1)
        try:
            return (parts[0], int(parts[-1]), p.name)
        except ValueError:
            return (p.stem, 0, p.name)

    all_imgs = [
        p for ext in IMAGE_EXTS
        for p in data_dir.glob(f"*{ext}")
        if p.is_file()
    ]

    for p in sorted(all_imgs, key=_sort_key):
        if target_samples is not None and p.stem not in target_samples:
            continue
        out.append((p.stem, p, format_image_record_path(p, data_dir)))
        if max_samples and len(out) >= max_samples:
            return out

    return out
138
+
139
+
140
def draw_boxes(image_path: Path, bboxes: list, out_path: Path, color: str = "lime", width: int = 3):
    """Draw bounding boxes on an image.

    Args:
        image_path: Image to annotate.
        bboxes: Iterable of boxes; only entries with exactly four values are
            drawn. Corner order is normalised, so swapped corners are fine.
        out_path: Destination file; missing parent directories are created.
        color: Outline colour passed to the drawing call.
        width: Outline width passed to the drawing call.
    """
    # Close the source file as soon as the RGB copy has been made.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    draw = ImageDraw.Draw(img)

    for b in bboxes:
        if len(b) != 4:
            continue
        x0, y0, x1, y1 = float(b[0]), float(b[1]), float(b[2]), float(b[3])
        # Make sure the first corner is the upper-left one before drawing.
        x0, x1 = min(x0, x1), max(x0, x1)
        y0, y1 = min(y0, y1), max(y0, y1)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=width)

    # Accept plain strings as well as Path objects for the destination.
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(out_path)
155
+
156
+
157
+ def infer_caption_bbox(image_path: str | Path, model, processor, *, prompt: str, max_new_tokens: int = 1024):
158
+ """Run caption + bbox inference for one image."""
159
+ path = Path(image_path)
160
+ if not path.exists():
161
+ return "", []
162
+
163
+ content = [
164
+ {"type": "image", "image": str(path.absolute())},
165
+ {"type": "text", "text": prompt},
166
+ ]
167
+
168
+ messages = [{"role": "user", "content": content}]
169
+
170
+ inputs = processor.apply_chat_template(
171
+ messages,
172
+ tokenize=True,
173
+ add_generation_prompt=True,
174
+ return_dict=True,
175
+ return_tensors="pt",
176
+ )
177
+
178
+ inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
179
+ inputs.pop("token_type_ids", None)
180
+
181
+ with torch.no_grad():
182
+ generated = model.generate(
183
+ **inputs,
184
+ max_new_tokens=max_new_tokens,
185
+ do_sample=True,
186
+ temperature=0.1,
187
+ repetition_penalty=1.1,
188
+ pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
189
+ )
190
+
191
+ input_len = inputs["input_ids"].shape[1]
192
+ output_ids = generated[:, input_len:]
193
+
194
+ output_text = processor.batch_decode(
195
+ output_ids,
196
+ skip_special_tokens=True,
197
+ clean_up_tokenization_spaces=True
198
+ )
199
+
200
+ raw = (output_text[0] or "").strip()
201
+ whole_caption, bboxes = parse_json_caption_bbox(raw)
202
+
203
+ result_boxes = []
204
+ for b in bboxes:
205
+ if isinstance(b, (list, tuple)) and len(b) >= 4:
206
+ result_boxes.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
207
+
208
+ return whole_caption, result_boxes
209
+
210
+
211
def main():
    """Command-line entry point.

    Collects images from ``--data-dir``, loads the model given by ``--model``,
    runs caption + bbox inference on each image and writes one JSON record per
    line to ``--output``. With ``--vis-dir`` set, an annotated copy of every
    image is saved there as well.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Caption + bbox inference (top-left origin)")
    cli.add_argument(
        "--data-dir",
        type=str,
        default="testset",
        help="Directory containing sample_* or image files",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="outputs/infer/caption_bbox_infer.jsonl",
        help="Output JSONL file",
    )
    cli.add_argument(
        "--model",
        type=str,
        default=DEFAULT_BBOX_MODEL,
        help="Model path (merged or LoRA) (default: %(default)s)",
    )
    cli.add_argument("--max-samples", type=int, default=None)
    cli.add_argument("--max-new-tokens", type=int, default=1024)
    cli.add_argument(
        "--samples",
        type=str,
        nargs="+",
        help="Specify sample names (e.g. sample_001)",
    )
    cli.add_argument(
        "--vis-dir",
        type=str,
        default=None,
        help="Optional directory for visualization",
    )
    args = cli.parse_args()

    data_dir = Path(args.data_dir)
    wanted = set(args.samples) if args.samples else None

    # Find the work first so the model is not loaded when there is nothing to do.
    rows = collect_images(data_dir, args.max_samples, wanted)
    if not rows:
        print(f"No images found under {data_dir}")
        return

    print(f"Loading model: {args.model}")
    model, processor = get_model_and_processor(args.model)

    print(f"Running inference on {len(rows)} samples...")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    vis_dir = None if not args.vis_dir else Path(args.vis_dir)

    with open(out_path, "w", encoding="utf-8") as sink:
        for name, image_path, image_record_path in rows:
            print(f" {name}")

            whole_caption, bboxes = infer_caption_bbox(
                image_path,
                model,
                processor,
                prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
                max_new_tokens=args.max_new_tokens,
            )

            line = json.dumps(
                {
                    "sample_or_stem": name,
                    "image": image_record_path,
                    "whole_caption": whole_caption,
                    "bboxes": bboxes,
                    "num_layers": len(bboxes),
                    "coord": "top_left",
                },
                ensure_ascii=False,
            )
            # Flush after each record so finished results survive a later failure.
            sink.write(line + "\n")
            sink.flush()

            if vis_dir is not None:
                draw_boxes(Path(image_path), bboxes, vis_dir / f"{name}_vis.png")

    print(f"Wrote {out_path}")
    if vis_dir is not None:
        print(f"Visualizations saved to {vis_dir}")


if __name__ == "__main__":
    main()