# NOTE(review): the three lines below were web-viewer residue captured with the
# file ("Daniel-F's picture" / commit message "debug" / hash b43eb74). They are
# not valid Python, so they are preserved here as a comment so the module parses.
import subprocess
import os
import sys
import subprocess
def run(cmd, cwd=None):
    """Print a shell command and execute it.

    Args:
        cmd: Command string handed to the shell (``shell=True`` is intentional:
            callers pass compound commands such as ``a && b``).
        cwd: Optional working directory for the subprocess.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    print("▶ {}".format(cmd))
    subprocess.check_call(cmd, shell=True, cwd=cwd)
def setup_deps():
    """Install runtime dependencies on first launch, then re-exec the app.

    On a fresh container this installs torch/torchvision/torchaudio, SAM2,
    the manual requirements, and pulls git-lfs checkpoints, then restarts
    the interpreter with HF_SPACE_BOOTSTRAPPED=1 so the relaunched process
    returns immediately instead of looping.
    """
    # Guard against infinite restart loops: the re-exec'd process sets this.
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return
    # Debug aid: report the checkpoint size, but don't crash when the LFS
    # file has not been pulled yet (git lfs pull happens further down).
    ckpt = "checkpoints/mouse_yolo.pt"
    if os.path.exists(ckpt):
        print(f"mouse_yolo.pt size: {os.path.getsize(ckpt)} bytes")
    # If the heavy deps import cleanly, the environment is already set up.
    try:
        import torch
        import sam2
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass
    print("🔧 Installing dependencies...")
    run("pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")
    run("git lfs install && git lfs pull", cwd=".")  # checkpoints/mouse_yolo.pt is LFS-tracked
    # Relaunch the script with an env flag to avoid looping.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)
# Bootstrap at import time so the heavy third-party imports below succeed.
setup_deps()
import gradio as gr
import cv2
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from gradio_image_annotation import image_annotator
from train.model import ResNetNoseRegressor
import sam_utils
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO
import torch
import os
import tqdm
import tempfile
from pathlib import Path
# Global states shared across Gradio callbacks (single user / single session).
step = 0  # workflow stage — 0: crop, 1: nose, 2: toys, 3: track, 4: export
selected_box = None  # most recent box drawn in the crop annotator
last_frame = None  # RGB frame currently displayed for cropping
cropped_image_original = None  # unmodified crop of the tracking area
nose_point = None  # (x, y) nose location within the crop
toy_1_area = None  # xyxy box drawn around toy 1
toy_2_area = None  # xyxy box drawn around toy 2
sam_model = None  # lazily-loaded SAM2ImagePredictor
yolo_model = None  # lazily-loaded YOLO mouse detector
nose_model = None  # lazily-loaded ResNet nose regressor
vis_frames = []  # rendered output frames awaiting video export
# Per-session metadata: start frame index, crop rectangle, nose point,
# and toy masks (stored as nested lists).
meta = {
    "start": None,
    "crop": None,
    "nose": None,
    "toys": None,
}
def get_nose_point(cropped_image):
    """Locate the mouse's nose in *cropped_image*.

    Runs the YOLO mouse detector to find the mouse bounding box, then a
    ResNet regressor on that sub-crop to predict normalized nose coordinates,
    which are mapped back into *cropped_image* pixel space.

    Args:
        cropped_image: HxWx3 uint8 image (the crop chosen in step 0).

    Returns:
        (x, y) nose coordinates in *cropped_image* pixel space; also stored
        in ``meta["nose"]``. Returns (0, 0) when no mouse is detected.
    """
    global meta, yolo_model, nose_model
    if yolo_model is None:
        # Lazy-load the detector once per process.
        yolo_model = YOLO("checkpoints/mouse_yolo.pt")
    results = yolo_model.predict(source=cropped_image, conf=0.5, verbose=False)
    box_mouse = None
    for r in results:
        if len(r.boxes) == 0:
            continue
        # Keep the highest-confidence detection (last non-empty result wins).
        best_idx = r.boxes.conf.argmax().item()
        best_box = r.boxes[best_idx]
        x1, y1, x2, y2 = map(int, best_box.xyxy[0].tolist())
        box_mouse = np.array([x1, y1, x2, y2])
    if box_mouse is None:
        # No detection: fall back to the image origin.
        meta["nose"] = (0, 0)
        return meta["nose"]
    box_mouse = box_mouse.astype(np.int32)
    # Slice first, then copy — avoids duplicating the whole frame.
    mouse = cropped_image[box_mouse[1]:box_mouse[3], box_mouse[0]:box_mouse[2]].copy()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if nose_model is None:
        nose_model = ResNetNoseRegressor(pretrained=False)
        nose_model.load_state_dict(torch.load("checkpoints/nose_detector.pth", map_location="cpu"))
        nose_model.eval().to(device)
    # NOTE(review): callers pass RGB frames (see get_frame), so this
    # BGR->RGB swap actually converts back to BGR — presumably matching the
    # regressor's training data; confirm against the training pipeline.
    image = Image.fromarray(cv2.cvtColor(mouse, cv2.COLOR_BGR2RGB))
    orig_w, orig_h = image.size
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0)  # add batch dimension
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        # Regressor output is normalized (x, y) in [0, 1] of the mouse crop.
        pred = nose_model(image_tensor)[0].cpu().numpy()  # shape: (2,)
    # Scale back to crop pixels, then offset into cropped_image coordinates.
    x_pred = int(pred[0] * orig_w) + box_mouse[0]
    y_pred = int(pred[1] * orig_h) + box_mouse[1]
    meta["nose"] = (x_pred, y_pred)
    return meta["nose"]
def get_video_info(video_path):
    """Return ``(frame_count, fps)`` for *video_path*, or ``(0, 0)`` if it cannot be opened."""
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return 0, 0
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return frame_count, frame_rate
def get_frame(video_path, frame_index):
    """Read frame *frame_index* from the video and return it as RGB, or None on failure."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ok, bgr_frame = capture.read()
    capture.release()
    if not ok:
        return None
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
def load_frame(video_file, frame_idx):
    """Fetch frame *frame_idx*, record it as the crop base, and build annotator input.

    Also stores *frame_idx* in ``meta["start"]`` as the tracking start frame.
    Returns an annotator dict ``{"image": ..., "boxes": []}`` or None if the
    frame could not be read.
    """
    global last_frame, meta
    meta["start"] = frame_idx
    last_frame = get_frame(video_file, frame_idx)
    if last_frame is None:
        return None
    return {"image": last_frame, "boxes": []}
def handle_box_select(data):
    """Remember the first drawn box and enable/disable the action button accordingly."""
    global selected_box
    drawn = data.get("boxes", [])
    if not drawn:
        # No box yet: keep the action button disabled.
        return gr.update(interactive=False)
    selected_box = drawn[0]
    return gr.update(interactive=True)
def unified_button_handler(cropped_img, toy_ann, video_file, progress=gr.Progress(track_tqdm=True)):
    """Single action-button handler driving the whole workflow state machine.

    Depending on the global ``step`` (0: crop, 2: confirm toys, 3: process
    video, 4: export), performs one stage and returns seven ``gr.update``
    values for: cropped_image, annotator, frame_slider, toy_annotator,
    action_button, result_image, downloadable_output — in that order.
    """
    global step, selected_box, last_frame, cropped_image_original, nose_point
    global toy_1_area, toy_2_area, sam_model, vis_frames
    # NOTE(review): download_path is declared global but never used here.
    global meta, download_path
    if step == 0:
        # --- Step 0: crop the tracking area from the selected frame ---
        if selected_box is None or last_frame is None:
            # Nothing selected yet: keep the UI unchanged.
            return gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(value="Crop"), gr.update(visible=False), gr.update(visible=False)
        x1, y1 = int(selected_box["xmin"]), int(selected_box["ymin"])
        x2, y2 = int(selected_box["xmax"]), int(selected_box["ymax"])
        meta["crop"] = {"x1": x1, "y1": y1, "x2": x2, "y2": y2}
        cropped = last_frame[y1:y2, x1:x2]
        cropped_image_original = cropped.copy()
        step = 1
        # Automatically determine nose point here using heuristic
        # Replace the following line with your heuristic
        nose_point = get_nose_point(cropped_image_original)
        meta["nose"] = nose_point
        # Optional visualization (can be removed)
        img_copy = cropped_image_original.copy()
        cv2.circle(img_copy, (nose_point[0], nose_point[1]), 4, (255, 0, 0), -1)
        # Nose detection is automatic, so step 1 is never surfaced to the
        # user — jump straight to toy annotation.
        step = 2
        return gr.update(value=cropped_image_original, visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value={"image": img_copy, "boxes": []}, visible=True), gr.update(value="Confirm Toys"), gr.update(visible=False), gr.update(visible=False)
    elif step == 2:
        # --- Step 2: segment both toys with SAM2 from the drawn boxes ---
        boxes = toy_ann.get("boxes", [])
        if len(boxes) < 2:
            # Need both toy boxes before proceeding; keep UI as-is.
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(value="Confirm Toys"), gr.update(visible=False), gr.update(visible=False)
        step = 3
        toy_1_area = np.array([boxes[0]["xmin"], boxes[0]["ymin"], boxes[0]["xmax"], boxes[0]["ymax"]])
        toy_2_area = np.array([boxes[1]["xmin"], boxes[1]["ymin"], boxes[1]["xmax"], boxes[1]["ymax"]])
        if sam_model is None:
            # Lazy-load SAM2 once; move it to GPU when available.
            model = sam_utils.load_SAM2("checkpoints/sam2_hiera_small.pt", "checkpoints/sam2_hiera_s.yaml")
            if torch.cuda.is_available():
                model = model.to("cuda")
            sam_model = SAM2ImagePredictor(model)
        sam_model.set_image(cropped_img)
        mask_1, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_1_area[None, :], multimask_output=False)
        mask_2, _, _ = sam_model.predict(point_coords=None, point_labels=None, box=toy_2_area[None, :], multimask_output=False)
        mask_1 = mask_1[0].astype(bool)
        mask_2 = mask_2[0].astype(bool)
        # Masks stored as nested lists in meta (JSON-serializable).
        meta["toys"] = {"toy_1": mask_1.tolist(), "toy_2": mask_2.tolist()}
        result = cropped_image_original.copy()
        result[mask_1] = [0, 0, 255]
        result[mask_2] = [0, 255, 0]
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="Process Whole Video"), gr.update(value=result, visible=True), gr.update(visible=False)
    elif step == 3:
        # --- Step 3: track the nose over the whole video, timing toy contact ---
        step = 4
        cap = cv2.VideoCapture(video_file)
        ms_per_frame = 1000 / cap.get(cv2.CAP_PROP_FPS)
        toy_1_time = 0.0
        toy_2_time = 0.0
        total_time = 0.0
        cap.set(cv2.CAP_PROP_POS_FRAMES, meta["start"])
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - meta["start"]
        vis_frames = []
        qbar = tqdm.tqdm(total=total_frames, desc="Processing Video", unit="frame", leave=False)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Crop each frame to the same region chosen in step 0.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)[meta["crop"]["y1"]:meta["crop"]["y2"], meta["crop"]["x1"]:meta["crop"]["x2"]]
            nose_point = get_nose_point(frame_rgb)
            vis_frame = frame_rgb.copy()
            overlay = vis_frame.copy()
            cv2.circle(overlay, tuple(nose_point), 10, (255, 0, 0), -1) # draw on overlay
            alpha = 0.4 # transparency factor
            # Blend the overlay with the original frame
            vis_frame = cv2.addWeighted(overlay, alpha, vis_frame, 1 - alpha, 0)
            # Toy overlays
            layer = np.zeros_like(vis_frame)
            layer[np.array(meta["toys"]["toy_1"])] = [0, 255, 0]
            layer[np.array(meta["toys"]["toy_2"])] = [0, 0, 255]
            # Dilate the toy masks so "near the toy" counts as interaction.
            # NOTE(review): rebuilding and dilating these masks every frame
            # is loop-invariant work that could be hoisted above the loop.
            toy_1_mask = np.array(meta["toys"]["toy_1"]).copy().astype(np.uint8)
            toy_2_mask = np.array(meta["toys"]["toy_2"]).copy().astype(np.uint8)
            toy_1_mask = cv2.dilate(toy_1_mask, np.ones((20, 20), np.uint8), iterations=1).astype(bool)
            toy_2_mask = cv2.dilate(toy_2_mask, np.ones((20, 20), np.uint8), iterations=1).astype(bool)
            x, y = nose_point
            increment = ms_per_frame / 1000.0
            # text colors of all 5 lines, default is white
            colors = [(255, 255, 255)] * 5
            # NOTE(review): mask indexing [y, x] assumes the nose point lies
            # inside the crop bounds — the (0, 0) fallback from
            # get_nose_point satisfies this, but verify for boundary boxes.
            if toy_1_mask[y, x] or toy_2_mask[y, x]:
                colors[3] = (255, 0, 0)
            # Cap combined toy time at 20 s: clamp the last increment so the
            # sum lands exactly on 20, then stop after this frame (below).
            if toy_1_time + toy_2_time + increment >= 20:
                increment = 20 - (toy_1_time + toy_2_time)
            if toy_1_mask[y, x]:
                toy_1_time += increment
                colors[1] = (255, 0, 0) # Red for Toy 1
            if toy_2_mask[y, x]:
                toy_2_time += increment
                colors[2] = (255, 0, 0) # Red for Toy 2
            total_time += increment
            # Time text
            cv2.putText(vis_frame, f"Time spent on toys:", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[0], 2)
            cv2.putText(vis_frame, f"Toy 1: {int(toy_1_time // 60):02}:{int(toy_1_time % 60):02}.{int((toy_1_time % 1) * 100):02}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[1], 2)
            cv2.putText(vis_frame, f"Toy 2: {int(toy_2_time // 60):02}:{int(toy_2_time % 60):02}.{int((toy_2_time % 1) * 100):02}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[2], 2)
            cv2.putText(vis_frame, f"Sum Time: {int((toy_1_time + toy_2_time) // 60):02}:{int((toy_1_time + toy_2_time) % 60):02}.{int(((toy_1_time + toy_2_time) % 1) * 100):02}", (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[3], 2)
            cv2.putText(vis_frame, f"Total Time: {int(total_time // 60):02}:{int(total_time % 60):02}.{int((total_time % 1) * 100):02}", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, colors[4], 2)
            vis_frame = cv2.addWeighted(vis_frame, 1, layer, 0.5, 0)
            vis_frames.append(vis_frame)
            if (toy_1_time + toy_2_time) >= 20:
                break
            qbar.update(1)
        # Copy the last frame for 3 second
        for _ in range(int(3 * cap.get(cv2.CAP_PROP_FPS))):
            vis_frames.append(vis_frames[-1].copy())
        qbar.close()
        cap.release()
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value="Export to Video"), gr.update(visible=False), gr.update(visible=False)
    elif step == 4:
        # --- Step 4: encode the collected frames to an MP4 for download ---
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        h, w, _ = vis_frames[0].shape
        # NOTE(review): writer FPS is hard-coded to 20, not the source
        # video's frame rate — playback speed may differ from real time.
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (w, h))
        for f in vis_frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
        writer.release()
        step = 5
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=output_path, visible=True)
def on_video_upload(video_file):
    """Reset the workflow to step 0 and show frame 0 of the newly uploaded video.

    Returns a slider update (range 0..frame_count-1) and the annotator input
    for the first frame.
    """
    global step
    step = 0
    frame_total, _fps = get_video_info(video_file)
    first_view = load_frame(video_file, 0)
    return gr.update(minimum=0, maximum=frame_total - 1), first_view
# --- Gradio UI definition and event wiring ---
with gr.Blocks() as demo:
    # Upload control plus a slider to choose the tracking start frame.
    video_input = gr.File(label="Upload Video", file_types=[".mp4", ".avi", ".mov"])
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, step=1)
    # Annotator for drawing the single crop/tracking-area box.
    annotator = image_annotator(value={"image": np.zeros((100,100,3)), "boxes": []}, label_list=["Tracking Area"], single_box=True)
    # Intermediate views, hidden until the workflow reaches their stage.
    cropped_image = gr.Image(label="Cropped Area", type="numpy", interactive=True, visible=False)
    toy_annotator = image_annotator(value={"image": np.zeros((100,100,3)), "boxes": []}, label_list=["Toy 1", "Toy 2"], image_type="numpy", visible=False)
    result_image = gr.Image(label="Final Output", visible=False)
    downloadable_output = gr.DownloadButton(label="Download Final Video", visible=False)
    # One button drives every stage; its label changes with the global step.
    action_button = gr.Button(value="Crop", interactive=False)
    # Event wiring: upload resets the workflow, slider picks the start frame,
    # drawing a box arms the button, and the button advances the state machine.
    video_input.change(on_video_upload, inputs=video_input, outputs=[frame_slider, annotator])
    frame_slider.change(load_frame, inputs=[video_input, frame_slider], outputs=annotator)
    annotator.change(handle_box_select, inputs=annotator, outputs=action_button)
    action_button.click(unified_button_handler, inputs=[cropped_image, toy_annotator, video_input], outputs=[cropped_image, annotator, frame_slider, toy_annotator, action_button, result_image, downloadable_output])
demo.launch()