| | import typer |
| | from groundingdino.util.inference import load_model, load_image, predict |
| | from tqdm import tqdm |
| | import torchvision |
| | import torch |
| | import fiftyone as fo |
| |
|
| |
|
def main(
    image_directory: str = 'test_grounding_dino',
    text_prompt: str = 'bus, car',
    box_threshold: float = 0.15,
    text_threshold: float = 0.10,
    export_dataset: bool = False,
    view_dataset: bool = False,
    export_annotated_images: bool = True,
    weights_path: str = "groundingdino_swint_ogc.pth",
    config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    subsample: int = None,
):
    """Run GroundingDINO over a directory of images and store the results in FiftyOne.

    Every image found in ``image_directory`` is loaded into a FiftyOne
    dataset, passed through ``predict`` with ``text_prompt`` as the caption,
    and the returned boxes are saved on the sample under the ``"detections"``
    field.

    Args:
        image_directory: Directory of images to build the dataset from.
        text_prompt: Caption handed to ``predict`` (e.g. ``'bus, car'``).
        box_threshold: ``box_threshold`` forwarded to ``predict``.
        text_threshold: ``text_threshold`` forwarded to ``predict``.
        export_dataset: If true, export the dataset in COCO detection format
            to ``coco_dataset``.
        view_dataset: If true, launch the FiftyOne app on the dataset and
            block until the session is closed.
        export_annotated_images: If true, render the ``"detections"`` field
            onto the images and write them to ``images_with_bounding_boxes``.
        weights_path: Model checkpoint passed to ``load_model``.
        config_path: Model config file passed to ``load_model``.
        subsample: If given and smaller than the dataset size, only that many
            randomly taken samples are processed (on a clone of the dataset).
    """
    model = load_model(config_path, weights_path)
    dataset = fo.Dataset.from_images_dir(image_directory)

    if subsample is not None and subsample < len(dataset):
        # clone() turns the random view into a standalone dataset so the
        # later app/export/draw steps only ever see the subsample.
        dataset = dataset.take(subsample).clone()

    for sample in tqdm(dataset):
        # load_image returns (source image, model-ready image); only the
        # latter is needed for inference.
        _, image = load_image(sample.filepath)

        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )

        # Convert the whole batch at once from centre format (cx, cy, w, h)
        # to top-left format (x, y, w, h), then drop down to plain Python
        # floats: FiftyOne label fields expect lists/floats, not tensors.
        xywh_boxes = torchvision.ops.box_convert(boxes, 'cxcywh', 'xywh').tolist()
        confidences = logits.tolist()

        detections = [
            fo.Detection(
                label=phrase,
                bounding_box=xywh_box,
                confidence=confidence,
            )
            for xywh_box, confidence, phrase in zip(xywh_boxes, confidences, phrases)
        ]

        sample["detections"] = fo.Detections(detections=detections)
        sample.save()

    if view_dataset:
        session = fo.launch_app(dataset)
        # Keep the process alive until the app session is closed.
        session.wait()

    if export_dataset:
        dataset.export(
            'coco_dataset',
            dataset_type=fo.types.COCODetectionDataset,
        )

    if export_annotated_images:
        dataset.draw_labels(
            'images_with_bounding_boxes',
            label_fields=['detections'],
        )
| |
|
| |
|
# Script entry point: let Typer turn main()'s keyword parameters into
# command-line options when this file is executed directly.
if __name__ == "__main__":
    typer.run(main)
| |
|