DSatishchandra commited on
Commit
ac0bf9a
·
verified ·
1 Parent(s): 1055087

Update modules/ai_model.py

Browse files
Files changed (1) hide show
  1. modules/ai_model.py +41 -13
modules/ai_model.py CHANGED
@@ -1,19 +1,47 @@
1
- from transformers import DetrForObjectDetection, AutoFeatureExtractor
2
  import torch
3
  from PIL import Image
 
4
 
5
- # Load the pre-trained DETR model and feature extractor
6
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
7
- feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
8
 
9
- def run_inference(image):
10
- # Preprocess the image using AutoFeatureExtractor
11
- inputs = feature_extractor(images=image, return_tensors="pt")
12
-
13
- # Run inference
14
- outputs = model(**inputs)
 
 
 
 
 
 
 
15
 
16
- # Post-process the output (get bounding boxes)
17
- results = feature_extractor.post_process_object_detection(outputs, target_sizes=[image.size[::-1]], threshold=0.9)[0]
18
 
19
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
2
  import torch
3
  from PIL import Image
4
+ import logging
5
 
6
+ # Set up logging
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
 
10
+ try:
11
+ # Load the pre-trained DETR model and image processor
12
+ logger.info("Loading facebook/detr-resnet-50 model and processor...")
13
+ processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
14
+ model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
15
+ logger.info("Model and processor loaded successfully.")
16
+ except Exception as e:
17
+ logger.error(f"Failed to load model or processor: {str(e)}")
18
+ raise
19
+
20
+ def run_inference(image: Image.Image) -> dict:
21
+ """
22
+ Run object detection inference on the input image.
23
 
24
+ Args:
25
+ image (PIL.Image.Image): Input image for object detection.
26
 
27
+ Returns:
28
+ dict: Processed results containing bounding boxes, scores, and labels.
29
+ """
30
+ try:
31
+ # Preprocess the image using AutoImageProcessor
32
+ inputs = processor(images=image, return_tensors="pt")
33
+
34
+ # Run inference
35
+ with torch.no_grad(): # Disable gradient calculation for inference
36
+ outputs = model(**inputs)
37
+
38
+ # Post-process the output (get bounding boxes)
39
+ target_sizes = torch.tensor([image.size[::-1]]) # Format: [height, width]
40
+ results = processor.post_process_object_detection(
41
+ outputs, target_sizes=target_sizes, threshold=0.9
42
+ )[0]
43
+
44
+ return results
45
+ except Exception as e:
46
+ logger.error(f"Error during inference: {str(e)}")
47
+ raise