# Source captured from a Hugging Face Space (Space status at capture time: Sleeping).
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import supervision as sv | |
| from pathlib import Path | |
| from dds_cloudapi_sdk import Config, Client, TextPrompt | |
| from dds_cloudapi_sdk.tasks.dinox import DinoxTask | |
| from dds_cloudapi_sdk.tasks.detection import DetectionTask | |
| from dds_cloudapi_sdk.tasks.types import DetectionTarget | |
# Constants
# DINO-X API token. Prefer the DDS_API_TOKEN environment variable (e.g. a
# deployment secret) so the token can be rotated without a code change.
# NOTE(review): the literal fallback is a credential committed to source;
# it should be revoked and the fallback removed once DDS_API_TOKEN is set.
API_TOKEN = os.environ.get("DDS_API_TOKEN") or "361d32fa5ce22649133660c65cfcaf22"
# Default '.'-separated class prompt for images (also the UI textbox default).
TEXT_PROMPT = "wheel . eye . helmet . mouse . mouth . vehicle . steering wheel . ear . nose"
# Default '.'-separated class prompt for videos ("aquarium" was misspelled
# "acquariam" previously).
VID_PROMPT = "wheel . mouse . pot . aquarium . box"
# Scratch space for per-frame uploads and the destination for annotated results.
TEMP_DIR = "./temp"
OUTPUT_DIR = "./outputs"
# Ensure directories exist
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
def initialize_dino_client():
    """Return a new DINO-X ``Client`` authenticated with the module's ``API_TOKEN``."""
    return Client(Config(API_TOKEN))
def get_class_mappings(text_prompt):
    """Split a '.'-separated prompt into class names and a name -> id mapping.

    Each class name is stripped and lower-cased, which matches how
    ``process_predictions`` normalises ``obj.category`` before lookup.

    Args:
        text_prompt: class names separated by '.', e.g. ``"wheel . eye"``.

    Returns:
        ``(classes, class_name_to_id)``: the ordered list of non-empty class
        names, and a dict mapping each name to its index in that list.
    """
    normalized = (part.strip().lower() for part in text_prompt.split('.'))
    # Filter on the *stripped* name so a trailing '.' or a whitespace-only
    # piece ("a . b . ") does not produce an empty "" class.
    classes = [name for name in normalized if name]
    class_name_to_id = {name: idx for idx, name in enumerate(classes)}
    return classes, class_name_to_id
def process_predictions(predictions, class_name_to_id):
    """Convert DINO-X prediction objects into arrays usable by ``sv.Detections``.

    Args:
        predictions: iterable of objects exposing ``bbox``, ``category`` and
            ``score``, and optionally a truthy ``mask`` with ``counts`` / ``size``
            (decoded via ``DetectionTask.string2rle`` + ``rle2mask``).
        class_name_to_id: mapping of lower-cased class name to integer id, as
            built by ``get_class_mappings``. It is not modified.

    Returns:
        dict with keys:
          - ``'boxes'``: float array, one ``bbox`` row per prediction; shape
            ``(0, 4)`` when there are no predictions.
          - ``'masks'``: array with one mask per prediction, or ``None`` if no
            prediction carried a mask.
          - ``'class_ids'``: int array, one id per prediction.
          - ``'class_names'``: list of stripped, lower-cased category names.
          - ``'confidences'``: list of ``score`` values.
    """
    boxes = []
    masks = []  # one entry per prediction; None where the object has no mask
    confidences = []
    class_names = []
    class_ids = []

    # Categories returned by the API that are missing from the prompt mapping
    # get fresh ids numbered after every known id, instead of a KeyError.
    next_free_id = max(class_name_to_id.values(), default=-1) + 1
    extra_ids = {}

    for obj in predictions:
        boxes.append(obj.bbox)
        if getattr(obj, 'mask', None):
            masks.append(DetectionTask.rle2mask(
                DetectionTask.string2rle(obj.mask.counts),
                obj.mask.size
            ))
        else:
            masks.append(None)
        cls_name = obj.category.lower().strip()
        class_names.append(cls_name)
        cls_id = class_name_to_id.get(cls_name)
        if cls_id is None:
            cls_id = extra_ids.setdefault(cls_name, next_free_id + len(extra_ids))
        class_ids.append(cls_id)
        confidences.append(obj.score)

    present = [m for m in masks if m is not None]
    if present:
        # Keep masks aligned with boxes: objects without a mask get an
        # all-zero mask of the same shape, so len(masks) == len(boxes).
        blank = np.zeros_like(present[0])
        mask_array = np.array([blank if m is None else m for m in masks])
    else:
        mask_array = None

    return {
        # An empty list would become shape (0,); sv.Detections needs (0, 4).
        'boxes': np.array(boxes, dtype=float) if boxes else np.empty((0, 4)),
        'masks': mask_array,
        'class_ids': np.array(class_ids, dtype=int),
        'class_names': class_names,
        'confidences': confidences
    }
def process_image(image_path, prompt=TEXT_PROMPT):
    """Detect and segment objects in one image with DINO-X and save an annotated copy.

    Args:
        image_path: path of the image; it is read with ``cv2.imread`` and
            uploaded to the DINO-X API.
        prompt: '.'-separated class names to detect.

    Returns:
        The annotated image path (``OUTPUT_DIR/result.jpg``) on success, or a
        message starting with ``"Error processing image:"`` on any failure.
        Errors are returned rather than raised.
    """
    try:
        # Read locally first so an unreadable file fails before any API call.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"could not read image file {image_path!r}")

        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)

        # Upload and request both boxes and segmentation masks.
        image_url = client.upload_file(image_path)
        task = DinoxTask(
            image_url=image_url,
            prompts=[TextPrompt(text=prompt)],
            bbox_threshold=0.25,
            targets=[DetectionTarget.BBox, DetectionTarget.Mask]
        )
        client.run_task(task)
        results = process_predictions(task.result.objects, class_name_to_id)

        # reshape keeps the (N, 4) layout even when nothing was detected
        # (np.array([]) has shape (0,), which sv.Detections rejects).
        xyxy = np.asarray(results['boxes'], dtype=float).reshape(-1, 4)
        masks = results['masks']
        if masks is not None and len(masks) != len(xyxy):
            # Only some objects came back with a mask; sv.Detections needs
            # exactly one mask per box, so fall back to boxes only.
            masks = None
        detections = sv.Detections(
            xyxy=xyxy,
            mask=masks.astype(bool) if masks is not None else None,
            class_id=np.asarray(results['class_ids'], dtype=int)
        )
        labels = [
            f"{name} {conf:.2f}"
            for name, conf in zip(results['class_names'], results['confidences'])
        ]

        annotated_frame = img.copy()
        if masks is not None:
            # Masks first, so boxes and labels are drawn on top of the overlay.
            annotated_frame = sv.MaskAnnotator().annotate(
                scene=annotated_frame,
                detections=detections
            )
        annotated_frame = sv.BoxAnnotator().annotate(
            scene=annotated_frame,
            detections=detections
        )
        annotated_frame = sv.LabelAnnotator().annotate(
            scene=annotated_frame,
            detections=detections,
            labels=labels
        )

        output_path = os.path.join(OUTPUT_DIR, "result.jpg")
        # cv2.imwrite reports failure by returning False, not by raising.
        if not cv2.imwrite(output_path, annotated_frame):
            raise IOError(f"could not write {output_path!r}")
        return output_path
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_video(video_path, prompt=VID_PROMPT, frame_stride=3):
    """Detect objects in a video with DINO-X and write an annotated copy.

    Detection (one upload + API call per sampled frame) runs only on every
    ``frame_stride``-th frame, starting with the first frame. The most recent
    detections are drawn on the frames in between, so the output keeps the
    input's frame count and plays at the input's speed.

    Args:
        video_path: path of a video readable by ``cv2.VideoCapture``.
        prompt: '.'-separated class names to detect.
        frame_stride: run detection on one frame out of every ``frame_stride``
            (values below 1 are treated as 1).

    Returns:
        The annotated video path (``OUTPUT_DIR/result.mp4``) on success, or a
        message starting with ``"Error processing video:"`` on any failure.
        Errors are returned rather than raised.
    """
    cap = None
    out = None
    temp_frame_path = os.path.join(TEMP_DIR, "temp_frame.jpg")
    try:
        client = initialize_dino_client()
        _, class_name_to_id = get_class_mappings(prompt)
        stride = max(1, int(frame_stride))

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"could not open video file {video_path!r}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97) instead of truncating them. Some
        # containers report 0 or NaN, so fall back to 30 ("not > 0" also
        # catches NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        output_path = os.path.join(OUTPUT_DIR, "result.mp4")
        # NOTE(review): 'mp4v' may not be playable in every browser; confirm
        # the front end can display it.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        box_annotator = sv.BoxAnnotator()
        label_annotator = sv.LabelAnnotator()
        detections = None
        labels = []
        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # frame_index 0 always satisfies this, so `detections` is set
            # before it is first used below.
            if frame_index % stride == 0:
                cv2.imwrite(temp_frame_path, frame)
                image_url = client.upload_file(temp_frame_path)
                task = DinoxTask(
                    image_url=image_url,
                    prompts=[TextPrompt(text=prompt)],
                    bbox_threshold=0.25
                )
                client.run_task(task)
                results = process_predictions(task.result.objects, class_name_to_id)
                detections = sv.Detections(
                    # reshape keeps (N, 4) even for an empty result
                    xyxy=np.asarray(results['boxes'], dtype=float).reshape(-1, 4),
                    class_id=np.asarray(results['class_ids'], dtype=int)
                )
                labels = [
                    f"{name} {conf:.2f}"
                    for name, conf in zip(results['class_names'], results['confidences'])
                ]
            # Every frame is written, so the output duration matches the input.
            annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame,
                detections=detections,
                labels=labels
            )
            out.write(annotated_frame)
            frame_index += 1
        return output_path
    except Exception as e:
        return f"Error processing video: {str(e)}"
    finally:
        # Release native handles and remove the scratch frame on every path,
        # including errors.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_frame_path):
            os.remove(temp_frame_path)
def process_input(input_file, prompt=TEXT_PROMPT):
    """Send an uploaded file to image or video processing based on its extension.

    Args:
        input_file: value from the ``gr.File`` input. Depending on the Gradio
            version and ``type`` setting, this is a path (``str`` /
            ``os.PathLike``) or a file wrapper exposing ``.name``. It is
            ``None`` when nothing was uploaded.
        prompt: '.'-separated class names to detect.

    Returns:
        The result of ``process_image`` / ``process_video`` (an output path or
        an error message), or a message string for missing or unsupported input.
    """
    if input_file is None:
        return "Please provide an input file"
    if not prompt or not prompt.strip():
        return "Please provide a detection prompt"
    # Use path-like values directly. pathlib.Path also has `.name`, but that
    # is only the basename, not the full path.
    if isinstance(input_file, (str, os.PathLike)):
        file_path = os.fspath(input_file)
    else:
        file_path = input_file.name
    extension = os.path.splitext(file_path)[1].lower()
    if extension in ('.jpg', '.jpeg', '.png'):
        return process_image(file_path, prompt)
    if extension in ('.mp4', '.avi', '.mov'):
        return process_video(file_path, prompt)
    return "Unsupported file format. Please use jpg/jpeg/png for images or mp4/avi/mov for videos."
# Create Gradio interface
def _run_and_route(input_file, prompt):
    """Run ``process_input`` and route the result to the image or video output.

    ``process_input`` returns either a path to an annotated file or a
    human-readable message. Annotated images go to the first output and
    videos to the second. Messages are raised as ``gr.Error`` so the UI shows
    them, rather than trying to load the text as a file path.
    """
    result = process_input(input_file, prompt)
    if not (isinstance(result, str) and os.path.isfile(result)):
        raise gr.Error(str(result))
    if os.path.splitext(result)[1].lower() in ('.mp4', '.avi', '.mov'):
        return None, result
    return result, None


demo = gr.Interface(
    fn=_run_and_route,
    inputs=[
        gr.File(
            label="Upload Image/Video",
            file_types=["image", "video"]
        ),
        gr.Textbox(
            label="Detection Prompt",
            value=TEXT_PROMPT,
            lines=2
        )
    ],
    # Two outputs: a gr.Image cannot display the annotated video.
    outputs=[
        gr.Image(label="Detection Result"),
        gr.Video(label="Detection Result (video)")
    ],
    title="DINO-X Object Detection",
    description="Upload an image or video to detect objects using DINO-X. You can modify the detection prompt to specify what objects to look for.",
    examples=[
        ["assets/demo.png", TEXT_PROMPT],
        ["assets/demo.mp4", VID_PROMPT]
    ],
    cache_examples=True
)
def _main():
    """Start the Gradio server hosting ``demo``."""
    demo.launch()


if __name__ == "__main__":
    # Launch only when executed as a script, not on import.
    _main()