NeerajCodz committed
Commit 4a26583 · verified · 1 Parent(s): 38bc350

Update app.py

Files changed (1)
  1. app.py  +62 -87
app.py CHANGED
@@ -4,10 +4,10 @@ import logging
 import os
 import sys
 import threading
+import traceback
 from collections import Counter
 from io import BytesIO
 from typing import Dict, List, Optional, Tuple, Union
-
 import gradio as gr
 import pandas as pd
 import requests
@@ -22,17 +22,19 @@ from transformers import (
     DetrImageProcessor,
     YolosForObjectDetection,
     YolosImageProcessor,
+    DeformableDetrForObjectDetection,
+    DeformableDetrImageProcessor,
 )
 import nest_asyncio
-
 # ------------------------------
 # Configuration
 # ------------------------------
-
 # Configure logging for debugging and monitoring
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
-
+# Define device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+logger.info(f"Using device: {device}")
 # Define constants for model and server configuration
 CONFIDENCE_THRESHOLD: float = 0.5  # Default threshold for object detection confidence
 VALID_MODELS: List[str] = [
@@ -41,7 +43,9 @@ VALID_MODELS: List[str] = [
     "facebook/detr-resnet-50-panoptic",
     "facebook/detr-resnet-101-panoptic",
     "hustvl/yolos-tiny",
+    "hustvl/yolos-small",
     "hustvl/yolos-base",
+    "SenseTime/deformable-detr",
 ]
 MODEL_DESCRIPTIONS: Dict[str, str] = {
     "facebook/detr-resnet-50": "DETR with ResNet-50 for object detection. Fast and accurate.",
@@ -49,22 +53,21 @@ MODEL_DESCRIPTIONS: Dict[str, str] = {
     "facebook/detr-resnet-50-panoptic": "DETR with ResNet-50 for panoptic segmentation.",
     "facebook/detr-resnet-101-panoptic": "DETR with ResNet-101 for panoptic segmentation.",
     "hustvl/yolos-tiny": "YOLOS Tiny. Lightweight and fast.",
-    "hustvl/yolos-base": "YOLOS Base. Balances speed and accuracy."
+    "hustvl/yolos-small": "YOLOS Small. Medium speed and accuracy.",
+    "hustvl/yolos-base": "YOLOS Base. Balances speed and accuracy.",
+    "SenseTime/deformable-detr": "Deformable DETR. Improved efficiency with deformable attention.",
 }
 DEFAULT_GRADIO_PORT: int = 7860  # Default port for Gradio UI
 DEFAULT_FASTAPI_PORT: int = 8000  # Default port for FastAPI server
 PORT_RANGE: range = range(7860, 7870)  # Range of ports to try for Gradio
 MAX_PORT_ATTEMPTS: int = 10  # Maximum attempts to find an available port
-
 # Thread-safe storage for lazy-loaded models and processors
 models: Dict[str, any] = {}
 processors: Dict[str, any] = {}
 model_lock = threading.Lock()
-
 # ------------------------------
 # Image Processing
 # ------------------------------
-
 def process_image(
     image: Optional[Image.Image],
     url: Optional[str],
@@ -74,14 +77,12 @@ def process_image(
 ) -> Union[Dict, Tuple[Optional[Image.Image], Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], str]]:
     """
     Process an image for object detection or panoptic segmentation, handling Gradio and FastAPI inputs.
-
     Args:
         image: PIL Image object from file upload (optional).
         url: URL of the image to process (optional).
         model_name: Name of the model to use (must be in VALID_MODELS).
        for_json: If True, return JSON dict for API/JSON tab; else, return tuple for Gradio Home tab.
         confidence_threshold: Minimum confidence score for detection (default: 0.5).
-
     Returns:
         For JSON: Dict with base64-encoded image, detected objects, and confidence scores.
         For Gradio: Tuple of (annotated image, objects DataFrame, unique objects DataFrame, properties DataFrame, error message).
@@ -95,16 +96,17 @@ def process_image(
         if model_name not in VALID_MODELS:
             error_msg = f"Invalid model: {model_name}. Choose from: {VALID_MODELS}"
             return {"error": error_msg} if for_json else (None, None, None, None, error_msg)
-
+        # Validate confidence threshold
+        if not 0 <= confidence_threshold <= 1:
+            error_msg = "Confidence threshold must be between 0 and 1."
+            return {"error": error_msg} if for_json else (None, None, None, None, error_msg)
         # Calculate margin threshold: (1 - confidence_threshold) / 2 + confidence_threshold
         margin_threshold = (1 - confidence_threshold) / 2 + confidence_threshold
-
         # Load image from URL if provided
         if url:
             response = requests.get(url, timeout=10)
             response.raise_for_status()
             image = Image.open(BytesIO(response.content)).convert("RGB")
-
         # Load model and processor thread-safely
         with model_lock:
             if model_name not in models:
@@ -112,32 +114,35 @@ def process_image(
                 try:
                     # Select appropriate model and processor based on model name
                     if "yolos" in model_name:
-                        models[model_name] = YolosForObjectDetection.from_pretrained(model_name)
+                        models[model_name] = YolosForObjectDetection.from_pretrained(model_name).eval().to(device)
                         processors[model_name] = YolosImageProcessor.from_pretrained(model_name)
                     elif "panoptic" in model_name:
-                        models[model_name] = DetrForSegmentation.from_pretrained(model_name)
+                        models[model_name] = DetrForSegmentation.from_pretrained(model_name).eval().to(device)
                         processors[model_name] = DetrImageProcessor.from_pretrained(model_name)
+                    elif "deformable" in model_name:
+                        models[model_name] = DeformableDetrForObjectDetection.from_pretrained(model_name).eval().to(device)
+                        processors[model_name] = DeformableDetrImageProcessor.from_pretrained(model_name)
                     else:
-                        models[model_name] = DetrForObjectDetection.from_pretrained(model_name)
+                        models[model_name] = DetrForObjectDetection.from_pretrained(model_name).eval().to(device)
                         processors[model_name] = DetrImageProcessor.from_pretrained(model_name)
                 except Exception as e:
                     error_msg = f"Failed to load model: {str(e)}"
                     logger.error(error_msg)
                     return {"error": error_msg} if for_json else (None, None, None, None, error_msg)
         model, processor = models[model_name], processors[model_name]
-
         # Prepare image for model processing
         inputs = processor(images=image, return_tensors="pt")
+        # Move inputs to device if using GPU
+        if device == "cuda":
+            inputs = {k: v.to(device) for k, v in inputs.items()}
         with torch.no_grad():
             outputs = model(**inputs)
-
         # Initialize drawing context for annotations
         draw = ImageDraw.Draw(image)
         object_names: List[str] = []
         confidence_scores: List[float] = []
         object_counter = Counter()
        target_sizes = torch.tensor([image.size[::-1]])
-
         # Process results based on model type (panoptic or object detection)
         if "panoptic" in model_name:
             # Handle panoptic segmentation
@@ -161,29 +166,26 @@ def process_image(
                 if score > confidence_threshold:
                     object_names.append(label_name)
                     confidence_scores.append(float(score))
-                    object_counter[label_name] = float(score)
+                    object_counter[label_name] = max(object_counter.get(label_name, 0.0), float(score))
         else:
             # Handle object detection
-            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=confidence_threshold)[0]
             for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-                if score > confidence_threshold:
-                    x, y, x2, y2 = box.tolist()
-                    label_name = model.config.id2label.get(label.item(), "Unknown")
-                    text = f"{label_name}: {score:.2f}"
-                    text_bbox = draw.textbbox((0, 0), text)
-                    text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
-                    # Use yellow for confidence_threshold <= score < margin_threshold, green for >= margin_threshold
-                    color = "#FFFF00" if score < margin_threshold else "#32CD32"
-                    draw.rectangle([x, y, x2, y2], outline=color, width=2)
-                    draw.text((x2 - text_width - 2, y - text_height - 2), text, fill=color)
-                    object_names.append(label_name)
-                    confidence_scores.append(float(score))
-                    object_counter[label_name] = float(score)
-
+                x, y, x2, y2 = box.tolist()
+                label_name = model.config.id2label.get(label.item(), "Unknown")
+                text = f"{label_name}: {score:.2f}"
+                text_bbox = draw.textbbox((0, 0), text)
+                text_width, text_height = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
+                # Use yellow for confidence_threshold <= score < margin_threshold, green for >= margin_threshold
+                color = "#FFFF00" if score < margin_threshold else "#32CD32"
+                draw.rectangle([x, y, x2, y2], outline=color, width=2)
+                draw.text((x2 - text_width - 2, y - text_height - 2), text, fill=color)
+                object_names.append(label_name)
+                confidence_scores.append(float(score))
+                object_counter[label_name] = max(object_counter.get(label_name, 0.0), float(score))
         # Compile unique objects and their highest confidence scores
         unique_objects = list(object_counter.keys())
         unique_confidences = [object_counter[obj] for obj in unique_objects]
-
         # Calculate image properties (metadata)
         properties: Dict[str, str] = {
             "Format": image.format if hasattr(image, "format") and image.format else "Unknown",
@@ -207,7 +209,6 @@ def process_image(
                 properties["StdDev (R,G,B)"] = ", ".join(f"{s:.2f}" for s in stat.stddev)
         except Exception as e:
             logger.error(f"Error calculating image stats: {str(e)}")
-
         # Prepare output based on request type
         if for_json:
             # Return JSON with base64-encoded image
@@ -231,9 +232,8 @@ def process_image(
             pd.DataFrame({"Unique Object": unique_objects, "Confidence Score": [f"{score:.2f}" for score in unique_confidences]})
             if unique_objects else pd.DataFrame(columns=["Unique Object", "Confidence Score"])
         )
-        properties_df = pd.DataFrame([properties]) if properties else pd.DataFrame(columns=properties.keys())
+        properties_df = pd.DataFrame([properties]) if properties else pd.DataFrame(columns=list(properties.keys()))
         return image, objects_df, unique_objects_df, properties_df, ""
-
     except requests.RequestException as e:
         # Handle URL fetch errors
         error_msg = f"Error fetching image from URL: {str(e)}"
@@ -244,13 +244,10 @@ def process_image(
         error_msg = f"Error processing image: {str(e)}"
         logger.error(f"{error_msg}\n{traceback.format_exc()}")
         return {"error": error_msg} if for_json else (None, None, None, None, error_msg)
-
 # ------------------------------
 # FastAPI Setup
 # ------------------------------
-
 app = FastAPI(title="Object Detection API")
-
 @app.post("/detect")
 async def detect_objects_endpoint(
     file: Optional[UploadFile] = File(None),
@@ -260,16 +257,13 @@ async def detect_objects_endpoint(
 ) -> JSONResponse:
     """
     FastAPI endpoint to detect objects in an image from file upload or URL.
-
     Args:
         file: Uploaded image file (optional).
         image_url: URL of the image (optional).
         model_name: Model to use for detection (default: first VALID_MODELS entry).
         confidence_threshold: Confidence threshold for detection (default: 0.5).
-
     Returns:
         JSONResponse with base64-encoded image, detected objects, and confidence scores.
-
     Raises:
         HTTPException: For invalid inputs or processing errors.
     """
@@ -277,9 +271,6 @@ async def detect_objects_endpoint(
         # Validate input: ensure exactly one of file or URL
         if (file is None and not image_url) or (file is not None and image_url):
             raise HTTPException(status_code=400, detail="Provide either an image file or an image URL, not both.")
-        # Validate confidence threshold
-        if not 0 <= confidence_threshold <= 1:
-            raise HTTPException(status_code=400, detail="Confidence threshold must be between 0 and 1.")
         # Load image from file if provided
         image = None
         if file:
@@ -297,18 +288,14 @@ async def detect_objects_endpoint(
     except Exception as e:
         logger.error(f"Error in FastAPI endpoint: {str(e)}\n{traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
-
 # ------------------------------
 # Gradio UI Setup
 # ------------------------------
-
 def create_gradio_ui() -> gr.Blocks:
     """
     Create and configure the Gradio UI for object detection with Home, JSON, and Help tabs.
-
     Returns:
         Gradio Blocks object representing the UI.
-
     Raises:
         RuntimeError: If UI creation fails.
     """
@@ -319,11 +306,11 @@ def create_gradio_ui() -> gr.Blocks:
             gr.Markdown(
                 f"""
                 # 🚀 Object Detection App
-                Upload an image or provide a URL to detect objects using transformer models (DETR, YOLOS).
+                Upload an image or provide a URL to detect objects using transformer models (DETR, YOLOS, Deformable DETR).
                 Running on port: {os.getenv('GRADIO_SERVER_PORT', 'auto-selected')}
+                Device: {device.upper()}
                 """
             )
-
             # Create tabbed interface
             with gr.Tabs():
                 # Home tab (formerly Image Upload)
@@ -335,6 +322,8 @@ def create_gradio_ui() -> gr.Blocks:
                             # Model selection dropdown
                            model_choice = gr.Dropdown(choices=VALID_MODELS, value=VALID_MODELS[0], label="🔎 Select Model")
                             model_info = gr.Markdown(f"**Model Info**: {MODEL_DESCRIPTIONS[VALID_MODELS[0]]}")
+                            # Confidence threshold slider
+                            confidence_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Confidence Threshold")
                             # Image upload input
                             image_input = gr.Image(type="pil", label="📷 Upload Image")
                             # Image URL input
@@ -343,14 +332,12 @@ def create_gradio_ui() -> gr.Blocks:
                             with gr.Row():
                                 submit_btn = gr.Button("✨ Detect", variant="primary")
                                 clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-
                             # Update model info when model changes
                             model_choice.change(
                                 fn=lambda model_name: f"**Model Info**: {MODEL_DESCRIPTIONS.get(model_name, 'No description available.')}",
                                 inputs=model_choice,
                                 outputs=model_info,
                             )
-
                         # Right column for results
                         with gr.Column(scale=2):
                             gr.Markdown("### Results")
@@ -364,21 +351,18 @@ def create_gradio_ui() -> gr.Blocks:
                             unique_objects_output = gr.DataFrame(label="🔍 Unique Objects", interactive=False)
                             # Image properties table
                             properties_output = gr.DataFrame(label="📄 Image Properties", interactive=False)
-
                     # Process image when Detect button is clicked
                     submit_btn.click(
                         fn=process_image,
-                        inputs=[image_input, image_url_input, model_choice],
+                        inputs=[image_input, image_url_input, model_choice, confidence_slider],
                        outputs=[output_image, objects_output, unique_objects_output, properties_output, error_output],
                     )
-
                     # Clear all inputs and outputs
                     clear_btn.click(
-                        fn=lambda: [None, "", None, None, None, None],
+                        fn=lambda: [None, "", VALID_MODELS[0], 0.5, None, pd.DataFrame(columns=["Object", "Confidence Score"]), pd.DataFrame(columns=["Unique Object", "Confidence Score"]), pd.DataFrame(), None],
                         inputs=None,
-                        outputs=[image_input, image_url_input, output_image, objects_output, unique_objects_output, properties_output, error_output],
+                        outputs=[image_input, image_url_input, model_choice, confidence_slider, output_image, objects_output, unique_objects_output, properties_output, error_output],
                     )
-
                 # JSON tab for API-like output
                 with gr.Tab("🔗 JSON"):
                     with gr.Row():
@@ -388,64 +372,64 @@ def create_gradio_ui() -> gr.Blocks:
                             # Model selection dropdown
                             url_model_choice = gr.Dropdown(choices=VALID_MODELS, value=VALID_MODELS[0], label="🔎 Select Model")
                             url_model_info = gr.Markdown(f"**Model Info**: {MODEL_DESCRIPTIONS[VALID_MODELS[0]]}")
+                            # Confidence threshold slider
+                            confidence_slider_json = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Confidence Threshold")
                             # Image upload input
                             image_input_json = gr.Image(type="pil", label="📷 Upload Image")
                             # Image URL input
                             image_url_input_json = gr.Textbox(label="🔗 Image URL", placeholder="https://example.com/image.jpg")
                             # Process button
                             url_submit_btn = gr.Button("🔄 Process", variant="primary")
-
+                            url_clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                             # Update model info when model changes
                             url_model_choice.change(
                                 fn=lambda model_name: f"**Model Info**: {MODEL_DESCRIPTIONS.get(model_name, 'No description available.')}",
                                 inputs=url_model_choice,
                                 outputs=url_model_info,
                             )
-
                         # Right column for JSON output
                         with gr.Column(scale=1):
                             # JSON output display
                             url_output = gr.JSON(label="API Response")
-
                     # Process image and return JSON when Process button is clicked
                     url_submit_btn.click(
-                        fn=lambda img, url, model: process_image(img, url, model, for_json=True),
-                        inputs=[image_input_json, image_url_input_json, url_model_choice],
+                        fn=lambda img, url, model, conf: process_image(img, url, model, for_json=True, confidence_threshold=conf),
+                        inputs=[image_input_json, image_url_input_json, url_model_choice, confidence_slider_json],
                        outputs=[url_output],
                     )
-
+                    # Clear inputs and output
+                    url_clear_btn.click(
+                        fn=lambda: [None, "", VALID_MODELS[0], 0.5, {}],
+                        inputs=None,
+                        outputs=[image_input_json, image_url_input_json, url_model_choice, confidence_slider_json, url_output],
+                    )
                 # Help tab with usage instructions
                 with gr.Tab("ℹ️ Help"):
                     gr.Markdown(
                        """
                         ## How to Use
-                        - **Home**: Select a model, upload an image or provide a URL, click "Detect" to see results.
-                        - **JSON**: Select a model, upload an image or enter a URL, click "Process" for JSON output.
-                        - **Models**: Choose DETR (detection or panoptic) or YOLOS (lightweight detection).
-                        - **Clear**: Reset inputs/outputs in Home tab.
+                        - **Home**: Select a model and confidence threshold, upload an image or provide a URL, click "Detect" to see results.
+                        - **JSON**: Select a model and confidence threshold, upload an image or enter a URL, click "Process" for JSON output.
+                        - **Models**: Choose DETR (detection or panoptic), YOLOS (lightweight detection), or Deformable DETR (efficient attention).
+                        - **Clear**: Reset inputs/outputs in Home or JSON tab.
                         - **Errors**: Check error box (Home) or JSON response (JSON) for issues.
-
                         ## Tips
                         - Use high-quality images for better results.
                         - Panoptic models provide segmentation masks for complex scenes.
                         - YOLOS-Tiny is faster for resource-constrained devices.
+                        - Adjust confidence threshold to filter detections (higher = fewer but more confident).
                         """
                     )
-
         return demo
-
     except Exception as e:
        logger.error(f"Error creating Gradio UI: {str(e)}\n{traceback.format_exc()}")
         raise RuntimeError(f"Failed to create Gradio UI: {str(e)}")
-
 # ------------------------------
 # Launcher
 # ------------------------------
-
 def parse_args() -> argparse.Namespace:
     """
     Parse command-line arguments for configuring the application.
-
     Returns:
         Parsed arguments as a Namespace object.
     """
@@ -464,16 +448,13 @@ def parse_args() -> argparse.Namespace:
     if not 0 <= args.confidence_threshold <= 1:
         parser.error("Confidence threshold must be between 0 and 1.")
     return args
-
 def find_available_port(start_port: int, port_range: range, max_attempts: int) -> Optional[int]:
     """
     Find an available port within the specified range.
-
     Args:
         start_port: Initial port to try.
         port_range: Range of ports to attempt.
         max_attempts: Maximum number of ports to try.
-
     Returns:
         Available port number, or None if no port is found.
     """
@@ -497,11 +478,9 @@ def find_available_port(start_port: int, port_range: range, max_attempts: int) -> Optional[int]:
                 raise
     logger.error(f"No available port in range {min(port_range)}-{max(port_range)}")
     return None
-
 def main() -> None:
     """
     Launch the Gradio UI and optional FastAPI server.
-
     Raises:
         SystemExit: On interruption or critical errors.
     """
@@ -516,7 +495,6 @@ def main() -> None:
        if gradio_port is None:
             logger.error("Failed to find an available port for Gradio UI")
             sys.exit(1)
-
         # Start FastAPI server in a thread if enabled
         if args.enable_fastapi:
             logger.info(f"Starting FastAPI on port {args.fastapi_port}")
@@ -525,18 +503,15 @@ def main() -> None:
                 daemon=True
             )
             fastapi_thread.start()
-
         # Launch Gradio UI
        logger.info(f"Starting Gradio UI on port {gradio_port}")
         demo = create_gradio_ui()
         demo.launch(server_port=gradio_port, server_name="0.0.0.0")
-
     except KeyboardInterrupt:
         logger.info("Application terminated by user.")
         sys.exit(0)
     except Exception as e:
         logger.error(f"Error: {str(e)}\n{traceback.format_exc()}")
         sys.exit(1)
-
 if __name__ == "__main__":
     main()
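
A worked instance of the color-margin arithmetic kept by this commit: margin_threshold = (1 - confidence_threshold) / 2 + confidence_threshold, the midpoint between the chosen threshold and 1.0. With the default confidence_threshold = 0.5 this gives (1 - 0.5) / 2 + 0.5 = 0.75, so a detection scoring 0.6 is boxed in yellow (#FFFF00) while one scoring 0.8 is boxed in green (#32CD32).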
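
The object_counter change is a correctness fix, not just a refactor: the old assignment object_counter[label_name] = float(score) kept whichever score happened to come last, whereas max(object_counter.get(label_name, 0.0), float(score)) keeps the largest. For three "person" detections scoring 0.6, 0.9, then 0.7, the "Unique Objects" table previously reported 0.7; it now reports 0.9, matching the "highest confidence scores" comment above it.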
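
The commit also moves the confidence_threshold range check out of the endpoint and into process_image, so an out-of-range value now comes back as a JSON {"error": ...} body rather than an HTTP 400. A minimal smoke test of the updated /detect endpoint is sketched below; the localhost:8000 address assumes the app was launched with FastAPI enabled on its default port, and passing image_url, model_name, and confidence_threshold in the query string is an assumption, since their parameter declarations fall outside this diff.

import requests

# Hypothetical client for the /detect endpoint (host, port, and parameter
# transport are assumptions; only file=File(None) is visible in the diff).
resp = requests.post(
    "http://localhost:8000/detect",
    params={
        "image_url": "https://example.com/image.jpg",  # placeholder image URL
        "model_name": "SenseTime/deformable-detr",     # one of the newly added models
        "confidence_threshold": 0.5,
    },
    timeout=30,
)
resp.raise_for_status()
# Per the endpoint docstring: base64-encoded image, detected objects, and confidence scores.
print(resp.json())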