sandbox338 commited on
Commit
e803eba
·
verified ·
1 Parent(s): 8f8c75e

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +48 -0
inference.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Simple inference script for wildlife ensemble detector
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from detectron2 import model_zoo
7
+ from detectron2.config import get_cfg
8
+ from detectron2.engine import DefaultPredictor
9
+
10
+ def load_ensemble_models(inception_path, resnet_path, class_names):
11
+ """Load both ensemble models for inference"""
12
+
13
+ # You'll need to copy the model_architecture.py content here
14
+ # and register the backbones before loading
15
+
16
+ # Setup configs (simplified)
17
+ cfg_inception = get_cfg()
18
+ cfg_inception.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
19
+ cfg_inception.MODEL.BACKBONE.NAME = "InceptionBackboneWrapper"
20
+ cfg_inception.MODEL.ROI_HEADS.NUM_CLASSES = len(class_names)
21
+ cfg_inception.MODEL.WEIGHTS = inception_path
22
+ cfg_inception.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
23
+
24
+ cfg_resnet = get_cfg()
25
+ cfg_resnet.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
26
+ cfg_resnet.MODEL.BACKBONE.NAME = "ResNetBackboneWrapper"
27
+ cfg_resnet.MODEL.ROI_HEADS.NUM_CLASSES = len(class_names)
28
+ cfg_resnet.MODEL.WEIGHTS = resnet_path
29
+ cfg_resnet.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
30
+
31
+ predictor_inception = DefaultPredictor(cfg_inception)
32
+ predictor_resnet = DefaultPredictor(cfg_resnet)
33
+
34
+ return predictor_inception, predictor_resnet
35
+
36
+ def predict_ensemble(image_path, predictor_inception, predictor_resnet, class_names):
37
+ """Run ensemble inference on an image"""
38
+
39
+ img = cv2.imread(image_path)
40
+
41
+ # Get predictions from both models
42
+ outputs_inc = predictor_inception(img)
43
+ outputs_res = predictor_resnet(img)
44
+
45
+ # Combine predictions (simplified)
46
+ # Add your ensemble logic here
47
+
48
+ return outputs_inc, outputs_res