# owlv2/owlv2_inference.py
import os
import sys
import json
import jax
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
from skimage import io as skimage_io
from skimage import transform as skimage_transform
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')
tf.config.experimental.set_visible_devices([], 'GPU')
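# TensorFlow is only used for tf.io.gfile below; hiding GPUs from it keeps
# TF from reserving device memory that JAX needs.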
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
from owlv2_helper_functions import plot_boxes_on_image, rescale_detection_box
from owlv2_helper_functions import format_string
# IMAGE_DIR = '/home/netzone22/bohanliu_2025/HALO/case28/case_28'
# OUTPUT_DIR = '/home/netzone22/bohanliu_2025/HALO/case28/case_28_detection'
# JSON_IN = '/home/netzone22/bohanliu_2025/structured_prompt.json'
# JSON_OUT = '/home/netzone22/bohanliu_2025/HALO/case28/case_28_detection.json'
# IMAGE_DIR = '/home/netzone22/bohanliu_2025/HALO_test/single_image'
# OUTPUT_DIR = '/home/netzone22/bohanliu_2025/HALO_test/single_image/detection'
# JSON_IN = '/home/netzone22/bohanliu_2025/structured_prompt.json'
# JSON_OUT = '/home/netzone22/bohanliu_2025/HALO_test/single_image/detection.json'
INSTANCE = 'machine'
IMAGE_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic/{INSTANCE}'
OUTPUT_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/detection'
JSON_IN = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/structured_prompt.json'
JSON_OUT = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/metadata.json'
THRESHOLD = 0.12
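# Expected shape of the structured prompt JSON, inferred from the keys read
# below; only the keys of 'referred_obj' are used here, and the real
# structured_prompt.json may carry additional fields:
# {
#     "target_obj": ["forklift"],
#     "spacial_info": true,
#     "referred_obj": {"loading dock": "..."}
# }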
TEXT = []
with open(JSON_IN, 'r') as file:
    prompts = json.load(file)
for target_obj in prompts['target_obj']:
    TEXT.append(format_string(target_obj))
if prompts['spacial_info']:  # key is spelled 'spacial_info' in the prompt JSON
    for referred_obj in prompts['referred_obj']:
        TEXT.append(format_string(referred_obj))
print(f"\nQueries: {TEXT}\n")
### Choose config
# config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint')
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')
### Load the model and variables
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    objectness_head_configs=config.model.objectness_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
### Prepare text queries
text_queries = TEXT # ['machine', 'human']
tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])
# Pad tokenized queries to avoid recompilation if number of queries changes:
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, 100 - len(text_queries)), (0, 0)),
    constant_values=0)
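# The pad width above goes negative (and np.pad raises) for more than 100
# queries, so make the limit explicit:
assert len(text_queries) <= 100, 'Query padding assumes at most 100 queries.'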
### Prepare image
jitted = jax.jit(module.apply, static_argnames=('train',))
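# static_argnames treats `train` as a compile-time constant, so passing
# train=False below reuses the same compiled function on every call.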
digital_twins = {}
# filenames = sorted(tf.io.gfile.listdir(IMAGE_DIR))
extensions = {".jpg", ".jpeg", ".png"}
filenames = sorted([
    file
    for file in tf.io.gfile.listdir(IMAGE_DIR)
    if any(file.lower().endswith(ext) for ext in extensions)
])
for filename in filenames:
    file_path = os.path.join(IMAGE_DIR, filename)
    image_uint8 = skimage_io.imread(file_path)
    image = image_uint8.astype(np.float32) / 255.0
    # Pad to square with gray pixels on bottom and right:
    h, w, _ = image.shape
    # print(f"original img: {h} x {w}")
    size = max(h, w)
    image_padded = np.pad(
        image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)
    # Resize to model input size:
    input_image = skimage_transform.resize(
        image_padded,
        (config.dataset_configs.input_size, config.dataset_configs.input_size),
        anti_aliasing=True)
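    # Box predictions are normalized to this padded square, so a normalized
    # coordinate maps back to pixels by multiplying by `size` (see the
    # detection_centroid computation below).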
    ### Get predictions
    # This will take a minute on the first execution due to model compilation.
    # Subsequent executions will be faster.
    # Note: The model expects a batch dimension.
    predictions = jitted(
        variables,
        input_image[None, ...],
        tokenized_queries[None, ...],
        train=False)
    # Remove batch dimension and convert to numpy:
    predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions)
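    # `predictions` now maps output names to per-image numpy arrays; per the
    # OWL-ViT demo code, 'pred_logits' has shape (num_instances, num_queries)
    # and 'pred_boxes' has shape (num_instances, 4) in normalized
    # (cx, cy, w, h) format.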
    # print(predictions.keys())
    ### Plot prediction
    score_threshold = THRESHOLD  # 0.1
    logits = predictions['pred_logits'][..., :len(text_queries)]  # Remove padding.
    scores = sigmoid(np.max(logits, axis=-1))
    # Take argmax over the padding-free logits; using the raw logits here
    # could pick a padded query index and overflow text_queries.
    labels = np.argmax(logits, axis=-1)
    raw_boxes = predictions['pred_boxes']
    boxes = rescale_detection_box(raw_boxes, image)
    ### Write results into JSON file.
    digital_twins[filename] = {}
    count = {query: 0 for query in text_queries}
    for score, raw_box, box, label in zip(scores, raw_boxes, boxes, labels):
        if score < score_threshold:
            continue
        x1, y1, x2, y2 = map(float, box)
        cx, cy, box_w, box_h = map(float, raw_box)
        digital_twins[filename][f'{text_queries[label]}_{count[text_queries[label]]}'] = {
            'detection_label': text_queries[label],
            'detection_box': [x1, y1, x2, y2],
            # Centroid in pixel coordinates of the padded square image:
            'detection_centroid': [cx * size, cy * size],
            'detection_score': round(float(score), 2),
        }
        count[text_queries[label]] += 1
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    plot_boxes_on_image(image, text_queries, scores, boxes, labels, filename,
                        score_threshold, OUTPUT_DIR)
with open(JSON_OUT, "w", encoding="utf-8") as json_f:
    json.dump(digital_twins, json_f, indent=4)
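# The resulting JSON maps each image filename to its detections, e.g.
# (illustrative values only):
# {
#     "frame_000.png": {
#         "machine_0": {
#             "detection_label": "machine",
#             "detection_box": [x1, y1, x2, y2],
#             "detection_centroid": [cx, cy],
#             "detection_score": 0.87
#         }
#     }
# }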