# owlv2 / owlv2_img_embeding.py
# (origin metadata: uploaded by fcxfcx — "Upload 549 files", commit 742a3d1 verified)
import os
import sys
import json
# pip install ott-jax==0.2.0
import jax
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import matplotlib as mpl
from matplotlib import pyplot as plt
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')
tf.config.experimental.set_visible_devices([], 'GPU')
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
# from owlv2_helper_functions import prepare_images
from owlv2_helper_functions import read_images, preprocess_images
from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image
from owlv2_helper_functions import top_object_index
from owlv2_helper_functions import rescale_detection_box
from owlv2_helper_functions import get_iou, boxes_filter
"""
Prepare OWLv2 pretrained model
"""
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
body_configs=config.model.body,
objectness_head_configs=config.model.objectness_head,
normalize=config.model.normalize,
box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
"""
Wrapped model components
"""
import functools
image_embedder = jax.jit(
functools.partial(module.apply, variables, train=False, method=module.image_embedder))
objectness_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.objectness_predictor))
box_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.box_predictor))
class_predictor = jax.jit(
functools.partial(module.apply, variables, method=module.class_predictor))
"""
Detect the main object on instances' images
"""
INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections'
model_input_size = config.dataset_configs.input_size
images, source_images_names = read_images(INSTANCE_DIR)
source_images = preprocess_images(images, model_input_size)
instances, query_embeddings, indexes = [], [], []
for source_image, source_image_name in zip(source_images, source_images_names):
feature_map = image_embedder(source_image[None, ...])
b, h, w, d = feature_map.shape
image_features = feature_map.reshape(b, h * w, d)
objectnesses = objectness_predictor(image_features)['objectness_logits']
bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
all_class_embeddings = class_predictor(image_features=image_features)['class_embeddings']
# Remove batch dimension
objectnesses = np.array(objectnesses[0])
bboxes = np.array(bboxes[0])
all_class_embeddings = np.array(all_class_embeddings[0])
top_k = 1
objectnesses = sigmoid(objectnesses)
objectness_threshold = np.partition(objectnesses, -top_k)[-top_k]
index = top_object_index(objectnesses, objectness_threshold)
query_embedding = all_class_embeddings[index]
indexes.append(index)
instances.append(source_image_name.split('_')[0])
query_embeddings.append(query_embedding)
# Plot instance detection
output_file = os.path.join(INSTANCE_DETECTION, source_image_name)
plot_bbox_on_image(source_image, bboxes, objectnesses, objectness_threshold, output_file)
# ---------------------------------------------------------------------------
# Match every instance query embedding against each target image.
#
# For each target frame: embed once, then score all patches against each
# query embedding, keep the best patch per query, filter the resulting
# boxes, and record the survivors in `digital_twins` keyed by frame name.
# ---------------------------------------------------------------------------
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08_detect'
# Ensure the plot output directory exists before writing into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

digital_twins = {}
images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(images, model_input_size)

# NOTE(review): `size` is taken from the FIRST raw image only but is used to
# rescale centroids for every frame — confirm all frames share one resolution.
h, w, d = images[0].shape
size = max(h, w)

for target_image, target_image_name, image in zip(target_images, target_images_names, images):
    # Embed once per frame; all queries below reuse these features.
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    image_features = feature_map.reshape(b, h * w, d)
    all_bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
    bboxes = np.array(all_bboxes[0])  # shared by every query for this frame

    pred_scores, pred_bboxes = [], []
    for query_embedding in query_embeddings:
        # Score every patch against this single query embedding.
        target_class_predictions = class_predictor(
            image_features=image_features,  # reuse; no per-query re-reshape needed
            query_embeddings=query_embedding[None, None, ...],  # [batch, queries, d]
        )
        # Remove batch dimension and convert to numpy.
        logits = np.array(target_class_predictions['pred_logits'][0])

        # Keep the single best-scoring patch for this query.
        top_ind = np.argmax(logits[:, 0], axis=0)
        pred_bboxes.append(bboxes[top_ind])
        pred_scores.append(logits[top_ind, 0])

    instances_dup = instances[:]
    pred_scores = sigmoid(pred_scores)
    rescaled_bboxes = rescale_detection_box(pred_bboxes, image)
    # Filter boxes (and their labels/scores) together so lists stay aligned.
    rescaled_bboxes, pred_bboxes, pred_scores, instances_dup = boxes_filter(
        rescaled_bboxes, pred_bboxes, pred_scores, instances_dup)

    # Per-label counter disambiguates repeated instances of the same class.
    count = dict.fromkeys(instances_dup, 0)
    digital_twins[target_image_name] = {}
    for instance_name, instance_box, instance_raw_box, instance_score in zip(
            instances_dup, rescaled_bboxes, pred_bboxes, pred_scores):
        x1, y1, x2, y2 = map(float, instance_box)
        cx, cy, box_w, box_h = map(float, instance_raw_box)
        digital_twins[target_image_name][f'{instance_name}_{count[instance_name]}'] = {
            'detection_label': instance_name,
            'detection_box': [x1, y1, x2, y2],
            # Raw model boxes appear normalized; scale centroid back to pixels.
            'detection_centroid': [cx * size, cy * size],
            'detection_score': round(float(instance_score), 2),
        }
        count[instance_name] += 1

    image_based_plot_boxes_on_image(image, instances_dup, pred_scores, rescaled_bboxes, target_image_name, OUTPUT_DIR)
JSON_OUT_PATH = "/home/netzone22/bohanliu_2025/DT_SPR_video08_detection.json"
if not os.path.exists(JSON_OUT_PATH):
os.makedirs(JSON_OUT_PATH)
with open(JSON_OUT_PATH, "w", encoding="utf-8") as json_f:
json.dump(digital_twins, json_f, indent=4)