DSatishchandra commited on
Commit
48a3570
·
verified ·
1 Parent(s): 0195fcb

Update modules/ai_model.py

Browse files
Files changed (1) hide show
  1. modules/ai_model.py +11 -7
modules/ai_model.py CHANGED
@@ -1,15 +1,19 @@
1
- from transformers import DetrImageProcessor, DetrForObjectDetection
2
  import torch
3
  from PIL import Image
4
 
5
- # Load pre-trained DETR model
6
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
7
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
8
 
9
- def run_inference(image_path):
10
- image = Image.open(image_path) # Replace with actual image path
11
- inputs = processor(images=image, return_tensors="pt")
12
 
13
  # Run inference
14
  outputs = model(**inputs)
15
- return outputs
 
 
 
 
 
1
+ from transformers import DetrForObjectDetection, AutoFeatureExtractor
2
  import torch
3
  from PIL import Image
4
 
5
+ # Load the pre-trained DETR model and feature extractor
 
6
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
7
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
8
 
9
+ def run_inference(image):
10
+ # Preprocess the image using AutoFeatureExtractor
11
+ inputs = feature_extractor(images=image, return_tensors="pt")
12
 
13
  # Run inference
14
  outputs = model(**inputs)
15
+
16
+ # Post-process the output (get bounding boxes)
17
+ results = feature_extractor.post_process_object_detection(outputs, target_sizes=[image.size[::-1]], threshold=0.9)[0]
18
+
19
+ return results