import os
import sys
import json

# pip install ott-jax==0.2.0
import jax
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import matplotlib as mpl
from matplotlib import pyplot as plt

sys.path.append('/home/netzone22/bohanliu_2025/VisionModels/Scenic_OWLv2/big_vision')

# Hide GPUs from TensorFlow: JAX owns the accelerator here; TF is only
# used for gfile-style filesystem access.
tf.config.experimental.set_visible_devices([], 'GPU')

from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models

from owlv2_helper_functions import plot_boxes_on_image, rescale_detection_box
from owlv2_helper_functions import format_string

# Paths for the current run. INSTANCE selects which semantic category to
# process; all inputs/outputs live under the matching semantic_output tree.
INSTANCE = 'machine'
IMAGE_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic/{INSTANCE}'
OUTPUT_DIR = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/detection'
# Fixed constant-name typo: was JOSN_IN (used only in this section).
JSON_IN = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/structured_prompt.json'
JSON_OUT = f'/home/netzone22/bohanliu_2025/HALO_test/semantic_output/{INSTANCE}/metadata.json'

# Minimum sigmoid detection score for a predicted box to be kept.
THRESHOLD = 0.12

# Build the text queries from the structured prompt: target objects always,
# referred objects only when the prompt carries spatial information.
TEXT = []
with open(JSON_IN, 'r') as file:
    prompts = json.load(file)
for target_obj in prompts['target_obj']:
    TEXT.append(format_string(target_obj))
# NOTE(review): 'spacial_info' spelling must match the key written by the
# prompt-generation step — do not "correct" it here without changing both.
if prompts['spacial_info']:
    for referred_obj in prompts['referred_obj']:
        TEXT.append(format_string(referred_obj))

print(f"\nQueries: {TEXT}\n")
### Choose config
# config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint')
config = configs.owl_v2_clip_l14.get_config(init_mode='canonical_checkpoint')

### Load the model and variables
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    objectness_head_configs=config.model.objectness_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)

### Prepare text queries
text_queries = TEXT  # e.g. ['machine', 'human']
tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])
# Pad the query batch to a fixed count (100) so jax.jit does not recompile
# whenever the number of queries changes.
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, 100 - len(text_queries)), (0, 0)),
    constant_values=0)

# Compile once; 'train' is static so False is baked into the trace.
jitted = jax.jit(module.apply, static_argnames=('train',))

digital_twins = {}

# Only process recognized image files in the input directory.
extensions = (".jpg", ".jpeg", ".png")
filenames = sorted(
    f for f in tf.io.gfile.listdir(IMAGE_DIR)
    if f.lower().endswith(extensions)
)

# Create the output directory once, before the loop (the original re-checked
# os.path.exists on every image).
os.makedirs(OUTPUT_DIR, exist_ok=True)

for i, filename in enumerate(filenames):
    file_path = os.path.join(IMAGE_DIR, filename)
    image_uint8 = skimage_io.imread(file_path)
    image = image_uint8.astype(np.float32) / 255.0

    # Pad to square with gray pixels on bottom and right:
    h, w, _ = image.shape
    size = max(h, w)
    image_padded = np.pad(
        image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)

    # Resize to model input size:
    input_image = skimage.transform.resize(
        image_padded,
        (config.dataset_configs.input_size, config.dataset_configs.input_size),
        anti_aliasing=True)

    ### Get predictions
    # The first execution compiles the model and takes a minute; subsequent
    # calls are fast. The model expects a batch dimension.
    predictions = jitted(
        variables,
        input_image[None, ...],
        tokenized_queries[None, ...],
        train=False)

    # Remove batch dimension and convert to numpy:
    predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions)

    score_threshold = THRESHOLD
    # Slice off the padded (all-zero) query columns before scoring.
    logits = predictions['pred_logits'][..., :len(text_queries)]
    scores = sigmoid(np.max(logits, axis=-1))
    # BUG FIX: argmax must run over the *sliced* logits. The original used the
    # full padded pred_logits, so argmax could select a padding column and
    # index text_queries out of range (IndexError) or mislabel detections.
    labels = np.argmax(logits, axis=-1)
    raw_boxes = predictions['pred_boxes']
    boxes = rescale_detection_box(raw_boxes, image)

    ### Accumulate results for the JSON metadata file.
    digital_twins[filename] = {}
    # Per-query running index used to build unique keys like 'machine_0'.
    count = {text_queries[label]: 0 for label in labels}
    for score, raw_box, box, label in zip(scores, raw_boxes, boxes, labels):
        if score < score_threshold:
            continue
        x1, y1, x2, y2 = map(float, box)
        cx, cy, box_w, box_h = map(float, raw_box)
        query = text_queries[label]
        digital_twins[filename][f'{query}_{count[query]}'] = {
            'detection_label': query,
            'detection_box': [x1, y1, x2, y2],
            # Raw box center is normalized to the padded square; scale by the
            # padded side length to recover pixel coordinates.
            'detection_centroid': [cx * size, cy * size],
            'detection_score': round(float(score), 2),
        }
        count[query] += 1

    plot_boxes_on_image(image, text_queries, scores, boxes, labels, filename,
                        score_threshold, OUTPUT_DIR)

# Dump the accumulated detections once after all images are processed.
# (The mangled original's indentation was ambiguous; writing inside the loop
# would only redundantly overwrite the same file each iteration.)
with open(JSON_OUT, "w", encoding="utf-8") as json_f:
    json.dump(digital_twins, json_f, indent=4)