John Ho
committed on
Commit
·
e7334c8
1
Parent(s):
df7d2e0
initial commit
Browse files- app.py +60 -0
- ffmpeg_extractor.py +242 -0
- requirements.txt +7 -0
- samv2_handler.py +214 -0
- toolbox/mask_encoding.py +43 -0
- toolbox/vid_utils.py +351 -0
app.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
from typing import Union

import gradio as gr
import spaces, torch
from PIL import Image

from samv2_handler import load_sam_image_model, run_sam_im_inference
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@spaces.GPU
def load_im_model(variant, auto_mask_gen: bool = False):
    """Load a SAM2 image model on the GPU (thin wrapper fixing device="cuda")."""
    return load_sam_image_model(
        variant=variant,
        device="cuda",
        auto_mask_gen=auto_mask_gen,
    )
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@spaces.GPU
def detect_image(
    im: Image.Image,
    variant: str,
    bboxes: Union[list, str] = None,
    points: Union[list, str] = None,
    point_labels: Union[list, str] = None,
):
    """
    SAM2 Image Segmentation.

    Args:
        im: input Pillow image.
        variant: model variant, one of "tiny", "small", "base_plus", "large".
        bboxes: list of bbox dicts [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]
            or a JSON string encoding that list.
        points: list of point dicts [{"x":..., "y":...}, ...] or a JSON string.
        point_labels: list of ints (one label per point) or a JSON string.

    Returns:
        list: base64-encoded masks, one per detected object.
    """

    def _maybe_json(v):
        # Prompts may arrive as JSON strings (e.g. via the MCP/API surface);
        # decode those, pass lists/None through unchanged.
        return json.loads(v) if isinstance(v, str) else v

    # Bugfix: `json` was used without being imported, and points/point_labels
    # were accepted but silently dropped instead of being forwarded.
    bboxes = _maybe_json(bboxes)
    points = _maybe_json(points)
    point_labels = _maybe_json(point_labels)

    model = load_im_model(variant=variant)
    return run_sam_im_inference(
        model,
        image=im,
        points=points or [],
        point_labels=point_labels or [],
        bboxes=bboxes,
        get_pil_mask=False,
        b64_encode_mask=True,
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Gradio UI: a single "Images" tab exposing detect_image as an Interface.
with gr.Blocks() as demo:
    with gr.Tab("Images"):
        gr.Interface(
            fn=detect_image,
            # NOTE(review): detect_image declares 5 parameters but only 3 input
            # components are wired here (points/point_labels rely on their
            # defaults) — confirm the pinned Gradio version accepts this.
            inputs=[
                gr.Image(label="Input Image", type="pil"),
                gr.Dropdown(
                    label="Model Variant",
                    choices=["tiny", "small", "base_plus", "large"],
                ),
                gr.JSON(
                    label='Bounding Boxes (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
                    # NOTE(review): `optional` is not a documented gr.JSON
                    # constructor argument in current Gradio — verify against
                    # the installed version.
                    optional=True,
                ),
            ],
            outputs=gr.JSON(label="Output JSON"),
            title="SAM2 for Images",
        )
demo.launch(
    mcp_server=True, app_kwargs={"docs_url": "/docs"}  # add FastAPI Swagger API Docs
)
|
ffmpeg_extractor.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ffmpeg, typer, os, sys, json, shutil
from loguru import logger

# Replace loguru's default sink with a compact, colorized stderr format.
logger.remove()
logger.add(
    sys.stderr,
    format="<d>{time:YYYY-MM-DD ddd HH:mm:ss}</d> | <lvl>{level}</lvl> | <lvl>{message}</lvl>",
)
# Typer CLI app; hide local variables from pretty tracebacks for brevity.
app = typer.Typer(pretty_exceptions_show_locals=False)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_frame_name(fname: str):
    """Split a frame filename like ``iframe_12.jpg`` into (frame_type, frame_index).

    The extension is dropped and the stem is assumed to be ``<type>_<index>``.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    kind, idx = stem.split("_")
    return kind, int(idx)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_fps_ffmpeg(video_path: str):
    """Return the video's frame rate (frames per second) read via ffprobe.

    Raises:
        ValueError: when the file contains no video stream.
    """
    meta = ffmpeg.probe(video_path)
    video_stream = None
    for s in meta["streams"]:
        if s["codec_type"] == "video":
            video_stream = s
            break
    if video_stream is None:
        raise ValueError("No video stream found")
    # r_frame_rate is a fraction string, e.g. '30000/1001'
    num, den = (int(p) for p in video_stream["r_frame_rate"].split("/"))
    return num / den
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.command()
def extract_keyframes_greedy(
    video_path: str,
    output_dir: str = None,
    threshold: float = 0.2,
    overwrite: bool = False,
):
    """Run i-frame extraction then scene-change keyframe extraction.

    Returns the combined list of frame paths, with scene-change frames that
    duplicate an i-frame removed.

    Args:
        video_path: source video file.
        output_dir: output root; defaults to the video's own directory.
        threshold: scene-change threshold; must be positive (for pure
            i-frame extraction use ``extract-keyframes`` directly).
        overwrite: remove any existing extracted data first.

    Raises:
        ValueError: if threshold is not positive.
        RuntimeError: if either extraction step fails.
    """
    # Validate with an explicit raise — `assert` is stripped under `python -O`.
    if threshold <= 0:
        raise ValueError(
            "threshold must be positive; for i-frame extraction use extract-keyframes instead"
        )

    iframes = extract_keyframes(
        video_path,
        output_dir=output_dir,
        threshold=0,
        overwrite=overwrite,
        append=False,
    )
    if iframes is None:
        raise RuntimeError("i-frames extraction failed")

    kframes = extract_keyframes(
        video_path,
        output_dir=output_dir,
        threshold=threshold,
        overwrite=False,
        append=True,  # keep the i-frames extracted above
    )
    if kframes is None:
        raise RuntimeError("keyframes extraction failed")

    # Remove kframes that duplicate an iframe (same PTS index, different prefix).
    # A set gives O(1) membership for the filter below.
    removed_kframes = set()
    for fn in kframes:
        fname = os.path.basename(fn)
        iframe_twin = os.path.join(
            os.path.dirname(fn), fname.replace("kframe_", "iframe_")
        )
        if os.path.isfile(iframe_twin):
            os.remove(fn)
            removed_kframes.add(fn)
    if removed_kframes:
        logger.warning(f"removed {len(removed_kframes)} redundant kframes")
    kframes = [kf for kf in kframes if kf not in removed_kframes]

    frames = iframes + kframes
    logger.success(f"extracted {len(frames)} total frames")
    return frames
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.command()
def extract_keyframes(
    video_path: str,
    output_dir: str = None,
    threshold: float = 0.3,
    overwrite: bool = False,
    append: bool = False,
):
    """Extract keyframes as JPEGs into ``<output_dir>/<video_name>/keyframes``.

    With a nonzero ``threshold`` the ffmpeg ``select`` filter keeps frame 0
    plus every frame whose scene-change score exceeds the threshold
    ("kframe_*"); with ``threshold`` falsy it keeps I-frames only
    ("iframe_*").

    Args:
        video_path: source video file.
        output_dir: output root; if not provided, will be the video's directory.
        threshold: scene-change score cutoff; 0/None switches to I-frame mode.
        overwrite: delete pre-existing output first.
        append: allow writing into an existing output directory.

    Returns:
        list of extracted frame paths, or None when output exists (and
        neither overwrite nor append is set) or ffmpeg fails.
    """
    # Create output directory if it doesn't exist
    output_dir = output_dir if output_dir else os.path.dirname(video_path)
    vname, vext = os.path.splitext(os.path.basename(video_path))
    output_dir = os.path.join(output_dir, vname, "keyframes")
    if os.path.isdir(output_dir):
        if overwrite:
            shutil.rmtree(output_dir)
            logger.warning(f"removed existing data: {output_dir}")
        elif not append:
            logger.error(f"overwrite is false and data already exists!")
            return None
    os.makedirs(output_dir, exist_ok=True)

    # Construct the ffmpeg-python pipeline
    stream = ffmpeg.input(video_path)
    config_dict = {
        "vsync": "0",        # passthrough timestamps (don't duplicate/drop frames)
        "frame_pts": "true",  # name output files by the frame's PTS, not a counter
    }

    if threshold:
        # always add in the first frame by default
        filter_value = f"eq(n,0)+gt(scene,{threshold})"
        frame_name = "kframe"
        logger.info(f"Extracting Scene-changing frames with {filter_value}")
    else:
        filter_value = f"eq(pict_type,I)"
        # config_dict["skip_frame"] = "nokey"
        frame_name = "iframe"
        logger.info(f"Extracting I-Frames since no threshold provided: {filter_value}")

    stream = ffmpeg.filter(stream, "select", filter_value)
    stream = ffmpeg.output(stream, f"{output_dir}/{frame_name}_%d.jpg", **config_dict)

    # Execute the ffmpeg command
    try:
        ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
        # List what actually landed on disk; in append mode the directory may
        # also contain frames with the other prefix, so filter by frame_name.
        frames = [
            os.path.join(output_dir, f)
            for f in os.listdir(output_dir)
            if f.endswith(".jpg") and frame_name in f
        ]
        logger.success(f"{len(frames)} {frame_name} extracted to {output_dir}")
        return frames
    except ffmpeg.Error as e:
        logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
        return None
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@app.command()
def extract_audio(video_path: str, output_dir: str = None, overwrite: bool = False):
    """Extract a video's audio track to ``<root>/<video_name>/<video_name>.m4a``
    without re-encoding (stream copy).

    ref: https://www.baeldung.com/linux/ffmpeg-audio-from-video#1-extracting-audio-without-re-encoding

    Args:
        video_path: source video file.
        output_dir: destination root; defaults to the video's own directory.
        overwrite: replace an existing output file when True.

    Returns:
        The .m4a path on success; None if the file already exists (and
        overwrite is False) or ffmpeg fails.
    """
    root = output_dir if output_dir else os.path.dirname(video_path)
    vname = os.path.splitext(os.path.basename(video_path))[0]
    dest_dir = os.path.join(root, vname)
    dest_file = os.path.join(dest_dir, vname + ".m4a")

    if os.path.isfile(dest_file):
        if not overwrite:
            logger.error(f"overwrite is false and data already exists!")
            return None
        os.remove(dest_file)
        logger.warning(f"removed existing data: {dest_file}")
    os.makedirs(dest_dir, exist_ok=True)

    # "map 0:a" selects only audio streams; "acodec copy" avoids re-encoding.
    pipeline = ffmpeg.output(
        ffmpeg.input(video_path), dest_file, **{"map": "0:a", "acodec": "copy"}
    )
    try:
        ffmpeg.run(pipeline, capture_stdout=True, capture_stderr=True)
    except ffmpeg.Error as e:
        logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
        return None
    logger.success(f"audio extracted to {dest_file}")
    return dest_file
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@app.command()
def extract_frames(
    video_path: str,
    output_dir: str = None,
    fps: int = None,
    every_x: int = None,
    overwrite: bool = False,
    append: bool = False,
    im_name_pattern: str = "frame_%05d.jpg",
):
    """Extract frames as JPEGs into ``<output_dir>/<video_name>/keyframes``.

    Args:
        video_path: source video file.
        output_dir: if not provided, will be in video_name/keyframes/.
        fps: sample at this rate (clamped to the video's own FPS).
        every_x: keep every x-th frame; ignored when fps is given.
        overwrite: delete pre-existing output first.
        append: allow writing into an existing output directory.
        im_name_pattern: printf-style output filename pattern.

    Returns:
        list of extracted frame paths, or None when output exists (and
        neither overwrite nor append is set) or ffmpeg fails.
    """
    # Create output directory if it doesn't exist
    vname, vext = os.path.splitext(os.path.basename(video_path))
    output_dir = output_dir if output_dir else os.path.dirname(video_path)
    output_dir = os.path.join(output_dir, vname, "keyframes")
    if os.path.isdir(output_dir):
        if overwrite:
            shutil.rmtree(output_dir)
            logger.warning(f"removed existing data: {output_dir}")
        elif not append:
            logger.error(f"overwrite is false and data already exists in {output_dir}!")
            return None
    os.makedirs(output_dir, exist_ok=True)

    # Construct the ffmpeg-python pipeline
    stream = ffmpeg.input(video_path)
    config_dict = {
        "vsync": 0,  # preserves the original timestamps
        "frame_pts": 1,  # set output file's %d to the frame's PTS
    }
    if fps:
        # check FPS — never upsample beyond the source frame rate
        vid_fps = get_fps_ffmpeg(video_path)
        fps = min(vid_fps, fps)
        logger.info(f"{vname}{vext} FPS: {vid_fps}, extraction FPS: {fps}")
        config_dict["vf"] = f"fps={fps}"
    elif every_x:
        # the backslash escapes the comma for ffmpeg's filter-graph parser
        config_dict["vf"] = f"select=not(mod(n\,{every_x}))"

    logger.info(
        f"Extracting Frames into {output_dir} with these configs: \n{config_dict}"
    )
    stream = ffmpeg.output(stream, f"{output_dir}/{im_name_pattern}", **config_dict)

    # Execute the ffmpeg command
    try:
        ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)
        frames = [
            os.path.join(output_dir, f)
            for f in os.listdir(output_dir)
            if f.endswith(".jpg")
        ]
        logger.success(f"{len(frames)} frames extracted to {output_dir}")
        return frames
    except ffmpeg.Error as e:
        logger.error(f"Error executing FFmpeg command: {e.stderr.decode()}")
        return None
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Run the Typer CLI when invoked as a script.
if __name__ == "__main__":
    app()
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg-python>=0.2.0
|
| 2 |
+
imageio[ffmpeg]>=2.37.0
|
| 3 |
+
loguru>=0.7.3
|
| 4 |
+
pydantic
|
| 5 |
+
retrying>=1.3.4
|
| 6 |
+
samv2==0.0.4
|
| 7 |
+
validators>=0.35.0
|
samv2_handler.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, shutil
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from typing import Literal, Any, Union, Generic, List
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from sam2.build_sam import build_sam2, build_sam2_video_predictor
|
| 7 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 8 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 9 |
+
from sam2.utils.misc import variant_to_config_mapping
|
| 10 |
+
from sam2.utils.visualization import show_masks
|
| 11 |
+
from ffmpeg_extractor import extract_frames, logger
|
| 12 |
+
from toolbox.vid_utils import VidInfo
|
| 13 |
+
from toolbox.mask_encoding import b64_mask_encode
|
| 14 |
+
|
| 15 |
+
# Local checkpoint paths for each SAM2 Hiera model variant; keys must match
# the `variant` Literal used by the loader functions below.
variant_checkpoints_mapping = {
    "tiny": "checkpoints/sam2_hiera_tiny.pt",
    "small": "checkpoints/sam2_hiera_small.pt",
    "base_plus": "checkpoints/sam2_hiera_base_plus.pt",
    "large": "checkpoints/sam2_hiera_large.pt",
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class bbox_xyxy(BaseModel):
    """Axis-aligned bounding box: (x0, y0) top-left, (x1, y1) bottom-right."""

    x0: Union[int, float]
    y0: Union[int, float]
    x1: Union[int, float]
    y1: Union[int, float]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class point_xy(BaseModel):
    """A single 2-D point prompt in image pixel coordinates."""

    x: Union[int, float]
    y: Union[int, float]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def mask_to_xyxy(mask: np.ndarray) -> tuple:
    """Bounding box of a binary (h, w) mask as (x_min, y_min, x_max, y_max).

    Returns None (after logging a warning) when the mask has no foreground.
    """
    rows, cols = np.where(mask)
    if rows.size == 0 or cols.size == 0:
        logger.warning("mask_to_xyxy: No object found in the mask")
        return None
    return (
        int(cols.min()),
        int(rows.min()),
        int(cols.max()),
        int(rows.max()),
    )
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_sam_image_model(
    variant: Literal["tiny", "small", "base_plus", "large"],
    device: str = "cpu",
    auto_mask_gen: bool = False,
) -> SAM2ImagePredictor:
    """Build a SAM2 image model for the given variant.

    Returns an automatic mask generator when ``auto_mask_gen`` is True,
    otherwise a prompt-driven ``SAM2ImagePredictor``.
    """
    sam = build_sam2(
        config_file=variant_to_config_mapping[variant],
        ckpt_path=variant_checkpoints_mapping[variant],
        device=device,
    )
    if auto_mask_gen:
        return SAM2AutomaticMaskGenerator(sam)
    return SAM2ImagePredictor(sam_model=sam)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_sam_video_model(
    variant: Literal["tiny", "small", "base_plus", "large"] = "small",
    device: str = "cpu",
) -> Any:
    """Build a SAM2 video predictor for the given model variant and device."""
    cfg = variant_to_config_mapping[variant]
    ckpt = variant_checkpoints_mapping[variant]
    return build_sam2_video_predictor(config_file=cfg, ckpt_path=ckpt, device=device)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def run_sam_im_inference(
    model: Any,
    image: Image.Image,
    points: Union[List[point_xy], List[dict]] = None,
    point_labels: List[int] = None,
    bboxes: Union[List[bbox_xyxy], List[dict]] = None,
    get_pil_mask: bool = False,
    b64_encode_mask: bool = False,
):
    """Run SAM2 prompted segmentation on a single image.

    Args:
        model: a SAM2ImagePredictor (from load_sam_image_model).
        image: input Pillow image (converted to RGB internally).
        points: point prompts as point_xy models or dicts.
        point_labels: one label per point.
        bboxes: box prompts as bbox_xyxy models or dicts.
        get_pil_mask: return show_masks() visualization output instead.
        b64_encode_mask: base64-encode each mask before returning.

    Returns:
        A list of np masks, each of shape (h, w) and dtype uint8 (or their
        base64 encodings / PIL visualization depending on the flags).

    Raises:
        ValueError: when neither points nor bboxes are given, or point/label
            counts disagree.
    """
    # Mutable default arguments ([]) are shared across calls; use None
    # sentinels instead. Passing explicit lists behaves exactly as before.
    points = points or []
    point_labels = point_labels or []
    bboxes = bboxes or []

    # Validate with explicit raises — `assert` is stripped under `python -O`.
    if not (points or bboxes):
        raise ValueError(
            "SAM2 Image Inference must have either bounding boxes or points. Neither were provided."
        )
    if points and len(points) != len(point_labels):
        raise ValueError(
            f"{len(points)} points provided but {len(point_labels)} labels given."
        )

    # determine multimask_output: multiple objects are implied by mixing
    # prompt kinds, by >1 distinct point label, or by >1 box
    has_multi = False
    if points and bboxes:
        has_multi = True
    elif points and len(set(point_labels)) > 1:
        has_multi = True
    elif bboxes and len(bboxes) > 1:
        has_multi = True

    # normalize prompts to pydantic models (dicts are accepted for convenience)
    bboxes = [bbox_xyxy(**b) if isinstance(b, dict) else b for b in bboxes]
    points = [point_xy(**p) if isinstance(p, dict) else p for p in points]

    # setup inference
    image = np.array(image.convert("RGB"))
    model.set_image(image)

    box_coords = (
        np.array([[b.x0, b.y0, b.x1, b.y1] for b in bboxes]) if bboxes else None
    )
    point_coords = np.array([[p.x, p.y] for p in points]) if points else None
    point_labels = np.array(point_labels) if point_labels else None

    masks, scores, _ = model.predict(
        box=box_coords,
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=has_multi,
    )
    # masks here is of shape (X, h, w), X = number of masks

    if get_pil_mask:
        return show_masks(image, masks, scores=None, display_image=False)

    output_masks = []
    for mask in masks:
        if mask.ndim > 2:  # shape (3, h, w)
            mask = np.transpose(mask, (1, 2, 0))  # shape (h, w, 3)
            mask = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
            output_masks.append(np.array(mask))
        else:
            output_masks.append(mask.squeeze().astype(np.uint8))
    return (
        [b64_mask_encode(m) for m in output_masks]
        if b64_encode_mask
        else output_masks
    )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def run_sam_video_inference(
    model: Any,
    video_path: str,
    masks: np.ndarray,
    device: str = "cpu",
    sample_fps: int = None,
    every_x: int = None,
    do_tidy_up: bool = False,
    drop_mask: bool = True,
):
    """Track the objects given by ``masks`` (prompts on frame 0) through a video.

    Frames are dumped to disk with extract_frames, fed to the SAM2 video
    predictor, and each propagated mask is converted to a normalized
    detection dict.

    Args:
        model: SAM2 video predictor (from load_sam_video_model).
        video_path: source video file.
        masks: one binary mask per object, applied at frame 0.
        device: inference device for init_state.
        sample_fps / every_x: frame-sampling controls passed to extract_frames.
        do_tidy_up: delete the extracted-frames directory afterwards.
        drop_mask: when False, include each mask base64-encoded in the dict.

    Returns:
        list of detection dicts with normalized x/y/w/h per frame/track.
    """
    # put video frames into directory
    # TODO:
    # change frame size
    # async frame load
    l_frames_fp = extract_frames(
        video_path,
        fps=sample_fps,
        every_x=every_x,
        overwrite=True,
        im_name_pattern="%05d.jpg",
    )
    vframes_dir = os.path.dirname(l_frames_fp[0])
    vinfo = VidInfo(video_path)
    # NOTE(review): cv2 returns frame dimensions as floats; the normalized
    # coordinates below are float either way — confirm consumers accept that.
    w = vinfo["frame_width"]
    h = vinfo["frame_height"]

    inference_state = model.init_state(video_path=vframes_dir, device=device)
    for i, mask in enumerate(masks):
        model.add_new_mask(
            inference_state=inference_state, frame_idx=0, obj_id=i, mask=mask
        )
    masks_generator = model.propagate_in_video(inference_state)

    detections = []
    for i, tracker_ids, mask_logits in masks_generator:
        # NOTE: this rebinds the `masks` parameter to the per-frame masks;
        # the original prompt masks are no longer needed at this point.
        masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
        for id, mask in zip(tracker_ids, masks):
            mask = mask.squeeze().astype(np.uint8)
            xyxy = mask_to_xyxy(mask)
            if not xyxy:  # mask is empty
                logger.debug(f"track_id {id} is missing mask at frame {i}")
                continue
            x0, y0, x1, y1 = xyxy
            det = {  # miro's detections format for videos
                "frame": i,
                "track_id": id,
                "x": x0 / w,
                "y": y0 / h,
                "w": (x1 - x0) / w,
                "h": (y1 - y0) / h,
                "conf": 1,
            }
            if not drop_mask:
                det["mask_b64"] = b64_mask_encode(mask)
            detections.append(det)

    if do_tidy_up:
        # remove vframes_dir
        shutil.rmtree(vframes_dir)

    return detections
|
toolbox/mask_encoding.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64, os, io, random, time
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def b64_mask_encode(mask_np_arr, tmp_dir = '/tmp/miro/mask_encoding/'):
    '''
    Turn a binary mask (numpy array) into a base64-encoded 1-bit PNG.

    Args:
        mask_np_arr: binary mask; nonzero pixels become white.
        tmp_dir: kept for backward compatibility; no longer used — the PNG
            is encoded in memory instead of via a temp file.

    Returns:
        bytes: base64 encoding of the PNG bytes.
    '''
    mask_im = Image.fromarray(np.array(mask_np_arr).astype(np.uint8) * 255)
    mask_im = mask_im.convert(mode = '1')  # 1-bit image keeps the PNG tiny

    # Encode in memory: the old implementation wrote one PNG per call under
    # tmp_dir and never deleted it (unbounded disk leak) and left the file
    # handle from open() unclosed.
    buf = io.BytesIO()
    mask_im.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue())
|
| 21 |
+
|
| 22 |
+
def b64_mask_decode(b64_string):
    '''
    Decode a base64-encoded mask image back into a numpy array.
    '''
    raw = base64.b64decode(b64_string)
    with io.BytesIO(raw) as buf:
        # np.array() forces the lazy PIL image to load before the buffer closes
        return np.array(Image.open(buf))
|
| 29 |
+
|
| 30 |
+
def get_true_mask(mask_arr, im_w_h:tuple, x0, y0, x1, y1):
    '''
    Paste a cropped mask back onto a zero canvas of the full image size.

    Args:
        mask_arr: mask covering only the (x0, y0, x1, y1) crop.
        im_w_h: (width, height) of the source image.

    Returns:
        np.uint8 array of shape (height, width) with the crop placed at its
        original location.
    '''
    w, h = im_w_h
    if x0 > w or x1 > w or y0 > h or y1 > h:
        raise ValueError(f'get_true_mask: Xs and Ys exceeded im_w_h bound: {im_w_h}')

    expected_shape = (y1 - y0, x1 - x0)
    if mask_arr.shape != expected_shape:
        raise ValueError(f'get_true_mask: Bounding Box h: {y1-y0} w: {x1-x0} does not match mask shape: {mask_arr.shape}')

    canvas = np.zeros((h, w), dtype = np.uint8)
    canvas[y0:y1, x0:x1] = mask_arr
    return canvas
|
toolbox/vid_utils.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import cv2, imageio, ffmpeg, os, time, shutil
|
| 4 |
+
|
| 5 |
+
def VidInfo(vid_path):
    '''
    Returns a dictionary of 'duration', 'fps', 'frame_count', 'frame_height',
    'frame_width', 'format', 'fourcc' for a video, or None when the video
    cannot be opened. https:// URLs that fail are retried over http://.
    '''
    vcap = cv2.VideoCapture(vid_path)
    if not vcap.isOpened():
        # cannot read video
        if vid_path.startswith('https://'):
            # likely a ffmpeg without open-ssl support issue
            # https://github.com/opencv/opencv-python/issues/204
            return VidInfo(vid_path.replace('https://','http://'))
        else:
            return None

    # NOTE(review): cv2 getters return floats, so frame_height/frame_width
    # (and format/fourcc) below are floats — confirm callers expect that.
    info_dict = {
        'fps' : round(vcap.get(cv2.CAP_PROP_FPS),2), #int(vcap.get(cv2.CAP_PROP_FPS)),
        'frame_count': int(vcap.get(cv2.CAP_PROP_FRAME_COUNT)), # number of frames should be integers
        'duration': round(
            int(vcap.get(cv2.CAP_PROP_FRAME_COUNT)) / vcap.get(cv2.CAP_PROP_FPS),
            2), # round number of seconds to 2 decimals
        'frame_height': vcap.get(cv2.CAP_PROP_FRAME_HEIGHT),
        'frame_width': vcap.get(cv2.CAP_PROP_FRAME_WIDTH),
        'format': vcap.get(cv2.CAP_PROP_FORMAT),
        'fourcc': vcap.get(cv2.CAP_PROP_FOURCC)
    }
    vcap.release()
    return info_dict
|
| 33 |
+
|
| 34 |
+
def VidReader(vid_path, verbose = False, use_imageio = True):
    '''
    Given a video file path, returns the video's frames.

    Args:
        vid_path: a MP4 file path
        verbose: print frame count / FPS / timing diagnostics (cv2 path only)
        use_imageio: if true, function returns a ImageIO reader object (RGB);
                     otherwise, a list of CV2 arrays (BGR) will be returned
    '''

    if use_imageio:
        # lazy reader: frames are decoded on access, nothing is loaded up front
        vid = imageio.get_reader(vid_path, 'ffmpeg')
        return vid

    # cv2 path: eagerly decode every frame into memory
    vcap = cv2.VideoCapture(vid_path)
    s_time = time.time()

    # try to determine the total number of frames in Vid
    frame_count = int(vcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_rate = int(vcap.get(cv2.CAP_PROP_FPS))
    if verbose:
        print(f'\t{frame_count} total frames in video {vid_path}')
        print(f'\t\t FPS: {frame_rate}')
        print(f'\t\t Video Duration: {frame_count/ frame_rate}s')

    # loop over frames; frames that fail to decode are silently skipped,
    # so len(results) can be < frame_count
    results = []
    for i in tqdm(range(frame_count)):
        grabbed, frame = vcap.read()
        if grabbed:
            results.append(frame)

    # Output
    r_time = "{:.2f}".format(time.time() - s_time)
    if verbose:
        print(f'\t{vid_path} loaded in {r_time} ({frame_count/float(r_time)} fps)')
    vcap.release()
    return results
|
| 71 |
+
|
| 72 |
+
def get_vid_frame(n, vid_path):
    '''
    Return frame(s) of a video as np.array(s).

    Args:
        n: a single frame index (int) or a list of frame indices.
        vid_path: path to the video file.

    Returns:
        A frame array; for a list input, a list of arrays with None for
        out-of-range indices; None for a single out-of-range index.

    Raises:
        ValueError: if n is neither an int nor a list.
    '''
    vreader = VidReader(vid_path, verbose = False, use_imageio = True)
    fcount = VidInfo(vid_path)['frame_count']

    # isinstance is the idiomatic type check (type(n) == list rejects subclasses)
    if isinstance(n, list):
        return [vreader.get_data(i) if 0 <= i < fcount else None for i in n]
    if isinstance(n, int):
        return vreader.get_data(n) if 0 <= n < fcount else None
    raise ValueError(f'n must be either int or list, {type(n)} detected.')
|
| 87 |
+
|
| 88 |
+
def vid_slicer(vid_path, output_path, start_frame, end_frame, keep_audio = False, overwrite = False):
    '''
    Cut frames [start_frame, end_frame] (inclusive) out of a video with ffmpeg.

    ref https://github.com/kkroening/ffmpeg-python/issues/184#issuecomment-493847192

    Args:
        vid_path: input video file path
        output_path: destination file path; its directory must already exist
            (an empty dirname means the current working directory and is allowed)
        start_frame: first frame to keep
        end_frame: last frame to keep (inclusive)
        keep_audio: also trim and remux the audio stream
        overwrite: overwrite an existing output file

    Returns:
        output_path on success; None when the output already exists and
        overwrite is False.
    '''
    out_dir = os.path.dirname(output_path)
    # only validate the directory when the path actually has one --
    # os.path.isdir('') is False, which used to reject bare filenames
    if out_dir and not os.path.isdir(out_dir):
        raise ValueError(f'output_path directory does not exist: {out_dir}')

    if os.path.isfile(output_path) and not overwrite:
        warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
        return None

    input_vid = ffmpeg.input(vid_path)
    vid_info = VidInfo(vid_path)
    # ffmpeg's trim end_frame is exclusive; bump it so end_frame is included
    end_frame += 1

    if keep_audio:
        vid = (
            input_vid
            .trim(start_frame = start_frame, end_frame = end_frame)
            .setpts('PTS-STARTPTS')
        )
        # audio is trimmed by time, so convert the frame range via the fps
        aud = (
            input_vid
            .filter_('atrim', start = start_frame / vid_info['fps'], end = end_frame / vid_info['fps'])
            .filter_('asetpts', 'PTS-STARTPTS')
        )
        joined = ffmpeg.concat(vid, aud, v = 1, a = 1).node
        output = ffmpeg.output(joined[0], joined[1], f'{output_path}').overwrite_output()
        output.run()
    else:
        (
            input_vid
            .trim(start_frame = start_frame, end_frame = end_frame)
            .setpts('PTS-STARTPTS')
            .output(f'{output_path}')
            .overwrite_output()
            .run()
        )
    return output_path
|
| 127 |
+
|
| 128 |
+
def vid_resize(vid_path, output_path, width, overwrite = False):
    '''
    Resize a video to the given width with ffmpeg, keeping the aspect ratio.

    Args:
        vid_path: input video file path
        output_path: destination file path; its directory must already exist
            (an empty dirname means the current working directory and is allowed)
        width: target width in pixels (height is derived automatically)
        overwrite: overwrite an existing output file

    Returns:
        output_path on success; None when the output already exists and
        overwrite is False.
    '''
    out_dir = os.path.dirname(output_path)
    # only validate the directory when the path actually has one --
    # os.path.isdir('') is False, which used to reject bare filenames
    if out_dir and not os.path.isdir(out_dir):
        raise ValueError(f'output_path directory does not exist: {out_dir}')

    if os.path.isfile(output_path) and not overwrite:
        warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
        return None

    (
        ffmpeg
        .input(vid_path)
        .filter('scale', width, -1)   # height -1: preserve aspect ratio
        .output(output_path)
        .overwrite_output()
        .run()
    )
    return output_path
|
| 148 |
+
|
| 149 |
+
def vid_reduce_framerate(vid_path, output_path, new_fps, overwrite = False):
    '''
    Re-encode a video at a lower frame rate using ffmpeg's fps filter.

    (Docstring fixed: the previous one was copy-pasted from vid_resize and
    incorrectly described resizing by width.)

    Args:
        vid_path: input video file path
        output_path: destination file path; its directory must already exist
            (an empty dirname means the current working directory and is allowed)
        new_fps: target frame rate
        overwrite: overwrite an existing output file

    Returns:
        output_path on success; None when the output already exists and
        overwrite is False.
    '''
    out_dir = os.path.dirname(output_path)
    # only validate the directory when the path actually has one --
    # os.path.isdir('') is False, which used to reject bare filenames
    if out_dir and not os.path.isdir(out_dir):
        raise ValueError(f'output_path directory does not exist: {out_dir}')

    if os.path.isfile(output_path) and not overwrite:
        warnings.warn(f'{output_path} already exists but overwrite switch is False, nothing done.')
        return None

    (
        ffmpeg
        .input(vid_path)
        .filter('fps', fps = new_fps, round = 'up')
        .output(output_path)
        .overwrite_output()
        .run()
    )
    return output_path
|
| 169 |
+
|
| 170 |
+
def seek_frame_count(VidReader, cv2_frame_count, guess_within = 0.1,
                     seek_rate = 1, bDebug = False):
    '''
    imageio/ffmpeg frame count could be different than cv2. this function
    returns the true frame count in the given vid reader. Returns None if frame
    count can't be determined
    Args:
        VidReader: ImageIO video reader object with method .get_data()
        cv2_frame_count: frame count from cv2
        guess_within: look for actual frame count within X% of cv2_frame_count
        seek_rate: step size when scanning frame indices (clamped to >= 1)
        bDebug: print progress/diagnostic output
    '''
    # scan backwards from cv2's estimate; the earliest index we will try is
    # guess_within percent below cv2_frame_count
    max_guess = int(cv2_frame_count * (1-guess_within))
    seek_rate = max(seek_rate, 1)
    pbar = reversed(range(max_guess, cv2_frame_count, seek_rate))
    if bDebug:
        pbar = tqdm(pbar, desc = f'seeking frame')
        print(f'seeking from {max_guess} to {cv2_frame_count} with seek_rate of {seek_rate}')

    for i in pbar:
        try:
            # probe whether frame i is actually readable
            im = VidReader.get_data(i)
        except IndexError:
            # frame i does not exist; keep scanning downwards
            if bDebug:
                print(f'{i} not found.')
            continue
        # Frame Found
        if i+1 == cv2_frame_count:
            # the very last probed index is readable: cv2's count is correct
            print(f'seek_frame_count: found frame count at {i+1}')
            return i + 1
        else:
            # a readable frame was found below the estimate: recurse with a
            # narrower window [i, i + seek_rate) and a halved step to refine
            return seek_frame_count(VidReader, cv2_frame_count = i + seek_rate,
                                    guess_within= seek_rate / (i + seek_rate),
                                    seek_rate= int(seek_rate/2),
                                    bDebug = bDebug)
    # no readable frame found within the guess window
    return None
|
| 205 |
+
|
| 206 |
+
def VidWriter(lFrames, output_path, strFourcc = 'MP4V', verbose = False, intFPS = 20, crf = None,
              use_imageio = False):
    '''
    Write a list of frames to an MP4 video file.

    Args:
        lFrames: list of numpy arrays, or (imageio path only) image filenames
            read from disk on the fly
        output_path: destination .mp4 path; its directory is created if missing
        strFourcc: four letter video codec for the OpenCV writer; XVID is more
            preferable. MJPG results in high size video. X264 gives very small
            size video; see
            https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_gui/py_video_display/py_video_display.html
        verbose: print codec and timing information
        intFPS: output frame rate
        crf: optional Constant Rate Factor (18-24) for ffmpeg compression after
            the video is written
        use_imageio: write with imageio (expects RGB frames) instead of OpenCV
            (expects BGR frames -- assumption from cv2 convention, TODO confirm)

    Returns:
        output_path
    '''
    s_time = time.time()

    if not output_path.endswith('.mp4'):
        raise ValueError('VidWriter: only mp4 video output supported.')

    if crf:
        crf = int(crf)
        if crf > 24 or crf < 18:
            raise ValueError('VidWriter: crf must be between 18 and 24')

    output_dir = os.path.dirname(output_path)
    # only create the directory when the path actually has one:
    # os.makedirs('') raises for bare filenames like 'out.mp4'
    if output_dir and not os.path.exists(output_dir):
        print(f'\t{output_dir} does not exist.\n\tCreating video file output directory: {output_dir}')
        os.makedirs(output_dir)

    if use_imageio:
        writer = imageio.get_writer(output_path, fps = intFPS)
        for frame in tqdm(lFrames, desc = "Writing video using ImageIO"):
            if not isinstance(frame, np.ndarray):
                # non-array entries are treated as image filenames
                if not os.path.isfile(frame):
                    raise ValueError(f'VidWriter: lFrames must be list of images (np.array) or filenames')
                frame = imageio.imread(frame)

            writer.append_data(frame)
        writer.close()
    else:
        if not lFrames:
            # previously this crashed with an opaque IndexError on lFrames[0]
            raise ValueError('VidWriter: lFrames is empty, nothing to write.')
        # init OpenCV Vid Writer from the first frame's dimensions
        H, W = lFrames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*strFourcc)
        if verbose:
            print(f'\tEncoding using fourcc: {strFourcc}')
        writer = cv2.VideoWriter(output_path, fourcc, fps = intFPS, frameSize = (W, H), isColor = True)

        for frame in tqdm(lFrames, desc = "Writing video using OpenCV"):
            writer.write(frame)
        writer.release()

    # Output timing; clamp elapsed time to avoid division by zero below
    r_time = "{:.2f}".format(max(time.time() - s_time, 0.01))
    if verbose:
        print(f'\t{output_path} written in {r_time} ({len(lFrames)/float(r_time)} fps)')

    if crf:
        if verbose:
            print(f'\tCompressing {output_path} with FFmpeg using crf: {crf}')

        # NOTE(review): VidCompress is commented out at the bottom of this
        # module -- confirm it is defined or imported elsewhere, otherwise
        # this call raises NameError whenever crf is set.
        isCompressed = VidCompress(output_path, crf = crf, use_ffmpy = False)

        if verbose:
            print(f'\tCompressed: {isCompressed}')

    return output_path
|
| 270 |
+
|
| 271 |
+
def im_dir_to_video(im_dir, output_path, fps, tup_im_extension = ('.jpg',),
                    max_long_edge = 600, filename_len = 6, pixel_format = 'yuv420p',
                    tqdm_func = tqdm):
    '''turn a directory of images into video using ffmpeg

    ref: https://github.com/kkroening/ffmpeg-python/issues/95#issuecomment-401428324

    Args:
        im_dir: directory of frame images named <frame_number><ext>
        output_path: destination video path
        fps: output frame rate
        tup_im_extension: tuple of image extensions considered for zero padding
            (fixed: the old default ('.jpg') was a plain string, not a tuple;
            note the ffmpeg glob below still only picks up *.jpg)
        max_long_edge: currently unused -- TODO wire up a scale filter or remove
        filename_len: ensure frame numbers are zero padded to this length;
            0 will skip this step
        pixel_format: for list of supported formats see
            https://en.wikipedia.org/wiki/FFmpeg#Pixel_formats
        tqdm_func: progress bar callable

    Returns:
        output_path
    '''
    if filename_len:
        # Ensure filenames sort correctly by zero padding the numeric stem
        l_im_fp = [f for f in os.listdir(im_dir) if f.endswith(tup_im_extension)]
        l_im_fp = sorted(l_im_fp, key = lambda f: int(f.split('.')[0]))
        for f in tqdm_func(l_im_fp, desc = 'ensuring image filenames are zero padded'):
            fname, fext = os.path.splitext(f)
            padded_f = fname.zfill(filename_len) + fext
            if not os.path.isfile(os.path.join(im_dir, padded_f)):
                shutil.move(os.path.join(im_dir, f), os.path.join(im_dir, padded_f))
            # no symlink back to the original name: it would duplicate frames
            # in the generated video
    # TODO: ensure image sizes are divisible by 2

    im_dir += '' if im_dir.endswith('/') else '/'
    im_stream_string = f'{im_dir}*.jpg'
    # escape glob special characters that may appear in the directory path
    im_stream_string = im_stream_string.translate(
        str.maketrans(
            {'[': r'\[',
             ']': r'\]'})
    )
    (
        ffmpeg
        .input(im_stream_string, pattern_type = 'glob', framerate = fps)
        .filter('format', pixel_format)
        # .filter('pad', 'ceil(iw/2)*2:ceil(ih/2)*2')
        .output(output_path)
        .run()
    )
    return output_path
|
| 310 |
+
#
|
| 311 |
+
# def VidCompress(input_path, output_path = None, crf = 24, use_ffmpy = False):
|
| 312 |
+
# '''
|
| 313 |
+
# Compress input_path video (mp4 only) using ffmpy
|
| 314 |
+
# crf: Constant Rate Factor for ffmpeg video compression, must be between 18 and 24
|
| 315 |
+
# use_ffmpy: use ffmpy instead of commandline call to ffmpeg
|
| 316 |
+
# '''
|
| 317 |
+
# if not input_path.endswith('.mp4'):
|
| 318 |
+
# print(f'\tFATAL: only mp4 videos supported.')
|
| 319 |
+
# return None
|
| 320 |
+
#
|
| 321 |
+
# output_fname = output_path if output_path else input_path
|
| 322 |
+
# tmp_fname = input_path.replace(".mp4","_tmp.mp4")
|
| 323 |
+
# os.rename(input_path, tmp_fname)
|
| 324 |
+
#
|
| 325 |
+
# try:
|
| 326 |
+
# if not use_ffmpy:
|
| 327 |
+
# #os.popen(f'ffmpeg -i {tmp_fname} -vcodec libx264 -crf {crf} {output_fname}')
|
| 328 |
+
#
|
| 329 |
+
# cmdOut = subprocess.Popen(['ffmpeg', '-i', tmp_fname, '-vcodec', 'libx264', '-crf', str(crf), output_fname],
|
| 330 |
+
# stdout = subprocess.PIPE,
|
| 331 |
+
# stderr = subprocess.STDOUT)
|
| 332 |
+
# stdout, stderr = cmdOut.communicate()
|
| 333 |
+
# if not stderr:
|
| 334 |
+
# os.remove(tmp_fname)
|
| 335 |
+
# return True
|
| 336 |
+
# else:
|
| 337 |
+
# return False
|
| 338 |
+
# else:
|
| 339 |
+
# ff = FFmpeg(
|
| 340 |
+
# inputs = {tmp_fname : None},
|
| 341 |
+
# outputs = {output_fname : f'-vcodec libx264 -crf {crf}'}
|
| 342 |
+
# )
|
| 343 |
+
# ff.run()
|
| 344 |
+
#
|
| 345 |
+
# os.remove(tmp_fname)
|
| 346 |
+
# return True
|
| 347 |
+
#
|
| 348 |
+
# except OSError as e:
|
| 349 |
+
# print(f'\tWARNING: Compression Failed; OSError\n\tLikely out of RAM\n\tError Msg: {e}')
|
| 350 |
+
# os.rename(tmp_fname, output_fname)
|
| 351 |
+
# return False
|