# Page header captured with the file: salso's picture · "Upload 28 files" ·
# commit 545e508 (verified) · raw · history · blame · 15.3 kB
import os
from typing import Tuple, Optional
import cv2
import gradio as gr
import numpy as np
#import spaces
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from utils.video import generate_unique_name, create_directory, delete_directory
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_DETAILED_CAPTION_TASK, \
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.modes import IMAGE_INFERENCE_MODES, IMAGE_OPEN_VOCABULARY_DETECTION_MODE, \
IMAGE_CAPTION_GROUNDING_MASKS_MODE, VIDEO_INFERENCE_MODES
from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
# Header text passed to gr.Markdown at the top of the demo page: a title, four
# badge links (GitHub, Colab, Roboflow blog post, YouTube) and a short
# description of the two-stage Florence2 -> SAM2 pipeline.
MARKDOWN = """
# Florence2 + SAM2 🔥
<div>
<a href="https://github.com/facebookresearch/segment-anything-2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/what-is-segment-anything-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>
This demo integrates Florence2 and SAM2 by creating a two-stage inference pipeline. In
the first stage, Florence2 performs tasks such as object detection, open-vocabulary
object detection, image captioning, or phrase grounding. In the second stage, SAM2
performs object segmentation on the image.
"""
# Example rows offered through gr.Examples on the "Image" tab. Each row is
# [mode, image URL, text prompt]; caption-grounding rows carry no prompt.
_DOG_2_IMAGE_URL = "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"
_DOG_3_IMAGE_URL = "https://media.roboflow.com/notebooks/examples/dog-3.jpeg"
IMAGE_PROCESSING_EXAMPLES = [
    [
        IMAGE_OPEN_VOCABULARY_DETECTION_MODE,
        _DOG_2_IMAGE_URL,
        "straw, white napkin, black napkin, hair",
    ],
    [IMAGE_OPEN_VOCABULARY_DETECTION_MODE, _DOG_3_IMAGE_URL, "tail"],
    [IMAGE_CAPTION_GROUNDING_MASKS_MODE, _DOG_2_IMAGE_URL, None],
    [IMAGE_CAPTION_GROUNDING_MASKS_MODE, _DOG_3_IMAGE_URL, None],
]

# Example rows offered through gr.Examples on the "Video" tab. Each row is
# [video path, text prompt]; all three clips share the same prompt.
_BASKETBALL_PROMPT = "player in white outfit, player in black outfit, ball, rim"
VIDEO_PROCESSING_EXAMPLES = [
    ["videos/clip-07-camera-1.mp4", _BASKETBALL_PROMPT],
    ["videos/clip-07-camera-2.mp4", _BASKETBALL_PROMPT],
    ["videos/clip-07-camera-3.mp4", _BASKETBALL_PROMPT],
]
# Factor applied to the video's width/height and to every frame before it is
# written out during video processing.
VIDEO_SCALE_FACTOR = 0.5
# Directory that receives the per-run frame folders and the rendered videos.
VIDEO_TARGET_DIRECTORY = "tmp"
# Runs at import time; presumably creates the directory when it is missing
# (behaviour lives in utils.video — confirm there).
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
# Prefer the GPU, but fall back to the CPU so importing this module does not
# crash on machines without CUDA (the CUDA-only calls below would raise).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    # Enter bfloat16 autocast for the rest of the process; the context is
    # deliberately never exited.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 matmul / cuDNN kernels only on compute capability >= 8.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
# All models are loaded once, at import time, onto DEVICE; the request
# handlers below reuse these module-level instances.
# Florence model + processor: passed to every run_florence_inference call.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
# SAM image model: passed to run_sam_inference for images and first video frames.
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# SAM video model: used by process_video (init_state / add_new_mask /
# propagate_in_video).
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
# Six hex colours shared by every annotator below.
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)

# Every annotator uses the same palette and picks the colour by detection index.
_PALETTE_KWARGS = {
    "color": COLOR_PALETTE,
    "color_lookup": sv.ColorLookup.INDEX,
}

BOX_ANNOTATOR = sv.BoxAnnotator(**_PALETTE_KWARGS)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
    **_PALETTE_KWARGS,
)
MASK_ANNOTATOR = sv.MaskAnnotator(**_PALETTE_KWARGS)
def annotate_image(image, detections):
    """Return a copy of ``image`` with ``detections`` drawn on it.

    The input image is not modified. Masks are drawn first, then boxes, then
    labels, each annotator receiving the result of the previous one.
    """
    annotated = image.copy()
    for annotator in (MASK_ANNOTATOR, BOX_ANNOTATOR, LABEL_ANNOTATOR):
        annotated = annotator.annotate(annotated, detections)
    return annotated
def on_mode_dropdown_change(text):
    """Build visibility updates for the two image-tab textboxes.

    Args:
        text: The mode currently selected in the dropdown.

    Returns:
        A two-item list of ``gr.Textbox`` updates: the first is visible only
        in open-vocabulary detection mode, the second only in caption
        grounding mode.
    """
    show_first = text == IMAGE_OPEN_VOCABULARY_DETECTION_MODE
    show_second = text == IMAGE_CAPTION_GROUNDING_MASKS_MODE
    first_update = gr.Textbox(visible=show_first)
    second_update = gr.Textbox(visible=show_second)
    return [first_update, second_update]
#@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    mode_dropdown, image_input, text_input
) -> Tuple[Optional[Image.Image], Optional[str]]:
    """Run the Florence2 -> SAM2 pipeline on a single image.

    Args:
        mode_dropdown: Selected inference mode; compared against
            ``IMAGE_OPEN_VOCABULARY_DETECTION_MODE`` and
            ``IMAGE_CAPTION_GROUNDING_MASKS_MODE``.
        image_input: Image to process (its ``.size`` is used as the
            resolution), or ``None`` when nothing was uploaded.
        text_input: Comma-separated prompts; only used in open-vocabulary
            detection mode.

    Returns:
        ``(annotated_image, caption)``. ``caption`` is only set in caption
        grounding mode. ``(None, None)`` is returned when a required input is
        missing or the mode is not recognised.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None, None

    if mode_dropdown == IMAGE_OPEN_VOCABULARY_DETECTION_MODE:
        # Split on commas and drop blank entries so inputs such as "a, , b" or
        # a whitespace-only prompt never reach Florence as an empty string.
        texts = [prompt.strip() for prompt in (text_input or "").split(",")]
        texts = [prompt for prompt in texts if prompt]
        if not texts:
            gr.Info("Please enter a text prompt.")
            return None, None

        detections_list = []
        for text in texts:
            _, result = run_florence_inference(
                model=FLORENCE_MODEL,
                processor=FLORENCE_PROCESSOR,
                device=DEVICE,
                image=image_input,
                task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
                text=text
            )
            detections = sv.Detections.from_lmm(
                lmm=sv.LMM.FLORENCE_2,
                result=result,
                resolution_wh=image_input.size
            )
            detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
            detections_list.append(detections)

        detections = sv.Detections.merge(detections_list)
        # NOTE(review): SAM already ran per prompt above; this second pass over
        # the merged set is kept exactly as in the original pipeline.
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        return annotate_image(image_input, detections), None

    if mode_dropdown == IMAGE_CAPTION_GROUNDING_MASKS_MODE:
        # Step 1: let Florence describe the image.
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_DETAILED_CAPTION_TASK
        )
        caption = result[FLORENCE_DETAILED_CAPTION_TASK]
        # Step 2: ground the phrases of that caption back onto the image.
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
            text=caption
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        return annotate_image(image_input, detections), caption

    # Unknown mode: return the declared 2-tuple instead of an implicit None so
    # the two output components can always be unpacked.
    return None, None
#@spaces.GPU(duration=300)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_video(
    video_input, text_input, progress=gr.Progress(track_tqdm=True)
) -> Optional[str]:
    """Detect prompted objects on the first frame and track them through the video.

    Florence2 + SAM2 produce masks on the first frame for every prompt; the
    SAM2 video model is seeded with those masks and propagates them over the
    down-scaled frames, which are annotated and written to a new video.

    Args:
        video_input: Path of the video to process; falsy when nothing was
            uploaded.
        text_input: Comma-separated prompts describing the objects to track.
        progress: Gradio progress handle (created with ``track_tqdm=True``);
            not referenced directly in the body.

    Returns:
        Path of the annotated ``.mp4`` inside ``VIDEO_TARGET_DIRECTORY``, or
        ``None`` when input is missing, the video has no readable frame, or
        nothing was detected on the first frame.
    """
    if not video_input:
        gr.Info("Please upload a video.")
        return None

    # Split on commas and drop blank entries so empty prompts never reach
    # Florence.
    texts = [prompt.strip() for prompt in (text_input or "").split(",")]
    texts = [prompt for prompt in texts if prompt]
    if not texts:
        gr.Info("Please enter a text prompt.")
        return None

    # A default avoids a bare StopIteration when the video yields no frames.
    first_frame = next(sv.get_video_frames_generator(video_input), None)
    if first_frame is None:
        gr.Info("Could not read any frames from the uploaded video.")
        return None
    frame = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))

    detections_list = []
    for text in texts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=frame,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=frame.size
        )
        detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
        detections_list.append(detections)

    detections = sv.Detections.merge(detections_list)
    detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)

    if detections.mask is None or len(detections.mask) == 0:
        # f-string so the user's prompt is shown instead of a literal
        # "{text_input}".
        gr.Info(
            f"No objects of class {text_input} found in the first frame of the video. "
            "Trim the video to make the object appear in the first frame or try a "
            "different text prompt."
        )
        return None

    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")

    try:
        frames_sink = sv.ImageSink(
            target_dir_path=frame_directory_path,
            image_name_pattern="{:05d}.jpeg"
        )

        # The output video uses the down-scaled resolution.
        video_info = sv.VideoInfo.from_video_path(video_input)
        video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
        video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)

        # Dump the scaled frames to disk; the SAM video model is initialised
        # from that directory.
        frames_generator = sv.get_video_frames_generator(video_input)
        with frames_sink:
            for frame in tqdm(
                frames_generator,
                total=video_info.total_frames,
                desc="splitting video into frames"
            ):
                frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
                frames_sink.save_image(frame)

        inference_state = SAM_VIDEO_MODEL.init_state(
            video_path=frame_directory_path,
            device=DEVICE
        )

        # Seed the tracker with one object per first-frame mask.
        for mask_index, mask in enumerate(detections.mask):
            SAM_VIDEO_MODEL.add_new_mask(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=mask_index,
                mask=mask
            )

        frames_generator = sv.get_video_frames_generator(video_input)
        masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
        with sv.VideoSink(video_path, video_info=video_info) as sink:
            for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
                frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
                masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
                # Drop the singleton second axis when masks arrive as 4-D.
                if len(masks.shape) == 4:
                    masks = np.squeeze(masks, axis=1)

                detections = sv.Detections(
                    xyxy=sv.mask_to_xyxy(masks=masks),
                    mask=masks,
                    class_id=np.array(tracker_ids)
                )
                annotated_frame = frame.copy()
                annotated_frame = MASK_ANNOTATOR.annotate(
                    scene=annotated_frame, detections=detections)
                annotated_frame = BOX_ANNOTATOR.annotate(
                    scene=annotated_frame, detections=detections)
                sink.write_frame(annotated_frame)
    finally:
        # Remove the temporary frames even when a step above fails; the
        # existence check keeps cleanup from masking the original error.
        if os.path.isdir(frame_directory_path):
            delete_directory(frame_directory_path)

    return video_path
# Gradio UI: an "Image" tab and a "Video" tab, each with inputs, an output and
# clickable examples, followed by the event wiring and the app launch.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Tab("Image"):
        image_processing_mode_dropdown_component = gr.Dropdown(
            choices=IMAGE_INFERENCE_MODES,
            value=IMAGE_INFERENCE_MODES[0],
            label="Mode",
            info="Select a mode to use.",
            interactive=True
        )
        with gr.Row():
            with gr.Column():
                image_processing_image_input_component = gr.Image(
                    type='pil', label='Upload image')
                image_processing_text_input_component = gr.Textbox(
                    label='Text prompt',
                    placeholder='Enter comma separated text prompts')
                image_processing_submit_button_component = gr.Button(
                    value='Submit', variant='primary')
            with gr.Column():
                image_processing_image_output_component = gr.Image(
                    type='pil', label='Image output')
                image_processing_text_output_component = gr.Textbox(
                    label='Caption output', visible=False)

        # The same component lists feed the examples and both submit triggers.
        image_processing_inputs = [
            image_processing_mode_dropdown_component,
            image_processing_image_input_component,
            image_processing_text_input_component
        ]
        image_processing_outputs = [
            image_processing_image_output_component,
            image_processing_text_output_component
        ]
        with gr.Row():
            gr.Examples(
                fn=process_image,
                examples=IMAGE_PROCESSING_EXAMPLES,
                inputs=image_processing_inputs,
                outputs=image_processing_outputs,
                run_on_click=True
            )

    with gr.Tab("Video"):
        video_processing_mode_dropdown_component = gr.Dropdown(
            choices=VIDEO_INFERENCE_MODES,
            value=VIDEO_INFERENCE_MODES[0],
            label="Mode",
            info="Select a mode to use.",
            interactive=True
        )
        with gr.Row():
            with gr.Column():
                video_processing_video_input_component = gr.Video(
                    label='Upload video')
                video_processing_text_input_component = gr.Textbox(
                    label='Text prompt',
                    placeholder='Enter comma separated text prompts')
                video_processing_submit_button_component = gr.Button(
                    value='Submit', variant='primary')
            with gr.Column():
                video_processing_video_output_component = gr.Video(
                    label='Video output')

        video_processing_inputs = [
            video_processing_video_input_component,
            video_processing_text_input_component
        ]
        with gr.Row():
            gr.Examples(
                fn=process_video,
                examples=VIDEO_PROCESSING_EXAMPLES,
                inputs=video_processing_inputs,
                outputs=video_processing_video_output_component,
                run_on_click=True
            )

    # Image tab: the submit button and Enter in the prompt box both run
    # process_image (registered in that order).
    for register_image_trigger in (
        image_processing_submit_button_component.click,
        image_processing_text_input_component.submit,
    ):
        register_image_trigger(
            fn=process_image,
            inputs=image_processing_inputs,
            outputs=image_processing_outputs
        )

    # Switching mode toggles which of the prompt / caption textboxes is shown.
    image_processing_mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[image_processing_mode_dropdown_component],
        outputs=[
            image_processing_text_input_component,
            image_processing_text_output_component
        ]
    )

    # Video tab: the submit button and Enter in the prompt box both run
    # process_video (registered in that order).
    for register_video_trigger in (
        video_processing_submit_button_component.click,
        video_processing_text_input_component.submit,
    ):
        register_video_trigger(
            fn=process_video,
            inputs=video_processing_inputs,
            outputs=video_processing_video_output_component
        )

demo.launch(debug=False, show_error=True, share=True)