sharvari0b26 commited on
Commit
bfef6f3
·
verified ·
1 Parent(s): a7f2ea8

Upload 3 files

Browse files
Files changed (3) hide show
  1. checkpoint_best_total.pth +3 -0
  2. rf-detr-base.pth +3 -0
  3. script.py +72 -0
checkpoint_best_total.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98f3d25a38ae7f9e598d02bead7903179ed99e8114ab99a31dfe9a81b69f532e
3
+ size 127634110
rf-detr-base.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8f70210e425a4a4234d547737f57500bcc4ac24a333b99e33d9d5a371e0b80f
3
+ size 372578043
script.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ from rfdetr import RFDETRBase
5
+
6
+
7
+ def run_inference(model, image_path, conf_threshold, save_path):
8
+
9
+ test_images = sorted([
10
+ f for f in os.listdir(image_path)
11
+ if f.lower().endswith((".jpg", ".jpeg", ".png"))
12
+ ])
13
+
14
+ bboxes = []
15
+ category_ids = []
16
+ test_images_names = []
17
+
18
+ for image_name in test_images:
19
+ test_images_names.append(image_name)
20
+
21
+ image_file = os.path.join(image_path, image_name)
22
+
23
+ bbox = []
24
+ category_id = []
25
+
26
+ preds = model.predict(image_file)
27
+
28
+ if preds is not None and preds.xyxy is not None and len(preds.xyxy) > 0:
29
+ for box, score, label in zip(
30
+ preds.xyxy,
31
+ preds.confidence,
32
+ preds.class_id
33
+ ):
34
+ score = float(score)
35
+ if score >= conf_threshold:
36
+ xmin, ymin, xmax, ymax = map(float, box)
37
+
38
+ width = xmax - xmin
39
+ height = ymax - ymin
40
+
41
+ bbox.append([xmin, ymin, width, height])
42
+ category_id.append(int(label))
43
+
44
+ bboxes.append(bbox)
45
+ category_ids.append(category_id)
46
+
47
+ df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
48
+
49
+ for i in range(len(test_images_names)):
50
+ new_row = pd.DataFrame({
51
+ "file_name": test_images_names[i],
52
+ "bbox": str(bboxes[i]),
53
+ "category_id": str(category_ids[i])
54
+ }, index=[0])
55
+
56
+ df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
57
+
58
+ df_predictions.to_csv(save_path, index=False)
59
+
60
+
61
+ if __name__ == "__main__":
62
+
63
+ TEST_IMAGE_PATH = r"rf-detr\dataset\test"
64
+ SUBMISSION_SAVE_PATH = "submission.csv"
65
+ CONF_THRESHOLD = 0.30
66
+
67
+ model = RFDETRBase(
68
+ checkpoint_path="checkpoint_best_total.pth",
69
+ device="cuda" if torch.cuda.is_available() else "cpu"
70
+ )
71
+
72
+ run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)