# (extraction artifacts removed: file-size header, git blame hashes, line-number gutter)
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from PIL import Image
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the DETR object-detection model and its image processor once at module
# import time, so every call to run_inference() reuses the same instances.
try:
    logger.info("Loading facebook/detr-resnet-50 model and processor...")
    processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
    model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # logger.exception captures the full traceback (logger.error with an
    # f-string did not); lazy %-args defer message formatting until emit.
    logger.exception("Failed to load model or processor: %s", e)
    # Re-raise: the module is unusable without the model, so fail loudly.
    raise
def run_inference(image: Image.Image, threshold: float = 0.9) -> dict:
    """
    Run object detection inference on the input image.

    Args:
        image (PIL.Image.Image): Input image for object detection.
        threshold (float): Minimum confidence score for a detection to be
            kept. Defaults to 0.9, matching the previously hard-coded value.

    Returns:
        dict: Post-processed results containing bounding boxes ("boxes"),
            confidence scores ("scores"), and class labels ("labels").

    Raises:
        Exception: Re-raised (after logging) if preprocessing, the forward
            pass, or post-processing fails.
    """
    try:
        # Preprocess the image into model-ready tensors.
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: disable gradient tracking to save memory/compute.
        with torch.no_grad():
            outputs = model(**inputs)
        # PIL's image.size is (width, height); post-processing expects
        # (height, width), hence the [::-1] reversal.
        target_sizes = torch.tensor([image.size[::-1]])
        # [0] selects the single image's results from the batch of one.
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]
        return results
    except Exception as e:
        # logger.exception records the full traceback before re-raising.
        logger.exception("Error during inference: %s", e)
        raise