|
|
import os |
|
|
import sys |
|
|
import json |
|
|
|
|
|
|
|
|
import jax |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from scipy.special import expit as sigmoid |
|
|
|
|
|
import skimage |
|
|
from skimage import io as skimage_io |
|
|
from skimage import transform as skimage_transform |
|
|
import matplotlib as mpl |
|
|
from matplotlib import pyplot as plt |
|
|
|
|
|
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision') |
|
|
tf.config.experimental.set_visible_devices([], 'GPU') |
|
|
|
|
|
from scenic.projects.owl_vit import configs |
|
|
from scenic.projects.owl_vit import models |
|
|
|
|
|
|
|
|
from owlv2_helper_functions import read_images, preprocess_images |
|
|
from owlv2_helper_functions import plot_bbox_on_image, image_based_plot_boxes_on_image |
|
|
from owlv2_helper_functions import top_object_index |
|
|
from owlv2_helper_functions import rescale_detection_box |
|
|
from owlv2_helper_functions import get_iou, boxes_filter |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Prepare OWLv2 pretrained model |
|
|
""" |
|
|
# Build the OWLv2 CLIP-L/14 text-zero-shot detection module from its canonical
# config and load the checkpoint that config points to.
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')

# Read the model sub-config once; every constructor argument below comes from it.
_model_config = config.model
module = models.TextZeroShotDetectionModule(
    body_configs=_model_config.body,
    objectness_head_configs=_model_config.objectness_head,
    normalize=_model_config.normalize,
    box_bias=_model_config.box_bias,
)

# Parameters used by every `module.apply` call further down the script.
variables = module.load_variables(config.init_from.checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Wrapped model components |
|
|
""" |
|
|
import functools


def _jit_module_method(method, **fixed_kwargs):
    """Return a jitted callable for one method of the detection module.

    The loaded `variables` are bound as the first positional argument of
    `module.apply`, so callers only pass the method's own inputs.

    Args:
        method: The bound module method to run through `module.apply`.
        **fixed_kwargs: Extra keyword arguments frozen into the call.

    Returns:
        The `jax.jit`-compiled partial application.
    """
    bound_apply = functools.partial(
        module.apply, variables, method=method, **fixed_kwargs)
    return jax.jit(bound_apply)


# Image -> feature map; `train=False` is fixed for inference.
image_embedder = _jit_module_method(module.image_embedder, train=False)

# Per-token heads operating on the flattened image features.
objectness_predictor = _jit_module_method(module.objectness_predictor)
box_predictor = _jit_module_method(module.box_predictor)
class_predictor = _jit_module_method(module.class_predictor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Detect the main object on instances' images |
|
|
""" |
|
|
# Reference ("instance") images and the folder their annotated copies go to.
INSTANCE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances'
INSTANCE_DETECTION = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_instances_detections'

model_input_size = config.dataset_configs.input_size
images, source_images_names = read_images(INSTANCE_DIR)
source_images = preprocess_images(images, model_input_size)

# One entry per instance image, kept in lockstep:
#   instances        - label taken from the file name (text before the first '_')
#   query_embeddings - class embedding of the selected token
#   indexes          - index returned by `top_object_index` for that image
instances, query_embeddings, indexes = [], [], []
for source_image, source_image_name in zip(source_images, source_images_names):
    # Embed the single image (a batch axis is added) and flatten the spatial grid.
    feature_map = image_embedder(source_image[None, ...])
    batch, grid_h, grid_w, depth = feature_map.shape
    image_features = feature_map.reshape(batch, grid_h * grid_w, depth)

    # Run the three heads, then keep the first (only) batch element of each.
    objectness_out = objectness_predictor(image_features)
    box_out = box_predictor(image_features=image_features, feature_map=feature_map)
    class_out = class_predictor(image_features=image_features)

    bboxes = np.array(box_out['pred_boxes'][0])
    all_class_embeddings = np.array(class_out['class_embeddings'][0])
    objectnesses = sigmoid(np.array(objectness_out['objectness_logits'][0]))

    # The threshold is the top_k-th largest objectness score, i.e. the maximum
    # when top_k == 1.
    top_k = 1
    objectness_threshold = np.partition(objectnesses, -top_k)[-top_k]

    index = top_object_index(objectnesses, objectness_threshold)
    query_embedding = all_class_embeddings[index]

    instances.append(source_image_name.split('_')[0])
    query_embeddings.append(query_embedding)
    indexes.append(index)

    # Draw the boxes on the instance image under the same file name.
    output_file = os.path.join(INSTANCE_DETECTION, source_image_name)
    plot_bbox_on_image(source_image, bboxes, objectnesses, objectness_threshold, output_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Target frames to search and the folder the annotated frames are written to.
IMAGE_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08'
OUTPUT_DIR = '/home/netzone22/bohanliu_2025/DT_SPR/DT_SPR_video08_detect'

# Maps frame name -> {"<label>_<n>": detection record}; dumped to JSON later.
digital_twins = {}

images, target_images_names = read_images(IMAGE_DIR)
target_images = preprocess_images(images, model_input_size)

for target_image, target_image_name, image in zip(target_images, target_images_names, images):
    # Scale for the centroid: the longer side of THIS frame. It is computed per
    # frame so that frames of different resolutions are scaled correctly and an
    # empty IMAGE_DIR no longer raises IndexError on `images[0]`.
    # NOTE(review): presumably matches the square the preprocessing pads to -
    # confirm against `preprocess_images` / `rescale_detection_box`.
    size = max(image.shape[:2])

    # Embed the frame once; the flattened features and the predicted boxes are
    # shared by every query below.
    feature_map = image_embedder(target_image[None, ...])
    b, h, w, d = feature_map.shape
    image_features = feature_map.reshape(b, h * w, d)
    all_bboxes = box_predictor(image_features=image_features, feature_map=feature_map)['pred_boxes']
    bboxes = np.array(all_bboxes[0])

    # For each instance query keep the single best-scoring token and its box.
    pred_scores, pred_bboxes = [], []
    for query_embedding in query_embeddings:
        target_class_predictions = class_predictor(
            image_features=image_features,
            query_embeddings=query_embedding[None, None, ...],
        )
        logits = np.array(target_class_predictions['pred_logits'][0])

        top_ind = np.argmax(logits[:, 0], axis=0)
        pred_bboxes.append(bboxes[top_ind])
        pred_scores.append(logits[top_ind, 0])

    # Work on a copy of the labels: `boxes_filter` returns a filtered label
    # list and the master `instances` list must stay intact for the next frame.
    instances_dup = instances[:]
    pred_scores = sigmoid(pred_scores)
    rescaled_bboxes = rescale_detection_box(pred_bboxes, image)

    rescaled_bboxes, pred_bboxes, pred_scores, instances_dup = boxes_filter(
        rescaled_bboxes, pred_bboxes, pred_scores, instances_dup)

    # Running per-label counter, used to give repeated labels unique keys.
    count = dict.fromkeys(instances_dup, 0)

    frame_detections = digital_twins[target_image_name] = {}
    for instance_name, instance_box, instance_raw_box, instance_score in zip(
            instances_dup, rescaled_bboxes, pred_bboxes, pred_scores):
        x1, y1, x2, y2 = map(float, instance_box)
        # Only the first two components of the raw box are needed here.
        cx, cy, _, _ = map(float, instance_raw_box)

        frame_detections[f'{instance_name}_{count[instance_name]}'] = {
            'detection_label': instance_name,
            'detection_box': [x1, y1, x2, y2],
            'detection_centroid': [cx*size, cy*size],
            'detection_score': round(float(instance_score), 2),
        }
        count[instance_name] += 1

    image_based_plot_boxes_on_image(image, instances_dup, pred_scores, rescaled_bboxes, target_image_name, OUTPUT_DIR)
|
|
|
|
|
|
|
|
# Destination file for the per-frame detection records.
JSON_OUT_PATH = "/home/netzone22/bohanliu_2025/DT_SPR_video08_detection.json"

# Create the *parent* directory only. Calling makedirs on JSON_OUT_PATH itself
# would create a directory at the file's path and make the open() below fail
# with IsADirectoryError.
_json_parent_dir = os.path.dirname(JSON_OUT_PATH)
if _json_parent_dir:
    os.makedirs(_json_parent_dir, exist_ok=True)

with open(JSON_OUT_PATH, "w", encoding="utf-8") as json_f:
    json.dump(digital_twins, json_f, indent=4)