Syzygianinfern0 committed

Commit d41a400 · Parent: 2d6187c

Bring over latest scripts from demo (7c8fc86)
execute_demo_v2.py CHANGED
@@ -1,588 +1,588 @@
import json
import os
import uuid
import cv2
import subprocess
import numpy as np
import gradio as gr
import tempfile
from typing import Dict, List, Iterable, Tuple

from ns_vfs.video.read_mp4 import Mp4Reader
from execute_with_mp4 import process_entry
from matplotlib import pyplot as plt

import base64

from openai import OpenAI

class VLLMClient:
    def __init__(
        self,
        api_key="EMPTY",
        api_base="http://localhost:8000/v1",
        model="OpenGVLab/InternVL2-8B",
        # model="Qwen/Qwen2.5-VL-7B-Instruct",
    ):
        self.client = OpenAI(api_key=api_key, base_url=api_base)
        self.model = model

    # def _encode_frame(self, frame):
    #     return base64.b64encode(frame.tobytes()).decode("utf-8")
    def _encode_frame(self, frame):
        # Encode a uint8 numpy array (image) as a JPEG and then base64 encode it.
        ret, buffer = cv2.imencode(".jpg", frame)
        if not ret:
            raise ValueError("Could not encode frame")
        return base64.b64encode(buffer).decode("utf-8")

    def caption( self, frames: list[np.ndarray]):

        parsing_rule = " You must return a caption for the sequence of images. The caption must be a single sentence. The caption must be in the same language as the question."
        prompt = rf"Give me a detailed description of what you see in the images " f"\n[PARSING RULE]: {parsing_rule}"

        # Encode each frame.
        encoded_images = [self._encode_frame(frame) for frame in frames]

        # Build the user message: a text prompt plus one image for each frame.
        user_content = [
            {
                "type": "text",
                "text": f"The following is the sequence of images",
            }
        ]
        for encoded in encoded_images:
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                }
            )

        # Create a chat completion request.
        chat_response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_content},
            ],
            max_tokens=1000,
            temperature=0.0,
            logprobs=True,
        )
        content = chat_response.choices[0].message.content
        return content


def _load_entry_from_reader(video_path, query_text):
    reader = Mp4Reader(
        [{"path": video_path, "query": query_text}],
        openai_save_path="",
        sampling_rate_fps=0.5
    )
    data = reader.read_video()
    if not data:
        raise RuntimeError("No data returned by Mp4Reader (check video path)")
    return data[0]


def _make_empty_video(path, width=320, height=240, fps=1.0):
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
    frame = np.zeros((height, width, 3), dtype=np.uint8)
    writer.write(frame)
    writer.release()
    return path


def _crop_video_ffmpeg(input_path, output_path, frame_indices, prop_matrix):
    if len(frame_indices) == 0:
        cap = cv2.VideoCapture(str(input_path))
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {input_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        _make_empty_video(output_path, width, height, fps=1.0)
        return

    def group_into_ranges(frames):
        if not frames:
            return []
        frames = sorted(set(frames))
        ranges = []
        start = prev = frames[0]
        for f in frames[1:]:
            if f == prev + 1:
                prev = f
            else:
                ranges.append((start, prev + 1))  # end-exclusive
                start = prev = f
        ranges.append((start, prev + 1))
        return ranges

    ranges = group_into_ranges(frame_indices)
    filters = []
    labels = []
    for i, (start, end) in enumerate(ranges):
        filters.append(
            f"[0:v]trim=start_frame={start}:end_frame={end},setpts=PTS-STARTPTS[v{i}]"
        )
        labels.append(f"[v{i}]")
    filters.append(f"{''.join(labels)}concat=n={len(ranges)}:v=1:a=0[outv]")

    cmd = [
        "ffmpeg", "-y", "-i", input_path,
        "-filter_complex", "; ".join(filters),
        "-map", "[outv]",
        "-c:v", "libx264", "-preset", "fast", "-crf", "23",
        output_path,
    ]
    subprocess.run(cmd, check=True)


def _crop_video(input_path: str, output_path: str, frame_indices: List[int], prop_matrix: Dict[str, List[int]]):
    input_path = str(input_path)
    output_path = str(output_path)

    # Probe width/height/fps
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video: {input_path}")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
    cap.release()
    if fps <= 0:
        fps = 30.0

    # If nothing to write, emit a 1-frame empty video
    if not frame_indices:
        from numpy import zeros, uint8
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 1.0, (width, height))
        out.write(zeros((height, width, 3), dtype=uint8))
        out.release()
        return

    # Helper: group consecutive integers into (start, end_exclusive)
    def _group_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
        f = sorted(set(int(x) for x in frames))
        if not f:
            return []
        out = []
        s = p = f[0]
        for x in f[1:]:
            if x == p + 1:
                p = x
            else:
                out.append((s, p + 1))
                s = p = x
        out.append((s, p + 1))
        return out

    # Invert prop_matrix to {frame_idx: sorted [props]}
    props_by_frame: Dict[int, List[str]] = {}
    for prop, frames in (prop_matrix or {}).items():
        for fi in frames:
            fi = int(fi)
            props_by_frame.setdefault(fi, []).append(prop)
    for fi in list(props_by_frame.keys()):
        props_by_frame[fi] = sorted(set(props_by_frame[fi]))

    # Only subtitle frames we will output
    fi_set = set(int(x) for x in frame_indices)
    frames_with_labels = sorted(fi for fi in fi_set if props_by_frame.get(fi))

    # Compress consecutive frames that share the same label set
    grouped_label_spans: List[Tuple[int, int, Tuple[str, ...]]] = []
    prev_f = None
    prev_labels: Tuple[str, ...] = ()
    span_start = None
    for f in frames_with_labels:
        labels = tuple(props_by_frame.get(f, []))
        if prev_f is None:
            span_start, prev_f, prev_labels = f, f, labels
        elif (f == prev_f + 1) and (labels == prev_labels):
            prev_f = f
        else:
            grouped_label_spans.append((span_start, prev_f + 1, prev_labels))
            span_start, prev_f, prev_labels = f, f, labels
    if prev_f is not None and prev_labels:
        grouped_label_spans.append((span_start, prev_f + 1, prev_labels))

    # Build ASS subtitle file (top-right)
    def ass_time(t_sec: float) -> str:
        cs = int(round(t_sec * 100))
        h = cs // (100 * 3600)
        m = (cs // (100 * 60)) % 60
        s = (cs // 100) % 60
        cs = cs % 100
        return f"{h}:{m:02d}:{s:02d}.{cs:02d}"

    def make_ass(width: int, height: int) -> str:
        lines = []
        lines.append("[Script Info]")
        lines.append("ScriptType: v4.00+")
        lines.append("ScaledBorderAndShadow: yes")
        lines.append(f"PlayResX: {width}")
        lines.append(f"PlayResY: {height}")
        lines.append("")
        lines.append("[V4+ Styles]")
        lines.append("Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, "
                     "Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, "
                     "Shadow, Alignment, MarginL, MarginR, MarginV, Encoding")
        # Font size 18 per your request; Alignment=9 (top-right)
        lines.append("Style: Default,DejaVu Sans,18,&H00FFFFFF,&H000000FF,&H00000000,&H64000000,"
                     "0,0,0,0,100,100,0,0,1,2,0.8,9,16,16,16,1")
        lines.append("")
        lines.append("[Events]")
        lines.append("Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text")

        for start_f, end_f, labels in grouped_label_spans:
            if not labels:
                continue
            start_t = ass_time(start_f / fps)
            end_t = ass_time(end_f / fps)
            text = r"\N".join(labels)  # stacked lines
            lines.append(f"Dialogue: 0,{start_t},{end_t},Default,,0,0,0,,{text}")

        return "\n".join(lines)

    tmp_dir = tempfile.mkdtemp(prefix="props_ass_")
    ass_path = os.path.join(tmp_dir, "props.ass")
    with open(ass_path, "w", encoding="utf-8") as f:
        f.write(make_ass(width, height))

    # Build trim/concat ranges from requested frame_indices
    ranges = _group_ranges(frame_indices)

    # Filtergraph with burned subtitles then trim/concat
    split_labels = [f"[s{i}]" for i in range(len(ranges))] if ranges else []
    out_labels = [f"[v{i}]" for i in range(len(ranges))] if ranges else []

    filters = []
    ass_arg = ass_path.replace("\\", "\\\\")
    filters.append(f"[0:v]subtitles='{ass_arg}'[sub]")

    if len(ranges) == 1:
        s0, e0 = ranges[0]
        filters.append(f"[sub]trim=start_frame={s0}:end_frame={e0},setpts=PTS-STARTPTS[v0]")
    else:
        if ranges:
            filters.append(f"[sub]split={len(ranges)}{''.join(split_labels)}")
        for i, (s, e) in enumerate(ranges):
            filters.append(f"{split_labels[i]}trim=start_frame={s}:end_frame={e},setpts=PTS-STARTPTS{out_labels[i]}")

    if ranges:
        filters.append(f"{''.join(out_labels)}concat=n={len(ranges)}:v=1:a=0[outv]")

    filter_complex = "; ".join(filters)

    cmd = [
        "ffmpeg", "-y",
        "-i", input_path,
        "-filter_complex", filter_complex,
        "-map", "[outv]" if ranges else "[sub]",
        "-c:v", "libx264", "-preset", "fast", "-crf", "23",
        output_path,
    ]
    try:
        subprocess.run(cmd, check=True)
    finally:
        try:
            os.remove(ass_path)
            os.rmdir(tmp_dir)
        except OSError:
            pass

def _format_prop_ranges_dict(prop_matrix: Dict[str, List[int]]) -> str:
    def group_into_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
        f = sorted(set(int(x) for x in frames))
        if not f:
            return []
        ranges: List[Tuple[int, int]] = []
        s = p = f[0]
        for x in f[1:]:
            if x == p + 1:
                p = x
            else:
                ranges.append((s, p))  # inclusive end for display
                s = p = x
        ranges.append((s, p))
        return ranges

    detections = {}
    for prop, frames in prop_matrix.items():
        ranges = group_into_ranges(frames)
        detections[prop] = ranges
    return detections


def _format_prop_ranges(prop_matrix: Dict[str, List[int]]) -> str:
    def group_into_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
        f = sorted(set(int(x) for x in frames))
        if not f:
            return []
        ranges: List[Tuple[int, int]] = []
        s = p = f[0]
        for x in f[1:]:
            if x == p + 1:
                p = x
            else:
                ranges.append((s, p))  # inclusive end for display
                s = p = x
        ranges.append((s, p))
        return ranges

    if not prop_matrix:
        return "No propositions detected."

    lines = []
    for prop, frames in prop_matrix.items():
        ranges = group_into_ranges(frames)
        pretty = prop.replace("_", " ").title()
        if not ranges:
            lines.append(f"{pretty}: —")
            continue
        parts = [f"{a}" if a == b else f"{a}-{b}" for (a, b) in ranges]
        lines.append(f"{pretty}: {', '.join(parts)}")
    return "\n".join(lines)

def generate_timeline_plot(detections, total_frames):
    """
    Generates a timeline plot from detection data using Matplotlib.

    Args:
        detections (dict): A dictionary where keys are string labels and values are lists
                           of (start_frame, end_frame) tuples.
                           e.g., {"dog": [(0, 45), (90, 100)], "grass": [(30, 80)]}
        total_frames (int): The total number of frames in the video for the x-axis scale.

    Returns:
        matplotlib.figure.Figure: The generated plot figure.
    """
    labels = list(detections.keys())
    num_labels = len(labels)

    # Handle case with no detections
    if num_labels == 0:
        fig, ax = plt.subplots(figsize=(10, 1))
        ax.text(0.5, 0.5, 'No propositions detected.', ha='center', va='center')
        ax.set_axis_off()
        return fig

    # Use a color map to assign distinct colors automatically
    colors = plt.cm.get_cmap('tab10', num_labels)

    fig, ax = plt.subplots(figsize=(10, num_labels * 0.6 + 0.5))

    ax.set_xlim(0, total_frames)
    ax.set_ylim(0, num_labels)
    ax.set_yticks(np.arange(num_labels) + 0.5)
    ax.set_yticklabels(labels, fontsize=12)
    ax.set_xlabel("Frame Number", fontsize=12)
    ax.grid(axis='x', linestyle='--', alpha=0.6)

    # Invert y-axis to have the first proposition on top
    ax.invert_yaxis()

    for i, label in enumerate(labels):
        # matplotlib's broken_barh needs a list of (start, width) tuples
        segments = [(start, end - start) for start, end in detections[label]]
        # The bar is drawn at y-position 'i' with a height of 0.8
        ax.broken_barh(segments, (i + 0.1, 0.8), facecolors=colors(i))

    plt.tight_layout()
    return fig

# -----------------------------
# Gradio handler
# -----------------------------
def run_pipeline(input_video, mode, query_text, propositions_json, specification_text):
    """
    Returns: (cropped_video_path, prop_ranges_text, tl_text)
    """

    def _err(msg, width=320, height=240):  # keep outputs shape consistent
        tmp_out = os.path.join("/tmp", f"empty_{uuid.uuid4().hex}.mp4")
        _make_empty_video(tmp_out, width=width, height=height, fps=1.0)
        return (
            tmp_out,
            "No propositions detected.",
            f"Error: {msg}"
        )

    # Resolve video path
    if isinstance(input_video, dict) and "name" in input_video:
        video_path = input_video["name"]
    elif isinstance(input_video, str):
        video_path = input_video
    else:
        return _err("Please provide a video.")

    # Build entry
    if mode == "Natural language query":
        if not query_text or not query_text.strip():
            return _err("Please enter a query.")
        entry = _load_entry_from_reader(video_path, query_text)
    else:
        if not (propositions_json and propositions_json.strip()) or not (specification_text and specification_text.strip()):
            return _err("Please provide both Propositions (array) and Specification.")
        entry = _load_entry_from_reader(video_path, "dummy-query")
        try:
            props = json.loads(propositions_json)
            if not isinstance(props, list):
                return _err("Propositions must be a JSON array.")
        except Exception as e:
            return _err(f"Failed to parse propositions JSON: {e}")
        entry["tl"] = {
            "propositions": props,
            "specification": specification_text
        }

    # Compute FOI
    try:
        foi, prop_matrix, p2 = process_entry(entry)  # list of frame indices & {prop: [frames]}
        print(foi)
        print(prop_matrix)
        print(p2)
    except Exception as e:
        return _err(f"Processing error: {e}")

    # Write cropped video
    try:
        out_path = os.path.join("/tmp", f"cropped_{uuid.uuid4().hex}.mp4")
        _crop_video(video_path, out_path, foi, prop_matrix)
        print(f"Wrote cropped video to: {out_path}")
    except Exception as e:
        return _err(f"Failed to write cropped video: {e}")

    # Build right-side text sections
    prop_ranges_text = _format_prop_ranges(prop_matrix)
    prop_ranges_dict = _format_prop_ranges_dict(prop_matrix)
    plot = generate_timeline_plot(prop_ranges_dict, entry["video_info"].frame_count)
    tl_text = (
        f"Propositions: {json.dumps(entry['tl']['propositions'], ensure_ascii=False)}\n"
        f"Specification: {entry['tl']['specification']}"
    )
    return out_path, prop_ranges_text, tl_text, plot

def generate_caption(video_path):
    """
    Simulates generating a caption for the given video file.
    """
    # If the video is cleared, the input will be None
    if video_path is None:
        # Hide the caption box and clear its content
        return gr.update(value="", visible=False)
    print(f"Generating caption for: {video_path}")
    vllm_client = VLLMClient()
    entry = _load_entry_from_reader(video_path, "dummy-query")
    # sample 4 frames from the video evenly
    len_frames = len(entry['images'])
    images = [entry['images'][i] for i in range(0, len_frames, len_frames//3)]
    caption_text = vllm_client.caption(images)
    # Simulate model inference time
    # Use gr.update to change both the value and visibility of the textbox
    return gr.update(value=caption_text, visible=True)
# -----------------------------
# UI
# -----------------------------
with gr.Blocks(css="""
#io-col {display: flex; gap: 1rem;}
#left {flex: 1;}
#right {flex: 1;}
""", title="NSVS-TL") as demo:

    gr.Markdown("# Neuro-Symbolic Visual Search with Temporal Logic")
    gr.Markdown(
        "Upload a video and either provide a natural-language **Query** *or* directly supply **Propositions** (array) + **Specification**. "
        "On the right, you'll get a **cropped video** containing only the frames of interest, a **Propositions by Frames** summary, and the combined TL summary."
    )

    with gr.Row(elem_id="io-col"):
        with gr.Column(elem_id="left"):
            mode = gr.Radio(
                choices=["Natural language query", "Props/Spec"],
                value="Natural language query",
                label="Input mode"
            )
            video = gr.Video(label="Upload Video")

            query = gr.Textbox(
                label="Query (natural language)",
                placeholder="e.g., a man is jumping and panting until he falls down"
            )

            captions = gr.Textbox(
                label="Video Caption",
                placeholder="e.g., a man is jumping and panting until he falls down",
                lines=4,
                visible=False
            )

            propositions = gr.Textbox(
                label="Propositions (JSON array)",
                placeholder='e.g., ["man_jumps", "man_pants", "man_falls_down"]',
                lines=4,
                visible=False
            )
            specification = gr.Textbox(
                label="Specification",
                placeholder='e.g., ("woman_jumps" & "woman_claps") U "candle_is_blown"',
                visible=False
            )

            def _toggle_fields(m):
                if m == "Natural language query":
                    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)

            mode.change(_toggle_fields, inputs=[mode], outputs=[query, propositions, specification])
            video.change(
                fn=generate_caption,
                inputs=[video],
                outputs=[captions]
            )
            run_btn = gr.Button("Run", variant="primary")

            gr.Examples(
                label="Examples (dummy paths + queries)",
                examples=[
                    ["demo_videos/dog_jump.mp4", "a dog jumps until a red tube is in view"],
                    ["demo_videos/blue_shirt.mp4", "a girl in a green shirt until a candle is blown"],
-                    ["demo_videos/car.mp4", "red car until a truck"],
-                    ["demo_videos/newyork_1.mp4", "bright lights until empire state building"],
-                    ["demo_videos/chicago_2.mp4", "ocean until ship"],
+                    ["demo_videos/car.mp4", "red car until a truck"]
                ],
                inputs=[video, query],
-                cache_examples=False,
+                cache_examples=False
            )

        with gr.Column(elem_id="right"):
            cropped_video = gr.Video(label="Cropped Video (Frames of Interest Only)")

            prop_ranges_out = gr.Textbox(
                label="Propositions by Frames",
                lines=6,
                interactive=False
            )

            timeline_plot_output = gr.Plot(label="Propositions Timeline")

            tl_out = gr.Textbox(
                label="TL (Propositions & Specification)",
                lines=8,
                interactive=False
            )

    run_btn.click(
        fn=run_pipeline,
        inputs=[video, mode, query, propositions, specification],
        outputs=[cropped_video, prop_ranges_out, tl_out, timeline_plot_output]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
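Both versions of the script crop the output video by grouping the frames of interest into contiguous runs and cutting each run out with ffmpeg's split/trim/concat filters. The following is a minimal standalone sketch of that core logic, not part of the commit; the names `group_ranges` and `build_filter_complex` are illustrative, mirroring `_group_ranges` and the filtergraph assembly inside `_crop_video` above:

# Illustrative sketch (assumes 0-based frame indices and end-exclusive runs, as in the script).
from typing import Iterable, List, Tuple

def group_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
    # Collapse sorted, de-duplicated indices into (start, end_exclusive) runs.
    f = sorted(set(int(x) for x in frames))
    if not f:
        return []
    runs, s, p = [], f[0], f[0]
    for x in f[1:]:
        if x == p + 1:
            p = x
        else:
            runs.append((s, p + 1))
            s = p = x
    runs.append((s, p + 1))
    return runs

def build_filter_complex(frames: Iterable[int]) -> str:
    # One split output per run, one trim per run, then concat back into [outv].
    # The empty-frames case is handled upstream in the script (a 1-frame empty video).
    ranges = group_ranges(frames)
    n = len(ranges)
    filters = [f"[0:v]split={n}" + "".join(f"[s{i}]" for i in range(n))]
    for i, (s, e) in enumerate(ranges):
        filters.append(f"[s{i}]trim=start_frame={s}:end_frame={e},setpts=PTS-STARTPTS[v{i}]")
    filters.append("".join(f"[v{i}]" for i in range(n)) + f"concat=n={n}:v=1:a=0[outv]")
    return "; ".join(filters)

print(group_ranges([3, 4, 5, 9]))         # [(3, 6), (9, 10)]
print(build_filter_complex([3, 4, 5, 9]))

The `setpts=PTS-STARTPTS` reset after each trim is what keeps the concatenated segments from inheriting their original timestamps.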
execute_demo_v3.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import json
4
+ import uuid
5
+ import base64
6
+ import tempfile
7
+ import subprocess
8
+ import numpy as np
9
+ import gradio as gr
10
+
11
+ from openai import OpenAI
12
+ from matplotlib import pyplot as plt
13
+ from typing import Dict, List, Iterable, Tuple, Union
14
+
15
+ from ns_vfs.video.read_mp4 import Mp4Reader
16
+ from execute_with_mp4 import process_entry
17
+
18
+ # Optional import of preprocess_yolo if available alongside process_entry
19
+ try:
20
+ from execute_with_mp4 import preprocess_yolo
21
+ except Exception:
22
+ preprocess_yolo = None
23
+
24
+
25
+ class VLLMClient:
26
+ def __init__(
27
+ self,
28
+ api_key="EMPTY",
29
+ api_base="http://localhost:8000/v1",
30
+ model="OpenGVLab/InternVL2-8B",
31
+ ):
32
+ self.client = OpenAI(api_key=api_key, base_url=api_base)
33
+ self.model = model
34
+
35
+ def _encode_frame(self, frame):
36
+ ok, buffer = cv2.imencode(".jpg", frame)
37
+ if not ok:
38
+ raise ValueError("Could not encode frame")
39
+ return base64.b64encode(buffer).decode("utf-8")
40
+
41
+ def caption(self, frames: list[np.ndarray]):
42
+ parsing_rule = (
43
+ " You must return a caption for the sequence of images. "
44
+ "The caption must be a single sentence. "
45
+ "The caption must be in the same language as the question."
46
+ )
47
+ prompt = (
48
+ r"Give me a detailed description of what you see in the images "
49
+ f"\n[PARSING RULE]: {parsing_rule}"
50
+ )
51
+ encoded_images = [self._encode_frame(frame) for frame in frames]
52
+ user_content = [{"type": "text", "text": "The following is the sequence of images"}]
53
+ for encoded in encoded_images:
54
+ user_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}})
55
+
56
+ chat_response = self.client.chat.completions.create(
57
+ model=self.model,
58
+ messages=[
59
+ {"role": "system", "content": prompt},
60
+ {"role": "user", "content": user_content},
61
+ ],
62
+ max_tokens=1000,
63
+ temperature=0.0,
64
+ logprobs=True,
65
+ )
66
+ return chat_response.choices[0].message.content
67
+
68
+
69
+ def _load_entry_from_reader(video_path, query_text):
70
+ reader = Mp4Reader(
71
+ [{"path": video_path, "query": query_text}],
72
+ openai_save_path="",
73
+ sampling_rate_fps=2
74
+ )
75
+ data = reader.read_video()
76
+ if not data:
77
+ raise RuntimeError("No data returned by Mp4Reader (check video path)")
78
+ return data[0]
79
+
80
+
81
+ def _make_empty_video(path, width=320, height=240, fps=1.0):
82
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
83
+ writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
84
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
85
+ writer.write(frame)
86
+ writer.release()
87
+ return path
88
+
89
+
90
+ # -----------------------------
91
+ # Helpers to detect bbox-style outputs and to convert them
92
+ # -----------------------------
93
+ BBox = Tuple[float, float, float, float]
94
+ YOLODict = Dict[str, List[Tuple[int, BBox]]]
95
+ VLMDict = Dict[str, List[int]]
96
+
97
+ def _has_bboxes(prop_matrix: Union[YOLODict, VLMDict]) -> bool:
98
+ """Return True if the prop_matrix contains (frame_idx, bbox) tuples."""
99
+ if not prop_matrix:
100
+ return False
101
+ for v in prop_matrix.values():
102
+ if not v:
103
+ continue
104
+ first = v[0]
105
+ if isinstance(first, tuple) and len(first) == 2 and hasattr(first[1], "__len__") and len(first[1]) == 4:
106
+ return True
107
+ return False
108
+
109
+ def _bbox_dict_to_frames_only(prop_bboxes: YOLODict) -> VLMDict:
110
+ """Convert {'car': [(i, (x1,y1,x2,y2)), ...], ...} -> {'car': [i, ...], ...}"""
111
+ out: VLMDict = {}
112
+ for k, pairs in (prop_bboxes or {}).items():
113
+ out[k] = [int(i) for i, _ in pairs]
114
+ return out
115
+
116
+
117
+ # -----------------------------
118
+ # Video cropping and overlays
119
+ # -----------------------------
120
+ def _crop_video_subtitles(input_path: str, output_path: str, frame_indices: List[int], prop_matrix: VLMDict):
121
+ """
122
+ Existing behavior (VLM/no bboxes):
123
+ - Keep only frames in frame_indices (in order, contiguous groups)
124
+ - Overlay top-right proposition text via ASS subtitles
125
+ """
126
+ input_path = str(input_path)
127
+ output_path = str(output_path)
128
+
129
+ cap = cv2.VideoCapture(input_path)
130
+ if not cap.isOpened():
131
+ raise RuntimeError(f"Could not open video: {input_path}")
132
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
133
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
134
+ fps = float(cap.get(cv2.CAP_PROP_FPS)) or 0.0
135
+ cap.release()
136
+ if fps <= 0:
137
+ fps = 30.0
138
+
139
+ if not frame_indices:
140
+ from numpy import zeros, uint8
141
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
142
+ out = cv2.VideoWriter(output_path, fourcc, 1.0, (width, height))
143
+ out.write(zeros((height, width, 3), dtype=uint8))
144
+ out.release()
145
+ return
146
+
147
+ def _group_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
148
+ f = sorted(set(int(x) for x in frames))
149
+ if not f:
150
+ return []
151
+ out = []
152
+ s = p = f[0]
153
+ for x in f[1:]:
154
+ if x == p + 1:
155
+ p = x
156
+ else:
157
+ out.append((s, p + 1))
158
+ s = p = x
159
+ out.append((s, p + 1))
160
+ return out
161
+
162
+ props_by_frame: Dict[int, List[str]] = {}
163
+ for prop, frames in (prop_matrix or {}).items():
164
+ for fi in frames:
165
+ fi = int(fi)
166
+ props_by_frame.setdefault(fi, []).append(prop)
167
+ for fi in list(props_by_frame.keys()):
168
+ props_by_frame[fi] = sorted(set(props_by_frame[fi]))
169
+
170
+ fi_set = set(int(x) for x in frame_indices)
171
+ frames_with_labels = sorted(fi for fi in fi_set if props_by_frame.get(fi))
172
+
173
+ grouped_label_spans: List[Tuple[int, int, Tuple[str, ...]]] = []
174
+ prev_f = None
175
+ prev_labels: Tuple[str, ...] = ()
176
+ span_start = None
177
+ for f in frames_with_labels:
178
+ labels = tuple(props_by_frame.get(f, []))
179
+ if prev_f is None:
180
+ span_start, prev_f, prev_labels = f, f, labels
181
+ elif (f == prev_f + 1) and (labels == prev_labels):
182
+ prev_f = f
183
+ else:
184
+ grouped_label_spans.append((span_start, prev_f + 1, prev_labels))
185
+ span_start, prev_f, prev_labels = f, f, labels
186
+ if prev_f is not None and prev_labels:
187
+ grouped_label_spans.append((span_start, prev_f + 1, prev_labels))
188
+
189
+ # Build ASS subtitle (top-right)
190
+ def ass_time(t_sec: float) -> str:
191
+ cs = int(round(t_sec * 100))
192
+ h = cs // (100 * 3600)
193
+ m = (cs // (100 * 60)) % 60
194
+ s = (cs // 100) % 60
195
+ cs = cs % 100
196
+ return f"{h}:{m:02d}:{s:02d}.{cs:02d}"
197
+
198
+ def make_ass(width: int, height: int) -> str:
199
+ lines = []
200
+ lines.append("[Script Info]")
201
+ lines.append("ScriptType: v4.00+")
202
+ lines.append("ScaledBorderAndShadow: yes")
203
+ lines.append(f"PlayResX: {width}")
204
+ lines.append(f"PlayResY: {height}")
205
+ lines.append("")
206
+ lines.append("[V4+ Styles]")
207
+ lines.append("Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, "
208
+ "Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, "
209
+ "Shadow, Alignment, MarginL, MarginR, MarginV, Encoding")
210
+ lines.append("Style: Default,DejaVu Sans,18,&H00FFFFFF,&H000000FF,&H00000000,&H64000000,"
211
+ "0,0,0,0,100,100,0,0,1,2,0.8,9,16,16,16,1")
212
+ lines.append("")
213
+ lines.append("[Events]")
214
+ lines.append("Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text")
215
+
216
+ for start_f, end_f, labels in grouped_label_spans:
217
+ if not labels:
218
+ continue
219
+ start_t = ass_time(start_f / fps)
220
+ end_t = ass_time(end_f / fps)
221
+ text = r"\N".join(labels) # stacked lines
222
+ lines.append(f"Dialogue: 0,{start_t},{end_t},Default,,0,0,0,,{text}")
223
+
224
+ return "\n".join(lines)
225
+
226
+ tmp_dir = tempfile.mkdtemp(prefix="props_ass_")
227
+ ass_path = os.path.join(tmp_dir, "props.ass")
228
+ with open(ass_path, "w", encoding="utf-8") as f:
229
+ f.write(make_ass(width, height))
230
+
231
+ ranges = _group_ranges(frame_indices)
232
+
233
+ split_labels = [f"[s{i}]" for i in range(len(ranges))] if ranges else []
234
+ out_labels = [f"[v{i}]" for i in range(len(ranges))] if ranges else []
235
+
236
+ filters = []
237
+ ass_arg = ass_path.replace("\\", "\\\\")
238
+ filters.append(f"[0:v]subtitles='{ass_arg}'[sub]")
239
+
240
+ if len(ranges) == 1:
241
+ s0, e0 = ranges[0]
242
+ filters.append(f"[sub]trim=start_frame={s0}:end_frame={e0},setpts=PTS-STARTPTS[v0]")
243
+ else:
244
+ if ranges:
245
+ filters.append(f"[sub]split={len(ranges)}{''.join(split_labels)}")
246
+ for i, (s, e) in enumerate(ranges):
247
+ filters.append(f"{split_labels[i]}trim=start_frame={s}:end_frame={e},setpts=PTS-STARTPTS{out_labels[i]}")
248
+
249
+ if ranges:
250
+ filters.append(f"{''.join(out_labels)}concat=n={len(ranges)}:v=1:a=0[outv]")
251
+
252
+ filter_complex = "; ".join(filters)
253
+
254
+ cmd = [
255
+ "ffmpeg", "-y",
256
+ "-i", input_path,
257
+ "-filter_complex", filter_complex,
258
+ "-map", "[outv]" if ranges else "[sub]",
259
+ "-c:v", "libx264", "-preset", "fast", "-crf", "23",
260
+ output_path,
261
+ ]
262
+ try:
263
+ subprocess.run(cmd, check=True)
264
+ finally:
265
+ try:
266
+ os.remove(ass_path)
267
+ os.rmdir(tmp_dir)
268
+ except OSError:
269
+ pass
270
+
271
+
272
+ def _crop_video_bboxes(input_path: str, output_path: str, frame_indices: List[int], prop_bboxes: YOLODict):
273
+ """
274
+ YOLO path (with bounding boxes):
275
+ - Keep only frames in frame_indices.
276
+ - Draw rectangles for each detected prop on the kept frames.
277
+ - Label each rectangle with the prop name (top-left of box).
278
+ """
279
+ keep_set = set(int(x) for x in frame_indices)
280
+ if not keep_set:
281
+ # output a 1-frame empty video (consistent with _crop_video_subtitles)
282
+ cap0 = cv2.VideoCapture(input_path)
283
+ if not cap0.isOpened():
284
+ raise RuntimeError(f"Could not open video: {input_path}")
285
+ width = int(cap0.get(cv2.CAP_PROP_FRAME_WIDTH))
286
+ height = int(cap0.get(cv2.CAP_PROP_FRAME_HEIGHT))
287
+ cap0.release()
288
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
289
+ out = cv2.VideoWriter(output_path, fourcc, 1.0, (width, height))
290
+ out.write(np.zeros((height, width, 3), dtype=np.uint8))
291
+ out.release()
292
+ return
293
+
294
+ # Build frame -> list[(prop, bbox)]
295
+ per_frame: Dict[int, List[Tuple[str, BBox]]] = {}
296
+ for prop, pairs in (prop_bboxes or {}).items():
297
+ for fi, bbox in pairs:
298
+ fi = int(fi)
299
+ per_frame.setdefault(fi, []).append((prop, bbox))
300
+
301
+ cap = cv2.VideoCapture(input_path)
302
+ if not cap.isOpened():
303
+ raise RuntimeError(f"Could not open video: {input_path}")
304
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
305
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
306
+ fps = float(cap.get(cv2.CAP_PROP_FPS)) or 30.0
307
+
308
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
309
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
310
+
311
+ idx = 0
312
+ ok, frame = cap.read()
313
+ while ok:
314
+ if idx in keep_set:
315
+ # draw all bboxes for this frame
316
+ for prop, (x1, y1, x2, y2) in per_frame.get(idx, []):
317
+ p1 = (int(round(x1)), int(round(y1)))
318
+ p2 = (int(round(x2)), int(round(y2)))
319
+ cv2.rectangle(frame, p1, p2, (0, 255, 0), 2) # green rectangle
320
+ # text background for readability
321
+ label = prop.replace("_", " ")
322
+ txt_origin = (p1[0], max(0, p1[1] - 5))
323
+ cv2.putText(frame, label, txt_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
324
+ cv2.putText(frame, label, txt_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)
325
+ out.write(frame)
326
+ idx += 1
327
+ ok, frame = cap.read()
328
+
329
+ cap.release()
330
+ out.release()
331
+
332
+
333
+ def _crop_video(
334
+ input_path: str,
335
+ output_path: str,
336
+ frame_indices: List[int],
337
+ prop_matrix: Union[VLMDict, YOLODict]
338
+ ):
339
+ """
340
+ Dispatch to the appropriate cropper:
341
+ - VLM/no-bbox: ASS subtitle overlay.
342
+ - YOLO with bbox: draw rectangles overlay via OpenCV.
343
+ """
344
+ if _has_bboxes(prop_matrix):
345
+ _crop_video_bboxes(input_path, output_path, frame_indices, prop_matrix) # type: ignore[arg-type]
346
+ else:
347
+ _crop_video_subtitles(input_path, output_path, frame_indices, prop_matrix) # type: ignore[arg-type]
348
+
349
+
350
+ # -----------------------------
351
+ # Text helpers (unchanged API, but robust to bbox dicts)
352
+ # -----------------------------
353
+ def _format_prop_ranges_dict(prop_matrix: Union[VLMDict, YOLODict]) -> Dict[str, List[Tuple[int, int]]]:
354
+ def group_into_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
355
+ f = sorted(set(int(x) for x in frames))
356
+ if not f:
357
+ return []
358
+ ranges: List[Tuple[int, int]] = []
359
+ s = p = f[0]
360
+ for x in f[1:]:
361
+ if x == p + 1:
362
+ p = x
363
+ else:
364
+ ranges.append((s, p))
365
+ s = p = x
366
+ ranges.append((s, p))
367
+ return ranges
368
+
369
+ if _has_bboxes(prop_matrix):
370
+ frames_only = _bbox_dict_to_frames_only(prop_matrix) # type: ignore[arg-type]
371
+ else:
372
+ frames_only = prop_matrix # type: ignore[assignment]
373
+
374
+ detections: Dict[str, List[Tuple[int, int]]] = {}
375
+ for prop, frames in (frames_only or {}).items():
376
+ detections[prop] = group_into_ranges(frames)
377
+ return detections
378
+
379
+
380
+def _format_prop_ranges(prop_matrix: Union[VLMDict, YOLODict]) -> str:
+    def group_into_ranges(frames: Iterable[int]) -> List[Tuple[int, int]]:
+        f = sorted(set(int(x) for x in frames))
+        if not f:
+            return []
+        ranges: List[Tuple[int, int]] = []
+        s = p = f[0]
+        for x in f[1:]:
+            if x == p + 1:
+                p = x
+            else:
+                ranges.append((s, p))
+                s = p = x
+        ranges.append((s, p))
+        return ranges
+
+    if not prop_matrix:
+        return "No propositions detected."
+
+    if _has_bboxes(prop_matrix):
+        frames_only = _bbox_dict_to_frames_only(prop_matrix)  # type: ignore[arg-type]
+    else:
+        frames_only = prop_matrix  # type: ignore[assignment]
+
+    lines = []
+    for prop, frames in (frames_only or {}).items():
+        ranges = group_into_ranges(frames)
+        pretty = prop.replace("_", " ").title()
+        if not ranges:
+            lines.append(f"{pretty}: —")
+            continue
+        parts = [f"{a}" if a == b else f"{a}-{b}" for (a, b) in ranges]
+        lines.append(f"{pretty}: {', '.join(parts)}")
+    return "\n".join(lines)
+
+
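# For intuition: group_into_ranges collapses consecutive frame indices into
# inclusive (start, end) runs, e.g.
#   group_into_ranges([3, 4, 5, 9, 11, 12]) -> [(3, 5), (9, 9), (11, 12)]
# which _format_prop_ranges then renders as "3-5, 9, 11-12".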
+# -----------------------------
+# Plotting
+# -----------------------------
+def generate_timeline_plot(detections, total_frames):
+    labels = list(detections.keys())
+    num_labels = len(labels)
+
+    if num_labels == 0:
+        fig, ax = plt.subplots(figsize=(10, 1))
+        ax.text(0.5, 0.5, 'No propositions detected.', ha='center', va='center')
+        ax.set_axis_off()
+        return fig
+
+    colors = plt.cm.get_cmap('tab10', num_labels)
+    fig, ax = plt.subplots(figsize=(10, num_labels * 0.6 + 0.5))
+
+    ax.set_xlim(0, total_frames)
+    ax.set_ylim(0, num_labels)
+    ax.set_yticks(np.arange(num_labels) + 0.5)
+    ax.set_yticklabels(labels, fontsize=12)
+    ax.set_xlabel("Frame Number", fontsize=12)
+    ax.grid(axis='x', linestyle='--', alpha=0.6)
+    ax.invert_yaxis()
+
+    for i, label in enumerate(labels):
+        # ranges are inclusive (start, end), so a single-frame run still gets width 1
+        segments = [(start, end - start + 1) for start, end in detections[label]]
+        ax.broken_barh(segments, (i + 0.1, 0.8), facecolors=colors(i))
+
+    plt.tight_layout()
+    return fig
+
+
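# A quick check of the plot helper (illustrative inputs, matching the dict
# shape produced by _format_prop_ranges_dict; the output path is hypothetical):
fig = generate_timeline_plot({"car": [(0, 14), (40, 55)], "truck": [(30, 60)]}, 90)
fig.savefig("/tmp/timeline_preview.png")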
+# -----------------------------
+# Helpers for YOLO cache path
+# -----------------------------
+def _yolo_cache_path_for_video(video_path: str) -> str:
+    """
+    Always save the YOLO cache in the demo_videos folder.
+        demo_videos/car.mp4 -> demo_videos/car.npz
+        uploads/tmp123.mp4  -> demo_videos/tmp123.npz
+    """
+    base = os.path.basename(video_path)
+    root, _ = os.path.splitext(base)
+    os.makedirs("demo_videos", exist_ok=True)
+    return os.path.join("demo_videos", f"{root}.npz")
+
+
+# -----------------------------
+# Gradio handler
+# -----------------------------
+def run_pipeline(input_video, mode, detector, query_text, propositions_json, specification_text):
+    def _err(msg, width=320, height=240):
+        tmp_out = os.path.join("/tmp", f"empty_{uuid.uuid4().hex}.mp4")
+        _make_empty_video(tmp_out, width=width, height=height, fps=1.0)
+        return (tmp_out, "No propositions detected.", f"Error: {msg}", None)
+
+    # Normalize input path
+    if isinstance(input_video, dict) and "name" in input_video:
+        video_path = input_video["name"]
+    elif isinstance(input_video, str):
+        video_path = input_video
+    else:
+        return _err("Please provide a video.")
+
+    # Build entry
+    if mode == "Natural language query":
+        if not query_text or not query_text.strip():
+            return _err("Please enter a query.")
+        entry = _load_entry_from_reader(video_path, query_text)
+    else:
+        if not (propositions_json and propositions_json.strip()) or not (specification_text and specification_text.strip()):
+            return _err("Please provide both Propositions (array) and Specification.")
+        entry = _load_entry_from_reader(video_path, "dummy-query")
+        try:
+            props = json.loads(propositions_json)
+            if not isinstance(props, list):
+                return _err("Propositions must be a JSON array.")
+        except Exception as e:
+            return _err(f"Failed to parse propositions JSON: {e}")
+        entry["tl"] = {"propositions": props, "specification": specification_text}
+
+    # Process depending on detector
+    foi = None
+    prop_matrix: Union[VLMDict, YOLODict] = {}
+
+    if detector == "YOLO":
+        cache_path = _yolo_cache_path_for_video(video_path)
+
+        # 1) run YOLO preprocessing when the YOLO detector is selected
+        try:
+            if preprocess_yolo is None:
+                raise NameError("preprocess_yolo() not defined")
+            ret_path = preprocess_yolo(
+                entry["images"],
+                model_weights="yolov8n.pt",
+                device="cuda:0",
+                out_path=cache_path
+            )
+            if isinstance(ret_path, str) and ret_path.strip():
+                cache_path = ret_path
+        except NameError:
+            return _err("YOLO selected but preprocess_yolo is not available.")
+        except Exception as e:
+            return _err(f"YOLO preprocessing error: {e}")
+
+        # 2) then run with YOLO
+        try:
+            res = process_entry(entry, run_with_yolo=True, cache_path=cache_path)
+            if isinstance(res, tuple) and len(res) == 2:
+                foi, prop_matrix = res
+            else:
+                foi = res
+                prop_matrix = {}
+        except Exception as e:
+            return _err(f"Processing error (YOLO mode): {e}")
+
+    else:
+        # VLM path only
+        try:
+            foi, prop_matrix = process_entry(entry, run_with_yolo=False)
+        except Exception as e:
+            return _err(f"Processing error (VLM mode): {e}")
+
+    # Export cropped video (with either subtitles or bbox overlays)
+    try:
+        out_path = os.path.join("/tmp", f"cropped_{uuid.uuid4().hex}.mp4")
+        _crop_video(video_path, out_path, foi, prop_matrix)
+    except Exception as e:
+        return _err(f"Failed to write cropped video: {e}")
+
+    # Text + plot (work from frames; ignore bbox coords)
+    try:
+        prop_ranges_text = _format_prop_ranges(prop_matrix)
+        prop_ranges_dict = _format_prop_ranges_dict(prop_matrix)
+        plot = generate_timeline_plot(prop_ranges_dict, entry["video_info"].frame_count)
+    except Exception:
+        prop_ranges_text = "No propositions detected." if not prop_matrix else str(prop_matrix)
+        plot = generate_timeline_plot({}, entry["video_info"].frame_count)
+
+    tl_text = (
+        f"Propositions: {json.dumps(entry['tl']['propositions'], ensure_ascii=False)}\n"
+        f"Specification: {entry['tl']['specification']}"
+    )
+    return out_path, prop_ranges_text, tl_text, plot
+
+
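# A minimal headless invocation sketch (not part of the app's event wiring;
# assumes demo_videos/car.mp4 exists and the vLLM/YOLO backends are reachable):
video_out, ranges_text, tl_text, plot = run_pipeline(
    "demo_videos/car.mp4", "Natural language query", "YOLO",
    "car until truck", None, None,
)
print(ranges_text)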
+def generate_caption(video_path):
+    if video_path is None:
+        return gr.update(value="", visible=False)
+    vllm_client = VLLMClient()
+    entry = _load_entry_from_reader(video_path, "dummy-query")
+    n = len(entry['images'])
+    if n == 0:
+        # guard against videos that yielded no sampled frames
+        return gr.update(value="", visible=False)
+    step = max(1, n // 3)
+    images = [entry['images'][i] for i in range(0, n, step)][:3]
+    caption_text = vllm_client.caption(images)
+    return gr.update(value=caption_text, visible=True)
+
+
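# Sampling arithmetic, for reference: with n = 10 sampled frames, step is
# max(1, 10 // 3) = 3, so range(0, 10, 3) yields indices 0, 3, 6, 9, which are
# truncated to the first three -> frames 0, 3, and 6 go to the captioner.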
+# -----------------------------
+# UI
+# -----------------------------
+with gr.Blocks(css="""
+#io-col {display: flex; gap: 1rem;}
+#left {flex: 1;}
+#right {flex: 1;}
+""", title="NSVS-TL") as demo:
+
+    gr.Markdown("# Neuro-Symbolic Visual Search with Temporal Logic")
+    gr.Markdown("Upload a video and either provide a natural-language **Query** *or* directly supply **Propositions** + **Specification**.")
+
+    with gr.Row(elem_id="io-col"):
+        with gr.Column(elem_id="left"):
+            mode = gr.Radio(
+                choices=["Natural language query", "Props/Spec"],
+                value="Natural language query",
+                label="Input mode"
+            )
+
+            detector = gr.Radio(
+                choices=["VLM", "YOLO"],
+                value="VLM",
+                label="Detector (VLM vs YOLO)"
+            )
+
+            video = gr.Video(label="Upload Video")
+
+            query = gr.Textbox(
+                label="Query (natural language)",
+                placeholder="e.g., a man is jumping and panting until he falls down"
+            )
+
+            captions = gr.Textbox(
+                label="Video Caption",
+                placeholder="Auto caption will appear here",
+                lines=4,
+                visible=False
+            )
+
+            propositions = gr.Textbox(
+                label="Propositions (JSON array)",
+                placeholder='e.g., ["man_jumps", "man_pants", "man_falls_down"]',
+                lines=4,
+                visible=False
+            )
+            specification = gr.Textbox(
+                label="Specification",
+                placeholder='e.g., ("woman_jumps" & "woman_claps") U "candle_is_blown"',
+                visible=False
+            )
+
+            def _toggle_fields(m):
+                if m == "Natural language query":
+                    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
+                else:
+                    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
+
+            # Only toggles visibility of fields; no processing
+            mode.change(_toggle_fields, inputs=[mode], outputs=[query, propositions, specification])
+
+            # Only auto-captioning runs on video change
+            video.change(fn=generate_caption, inputs=[video], outputs=[captions], queue=False)
+
+            run_btn = gr.Button("Run", variant="primary")
+
+            gr.Examples(
+                label="Examples",
+                examples=[
+                    ["demo_videos/dog_jump.mp4", "a dog jumps until a red tube is in view"],
+                    ["demo_videos/blue_shirt.mp4", "a girl in a green shirt until a candle is blown"],
+                    ["demo_videos/car.mp4", "red car until a truck"],
+                    ["demo_videos/newyork_1.mp4", "taxi until empire state building"],
+                    ["demo_videos/chicago_2.mp4", "boat until ferris wheel"]
+                ],
+                inputs=[video, query],
+                cache_examples=False
+            )
+
+        with gr.Column(elem_id="right"):
+            cropped_video = gr.Video(label="Cropped Video (Frames of Interest Only)")
+            prop_ranges_out = gr.Textbox(label="Propositions by Frames", lines=6, interactive=False)
+            timeline_plot_output = gr.Plot(label="Propositions Timeline")
+            tl_out = gr.Textbox(label="TL (Propositions & Specification)", lines=8, interactive=False)
+
+    # ONLY the Run button triggers processing/preprocessing
+    run_btn.click(
+        fn=run_pipeline,
+        inputs=[video, mode, detector, query, propositions, specification],
+        outputs=[cropped_video, prop_ranges_out, tl_out, timeline_plot_output]
+    )
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
+
execute_with_mp4.py CHANGED
@@ -1,4 +1,4 @@
-from tqdm import tqdm
 import itertools
 import operator
 import json
@@ -6,24 +6,30 @@ import time
 import os
 
 from ns_vfs.nsvs import run_nsvs
 from ns_vfs.video.read_mp4 import Mp4Reader
 
 
 VIDEOS = [
     {
-        "path": "demo_videos/blue_shirt.mp4",
-        "query": "a woman is jumping and clapping until a candle is blown"
     }
 ]
 DEVICE = 7  # GPU device index
 OPENAI_SAVE_PATH = ""
 OUTPUT_DIR = "output"
 
 def fill_in_frame_count(arr, entry):
     scale = (entry["video_info"].fps) / (entry["metadata"]["sampling_rate_fps"])
 
     runs = []
-    for _, grp in itertools.groupby(sorted(arr), key=lambda x, c=[0]: (x - (c.__setitem__(0, c[0]+1) or c[0]))):
         g = list(grp)
         runs.append((g[0], g[-1]))
@@ -36,30 +42,106 @@ def fill_in_frame_count(arr, entry):
         real.extend(range(a, b + 1))
     return real
 
-def process_entry(entry):
-    foi, object_frame_dict, px = run_nsvs(
-        frames=entry['images'],
-        proposition=entry['tl']['propositions'],
-        specification=entry['tl']['specification'],
-        model_name="InternVL2-8B",
-        device=DEVICE
-    )
 
-    foi = fill_in_frame_count([i for sub in foi for i in sub], entry)
-    object_frame_dict = {key: fill_in_frame_count(value, entry) for key, value in object_frame_dict.items()}
-    px = {key: fill_in_frame_count(value, entry) for key, value in px.items()}
-    return foi, object_frame_dict, px
 
 def main():
     reader = Mp4Reader(VIDEOS, OPENAI_SAVE_PATH, sampling_rate_fps=1)
     data = reader.read_video()
     if not data:
         return
 
-    with tqdm(enumerate(data), total=len(data), desc="Processing entries") as pbar:
         for i, entry in pbar:
             start_time = time.time()
-            foi = process_entry(entry)
             end_time = time.time()
             processing_time = round(end_time - start_time, 3)
+import tqdm
 import itertools
 import operator
 import json
 import os
 
 from ns_vfs.nsvs import run_nsvs
+from ns_vfs.nsvs_yolo import *
 from ns_vfs.video.read_mp4 import Mp4Reader
 
 
 VIDEOS = [
     {
+        "path": "demo_videos/car.mp4",
+        "query": "car until truck"
     }
 ]
 DEVICE = 7  # GPU device index
 OPENAI_SAVE_PATH = ""
 OUTPUT_DIR = "output"
 
+import itertools
+
 def fill_in_frame_count(arr, entry):
     scale = (entry["video_info"].fps) / (entry["metadata"]["sampling_rate_fps"])
 
     runs = []
+    for _, grp in itertools.groupby(
+        sorted(arr),
+        key=lambda x, c=[0]: (x - (c.__setitem__(0, c[0] + 1) or c[0]))
+    ):
         g = list(grp)
         runs.append((g[0], g[-1]))
 
         real.extend(range(a, b + 1))
     return real
 
+def _fill_in_frame_count_pairs(pairs, entry):
+    if not pairs:
+        return []
+    scale = (entry["video_info"].fps) / (entry["metadata"]["sampling_rate_fps"])
+
+    pairs = sorted(pairs, key=lambda t: int(t[0]))
+    sampled_indices = [int(i) for i, _ in pairs]
+
+    runs = []
+    for _, grp in itertools.groupby(
+        sampled_indices,
+        key=lambda x, c=[0]: (x - (c.__setitem__(0, c[0] + 1) or c[0]))
+    ):
+        g = list(grp)
+        runs.append((g[0], g[-1]))
+
+    idx2bbox = {}
+    for i, bbox in pairs:
+        i = int(i)
+        if i not in idx2bbox:
+            idx2bbox[i] = bbox
+
+    expanded: list[tuple[int, tuple[float, float, float, float]]] = []
+    last_real = -1
+
+    for start_i, end_i in runs:
+        rep_bbox = idx2bbox.get(start_i)
+        if rep_bbox is None:
+            for k in range(start_i, end_i + 1):
+                if k in idx2bbox:
+                    rep_bbox = idx2bbox[k]
+                    break
+        if rep_bbox is None:
+            continue
+
+        a = int(round(start_i * scale))
+        b = int(round(end_i * scale))
+        if expanded and a <= last_real:
+            a = last_real + 1
+        for real_i in range(a, b + 1):
+            expanded.append((real_i, rep_bbox))
+        last_real = b
+
+    return expanded
+
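# Worked example of the expansion above: with fps = 30 and
# sampling_rate_fps = 15, scale = 2.0, so sampled detections
# [(3, bbox), (4, bbox)] form one run (3, 4) and expand to real frames
# round(3*2)=6 .. round(4*2)=8, each paired with the run's representative
# bbox: [(6, bbox), (7, bbox), (8, bbox)].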
+def process_entry(entry, run_with_yolo=False, cache_path=""):
+    """
+    VLM path (run_with_yolo=False):
+      - Returns (foi, object_frame_dict_expanded)
+        where object_frame_dict_expanded: Dict[str, List[int]] (real frame indices)
+
+    YOLO path (run_with_yolo=True):
+      - Expects run_nsvs_yolo to return (foi, object_frame_bounding_boxes)
+        where object_frame_bounding_boxes: Dict[str, List[(sample_idx, bbox)]]
+      - Returns (foi, object_frame_bounding_boxes_expanded)
+        where each bbox is duplicated across the scaled span to real frames:
+        Dict[str, List[(real_idx, bbox)]]
+    """
+    if run_with_yolo:
+        foi, object_frame_bounding_boxes = run_nsvs_yolo(
+            frames=entry["images"],
+            proposition=entry['tl']['propositions'],
+            specification=entry['tl']['specification'],
+            yolo_cache_path=cache_path,
+            vlm_detection_threshold=0.35,
+        )
+        foi = fill_in_frame_count([i for sub in foi for i in sub], entry)
+
+        expanded_boxes = {}
+        for key, pairs in (object_frame_bounding_boxes or {}).items():
+            expanded_boxes[key] = _fill_in_frame_count_pairs(pairs, entry)
+        return foi, expanded_boxes
+
+    else:
+        foi, object_frame_dict = run_nsvs(
+            frames=entry['images'],
+            proposition=entry['tl']['propositions'],
+            specification=entry['tl']['specification'],
+            model_name="InternVL2-8B",
+            device=DEVICE
+        )
+        foi = fill_in_frame_count([i for sub in foi for i in sub], entry)
+        object_frame_dict = {key: fill_in_frame_count(value, entry) for key, value in (object_frame_dict or {}).items()}
+        return foi, object_frame_dict
 
 def main():
     reader = Mp4Reader(VIDEOS, OPENAI_SAVE_PATH, sampling_rate_fps=1)
     data = reader.read_video()
     if not data:
         return
+
+    # cache_path = preprocess_yolo(entry["images"], model_weights="yolov8n.pt",
+    #                              device="cuda:0", out_path="yolo_cache.npz")
 
+    with tqdm.tqdm(enumerate(data), total=len(data), desc="Processing entries") as pbar:
         for i, entry in pbar:
             start_time = time.time()
+            # cache_path matches the out_path of the (commented) preprocess_yolo call above;
+            # the default "" would raise FileNotFoundError inside run_nsvs_yolo
+            foi = process_entry(entry, run_with_yolo=True, cache_path="yolo_cache.npz")
             end_time = time.time()
             processing_time = round(end_time - start_time, 3)

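# Return shapes of process_entry, side by side (illustrative values only):
#   VLM path:  {"car": [0, 1, 2, 30, 31]}                        # real frame indices
#   YOLO path: {"car": [(0, (12.0, 40.0, 180.0, 220.0)), ...]}   # (real_idx, bbox) pairs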
launch_space.sh CHANGED
@@ -2,6 +2,7 @@
 
 apt update
 apt install -y ffmpeg
+pip install ultralytics
 
 # Start vLLM server in background
 ./vllm_serve.sh &
@@ -19,4 +20,4 @@ echo "
 "
 
 # Start Gradio app
-python3 execute_demo_v2.py
+python3 execute_demo_v3.py
ns_vfs/nsvs.py CHANGED
@@ -24,15 +24,12 @@ def run_nsvs(
     tl_satisfaction_threshold: float = 0.6,
     detection_threshold: float = 0.5,
     vlm_detection_threshold: float = 0.35,
-    image_output_dir: str = "output"
 ):
     """Find relevant frames from a video that satisfy a specification"""
 
     object_frame_dict = {}
-    object_frame_dict_prob = {}
-    vlm = VLLMClient()
-    # vlm = InternVL(model_name=model_name, device=device)
 
+    vlm = VLLMClient()
     automaton = VideoAutomaton(include_initial_state=True)
     automaton.set_up(proposition_set=proposition)
 
@@ -62,12 +59,6 @@ def run_nsvs(
             object_of_interest[prop] = detected_object
             if detected_object.is_detected:
                 multi_frame_arr = [frame_count * num_of_frame_in_sequence + j for j in range(num_of_frame_in_sequence)]
-                p2 = f"{prop}: {detected_object.probability}"
-                if p2 in object_frame_dict_prob:
-                    object_frame_dict_prob[p2].extend(multi_frame_arr)
-                else:
-                    object_frame_dict_prob[p2] = multi_frame_arr
-
                 if prop in object_frame_dict:
                     object_frame_dict[prop].extend(multi_frame_arr)
                 else:
@@ -93,9 +84,6 @@ def run_nsvs(
         print("\n" + "*"*50 + f" {i}/{len(frame_windows)-1} " + "*"*50)
         print("Detections:")
         frame = process_frame(sequence_of_frames, i)
-        if PRINT_ALL:
-            os.makedirs(image_output_dir, exist_ok=True)
-            frame.save_frame_img(save_path=os.path.join(image_output_dir, f"{i}"))
 
         if checker.validate_frame(frame_of_interest=frame):
             automaton.add_frame(frame=frame)
@@ -112,5 +100,5 @@ def run_nsvs(
     print("Detected frames of interest:")
     print(foi)
 
-    return foi, object_frame_dict, object_frame_dict_prob
+    return foi, object_frame_dict
 
ns_vfs/nsvs_yolo.py ADDED
@@ -0,0 +1,215 @@
+# -------------------------------
+# Preprocess: per-frame dicts {class: List[(conf, (x1,y1,x2,y2))]}
+# -------------------------------
+from ultralytics import YOLO
+import numpy as np
+import warnings
+import tqdm
+import os
+import pickle
+import re
+from typing import Dict, List, Literal, Tuple
+
+from ns_vfs.model_checker.property_checker import PropertyChecker
+from ns_vfs.model_checker.video_automaton import VideoAutomaton
+from ns_vfs.vlm.obj import DetectedObject
+from ns_vfs.vlm.vllm_client import VLLMClient
+from ns_vfs.video.frame import FramesofInterest, VideoFrame
+
+PRINT_ALL = True
+warnings.filterwarnings("ignore")
+
+
+def preprocess_yolo(
+    frames: List[np.ndarray],
+    model_weights: str = "yolov8n.pt",
+    device: str | int = "cuda:0",
+    batch_size: int = 16,
+    out_path: str = "yolo_det_cache.pkl",
+    conf_threshold: float = 0.001,
+    iou: float = 0.7,
+) -> str:
+    """
+    Run YOLOv8 detection on every frame and save a list of dicts.
+    Cache format:
+        yolo_dets: List[Dict[str, List[Tuple[float, Tuple[float, float, float, float]]]]]
+        # one item per frame
+        # each frame dict maps: class_name (lowercase, spaces) ->
+        #     list of (confidence, (x1, y1, x2, y2)) in pixel coordinates
+    """
+    model = YOLO(model_weights)
+    id_to_name: Dict[int, str] = {int(k): str(v).lower() for k, v in model.names.items()}
+
+    yolo_dets: List[Dict[str, List[Tuple[float, Tuple[float, float, float, float]]]]] = []
+
+    for start in range(0, len(frames), batch_size):
+        batch = frames[start:start + batch_size]
+        results = model.predict(
+            batch,
+            device=device,
+            conf=conf_threshold,
+            iou=iou,
+            verbose=False,
+        )
+
+        for r in results:
+            frame_dict: Dict[str, List[Tuple[float, Tuple[float, float, float, float]]]] = {}
+            if r.boxes is not None and len(r.boxes) > 0:
+                # xyxy in pixels, conf, and class ids
+                xyxy = r.boxes.xyxy.detach().cpu().numpy().astype(float)
+                confs = r.boxes.conf.detach().cpu().numpy().astype(float)
+                cls_ids = r.boxes.cls.detach().cpu().numpy().astype(int)
+
+                for (x1, y1, x2, y2), conf, cid in zip(xyxy, confs, cls_ids):
+                    name = id_to_name.get(int(cid), str(cid))  # e.g., "traffic light"
+                    frame_dict.setdefault(name, []).append(
+                        (float(conf), (float(x1), float(y1), float(x2), float(y2)))
+                    )
+
+            yolo_dets.append(frame_dict)
+
+    assert len(yolo_dets) == len(frames), f"expected {len(frames)} dicts, got {len(yolo_dets)}"
+
+    with open(out_path, "wb") as f:
+        pickle.dump(yolo_dets, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    return out_path
+
+
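# Reading the cache back is a plain pickle load (sketch; the path is whatever
# preprocess_yolo was given as out_path):
with open("yolo_det_cache.pkl", "rb") as f:
    dets = pickle.load(f)
# e.g. 120 frames; first "car" hit might look like (0.87, (34.0, 50.0, 310.0, 240.0))
print(len(dets), dets[0].get("car", [])[:1])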
+# -------------------------------
+# NSVS using cached YOLO dicts; 1 frame per step
+# -------------------------------
+
+# normalize props to YOLO label style (spaces, lowercase, collapsed whitespace)
+_WS = re.compile(r"\s+")
+def normalize_label_for_yolo(s: str) -> str:
+    s = (s or "").strip().lower()
+    s = s.replace("_", " ")
+    s = s.replace("-", " ").replace("–", " ").replace("—", " ")
+    s = _WS.sub(" ", s)
+    return s
+
+
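# e.g. normalize_label_for_yolo("Traffic_Light")  -> "traffic light"
#      normalize_label_for_yolo("fire--hydrant")  -> "fire hydrant"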
+def run_nsvs_yolo(
+    frames: List[np.ndarray],
+    proposition: List[str],
+    specification: str,
+    *,
+    yolo_cache_path: str = "yolo_det_cache.pkl",
+    model_type: str = "dtmc",
+    tl_satisfaction_threshold: float = 0.6,
+    detection_threshold: float = 0.5,
+    vlm_detection_threshold: float = 0.35,  # used as 'false_threshold' in calibrate()
+    image_output_dir: str = "output",
+) -> Tuple[List[VideoFrame], Dict[str, List[Tuple[int, Tuple[float, float, float, float]]]]]:
+    """
+    Replaces vlm.detect with cached YOLO per frame (1-frame sequences).
+    Returns:
+        foi: List[VideoFrame]
+        object_frame_bounding_boxes:
+            Dict[str, List[(frame_index, (x1, y1, x2, y2))]]
+            # one bbox per frame (the highest-confidence bbox for that class in that frame)
+    """
+    if not os.path.exists(yolo_cache_path):
+        raise FileNotFoundError(
+            f"YOLO cache not found at '{yolo_cache_path}'. "
+            f"Call preprocess_yolo(frames, out_path='yolo_det_cache.pkl') first."
+        )
+
+    with open(yolo_cache_path, "rb") as f:
+        # List[Dict[str, List[(conf, (x1,y1,x2,y2))]]]
+        yolo_dets: List[Dict[str, List[Tuple[float, Tuple[float, float, float, float]]]]] = pickle.load(f)
+
+    if len(yolo_dets) != len(frames):
+        raise ValueError(f"cache length {len(yolo_dets)} != frames length {len(frames)}")
+
+    # Build normalized lookup (e.g., "traffic_light" -> "traffic light")
+    prop_lookup: Dict[str, str] = {prop_raw: normalize_label_for_yolo(prop_raw) for prop_raw in proposition}
+
+    automaton = VideoAutomaton(include_initial_state=True)
+    automaton.set_up(proposition_set=proposition)  # original props for TL
+
+    checker = PropertyChecker(
+        proposition=proposition,
+        specification=specification,
+        model_type=model_type,
+        tl_satisfaction_threshold=tl_satisfaction_threshold,
+        detection_threshold=detection_threshold,
+    )
+
+    frame_of_interest = FramesofInterest(1)  # 1-frame sequences
+    object_frame_bounding_boxes: Dict[str, List[Tuple[int, Tuple[float, float, float, float]]]] = {}
+
+    calibrator = VLLMClient()
+
+    def _mk_detected_object(name: str, confidence: float) -> DetectedObject:
+        probability = calibrator.calibrate(confidence=confidence, false_threshold=vlm_detection_threshold)
+        return DetectedObject(
+            name=name,
+            is_detected=confidence >= vlm_detection_threshold,
+            confidence=confidence,
+            probability=probability,
+        )
+
+    looper = range(len(frames)) if PRINT_ALL else tqdm.tqdm(range(len(frames)))
+    for i in looper:
+        if PRINT_ALL:
+            print("\n" + "*" * 50 + f" {i}/{len(frames) - 1} " + "*" * 50)
+            print("Detections:")
+
+        # Per-frame dict: class -> List[(conf, (x1,y1,x2,y2))]
+        det_dict = yolo_dets[i]
+        object_of_interest = {}
+
+        for prop_raw in proposition:
+            yolo_label = prop_lookup[prop_raw]
+            dets_for_class = det_dict.get(yolo_label, [])
+
+            # confidence for the decision = max conf for that class in this frame (0 if none)
+            if dets_for_class:
+                confs = [c for c, _ in dets_for_class]
+                max_idx = int(np.argmax(confs))
+                best_conf, best_bbox = dets_for_class[max_idx]
+            else:
+                best_conf, best_bbox = 0.0, None
+
+            det = _mk_detected_object(prop_raw, float(best_conf))
+            object_of_interest[prop_raw] = det
+
+            if det.is_detected and best_bbox is not None:
+                # one bbox per frame (the highest-confidence one)
+                object_frame_bounding_boxes.setdefault(prop_raw, []).append((i, best_bbox))
+
+            if PRINT_ALL:
+                if best_bbox is not None:
+                    x1, y1, x2, y2 = best_bbox
+                    print(f"\t{prop_raw} (yolo='{yolo_label}'): conf={det.confidence:.3f} "
+                          f"-> prob={det.probability:.3f} bbox=({x1:.1f},{y1:.1f},{x2:.1f},{y2:.1f})"
+                          + (" [DETECTED]" if det.is_detected else ""))
+                else:
+                    print(f"\t{prop_raw} (yolo='{yolo_label}'): conf=0.000 -> prob={det.probability:.3f}")
+
+        frame = VideoFrame(
+            frame_idx=i,
+            frame_images=[frames[i]],  # single frame
+            object_of_interest=object_of_interest,
+        )
+
+        if checker.validate_frame(frame_of_interest=frame):
+            automaton.add_frame(frame=frame)
+            frame_of_interest.frame_buffer.append(frame)
+            model_check = checker.check_automaton(automaton=automaton)
+            if model_check:
+                automaton.reset()
+                frame_of_interest.flush_frame_buffer()
+
+    foi = frame_of_interest.foi_list
+
+    if PRINT_ALL:
+        print("\n" + "-" * 107)
+        print("Detected frames of interest:")
+        print(foi)
+
+    # NOTE: replaced the old object_frame_dict return
+    return foi, object_frame_bounding_boxes
+
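# End-to-end sketch of the YOLO path (illustrative, not part of the module;
# assumes `frames` is the sampled List[np.ndarray] from Mp4Reader, a GPU for
# YOLO, and a reachable vLLM server behind VLLMClient.calibrate; the
# proposition names and TL specification below are placeholders):
cache = preprocess_yolo(frames, model_weights="yolov8n.pt", device="cuda:0")
foi, boxes = run_nsvs_yolo(
    frames=frames,
    proposition=["car", "truck"],
    specification='"car" U "truck"',
    yolo_cache_path=cache,
)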
pyproject.toml CHANGED
@@ -18,5 +18,6 @@ dependencies = [
     "timm>=1.0.19",
     "tqdm>=4.67.1",
     "transformers>=4.41,<4.47",
+    "ultralytics>=8.3.201",
 ]
 