Spaces:
Sleeping
Sleeping
import os
import torch
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import gradio as gr

# Disable TorchDynamo compilation for this process.
# NOTE(review): set after `import torch` — presumably Dynamo reads this lazily
# at first compile, so the ordering still works; confirm for the torch version in use.
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Process-wide cache for the detector and its processor.
# Both are populated lazily by load_model() on the first detection request.
model = None
processor = None
# Load Model and Processor
def load_model():
    """Load the OwlViT detector and processor, caching them in module globals.

    A local checkout at ``./owlvit-base-patch32`` takes precedence; otherwise
    the weights are fetched from the Hugging Face Hub. The model is put in
    eval mode and moved to GPU when one is available.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        RuntimeError: wrapping any failure during download or deserialization.
    """
    global model, processor

    # Fast path: a previous call already populated the cache.
    if model is not None and processor is not None:
        return model, processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "google/owlvit-base-patch32"
    local_model_path = "./owlvit-base-patch32"

    try:
        # isdir() already implies existence, so a separate exists() check is redundant.
        if os.path.isdir(local_model_path):
            print(f"Loading model from local directory: {local_model_path}")
            source = local_model_path
        else:
            print(f"Loading model from Hugging Face Hub: {model_name}")
            source = model_name

        processor = OwlViTProcessor.from_pretrained(source)
        model = OwlViTForObjectDetection.from_pretrained(source)

        model.eval()
        model.to(device)
        print("Model loaded successfully!")
        return model, processor
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}")
# Draw Bounding Boxes Function
def draw_boxes(image, results, queries):
    """Annotate *image* with the detections from *results*.

    Args:
        image: ``PIL.Image.Image`` to draw on (mutated in place).
        results: output of ``OwlViTProcessor.post_process_object_detection``;
            only ``results[0]`` (the single processed image) is consumed.
        queries: list of text queries; each detection's label id indexes it.

    Returns:
        The same PIL image with a red rectangle and a ``"query: score"``
        caption per detection.
    """
    draw = ImageDraw.Draw(image)
    boxes = results[0]["boxes"]
    scores = results[0]["scores"]
    labels = results[0]["labels"]
    for box, score, label in zip(boxes, scores, labels):
        x1, y1, x2, y2 = box.tolist()
        # Draw rectangle
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Draw label. Clamp the caption's y so it stays on-canvas: the
        # original placed it at y1 - 15, which is negative (invisible)
        # whenever a box touches the top edge of the image.
        text = f"{queries[label]}: {score:.2f}"
        draw.text((x1, max(0, y1 - 15)), text, fill="red")
    return image
# Prediction Function
def detect_objects(image, text_query, threshold):
    """Run zero-shot object detection and return an annotated image.

    Args:
        image: input image (PIL image or numpy array, as delivered by Gradio),
            or ``None`` when the user has not uploaded anything.
        text_query: comma-separated object descriptions, e.g. ``"person, car"``.
        threshold: confidence cutoff in [0, 1] for keeping detections.

    Returns:
        A new PIL image with boxes drawn; the unmodified input image when the
        query is empty or an error occurs; ``None`` when no image was given.
    """
    global model, processor
    if image is None:
        return None
    try:
        # Lazily load the model on the first request.
        if model is None or processor is None:
            model, processor = load_model()

        # Gradio may hand us a numpy array; normalize to an RGB PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        # Parse text queries (split by comma). Guard against a None textbox
        # value, which would otherwise crash .split() and be silently
        # swallowed by the broad except below.
        text_queries = [q.strip() for q in (text_query or "").split(",") if q.strip()]
        if not text_queries:
            return image

        # Tokenize the queries and preprocess the image, then move every
        # tensor to the model's device.
        inputs = processor(text=text_queries, images=image, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing expects (height, width); PIL's .size is (width,
        # height), hence the [::-1]. Use torch.tensor (not the legacy
        # torch.Tensor constructor) so dtype is inferred from the values.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs=outputs,
            threshold=threshold,
            target_sizes=target_sizes,
        )

        # Draw on a copy so the caller's image stays untouched.
        return draw_boxes(image.copy(), results, text_queries)
    except Exception as e:
        # UI boundary: log and fall back to the unannotated image rather
        # than surfacing a stack trace to the user.
        print(f"Error during detection: {str(e)}")
        return image
# Gradio Interface: two-column layout (inputs left, annotated output right),
# wired to detect_objects via button click and textbox Enter key.
with gr.Blocks(title="Query based object detection") as demo:
    # Usage instructions shown at the top of the app.
    gr.Markdown(
        """
Upload an image and describe what you want to detect. You can specify multiple objects separated by commas.
**Example queries:**
- `a dog on couch sofa`
- `person, car, bicycle`
- `red apple, green apple`
"""
    )
    with gr.Row():
        with gr.Column():
            # Left column: image upload, query text, and threshold slider.
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            text_input = gr.Textbox(
                label="Text Query",
                placeholder="e.g., a dog on couch sofa",
                value="a dog on couch sofa"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
                info="Lower values detect more objects but may include false positives"
            )
            detect_btn = gr.Button("Detect Objects", variant="primary")
        with gr.Column():
            # Right column: the annotated result image.
            output_image = gr.Image(
                label="Detected Objects",
                type="pil",
                height=400
            )
    # Example queries (text + threshold only; the user supplies the image).
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["a dog on couch sofa", 0.1],
            ["person, car", 0.1],
            ["cat, dog", 0.1],
        ],
        inputs=[text_input, threshold],
        label="Try these queries"
    )
    # Set up the function call: button click runs detection.
    detect_btn.click(
        fn=detect_objects,
        inputs=[image_input, text_input, threshold],
        outputs=output_image
    )
    # Also allow Enter key (textbox submit) to trigger detection.
    text_input.submit(
        fn=detect_objects,
        inputs=[image_input, text_input, threshold],
        outputs=output_image
    )

# Start the Gradio server when the script is executed.
demo.launch()