File size: 13,820 Bytes
1b944d0
 
 
 
 
 
 
 
 
 
 
b43eb74
1b944d0
 
 
 
 
 
 
 
 
 
 
 
 
 
a33a0b8
1b944d0
 
b43eb74
1b944d0
 
 
 
 
 
 
 
8ea6230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import subprocess

import os
import sys
import subprocess

def run(cmd, cwd=None):
    """Echo a shell command line, then execute it through the shell.

    Args:
        cmd: Command line handed to the shell unchanged.
        cwd: Optional working directory for the child process.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    banner = f"▶ {cmd}"
    print(banner)
    # check=True turns a non-zero exit status into CalledProcessError.
    subprocess.run(cmd, shell=True, cwd=cwd, check=True)

def setup_deps():
    """Install runtime dependencies on first launch, then re-exec the script.

    The function is a no-op when the ``HF_SPACE_BOOTSTRAPPED`` environment
    variable equals ``"1"`` or when both ``torch`` and ``sam2`` can already be
    imported. Otherwise it installs the pinned CPU PyTorch wheels, the local
    ``segment-anything-2`` checkout and ``requirements_manual.txt``, pulls
    Git LFS files, sets the flag and replaces the current process with a
    fresh interpreter running the same command line.

    Raises:
        subprocess.CalledProcessError: If any installation command fails.
    """
    checkpoint = 'checkpoints/mouse_yolo.pt'
    try:
        print(f"mouse_yolo.pt size: {os.path.getsize(checkpoint)} bytes")
    except OSError:
        # The checkpoint may be absent until the `git lfs pull` further down
        # has run; a purely diagnostic print must not abort the bootstrap.
        print("mouse_yolo.pt size: unavailable (file not found)")

    # Use a flag to prevent infinite restarts
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    # Probe two of the heavy packages to decide whether setup already ran.
    try:
        import torch
        import sam2
    except ImportError:
        pass
    else:
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall

    print("🔧 Installing dependencies...")
    run("pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")
    run("git lfs install && git lfs pull", cwd=".")  # assuming checkpoints/mouse_yolo.pt is tracked

    # Relaunch the script with an env flag to avoid looping; the child
    # inherits os.environ, so it takes the early return above.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)

# Must run before the third-party imports below: on a fresh environment this
# call installs packages and re-executes the script, so the imports only run
# once the dependencies are importable.
setup_deps()

import gradio as gr
import cv2
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from gradio_image_annotation import image_annotator
from train.model import ResNetNoseRegressor
import sam_utils
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
import torch
import os
import tqdm
import tempfile
from pathlib import Path

# Module-level state shared by the Gradio callbacks.

# Position in the wizard — 0: crop, 1: nose, 2: toys, 3: track, 4: export
step = 0

# User selections and intermediate results; unset until the matching step runs.
selected_box = last_frame = cropped_image_original = None
nose_point = toy_1_area = toy_2_area = None

# Model handles, created lazily on first use.
sam_model = yolo_model = nose_model = None

# Rendered frames produced by the tracking step.
vis_frames = []

# Parameters recorded as the user moves through the steps.
meta = dict.fromkeys(("start", "crop", "nose", "toys"))

def get_nose_point(cropped_image):
    """Estimate the nose position in an image and store it in ``meta["nose"]``.

    The YOLO detector (loaded on first use from ``checkpoints/mouse_yolo.pt``)
    proposes boxes; the highest-confidence box is cut out, resized to 64x64
    and passed to the nose regressor (loaded on first use from
    ``checkpoints/nose_detector.pth``). The regressor output is treated as a
    normalised ``(x, y)`` pair, scaled by the cut-out size and shifted by the
    box origin so the result is expressed in ``cropped_image`` coordinates.

    Args:
        cropped_image: Image as a NumPy array indexed ``[row, col, channel]``.

    Returns:
        ``(x, y)`` as plain Python ints, clamped to lie inside
        ``cropped_image``. ``(0, 0)`` is returned when nothing is detected or
        the detected box has no area.
    """
    global meta, yolo_model, nose_model
    if yolo_model is None:
        yolo_model = YOLO("checkpoints/mouse_yolo.pt")

    results = yolo_model.predict(source=cropped_image, conf=0.5, verbose=False)

    box_mouse = None
    for r in results:
        if len(r.boxes) == 0:
            continue

        # Get highest confidence box
        best_idx = r.boxes.conf.argmax().item()
        best_box = r.boxes[best_idx]
        box_mouse = [int(v) for v in best_box.xyxy[0].tolist()]

    if box_mouse is None:
        meta["nose"] = (0, 0)
        return meta["nose"]

    img_h, img_w = cropped_image.shape[:2]
    # Keep the box inside the image so the slice below cannot wrap around
    # through negative indices.
    x1 = min(max(box_mouse[0], 0), img_w)
    y1 = min(max(box_mouse[1], 0), img_h)
    x2 = min(max(box_mouse[2], 0), img_w)
    y2 = min(max(box_mouse[3], 0), img_h)
    mouse = cropped_image[y1:y2, x1:x2]
    if mouse.size == 0:
        # A zero-area box cannot be resized or regressed on.
        meta["nose"] = (0, 0)
        return meta["nose"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if nose_model is None:
        nose_model = ResNetNoseRegressor(pretrained=False)
        nose_model.load_state_dict(torch.load("checkpoints/nose_detector.pth", map_location="cpu"))
        nose_model.eval().to(device)

    # NOTE(review): the callers in this file pass RGB frames, so this
    # conversion swaps the channels once more — confirm it matches the channel
    # order the regressor was trained with.
    image = Image.fromarray(cv2.cvtColor(mouse, cv2.COLOR_BGR2RGB))
    orig_w, orig_h = image.size

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        pred = nose_model(image_tensor)[0].cpu().numpy()  # shape: (2,) normalized

    # Scale to the cut-out, shift to image coordinates, then clamp: callers
    # use the point to index per-pixel masks of the same size as the image.
    x_pred = int(pred[0] * orig_w) + x1
    y_pred = int(pred[1] * orig_h) + y1
    x_pred = min(max(x_pred, 0), img_w - 1)
    y_pred = min(max(y_pred, 0), img_h - 1)
    meta["nose"] = (x_pred, y_pred)

    return meta["nose"]

def get_video_info(video_path):
    """Return ``(frame_count, fps)`` for a video file.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        A tuple of the frame count as an int and the frame rate as reported by
        OpenCV, or ``(0, 0)`` when the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return 0, 0
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        return int(frames), fps
    finally:
        # Release on every path, including the early return and any error
        # raised while querying properties.
        cap.release()

def get_frame(video_path, frame_index):
    """Read one frame from a video and return it in RGB channel order.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        frame_index: Frame position to seek to before reading.

    Returns:
        The frame converted from BGR to RGB, or ``None`` if no frame could be
        read.
    """
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def load_frame(video_file, frame_idx):
    """Load a frame for annotation and remember it as the tracking start.

    Stores ``frame_idx`` in ``meta["start"]`` and the decoded frame in the
    module-level ``last_frame``.

    Returns:
        An annotator value ``{"image": frame, "boxes": []}``, or ``None`` when
        the frame could not be read.
    """
    global last_frame, meta
    meta["start"] = frame_idx
    last_frame = get_frame(video_file, frame_idx)
    if last_frame is None:
        return None
    return {"image": last_frame, "boxes": []}

def handle_box_select(data):
    """Track the annotator's first box and toggle the action button.

    Args:
        data: Annotator value, a dict that may carry a ``"boxes"`` list. It
            can be ``None`` when the annotator has no value (for example after
            a frame failed to load).

    Returns:
        A Gradio update enabling the button when a box exists and disabling
        it otherwise. The module-level ``selected_box`` is set to the first
        box, or cleared so it never outlives the box the user removed.
    """
    global selected_box
    boxes = (data.get("boxes") or []) if data else []
    if boxes:
        selected_box = boxes[0]
        return gr.update(interactive=True)
    selected_box = None
    return gr.update(interactive=False)

def _format_clock(seconds):
    """Render a duration in seconds as ``MM:SS.cc`` (truncated, not rounded)."""
    return f"{int(seconds // 60):02}:{int(seconds % 60):02}.{int((seconds % 1) * 100):02}"


def unified_button_handler(cropped_img, toy_ann, video_file, progress=gr.Progress(track_tqdm=True)):
    """Advance the single-button wizard by one step.

    What happens depends on the module-level ``step``:

    * ``0`` — crop ``last_frame`` to ``selected_box``, detect the nose and
      show the toy annotator (moves to step 2).
    * ``2`` — segment the two annotated toy boxes with SAM2 and show the
      mask preview (moves to step 3).
    * ``3`` — run nose tracking from ``meta["start"]`` until the combined toy
      time reaches 20 seconds or the video ends, rendering overlay frames
      into ``vis_frames`` (moves to step 4).
    * ``4`` — write ``vis_frames`` to a temporary ``.mp4`` and expose it for
      download (moves to step 5).

    ``step`` only advances after a stage succeeds, so a failed stage can be
    retried with the same button.

    Args:
        cropped_img: Current value of the cropped-image component; used as
            the SAM2 input image.
        toy_ann: Current value of the toy annotator (dict with ``"boxes"``).
        video_file: Path of the uploaded video.
        progress: Not referenced in the body; presumably its presence with
            ``track_tqdm=True`` lets Gradio mirror the tqdm bar — confirm
            against the Gradio docs.

    Returns:
        A 7-tuple of updates for, in order: cropped image, frame annotator,
        frame slider, toy annotator, action button, result image, download
        button.
    """
    global step, selected_box, last_frame, cropped_image_original, nose_point
    global toy_1_area, toy_2_area, sam_model, vis_frames
    global meta

    white, red = (255, 255, 255), (255, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    time_budget_s = 20      # stop once this much toy time has accumulated
    hold_seconds = 3        # how long the final frame is repeated
    zone_kernel = np.ones((20, 20), np.uint8)  # grows each toy mask into a contact zone

    if step == 0:
        stay_on_crop = (
            gr.update(visible=False), gr.update(), gr.update(), gr.update(),
            gr.update(value="Crop"), gr.update(visible=False), gr.update(visible=False),
        )
        if selected_box is None or last_frame is None:
            return stay_on_crop
        x1, y1 = int(selected_box["xmin"]), int(selected_box["ymin"])
        x2, y2 = int(selected_box["xmax"]), int(selected_box["ymax"])
        cropped = last_frame[y1:y2, x1:x2]
        if cropped.size == 0:
            # A degenerate box gives nothing to detect on; keep asking for a crop.
            return stay_on_crop
        meta["crop"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
        cropped_image_original = cropped.copy()

        # The nose is located automatically; there is no manual nose step.
        nose_point = get_nose_point(cropped_image_original)
        meta["nose"] = nose_point

        # Mark the detected nose on the image shown in the toy annotator.
        img_copy = cropped_image_original.copy()
        cv2.circle(img_copy, (int(nose_point[0]), int(nose_point[1])), 4, red, -1)

        step = 2
        return (
            gr.update(value=cropped_image_original, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value={"image": img_copy, "boxes": []}, visible=True),
            gr.update(value="Confirm Toys"),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    elif step == 2:
        boxes = ((toy_ann or {}).get("boxes")) or []
        if len(boxes) < 2:
            return (
                gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(value="Confirm Toys"), gr.update(visible=False), gr.update(visible=False),
            )

        toy_1_area = np.array([boxes[0]["xmin"], boxes[0]["ymin"], boxes[0]["xmax"], boxes[0]["ymax"]])
        toy_2_area = np.array([boxes[1]["xmin"], boxes[1]["ymin"], boxes[1]["xmax"], boxes[1]["ymax"]])
        if sam_model is None:
            model = sam_utils.load_SAM2("checkpoints/sam2_hiera_small.pt", "checkpoints/sam2_hiera_s.yaml")
            if torch.cuda.is_available():
                model = model.to("cuda")
            sam_model = SAM2ImagePredictor(model)

        sam_model.set_image(cropped_img)
        mask_1, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_1_area[None, :], multimask_output=False)
        mask_2, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_2_area[None, :], multimask_output=False)
        mask_1 = mask_1[0].astype(bool)
        mask_2 = mask_2[0].astype(bool)
        meta["toys"] = {"toy_1": mask_1.tolist(), "toy_2": mask_2.tolist()}
        # NOTE(review): the preview paints toy 1 as [0, 0, 255] and toy 2 as
        # [0, 255, 0], while the tracking overlay uses the opposite
        # assignment — confirm which is intended.
        result = cropped_image_original.copy()
        result[mask_1] = [0, 0, 255]
        result[mask_2] = [0, 255, 0]
        step = 3
        return (
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(visible=False), gr.update(value="Process Whole Video"),
            gr.update(value=result, visible=True), gr.update(visible=False),
        )

    elif step == 3:
        cap = cv2.VideoCapture(video_file)
        qbar = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not cap.isOpened() or fps <= 0:
                # Nothing can be timed without a readable video and a frame
                # rate; leave the UI as it is so the user can retry.
                return (
                    gr.update(), gr.update(), gr.update(), gr.update(),
                    gr.update(value="Process Whole Video"), gr.update(), gr.update(),
                )
            ms_per_frame = 1000 / fps
            frame_seconds = ms_per_frame / 1000.0
            toy_1_time = 0.0
            toy_2_time = 0.0
            total_time = 0.0
            cap.set(cv2.CAP_PROP_POS_FRAMES, meta["start"])
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - meta["start"]

            # The masks do not change between frames: decode the stored
            # nested lists and grow the contact zones once, not per frame.
            crop = meta["crop"]
            toy_1_mask = np.array(meta["toys"]["toy_1"])
            toy_2_mask = np.array(meta["toys"]["toy_2"])
            toy_1_zone = cv2.dilate(toy_1_mask.astype(np.uint8), zone_kernel, iterations=1).astype(bool)
            toy_2_zone = cv2.dilate(toy_2_mask.astype(np.uint8), zone_kernel, iterations=1).astype(bool)
            zone_h, zone_w = toy_1_zone.shape[:2]
            layer = None  # coloured toy overlay, built from the first frame's shape

            frames = []
            qbar = tqdm.tqdm(total=total_frames, desc="Processing Video", unit="frame", leave=False)

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)[crop["y1"]:crop["y2"], crop["x1"]:crop["x2"]]
                nose_point = get_nose_point(frame_rgb)
                x, y = int(nose_point[0]), int(nose_point[1])

                # Semi-transparent nose marker.
                vis_frame = frame_rgb.copy()
                overlay = vis_frame.copy()
                cv2.circle(overlay, (x, y), 10, red, -1)
                alpha = 0.4  # transparency factor
                vis_frame = cv2.addWeighted(overlay, alpha, vis_frame, 1 - alpha, 0)

                if layer is None:
                    layer = np.zeros_like(vis_frame)
                    layer[toy_1_mask] = [0, 255, 0]
                    layer[toy_2_mask] = [0, 0, 255]

                # A point outside the mask must count as "not on a toy"
                # instead of raising or wrapping around via negative indices.
                in_bounds = 0 <= y < zone_h and 0 <= x < zone_w
                on_toy_1 = in_bounds and bool(toy_1_zone[y, x])
                on_toy_2 = in_bounds and bool(toy_2_zone[y, x])

                increment = frame_seconds
                # text colors of all 5 lines, default is white
                colors = [white] * 5
                if on_toy_1 or on_toy_2:
                    colors[3] = red
                    # Trim the last increment so the sum lands on the budget.
                    if toy_1_time + toy_2_time + increment >= time_budget_s:
                        increment = time_budget_s - (toy_1_time + toy_2_time)
                if on_toy_1:
                    toy_1_time += increment
                    colors[1] = red
                if on_toy_2:
                    toy_2_time += increment
                    colors[2] = red
                total_time += increment

                # Time text
                sum_time = toy_1_time + toy_2_time
                cv2.putText(vis_frame, "Time spent on toys:", (10, 20), font, 0.7, colors[0], 2)
                cv2.putText(vis_frame, f"Toy 1: {_format_clock(toy_1_time)}", (10, 40), font, 0.7, colors[1], 2)
                cv2.putText(vis_frame, f"Toy 2: {_format_clock(toy_2_time)}", (10, 60), font, 0.7, colors[2], 2)
                cv2.putText(vis_frame, f"Sum Time: {_format_clock(sum_time)}", (10, 80), font, 0.7, colors[3], 2)
                cv2.putText(vis_frame, f"Total Time: {_format_clock(total_time)}", (10, 100), font, 0.7, colors[4], 2)

                vis_frame = cv2.addWeighted(vis_frame, 1, layer, 0.5, 0)
                frames.append(vis_frame)

                if sum_time >= time_budget_s:
                    break

                qbar.update(1)

            if not frames:
                # No frame could be decoded from the start position.
                return (
                    gr.update(), gr.update(), gr.update(), gr.update(),
                    gr.update(value="Process Whole Video"), gr.update(), gr.update(),
                )

            # Hold the last frame so the final times stay readable.
            last = frames[-1]
            frames.extend(last.copy() for _ in range(int(hold_seconds * fps)))
            vis_frames = frames
            step = 4
        finally:
            if qbar is not None:
                qbar.close()
            cap.release()
        return (
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(visible=False), gr.update(value="Export to Video"),
            gr.update(visible=False), gr.update(visible=False),
        )

    elif step == 4:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        h, w, _ = vis_frames[0].shape
        # NOTE(review): the output is written at a fixed 20 fps regardless of
        # the source frame rate — confirm this is intended.
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (w, h))
        try:
            for rendered in vis_frames:
                writer.write(cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR))
        finally:
            writer.release()
        step = 5
        return (
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(value=output_path, visible=True),
        )

    # Any other step has nothing to do; return no-op updates so the callback
    # still yields one value per wired output.
    return tuple(gr.update() for _ in range(7))

def on_video_upload(video_file):
    """Reset the wizard for a newly uploaded video and show its first frame.

    Args:
        video_file: Path of the uploaded video, or a falsy value when the
            upload was cleared.

    Returns:
        A pair of updates: the frame slider's range and the annotator value
        for frame 0. Both are no-op updates when there is no file.
    """
    global step
    step = 0
    if not video_file:
        # Clearing the upload also fires this callback; there is nothing to
        # open, so leave the slider and annotator untouched.
        return gr.update(), gr.update()
    total_frames, _ = get_video_info(video_file)
    frame = load_frame(video_file, 0)
    # An unreadable video reports 0 frames; keep maximum >= minimum.
    return gr.update(minimum=0, maximum=max(total_frames - 1, 0)), frame

with gr.Blocks() as demo:
    # --- Components -------------------------------------------------------
    video_input = gr.File(
        label="Upload Video",
        file_types=[".mp4", ".avi", ".mov"],
    )
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, step=1)
    annotator = image_annotator(
        value={"image": np.zeros((100,100,3)), "boxes": []},
        label_list=["Tracking Area"],
        single_box=True,
    )
    cropped_image = gr.Image(
        label="Cropped Area",
        type="numpy",
        interactive=True,
        visible=False,
    )
    toy_annotator = image_annotator(
        value={"image": np.zeros((100,100,3)), "boxes": []},
        label_list=["Toy 1", "Toy 2"],
        image_type="numpy",
        visible=False,
    )
    result_image = gr.Image(label="Final Output", visible=False)
    downloadable_output = gr.DownloadButton(label="Download Final Video", visible=False)
    action_button = gr.Button(value="Crop", interactive=False)

    # --- Event wiring -----------------------------------------------------
    # A new upload resets the slider range and shows frame 0.
    video_input.change(
        fn=on_video_upload,
        inputs=video_input,
        outputs=[frame_slider, annotator],
    )
    # Moving the slider loads the chosen frame into the annotator.
    frame_slider.change(
        fn=load_frame,
        inputs=[video_input, frame_slider],
        outputs=annotator,
    )
    # Drawing or removing a box toggles the action button.
    annotator.change(
        fn=handle_box_select,
        inputs=annotator,
        outputs=action_button,
    )
    # One button drives every step; the handler returns one update per output,
    # in this order.
    action_button.click(
        fn=unified_button_handler,
        inputs=[cropped_image, toy_annotator, video_input],
        outputs=[
            cropped_image,
            annotator,
            frame_slider,
            toy_annotator,
            action_button,
            result_image,
            downloadable_output,
        ],
    )

    demo.launch()