DSatishchandra's picture
Update modules/ai_model.py
0394517 verified
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from PIL import Image
import logging
# NOTE(review): basicConfig at import time configures the *root* logger for the
# whole process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single source of truth for the checkpoint, so the processor and the model are
# always loaded from the same Hugging Face Hub repository.
MODEL_NAME = "facebook/detr-resnet-50"

# Load once at import time; run_inference reuses these module-level objects on
# every call instead of reloading the checkpoint.
try:
    logger.info("Loading %s model and processor...", MODEL_NAME)
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForObjectDetection.from_pretrained(MODEL_NAME)
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # logger.exception keeps the full traceback (download/cache/auth failures
    # are hard to diagnose from the message alone); re-raise so a failed load
    # aborts the import instead of leaving `processor`/`model` undefined.
    logger.exception("Failed to load model or processor: %s", e)
    raise
def run_inference(image: Image.Image, *, threshold: float = 0.9) -> dict:
    """
    Run object detection inference on a single input image.

    Args:
        image (PIL.Image.Image): Input image for object detection. Images whose
            mode is not "RGB" (e.g. RGBA, grayscale, palette) are converted to
            RGB first, because the DETR processor expects three channels. The
            caller's image object is not modified.
        threshold (float): Minimum confidence score a detection must have to be
            kept. Defaults to 0.9, the value previously hard-coded here.

    Returns:
        dict: The single-image result of
        ``processor.post_process_object_detection``. Per the transformers
        documentation, this holds ``"scores"``, ``"labels"`` (class ids) and
        ``"boxes"`` tensors. The boxes are in (x_min, y_min, x_max, y_max)
        pixel coordinates of the original image.

    Raises:
        Exception: Any error raised by preprocessing, the forward pass, or
            post-processing is logged with its traceback and then re-raised.
    """
    try:
        # 1- or 4-channel images may fail in the processor's 3-channel
        # normalization. convert() returns a new image, so the caller's object
        # is left untouched.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess (resize/normalize) into a batch of PyTorch tensors.
        inputs = processor(images=image, return_tensors="pt")

        # Forward pass only; no autograd graph is needed for inference.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports size as (width, height); post-processing expects
        # [height, width] to rescale boxes back onto the original image.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]  # one image in the batch -> take its entry
        return results
    except Exception as e:
        # logger.exception preserves the traceback for debugging; re-raise so
        # callers still see the original error.
        logger.exception("Error during inference: %s", e)
        raise