|
|
import os |
|
|
import sys |
|
|
import json |
|
|
|
|
|
|
|
|
import jax |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from scipy.special import expit as sigmoid |
|
|
|
|
|
import skimage |
|
|
from skimage import io as skimage_io |
|
|
from skimage import transform as skimage_transform |
|
|
import matplotlib as mpl |
|
|
from matplotlib import pyplot as plt |
|
|
|
|
|
# Make the local big_vision checkout importable before the scenic imports below.
sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')


# Hide all GPUs from TensorFlow (used only for file I/O here) so it does not
# reserve device memory that JAX needs for the model.
tf.config.experimental.set_visible_devices([], 'GPU')
|
|
|
|
|
from scenic.projects.owl_vit import configs |
|
|
from scenic.projects.owl_vit import models |
|
|
|
|
|
from owlv2_helper_functions import plot_boxes_on_image, rescale_detection_box |
|
|
from owlv2_helper_functions import format_string |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Run configuration -------------------------------------------------------

# Scene/instance name; selects which image set and prompt files to process.
INSTANCE = 'machine'

# Directory of input images for this instance.
IMAGE_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic/{INSTANCE}'

# Directory where annotated detection plots are written.
OUTPUT_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/detection'

# Structured prompt produced by an upstream stage (constant renamed: was the
# typo'd JOSN_IN, referenced nowhere else in this file).
JSON_IN = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/structured_prompt.json'

# Per-image detection metadata written at the end of the run.
JSON_OUT = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/metadata.json'

# Minimum sigmoid score for a detection to be kept.
THRESHOLD = 0.12

# Build the list of text queries from the structured prompt: always the target
# objects, plus the referred objects when spatial context is requested.
TEXT = []
with open(JSON_IN, 'r', encoding='utf-8') as file:
    prompts = json.load(file)

for target_obj in prompts['target_obj']:
    TEXT.append(format_string(target_obj))
# NOTE: the key 'spacial_info' (sic) matches the upstream JSON schema.
if prompts['spacial_info'] is True:
    for referred_obj in prompts['referred_obj']:
        TEXT.append(format_string(referred_obj))

print(f"\nQueries: {TEXT}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the OWLv2 CLIP-L/14 configuration and instantiate the detection module.
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')

# Pull the model sub-config once instead of repeating the attribute chain.
model_cfg = config.model
module = models.TextZeroShotDetectionModule(
    body_configs=model_cfg.body,
    objectness_head_configs=model_cfg.objectness_head,
    normalize=model_cfg.normalize,
    box_bias=model_cfg.box_bias,
)

# Restore pretrained weights from the canonical checkpoint.
variables = module.load_variables(config.init_from.checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenize every text query to the model's fixed per-query token length.
text_queries = TEXT

# The detection head expects a fixed number of query slots; unused slots are
# filled with all-zero (empty) queries below.
MAX_QUERIES = 100
if len(text_queries) > MAX_QUERIES:
    # Fail with a clear message instead of np.pad's opaque negative-width error.
    raise ValueError(
        f'Got {len(text_queries)} queries; at most {MAX_QUERIES} are supported.')

tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])

# Pad the query axis up to the fixed slot count with zero tokens.
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, MAX_QUERIES - len(text_queries)), (0, 0)),
    constant_values=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# JIT-compile the forward pass once; 'train' must be static for tracing.
jitted = jax.jit(module.apply, static_argnames=('train',))

# Accumulates per-image detection records, keyed by image filename.
digital_twins = {}


# Collect image files (case-insensitive extension match), sorted for a
# deterministic processing order.
extensions = (".jpg", ".jpeg", ".png")
filenames = sorted(
    name for name in tf.io.gfile.listdir(IMAGE_DIR)
    if name.lower().endswith(extensions)
)
|
|
|
|
|
# Run detection on every image: preprocess to a square model input, run the
# jitted forward pass, then record every above-threshold box per query label
# and save an annotated plot.
for filename in filenames:
    file_path = os.path.join(IMAGE_DIR, filename)
    image_uint8 = skimage_io.imread(file_path)
    image = image_uint8.astype(np.float32) / 255.0  # scale to [0, 1]

    h, w, _ = image.shape

    # Pad to a square with mid-gray (0.5) so the aspect ratio is preserved
    # when resizing to the model's fixed input size.
    size = max(h, w)
    image_padded = np.pad(
        image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)

    # Use the skimage_transform alias imported at the top of the file
    # (was skimage.transform — same function, consistent with the imports).
    input_image = skimage_transform.resize(
        image_padded,
        (config.dataset_configs.input_size, config.dataset_configs.input_size),
        anti_aliasing=True,
    )

    # Forward pass with a leading batch dimension; strip it from every output.
    predictions = jitted(
        variables,
        input_image[None, ...],
        tokenized_queries[None, ...],
        train=False)

    predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions)

    score_threshold = THRESHOLD

    # Restrict logits to the real (non-padding) queries BEFORE the argmax.
    # BUG FIX: the argmax was previously taken over all padded query slots,
    # so a padded (empty) query could win and index text_queries out of range.
    logits = predictions['pred_logits'][..., :len(text_queries)]
    scores = sigmoid(np.max(logits, axis=-1))
    labels = np.argmax(logits, axis=-1)
    raw_boxes = predictions['pred_boxes']
    boxes = rescale_detection_box(raw_boxes, image)

    digital_twins[filename] = {}

    # Per-label running index used to build unique keys like 'cup_0', 'cup_1'.
    count = {}
    for label in labels:
        count[text_queries[label]] = 0

    for score, raw_box, box, label in zip(scores, raw_boxes, boxes, labels):
        if score < score_threshold:
            continue

        label_name = text_queries[label]
        x1, y1, x2, y2 = map(float, box)
        # raw_box is (cx, cy, w, h) normalized to the padded square image;
        # only the center is recorded.
        cx, cy, _, _ = map(float, raw_box)

        digital_twins[filename][f'{label_name}_{count[label_name]}'] = {
            'detection_label': label_name,
            'detection_box': [x1, y1, x2, y2],
            # Centroid in padded-image pixel coordinates.
            'detection_centroid': [cx * size, cy * size],
            'detection_score': round(float(score), 2),
        }
        count[label_name] += 1

    # Save the annotated visualization for this image.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    plot_boxes_on_image(image, text_queries, scores, boxes, labels, filename,
                        score_threshold, OUTPUT_DIR)
|
|
|
|
|
|
|
|
# Persist all per-image detection metadata as pretty-printed JSON.
serialized = json.dumps(digital_twins, indent=4)
with open(JSON_OUT, "w", encoding="utf-8") as json_f:
    json_f.write(serialized)
|
|
|