sharvari0b26 commited on
Commit
68a0691
·
verified ·
1 Parent(s): 6014d5d

Upload 2 files

Browse files
Files changed (2) hide show
  1. checkpoint_best_total.pth +3 -0
  2. script.py +65 -0
checkpoint_best_total.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98f3d25a38ae7f9e598d02bead7903179ed99e8114ab99a31dfe9a81b69f532e
3
+ size 127634110
script.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ from PIL import Image
5
+ import numpy as np
6
+ from rfdetr import RFDETRBase
7
+
8
+
9
+ def run_inference(model, image_path, conf_threshold, save_path):
10
+
11
+ test_images = sorted(os.listdir(image_path))
12
+
13
+ bboxes = []
14
+ category_ids = []
15
+ test_images_names = []
16
+
17
+ for image_name in test_images:
18
+ test_images_names.append(image_name)
19
+
20
+ image_file = os.path.join(image_path, image_name)
21
+ image = Image.open(image_file).convert("RGB")
22
+
23
+ preds = model.predict(image)
24
+
25
+ image_bboxes = []
26
+ image_categories = []
27
+
28
+ for box, score, label in zip(
29
+ preds["boxes"], preds["scores"], preds["labels"]
30
+ ):
31
+ if score >= conf_threshold:
32
+ xmin, ymin, xmax, ymax = box.tolist()
33
+ width = xmax - xmin
34
+ height = ymax - ymin
35
+
36
+ image_bboxes.append([xmin, ymin, width, height])
37
+ image_categories.append(int(label))
38
+
39
+ bboxes.append(image_bboxes)
40
+ category_ids.append(image_categories)
41
+
42
+ df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
43
+
44
+ for i in range(len(test_images_names)):
45
+ df_predictions.loc[i] = [
46
+ test_images_names[i],
47
+ str(bboxes[i]),
48
+ str(category_ids[i]),
49
+ ]
50
+
51
+ df_predictions.to_csv(save_path, index=False)
52
+
53
+
54
+ if __name__ == "__main__":
55
+
56
+ TEST_IMAGE_PATH = "/tmp/data/test_images"
57
+ SUBMISSION_SAVE_PATH = "submission.csv"
58
+ CONF_THRESHOLD = 0.30
59
+
60
+ model = RFDETRBase(
61
+ checkpoint_path="checkpoint_best_ema.pth",
62
+ device="cuda" if torch.cuda.is_available() else "cpu"
63
+ )
64
+
65
+ run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)