# Hugging Face Space app: mouse nose-tracking and toy-interaction timer.
| import subprocess | |
| import os | |
| import sys | |
| import subprocess | |
def run(cmd, cwd=None):
    """Print and execute a shell command, raising on failure.

    Args:
        cmd: Command line passed to the shell (shell=True so `&&` chains work).
        cwd: Optional working directory for the command.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    print(f"▶ {cmd}")
    subprocess.check_call(cmd, shell=True, cwd=cwd)
def setup_deps():
    """Install heavyweight runtime dependencies on first launch, then re-exec.

    The HF_SPACE_BOOTSTRAPPED env flag (preserved across os.execv, which
    keeps the current environment) guarantees the install-and-restart cycle
    runs at most once.
    """
    # Use a flag to prevent infinite restarts
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return
    # Diagnostic only. Originally this getsize() ran unconditionally, before
    # the bootstrap check and before `git lfs pull`, so it crashed with
    # FileNotFoundError whenever the LFS checkpoint had not been fetched yet.
    if os.path.exists("checkpoints/mouse_yolo.pt"):
        print(f"mouse_yolo.pt size: {os.path.getsize('checkpoints/mouse_yolo.pt')} bytes")
    # Try importing something to check if it's already set up
    try:
        import torch  # noqa: F401
        import sam2   # noqa: F401
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass
    print("🔧 Installing dependencies...")
    run("pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")
    run("git lfs install && git lfs pull", cwd=".")  # assuming checkpoints/mouse_yolo.pt is tracked
    # Relaunch the script with an env flag to avoid looping
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)
# Bootstrap (and possibly re-exec) before the heavyweight imports below.
setup_deps()
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| from gradio_image_annotation import image_annotator | |
| from train.model import ResNetNoseRegressor | |
| import sam_utils | |
| from PIL import Image | |
| from torchvision import transforms | |
| from ultralytics import YOLO | |
| import torch | |
| import os | |
| import tqdm | |
| import tempfile | |
| from pathlib import Path | |
# Global states — module-level, mutated by the Gradio event handlers below.
step = 0  # 0: crop, 1: nose, 2: toys, 3: track, 4: export
selected_box = None  # last box drawn in the crop annotator
last_frame = None  # RGB frame currently shown in the annotator
cropped_image_original = None  # unannotated copy of the cropped region
nose_point = None  # (x, y) of the detected nose inside the crop
toy_1_area = None  # xyxy box for toy 1
toy_2_area = None  # xyxy box for toy 2
sam_model = None  # lazily-built SAM2ImagePredictor
yolo_model = None  # lazily-loaded YOLO mouse detector
nose_model = None  # lazily-loaded ResNet nose regressor
vis_frames = []  # annotated RGB frames accumulated for export
meta = {
    "start": None,  # first frame index to process
    "crop": None,   # crop rectangle {"x1", "y1", "x2", "y2"}
    "nose": None,   # last detected nose point (x, y)
    "toys": None,   # segmentation masks for both toys (nested lists)
}
def get_nose_point(cropped_image):
    """Locate the mouse's nose in *cropped_image* and cache it in meta["nose"].

    Runs the YOLO mouse detector (lazily loaded), crops the best-confidence
    box, then regresses the normalized nose position with the ResNet model
    and maps it back to *cropped_image* coordinates.

    Returns:
        (x, y) tuple in *cropped_image* pixel coordinates; (0, 0) when no
        mouse is detected or the detected box is degenerate.
    """
    global meta, yolo_model
    if yolo_model is None:
        yolo_model = YOLO("checkpoints/mouse_yolo.pt")
    results = yolo_model.predict(source=cropped_image, conf=0.5, verbose=False)
    box_mouse = None
    for r in results:
        if len(r.boxes) == 0:
            continue
        # Get highest confidence box of this result (last result wins).
        best_idx = r.boxes.conf.argmax().item()
        best_box = r.boxes[best_idx]
        x1, y1, x2, y2 = map(int, best_box.xyxy[0].tolist())
        box_mouse = np.array([x1, y1, x2, y2])
    if box_mouse is None or len(box_mouse) == 0:
        meta["nose"] = (0, 0)
        return meta["nose"]
    box_mouse = box_mouse.astype(np.int32)
    # Slice first, then copy — avoids duplicating the whole frame.
    mouse = cropped_image[box_mouse[1]:box_mouse[3], box_mouse[0]:box_mouse[2]].copy()
    if mouse.size == 0:
        # Degenerate box — fall back like the no-detection case instead of
        # crashing inside the image transforms below.
        meta["nose"] = (0, 0)
        return meta["nose"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global nose_model
    if nose_model is None:
        nose_model = ResNetNoseRegressor(pretrained=False)
        nose_model.load_state_dict(torch.load("checkpoints/nose_detector.pth", map_location="cpu"))
        nose_model.eval().to(device)  # .to() is in-place for nn.Module
    # NOTE(review): frames reaching here are already RGB (see get_frame), so
    # this BGR2RGB swap re-inverts the channels — presumably matching how the
    # regressor was trained; confirm before changing.
    image = Image.fromarray(cv2.cvtColor(mouse, cv2.COLOR_BGR2RGB))
    orig_w, orig_h = image.size
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0)  # add batch dim: (1, C, 64, 64)
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        pred = nose_model(image_tensor)[0].cpu().numpy()  # shape (2,), normalized x/y
    # Scale normalized coords to the mouse crop, then offset to full-crop coords.
    x_pred = int(pred[0] * orig_w) + box_mouse[0]
    y_pred = int(pred[1] * orig_h) + box_mouse[1]
    meta["nose"] = (x_pred, y_pred)
    return meta["nose"]
def get_video_info(video_path):
    """Return (frame_count, fps) for *video_path*; (0, 0) if it cannot be opened."""
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return 0, 0
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    return int(frame_total), frame_rate
def get_frame(video_path, frame_index):
    """Read frame *frame_index* from the video and return it as RGB, or None on failure."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def load_frame(video_file, frame_idx):
    """Record *frame_idx* as the tracking start and feed that frame to the annotator."""
    global last_frame, meta
    meta["start"] = frame_idx
    last_frame = get_frame(video_file, frame_idx)
    if last_frame is None:
        return None
    return {"image": last_frame, "boxes": []}
def handle_box_select(data):
    """Remember the first drawn box and enable/disable the action button accordingly."""
    global selected_box
    drawn = data.get("boxes", [])
    if not drawn:
        return gr.update(interactive=False)
    selected_box = drawn[0]
    return gr.update(interactive=True)
def unified_button_handler(cropped_img, toy_ann, video_file, progress=gr.Progress(track_tqdm=True)):
    """Single multi-purpose button driving the workflow; behavior follows global `step`.

    step 0: crop the selected region and auto-detect the nose.
    step 2: segment the two toy boxes with SAM2 and preview the masks.
    step 3: process the whole video, timing nose-on-toy contact (capped at 20 s total).
    step 4: encode the accumulated frames to an mp4 for download.

    Returns a 7-tuple of gr.update()s for:
    (cropped_image, annotator, frame_slider, toy_annotator, action_button,
     result_image, downloadable_output).
    """
    global step, selected_box, last_frame, cropped_image_original, nose_point
    global toy_1_area, toy_2_area, sam_model, vis_frames
    global meta, download_path  # NOTE(review): download_path is never assigned or read here
    if step == 0:
        # Nothing selected yet — leave the UI unchanged.
        if selected_box is None or last_frame is None:
            return gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(value="Crop"), gr.update(visible=False), gr.update(visible=False)
        x1, y1 = int(selected_box["xmin"]), int(selected_box["ymin"])
        x2, y2 = int(selected_box["xmax"]), int(selected_box["ymax"])
        meta["crop"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
        cropped = last_frame[y1:y2, x1:x2]
        cropped_image_original = cropped.copy()
        step = 1  # NOTE(review): dead store — overwritten by `step = 2` below
        # Automatically determine nose point here using heuristic
        # Replace the following line with your heuristic
        nose_point = get_nose_point(cropped_image_original)
        meta["nose"] = nose_point
        # Optional visualization (can be removed)
        img_copy = cropped_image_original.copy()
        cv2.circle(img_copy, (nose_point[0], nose_point[1]), 4, (255, 0, 0), -1)
        step = 2
        return gr.update(value=cropped_image_original, visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value={"image": img_copy, "boxes": []}, visible=True), gr.update(value="Confirm Toys"), gr.update(visible=False), gr.update(visible=False)
    elif step == 2:
        boxes = toy_ann.get("boxes", [])
        # Both toys must be annotated before segmentation can proceed.
        if len(boxes) < 2:
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(value="Confirm Toys"), gr.update(visible=False), gr.update(visible=False)
        step = 3
        toy_1_area = np.array([boxes[0]["xmin"], boxes[0]["ymin"], boxes[0]["xmax"], boxes[0]["ymax"]])
        toy_2_area = np.array([boxes[1]["xmin"], boxes[1]["ymin"], boxes[1]["xmax"], boxes[1]["ymax"]])
        if sam_model is None:
            # Lazy one-time SAM2 load; moved to GPU when available.
            model = sam_utils.load_SAM2("checkpoints/sam2_hiera_small.pt", "checkpoints/sam2_hiera_s.yaml")
            if torch.cuda.is_available():
                model = model.to("cuda")
            sam_model = SAM2ImagePredictor(model)
        sam_model.set_image(cropped_img)
        # Box-prompted segmentation, one mask per toy.
        mask_1, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_1_area[None, :], multimask_output=False)
        mask_2, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_2_area[None, :], multimask_output=False)
        mask_1 = mask_1[0].astype(bool)
        mask_2 = mask_2[0].astype(bool)
        # Masks stored as nested lists; converted back to arrays every frame
        # in step 3 — NOTE(review): potential per-frame hotspot.
        meta["toys"] = {"toy_1": mask_1.tolist(), "toy_2": mask_2.tolist()}
        result = cropped_image_original.copy()
        result[mask_1] = [0, 0, 255]
        result[mask_2] = [0, 255, 0]
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="Process Whole Video"), gr.update(value=result, visible=True), gr.update(visible=False)
    elif step == 3:
        step = 4
        cap = cv2.VideoCapture(video_file)
        ms_per_frame = 1000 / cap.get(cv2.CAP_PROP_FPS)
        toy_1_time = 0.0
        toy_2_time = 0.0
        total_time = 0.0
        cap.set(cv2.CAP_PROP_POS_FRAMES, meta["start"])
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - meta["start"]
        vis_frames = []
        qbar = tqdm.tqdm(total=total_frames, desc="Processing Video", unit="frame", leave=False)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert to RGB and restrict to the user-selected crop.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)[meta["crop"]["y1"]:meta["crop"]["y2"], meta["crop"]["x1"]:meta["crop"]["x2"]]
            nose_point = get_nose_point(frame_rgb)
            vis_frame = frame_rgb.copy()
            overlay = vis_frame.copy()
            cv2.circle(overlay, tuple(nose_point), 10, (255, 0, 0), -1)  # draw on overlay
            alpha = 0.4  # transparency factor
            # Blend the overlay with the original frame
            vis_frame = cv2.addWeighted(overlay, alpha, vis_frame, 1 - alpha, 0)
            # Toy overlays
            # NOTE(review): colors here are swapped relative to the step-2 preview.
            layer = np.zeros_like(vis_frame)
            layer[np.array(meta["toys"]["toy_1"])] = [0, 255, 0]
            layer[np.array(meta["toys"]["toy_2"])] = [0, 0, 255]
            # NOTE(review): masks and dilation are loop-invariant but recomputed
            # every frame — hoisting them above the loop would speed this up.
            toy_1_mask = np.array(meta["toys"]["toy_1"]).copy().astype(np.uint8)
            toy_2_mask = np.array(meta["toys"]["toy_2"]).copy().astype(np.uint8)
            # Dilate so near-contact with a toy still counts as "on the toy".
            toy_1_mask = cv2.dilate(toy_1_mask, np.ones((20, 20), np.uint8), iterations=1).astype(bool)
            toy_2_mask = cv2.dilate(toy_2_mask, np.ones((20, 20), np.uint8), iterations=1).astype(bool)
            x, y = nose_point
            increment = ms_per_frame / 1000.0
            # text colors of all 5 lines, default is white
            colors = [(255, 255, 255)] * 5
            if toy_1_mask[y, x] or toy_2_mask[y, x]:
                colors[3] = (255, 0, 0)
                # Clamp the last increment so combined toy time never exceeds 20 s.
                if toy_1_time + toy_2_time + increment >= 20:
                    increment = 20 - (toy_1_time + toy_2_time)
            if toy_1_mask[y, x]:
                toy_1_time += increment
                colors[1] = (255, 0, 0)  # Red for Toy 1
            if toy_2_mask[y, x]:
                toy_2_time += increment
                colors[2] = (255, 0, 0)  # Red for Toy 2
            total_time += increment
            # Time text
            cv2.putText(vis_frame, f"Time spent on toys:", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[0], 2)
            cv2.putText(vis_frame, f"Toy 1: {int(toy_1_time // 60):02}:{int(toy_1_time % 60):02}.{int((toy_1_time % 1) * 100):02}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[1], 2)
            cv2.putText(vis_frame, f"Toy 2: {int(toy_2_time // 60):02}:{int(toy_2_time % 60):02}.{int((toy_2_time % 1) * 100):02}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[2], 2)
            cv2.putText(vis_frame, f"Sum Time: {int((toy_1_time + toy_2_time) // 60):02}:{int((toy_1_time + toy_2_time) % 60):02}.{int(((toy_1_time + toy_2_time) % 1) * 100):02}", (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[3], 2)
            cv2.putText(vis_frame, f"Total Time: {int(total_time // 60):02}:{int(total_time % 60):02}.{int((total_time % 1) * 100):02}", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[4], 2)
            vis_frame = cv2.addWeighted(vis_frame, 1, layer, 0.5, 0)
            vis_frames.append(vis_frame)
            # Stop early once 20 s of combined toy time has accumulated.
            if (toy_1_time + toy_2_time) >= 20:
                break
            qbar.update(1)
        # Copy the last frame for 3 second
        # NOTE(review): IndexError if no frame was ever read (empty vis_frames).
        for _ in range(int(3 * cap.get(cv2.CAP_PROP_FPS))):
            vis_frames.append(vis_frames[-1].copy())
        qbar.close()
        cap.release()
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="Export to Video"), gr.update(visible=False), gr.update(visible=False)
    elif step == 4:
        # Reserve a temp .mp4 path; the handle closes immediately, the name is reused.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        h, w, _ = vis_frames[0].shape
        # Fixed 20 fps output — NOTE(review): not the source video's fps.
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (w, h))
        for f in vis_frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
        writer.release()
        step = 5
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=output_path, visible=True)
def on_video_upload(video_file):
    """Reset the workflow to the crop step and show frame 0 of the new video."""
    global step
    step = 0
    frame_count, _ = get_video_info(video_file)
    first_frame = load_frame(video_file, 0)
    return gr.update(minimum=0, maximum=frame_count - 1), first_frame
# ---- Gradio UI layout and event wiring ----
with gr.Blocks() as demo:
    video_input = gr.File(label="Upload Video", file_types=[".mp4", ".avi", ".mov"])
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, step=1)
    # Annotator for choosing the tracking (crop) region on the selected frame.
    annotator = image_annotator(value={"image": np.zeros((100,100,3)), "boxes": []}, label_list=["Tracking Area"], single_box=True)
    cropped_image = gr.Image(label="Cropped Area", type="numpy", interactive=True, visible=False)
    # Annotator for marking the two toy regions inside the crop.
    toy_annotator = image_annotator(value={"image": np.zeros((100,100,3)), "boxes": []}, label_list=["Toy 1", "Toy 2"], image_type="numpy", visible=False)
    result_image = gr.Image(label="Final Output", visible=False)
    downloadable_output = gr.DownloadButton(label="Download Final Video", visible=False)
    # Single multi-purpose button; its meaning follows the global `step`.
    action_button = gr.Button(value="Crop", interactive=False)
    video_input.change(on_video_upload, inputs=video_input, outputs=[frame_slider, annotator])
    frame_slider.change(load_frame, inputs=[video_input, frame_slider], outputs=annotator)
    annotator.change(handle_box_select, inputs=annotator, outputs=action_button)
    action_button.click(unified_button_handler, inputs=[cropped_image, toy_annotator, video_input], outputs=[cropped_image, annotator, frame_slider, toy_annotator, action_button, result_image, downloadable_output])
demo.launch()