geetxnsh committed on
Commit
aceaec1
·
verified ·
1 Parent(s): 9aaf5f3

Upload video_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. video_inference.py +251 -0
video_inference.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ video_inference.py
3
+ ------------------
4
+ Process an MP4 (or webcam) through the Screen ON/OFF classifier.
5
+
6
+ Requirements:
7
+ pip install opencv-python-headless numpy onnxruntime
8
+
9
+ Usage:
10
+ # Annotate a video, write to file (no GUI needed)
11
+ python video_inference.py --video input.mp4 --roi 200 100 300 400 --out output.mp4
12
+
13
+ # Frame-by-frame (lowest latency, best for real-time preview)
14
+ python video_inference.py --video input.mp4 --roi 200 100 300 400 --display --batch 1
15
+
16
+ # Batch mode (higher throughput, slight latency trade-off)
17
+ python video_inference.py --video input.mp4 --roi 200 100 300 400 --out output.mp4 --batch 8
18
+
19
+ # Live webcam
20
+ python video_inference.py --camera 0 --roi 200 100 300 400 --display
21
+
22
+ The --roi values are: x y width height (pixel coords in the original frame).
23
+ If your video is already cropped to the phone screen, omit --roi.
24
+ """
25
+ import argparse
26
+ import time
27
+
28
+ import cv2
29
+ import numpy as np
30
+ import onnxruntime as ort
31
+
32
+
33
class ScreenClassifier:
    """ONNX wrapper with the exact preprocessing used during training.

    The model is fed 64x64 single-channel float32 tensors scaled to [-1, 1]
    and its raw output is treated as one logit per frame; a sigmoid
    probability above 0.5 is reported as "ON", anything else as "OFF".
    """

    def __init__(self, onnx_path: str = "screen_classifier.onnx"):
        """Create a CPU inference session for the model at *onnx_path*."""
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2

        self.session = ort.InferenceSession(
            onnx_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )
        self.input_name = self.session.get_inputs()[0].name

    def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
        """BGR/HWC -> normalised greyscale NCHW (1,1,64,64)."""
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_LINEAR)
        # Scale 0..255 to 0..1, then shift/scale to -1..1 (mean 0.5, std 0.5).
        gray = (gray.astype(np.float32) / 255.0 - 0.5) / 0.5
        return gray[np.newaxis, np.newaxis, :, :]  # (1, 1, 64, 64)

    @staticmethod
    def _decode(logits) -> list[tuple[str, float]]:
        """Turn raw logits into ``(label, confidence)`` pairs.

        Args:
            logits: Array-like of any shape; it is flattened, one logit per
                frame.

        Returns:
            One ``("ON" | "OFF", confidence)`` tuple per logit, where the
            confidence is the probability of the reported label.
        """
        z = np.asarray(logits, dtype=np.float64).reshape(-1)
        # sigmoid(z) == 0.5 * (1 + tanh(z / 2)). Unlike 1 / (1 + exp(-z)) this
        # form never overflows, so extreme logits do not raise RuntimeWarning.
        probs = 0.5 * (1.0 + np.tanh(0.5 * z))
        return [
            ("ON", float(p)) if p > 0.5 else ("OFF", float(1.0 - p))
            for p in probs
        ]

    def predict(self, frame: np.ndarray) -> tuple[str, float]:
        """Classify a single BGR frame and return ``(label, confidence)``."""
        x = self._preprocess(frame)
        logit = self.session.run(None, {self.input_name: x})[0]
        # .item() keeps the original contract: exactly one output value.
        return self._decode([logit.item()])[0]

    def predict_batch(self, frames: list[np.ndarray]) -> list[tuple[str, float]]:
        """Classify several BGR frames in one session run.

        Returns an empty list for empty input without touching the session.
        """
        if not frames:
            return []
        batch = np.concatenate([self._preprocess(f) for f in frames], axis=0)
        logits = self.session.run(None, {self.input_name: batch})[0]
        return self._decode(logits)
75
+
76
+
77
def draw_label(frame: np.ndarray, label: str, conf: float,
               x: int = 10, y: int = 30) -> np.ndarray:
    """Stamp the classifier verdict onto a BGR frame.

    The text is green when *label* is "ON" and red otherwise. The frame is
    modified in place and the same array is returned for convenience.
    """
    if label == "ON":
        bgr = (0, 255, 0)
    else:
        bgr = (0, 0, 255)
    caption = f"{label} {conf:.2%}"
    origin = (x, y)
    cv2.putText(frame, caption, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.8, bgr, 2)
    return frame
85
+
86
+
87
+ def _safe_display(win_name: str, frame: np.ndarray, display_enabled: bool) -> bool:
88
+ """Show frame if display is enabled; silently skip in headless envs."""
89
+ if not display_enabled:
90
+ return True
91
+ try:
92
+ cv2.imshow(win_name, frame)
93
+ return (cv2.waitKey(1) & 0xFF) != ord("q")
94
+ except cv2.error:
95
+ return True # headless
96
+
97
+
98
+ def main():
99
+ parser = argparse.ArgumentParser(description="Screen ON/OFF classifier for video")
100
+ parser.add_argument("--video", type=str, default=None,
101
+ help="Path to input MP4/video file")
102
+ parser.add_argument("--camera", type=int, default=None,
103
+ help="Webcam index (e.g. 0). Mutually exclusive with --video")
104
+ parser.add_argument("--roi", type=int, nargs=4, metavar=("X", "Y", "W", "H"),
105
+ default=None,
106
+ help="Crop region: x y width height")
107
+ parser.add_argument("--out", type=str, default=None,
108
+ help="Path to write annotated output video (MP4)")
109
+ parser.add_argument("--display", action="store_true",
110
+ help="Show live preview window (needs GUI)")
111
+ parser.add_argument("--model", type=str, default="screen_classifier.onnx",
112
+ help="Path to ONNX model")
113
+ parser.add_argument("--batch", type=int, default=1,
114
+ help="Inference batch size (1 = lowest latency, >1 = higher throughput)")
115
+ args = parser.parse_args()
116
+
117
+ if args.video is None and args.camera is None:
118
+ parser.error("Provide either --video <path> or --camera <index>")
119
+ if args.video and args.camera is not None:
120
+ parser.error("Use --video OR --camera, not both")
121
+
122
+ # ------------------------------------------------------------------ #
123
+ # Open source
124
+ # ------------------------------------------------------------------ #
125
+ source = args.video if args.video else args.camera
126
+ cap = cv2.VideoCapture(source)
127
+ if not cap.isOpened():
128
+ raise RuntimeError(f"Cannot open video source: {source}")
129
+
130
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
131
+ frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
132
+ frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
133
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if args.video else -1
134
+
135
+ print(f"Source : {source}")
136
+ print(f"Resolution : {frame_w}x{frame_h} @ {fps:.1f} FPS")
137
+ print(f"Total frames : {total_frames if total_frames > 0 else 'N/A (live)'}")
138
+ print(f"Model : {args.model}")
139
+ print(f"Batch size : {args.batch}")
140
+
141
+ # ------------------------------------------------------------------ #
142
+ # Optional output writer
143
+ # ------------------------------------------------------------------ #
144
+ writer = None
145
+ if args.out:
146
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
147
+ writer = cv2.VideoWriter(args.out, fourcc, fps, (frame_w, frame_h))
148
+ if not writer.isOpened():
149
+ raise RuntimeError(f"Cannot open VideoWriter for {args.out}")
150
+
151
+ # ------------------------------------------------------------------ #
152
+ # Classifier
153
+ # ------------------------------------------------------------------ #
154
+ clf = ScreenClassifier(args.model)
155
+
156
+ # ROI defaults to full frame if not given
157
+ roi = args.roi
158
+ if roi is None:
159
+ roi = (0, 0, frame_w, frame_h)
160
+ print("No --roi specified; using full frame.")
161
+ else:
162
+ print(f"Crop ROI : x={roi[0]}, y={roi[1]}, w={roi[2]}, h={roi[3]}")
163
+
164
+ rx, ry, rw, rh = roi
165
+
166
+ # ------------------------------------------------------------------ #
167
+ # Main loop
168
+ # ------------------------------------------------------------------ #
169
+ frame_idx = 0
170
+ t0 = time.perf_counter()
171
+
172
+ # For batch mode we accumulate (original_frame, crop) tuples
173
+ batch_buffer: list[tuple[np.ndarray, np.ndarray, int, int]] = []
174
+
175
+ while True:
176
+ ok, original_frame = cap.read()
177
+ if not ok:
178
+ break
179
+
180
+ crop = original_frame[ry:ry + rh, rx:rx + rw]
181
+
182
+ if args.batch == 1:
183
+ label, conf = clf.predict(crop)
184
+ out_frame = draw_label(original_frame.copy(), label, conf,
185
+ x=rx + 10, y=ry + 30)
186
+
187
+ if not _safe_display("Screen ON/OFF", out_frame, args.display):
188
+ break
189
+ if writer:
190
+ writer.write(out_frame)
191
+ frame_idx += 1
192
+
193
+ else:
194
+ batch_buffer.append((original_frame, crop, rx, ry))
195
+
196
+ if len(batch_buffer) == args.batch:
197
+ crops = [c for _, c, _, _ in batch_buffer]
198
+ results = clf.predict_batch(crops)
199
+
200
+ for i, (label, conf) in enumerate(results):
201
+ orig, _, bx, by = batch_buffer[i]
202
+ annotated = draw_label(orig, label, conf, x=bx + 10, y=by + 30)
203
+
204
+ if not _safe_display("Screen ON/OFF", annotated, args.display):
205
+ cap.release()
206
+ if writer:
207
+ writer.release()
208
+ cv2.destroyAllWindows()
209
+ return
210
+
211
+ if writer:
212
+ writer.write(annotated)
213
+
214
+ frame_idx += len(batch_buffer)
215
+ batch_buffer.clear()
216
+
217
+ if frame_idx % 60 == 0 and frame_idx > 0:
218
+ elapsed = time.perf_counter() - t0
219
+ print(f"Processed {frame_idx} frames | "
220
+ f"{frame_idx / elapsed:.1f} FPS | "
221
+ f"{elapsed:.1f} s elapsed")
222
+
223
+ # ------------------------------------------------------------------ #
224
+ # Drain remaining frames in batch buffer
225
+ # ------------------------------------------------------------------ #
226
+ if args.batch > 1 and batch_buffer:
227
+ crops = [c for _, c, _, _ in batch_buffer]
228
+ results = clf.predict_batch(crops)
229
+ for i, (label, conf) in enumerate(results):
230
+ orig, _, bx, by = batch_buffer[i]
231
+ annotated = draw_label(orig, label, conf, x=bx + 10, y=by + 30)
232
+ if writer:
233
+ writer.write(annotated)
234
+ frame_idx += len(batch_buffer)
235
+ batch_buffer.clear()
236
+
237
+ cap.release()
238
+ if writer:
239
+ writer.release()
240
+ try:
241
+ cv2.destroyAllWindows()
242
+ except cv2.error:
243
+ pass
244
+
245
+ total_time = time.perf_counter() - t0
246
+ avg_fps = frame_idx / total_time if total_time > 0 else 0.0
247
+ print(f"\nDone. {frame_idx} frames in {total_time:.2f} s ({avg_fps:.1f} FPS average)")
248
+
249
+
250
+ if __name__ == "__main__":
251
+ main()