Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import cv2
from PIL import Image
from eval.grounded_sam.grounded_sam2_florence2_autolabel_pipeline import FlorenceSAM
class ObjectDetector:
    """Detect and crop object instances from an image via FlorenceSAM grounding."""

    def __init__(self, device):
        """Initialize the underlying FlorenceSAM grounding/segmentation model.

        Args:
            device: torch device identifier (e.g. "cuda" or "cpu").
        """
        self.device = torch.device(device)
        self.detector = FlorenceSAM(device)

    def get_instances(self, gen_image, label, min_size=64):
        """Return cropped RGB instance images of `label` found in `gen_image`.

        Instances whose area is below `min_size**2` pixels, or whose shorter
        side is below `min_size // 4`, are discarded.

        Args:
            gen_image: input image passed through to the detector
                (presumably a PIL image — TODO confirm against FlorenceSAM).
            label: text prompt describing the object(s) to ground.
            min_size: pixel-size threshold for filtering small detections.

        Returns:
            list of PIL.Image.Image: filtered instance crops in RGB.
        """
        _, instance_result_dict = \
            self.detector.od_grounding_and_segmentation(
                image=gen_image, text_input=label,
            )
        instances = instance_result_dict["instance_images"]
        filtered_instances = []
        for img in instances:
            # numpy/OpenCV arrays are (height, width, channels); the original
            # code named these (width, height). The two checks below are
            # symmetric in the operands, so behavior is unchanged — only the
            # names are corrected.
            height, width = img.shape[:2]
            if width * height < min_size * min_size or min(width, height) < min_size // 4:
                continue
            # OpenCV arrays are BGR; convert before handing to PIL.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            filtered_instances.append(img)
        return filtered_instances

    def get_multiple_instances(self, gen_image, label, min_size=64):
        """Run grounding + segmentation and return the raw result dict.

        Unlike `get_instances`, no size filtering or color conversion is
        applied. `min_size` is currently unused; it is kept for interface
        parity with `get_instances`.
        """
        # self.detector.phrase_grounding_and_segmentation(
        _, instance_result_dict = \
            self.detector.od_grounding_and_segmentation(
                image=gen_image, text_input=label,
            )
        return instance_result_dict
if __name__ == "__main__":
    # Quick manual smoke test: detect instances of `label` in a sample image
    # and dump the square-padded crops to ./tmp for visual inspection.
    # online demo: https://dun.163.com/trial/face/compare
    from glob import glob
    from tqdm import tqdm
    from src.train.data.data_utils import split_grid, pad_to_square
    from eval.idip.dino import DINOScore

    detector = ObjectDetector("cuda")
    dino_model = DINOScore("cuda")

    gen_image = Image.open("assets/tests/20250320-151038.jpeg").convert("RGB")
    label = "two people"
    save_dir = "tmp"  # fix: was f"tmp" — an f-string with no placeholders
    os.makedirs(save_dir, exist_ok=True)

    # for i, img in enumerate(split_grid(gen_image)):
    for i, img in enumerate([gen_image]):
        # Keep at most 3 detections; scale the size threshold to the image width.
        found_ips = detector.get_instances(img, label, min_size=img.size[0] // 20)[:3]
        found_ips = [pad_to_square(x) for x in found_ips]
        for j, ip in enumerate(found_ips):
            # score = dino_model(real_image, ip)
            score = 1
            # `ip` was already padded to square above; the original padded a
            # second time at save, which is redundant (padding a square image
            # is a no-op — NOTE(review): confirm pad_to_square is idempotent).
            ip.save(f"{save_dir}/{label}_{i}_{j}_{score}.png")