import os
import subprocess
import sys


def run(cmd, cwd=None):
    """Echo then execute a shell command; raises CalledProcessError on failure.

    NOTE: shell=True — only ever call with trusted, hard-coded command strings.
    """
    print(f"▶ {cmd}")
    subprocess.check_call(cmd, shell=True, cwd=cwd)


def setup_deps():
    """One-time dependency bootstrap for the hosted Space, then re-exec the app.

    Idempotent: a successful `import torch; import sam2` or the
    HF_SPACE_BOOTSTRAPPED env flag short-circuits the install.
    """
    ckpt = "checkpoints/mouse_yolo.pt"
    # Guard: the LFS-tracked checkpoint may not exist before `git lfs pull`.
    if os.path.exists(ckpt):
        print(f"mouse_yolo.pt size: {os.path.getsize(ckpt)} bytes")

    # Flag prevents an infinite restart loop after os.execv below
    # (execv preserves os.environ).
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    # If the heavy dependencies already import, the environment is ready.
    try:
        import torch  # noqa: F401
        import sam2  # noqa: F401
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass

    print("🔧 Installing dependencies...")
    run("pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")
    run("git lfs install && git lfs pull", cwd=".")  # assuming checkpoints/mouse_yolo.pt is tracked

    # Relaunch the script with an env flag to avoid looping.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)


setup_deps()

# Heavy imports deliberately happen AFTER setup_deps() — they only resolve
# once the bootstrap above has installed them.
import gradio as gr
import cv2
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from gradio_image_annotation import image_annotator
from train.model import ResNetNoseRegressor
import sam_utils
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
import torch
import tqdm
import tempfile
from pathlib import Path

# ---------------------------------------------------------------------------
# Global UI state (Gradio callbacks share state through these module globals)
# ---------------------------------------------------------------------------
step = 0  # 0: crop, 1: nose, 2: toys, 3: track, 4: export
selected_box = None            # last box drawn in the crop annotator
last_frame = None              # RGB frame currently shown in the annotator
cropped_image_original = None  # RGB crop of the tracking area
nose_point = None              # (x, y) in cropped-image coordinates
toy_1_area = None              # xyxy box for toy 1
toy_2_area = None              # xyxy box for toy 2
sam_model = None               # lazily-built SAM2ImagePredictor
yolo_model = None              # lazily-loaded YOLO mouse detector
nose_model = None              # lazily-loaded nose regressor
vis_frames = []                # rendered RGB frames awaiting export
meta = {
    "start": None,  # start frame index chosen on the slider
    "crop": None,   # {"x1","y1","x2","y2"} crop rectangle in full-frame coords
    "nose": None,   # last nose point (x, y)
    "toys": None,   # {"toy_1": mask-as-list, "toy_2": mask-as-list}
}


def get_nose_point(cropped_image):
    """Locate the mouse's nose in a cropped frame.

    Runs YOLO to find the mouse box, then the nose regressor on the box crop,
    and maps the normalized prediction back to cropped-image coordinates.

    Returns an (x, y) tuple; (0, 0) is the sentinel for "no mouse detected".
    The result is also stored in meta["nose"].
    """
    global meta, yolo_model, nose_model

    if yolo_model is None:
        yolo_model = YOLO("checkpoints/mouse_yolo.pt")
    results = yolo_model.predict(source=cropped_image, conf=0.5, verbose=False)

    box_mouse = None
    for r in results:
        if len(r.boxes) == 0:
            continue
        # Keep only the highest-confidence detection.
        best_idx = r.boxes.conf.argmax().item()
        best_box = r.boxes[best_idx]
        x1, y1, x2, y2 = map(int, best_box.xyxy[0].tolist())
        box_mouse = np.array([x1, y1, x2, y2], dtype=np.int32)

    if box_mouse is None:
        meta["nose"] = (0, 0)  # sentinel: no detection this frame
        return meta["nose"]

    mouse = cropped_image[box_mouse[1]:box_mouse[3], box_mouse[0]:box_mouse[2]].copy()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if nose_model is None:
        nose_model = ResNetNoseRegressor(pretrained=False)
        nose_model.load_state_dict(
            torch.load("checkpoints/nose_detector.pth", map_location="cpu"))
        nose_model.eval().to(device)

    image = Image.fromarray(cv2.cvtColor(mouse, cv2.COLOR_BGR2RGB))
    orig_w, orig_h = image.size
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        pred = nose_model(image_tensor)[0].cpu().numpy()  # shape: (2,), normalized x/y

    # Scale back to mouse-box pixels, then offset into cropped-image coords.
    x_pred = int(pred[0] * orig_w) + int(box_mouse[0])
    y_pred = int(pred[1] * orig_h) + int(box_mouse[1])
    meta["nose"] = (x_pred, y_pred)
    return meta["nose"]


def get_video_info(video_path):
    """Return (frame_count, fps) for a video file, or (0, 0) if it won't open."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return 0, 0
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    cap.release()
    return int(frames), fps


def get_frame(video_path, frame_index):
    """Return frame `frame_index` as an RGB array, or None if it can't be read."""
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ret, frame = cap.read()
    cap.release()
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None


def load_frame(video_file, frame_idx):
    """Load a frame into the crop annotator and record the chosen start frame."""
    global last_frame, meta
    meta["start"] = frame_idx
    last_frame = get_frame(video_file, frame_idx)
    return {"image": last_frame, "boxes": []} if last_frame is not None else None


def handle_box_select(data):
    """Enable the action button once the user has drawn a crop box."""
    global selected_box
    boxes = data.get("boxes", [])
    if boxes:
        selected_box = boxes[0]
        return gr.update(interactive=True)
    return gr.update(interactive=False)


def _fmt_time(t):
    """Format seconds as MM:SS.hh (hundredths)."""
    return f"{int(t // 60):02}:{int(t % 60):02}.{int((t % 1) * 100):02}"


def unified_button_handler(cropped_img, toy_ann, video_file,
                           progress=gr.Progress(track_tqdm=True)):
    """Single state-machine handler behind the action button.

    step 0: crop + auto-detect nose  →  step 2: confirm toys + SAM2 masks
    →  step 3: process whole video   →  step 4: export mp4.

    Always returns a 7-tuple of gr.update for:
    (cropped_image, annotator, frame_slider, toy_annotator,
     action_button, result_image, downloadable_output).
    """
    global step, selected_box, last_frame, cropped_image_original, nose_point
    global toy_1_area, toy_2_area, sam_model, vis_frames
    global meta

    if step == 0:
        if selected_box is None or last_frame is None:
            return (gr.update(visible=False), gr.update(), gr.update(),
                    gr.update(), gr.update(value="Crop"),
                    gr.update(visible=False), gr.update(visible=False))
        x1, y1 = int(selected_box["xmin"]), int(selected_box["ymin"])
        x2, y2 = int(selected_box["xmax"]), int(selected_box["ymax"])
        meta["crop"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
        cropped_image_original = last_frame[y1:y2, x1:x2].copy()
        step = 1
        # Automatically determine the nose point on the cropped start frame.
        nose_point = get_nose_point(cropped_image_original)
        meta["nose"] = nose_point
        # Visualization of the detected nose for the toy annotator.
        img_copy = cropped_image_original.copy()
        cv2.circle(img_copy, (nose_point[0], nose_point[1]), 4, (255, 0, 0), -1)
        step = 2
        return (gr.update(value=cropped_image_original, visible=False),
                gr.update(visible=False), gr.update(visible=False),
                gr.update(value={"image": img_copy, "boxes": []}, visible=True),
                gr.update(value="Confirm Toys"),
                gr.update(visible=False), gr.update(visible=False))

    elif step == 2:
        boxes = toy_ann.get("boxes", [])
        if len(boxes) < 2:
            # Need both toy boxes before segmenting.
            return (gr.update(), gr.update(), gr.update(), gr.update(),
                    gr.update(value="Confirm Toys"),
                    gr.update(visible=False), gr.update(visible=False))
        step = 3
        toy_1_area = np.array([boxes[0]["xmin"], boxes[0]["ymin"],
                               boxes[0]["xmax"], boxes[0]["ymax"]])
        toy_2_area = np.array([boxes[1]["xmin"], boxes[1]["ymin"],
                               boxes[1]["xmax"], boxes[1]["ymax"]])
        if sam_model is None:
            model = sam_utils.load_SAM2("checkpoints/sam2_hiera_small.pt",
                                        "checkpoints/sam2_hiera_s.yaml")
            if torch.cuda.is_available():
                model = model.to("cuda")
            sam_model = SAM2ImagePredictor(model)
        sam_model.set_image(cropped_img)
        mask_1, _, _ = sam_model.predict(point_coords=None, point_labels=None,
                                         box=toy_1_area[None, :],
                                         multimask_output=False)
        mask_2, _, _ = sam_model.predict(point_coords=None, point_labels=None,
                                         box=toy_2_area[None, :],
                                         multimask_output=False)
        mask_1 = mask_1[0].astype(bool)
        mask_2 = mask_2[0].astype(bool)
        meta["toys"] = {"toy_1": mask_1.tolist(), "toy_2": mask_2.tolist()}
        result = cropped_image_original.copy()
        # NOTE(review): preview colors toy 1 blue / toy 2 green, but the video
        # overlay below colors toy 1 green / toy 2 blue — confirm intended.
        result[mask_1] = [0, 0, 255]
        result[mask_2] = [0, 255, 0]
        return (gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False), gr.update(visible=False),
                gr.update(value="Process Whole Video"),
                gr.update(value=result, visible=True),
                gr.update(visible=False))

    elif step == 3:
        step = 4
        cap = cv2.VideoCapture(video_file)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # guard: broken containers report 0
        ms_per_frame = 1000 / fps
        toy_1_time = 0.0
        toy_2_time = 0.0
        total_time = 0.0
        cap.set(cv2.CAP_PROP_POS_FRAMES, meta["start"])
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - meta["start"]
        vis_frames = []

        # Hoisted out of the per-frame loop: toy masks, their dilated "hit"
        # zones, and the colored overlay layer are all constant per run.
        toy_1_base = np.array(meta["toys"]["toy_1"], dtype=bool)
        toy_2_base = np.array(meta["toys"]["toy_2"], dtype=bool)
        kernel = np.ones((20, 20), np.uint8)
        toy_1_mask = cv2.dilate(toy_1_base.astype(np.uint8), kernel,
                                iterations=1).astype(bool)
        toy_2_mask = cv2.dilate(toy_2_base.astype(np.uint8), kernel,
                                iterations=1).astype(bool)
        mask_h, mask_w = toy_1_mask.shape[:2]
        layer = np.zeros((mask_h, mask_w, 3), dtype=np.uint8)
        layer[toy_1_base] = [0, 255, 0]
        layer[toy_2_base] = [0, 0, 255]

        qbar = tqdm.tqdm(total=total_frames, desc="Processing Video",
                         unit="frame", leave=False)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)[
                meta["crop"]["y1"]:meta["crop"]["y2"],
                meta["crop"]["x1"]:meta["crop"]["x2"]]
            nose_point = get_nose_point(frame_rgb)

            vis_frame = frame_rgb.copy()
            overlay = vis_frame.copy()
            cv2.circle(overlay, tuple(nose_point), 10, (255, 0, 0), -1)  # draw on overlay
            alpha = 0.4  # transparency factor
            # Blend the overlay with the original frame.
            vis_frame = cv2.addWeighted(overlay, alpha, vis_frame, 1 - alpha, 0)

            # Clamp the nose point into mask bounds: the regressor can
            # extrapolate slightly outside the crop.
            x, y = nose_point
            x = min(max(x, 0), mask_w - 1)
            y = min(max(y, 0), mask_h - 1)

            increment = ms_per_frame / 1000.0
            # Text colors of all 5 lines; default is white.
            colors = [(255, 255, 255)] * 5
            if toy_1_mask[y, x] or toy_2_mask[y, x]:
                colors[3] = (255, 0, 0)
                # Cap the summed toy time at exactly 20 s.
                if toy_1_time + toy_2_time + increment >= 20:
                    increment = 20 - (toy_1_time + toy_2_time)
                if toy_1_mask[y, x]:
                    toy_1_time += increment
                    colors[1] = (255, 0, 0)  # Red for Toy 1
                if toy_2_mask[y, x]:
                    toy_2_time += increment
                    colors[2] = (255, 0, 0)  # Red for Toy 2
            total_time += increment

            # Timer text block.
            cv2.putText(vis_frame, f"Time spent on toys:", (10, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[0], 2)
            cv2.putText(vis_frame, f"Toy 1: {_fmt_time(toy_1_time)}", (10, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[1], 2)
            cv2.putText(vis_frame, f"Toy 2: {_fmt_time(toy_2_time)}", (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[2], 2)
            cv2.putText(vis_frame, f"Sum Time: {_fmt_time(toy_1_time + toy_2_time)}",
                        (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[3], 2)
            cv2.putText(vis_frame, f"Total Time: {_fmt_time(total_time)}",
                        (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[4], 2)

            vis_frame = cv2.addWeighted(vis_frame, 1, layer, 0.5, 0)
            vis_frames.append(vis_frame)
            qbar.update(1)  # count the frame before a possible early stop
            if (toy_1_time + toy_2_time) >= 20:
                break

        # Hold the last frame for 3 seconds (guard: video may yield no frames).
        if vis_frames:
            for _ in range(int(3 * fps)):
                vis_frames.append(vis_frames[-1].copy())
        qbar.close()
        cap.release()
        return (gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False), gr.update(visible=False),
                gr.update(value="Export to Video"),
                gr.update(visible=False), gr.update(visible=False))

    elif step == 4:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        h, w, _ = vis_frames[0].shape
        # NOTE(review): writer fps is hard-coded to 20 while the timers advance
        # at the source fps — playback speed may not match wall time. TODO confirm.
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                                 20, (w, h))
        for f in vis_frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
        writer.release()
        step = 5
        return (gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False), gr.update(visible=False),
                gr.update(value=output_path, visible=True))


def on_video_upload(video_file):
    """Reset the state machine and show frame 0 of the freshly uploaded video."""
    global step
    step = 0
    total_frames, _ = get_video_info(video_file)
    frame = load_frame(video_file, 0)
    return gr.update(minimum=0, maximum=total_frames - 1), frame


with gr.Blocks() as demo:
    video_input = gr.File(label="Upload Video", file_types=[".mp4", ".avi", ".mov"])
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, step=1)
    annotator = image_annotator(value={"image": np.zeros((100, 100, 3)), "boxes": []},
                                label_list=["Tracking Area"], single_box=True)
    cropped_image = gr.Image(label="Cropped Area", type="numpy",
                             interactive=True, visible=False)
    toy_annotator = image_annotator(value={"image": np.zeros((100, 100, 3)), "boxes": []},
                                    label_list=["Toy 1", "Toy 2"],
                                    image_type="numpy", visible=False)
    result_image = gr.Image(label="Final Output", visible=False)
    downloadable_output = gr.DownloadButton(label="Download Final Video", visible=False)
    action_button = gr.Button(value="Crop", interactive=False)

    video_input.change(on_video_upload, inputs=video_input,
                       outputs=[frame_slider, annotator])
    frame_slider.change(load_frame, inputs=[video_input, frame_slider],
                        outputs=annotator)
    annotator.change(handle_box_select, inputs=annotator, outputs=action_button)
    action_button.click(unified_button_handler,
                        inputs=[cropped_image, toy_annotator, video_input],
                        outputs=[cropped_image, annotator, frame_slider,
                                 toy_annotator, action_button, result_image,
                                 downloadable_output])

demo.launch()