sharvari0b26 commited on
Commit
d154083
·
verified ·
1 Parent(s): 4a8250e

Upload script.py

Browse files
Files changed (1) hide show
  1. script.py +78 -0
script.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ from rfdetr import RFDETRBase
5
+
6
+ def run_inference(model, image_path, conf_threshold, save_path):
7
+
8
+ test_images = sorted([
9
+ f for f in os.listdir(image_path)
10
+ if f.lower().endswith((".jpg", ".jpeg", ".png"))
11
+ ])
12
+
13
+ print(f"Found {len(test_images)} images for inference.")
14
+
15
+ bboxes = []
16
+ category_ids = []
17
+ test_images_names = []
18
+
19
+ for idx, image_name in enumerate(test_images):
20
+ print(f"\nProcessing image {idx+1}/{len(test_images)}: {image_name}")
21
+ test_images_names.append(image_name)
22
+
23
+ image_file = os.path.join(image_path, image_name)
24
+
25
+ # Run prediction
26
+ preds = model.predict(image_file)
27
+
28
+ # Debug print
29
+ print(f"Raw prediction output: {preds}")
30
+
31
+ # Handle empty predictions
32
+ if preds is None or preds.xyxy is None or len(preds.xyxy) == 0:
33
+ print("No predictions returned for this image.")
34
+ image_bboxes = []
35
+ image_categories = []
36
+ else:
37
+ image_bboxes = []
38
+ image_categories = []
39
+ for box, score, label in zip(preds.xyxy, preds.confidence, preds.class_id):
40
+ if score >= conf_threshold:
41
+ xmin, ymin, xmax, ymax = box
42
+ width = xmax - xmin
43
+ height = ymax - ymin
44
+ image_bboxes.append([float(xmin), float(ymin), float(width), float(height)])
45
+ image_categories.append(int(label))
46
+
47
+ print(f"Detected {len(image_bboxes)} objects above threshold {conf_threshold}.")
48
+
49
+ bboxes.append(image_bboxes)
50
+ category_ids.append(image_categories)
51
+
52
+ # Prepare DataFrame
53
+ df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
54
+ for i in range(len(test_images_names)):
55
+ df_predictions.loc[i] = [
56
+ test_images_names[i],
57
+ str(bboxes[i]),
58
+ str(category_ids[i]),
59
+ ]
60
+
61
+ df_predictions.to_csv(save_path, index=False)
62
+ print(f"\nInference complete. Predictions saved to {save_path}")
63
+
64
+
65
+ if __name__ == "__main__":
66
+
67
+ TEST_IMAGE_PATH = r"rf-detr\dataset\test"
68
+ SUBMISSION_SAVE_PATH = "submission.csv"
69
+ CONF_THRESHOLD = 0.30
70
+
71
+ print("Loading RF-DETR model...")
72
+ model = RFDETRBase(
73
+ checkpoint_path="checkpoint_best_ema.pth",
74
+ device="cuda" if torch.cuda.is_available() else "cpu"
75
+ )
76
+
77
+ print("Starting inference...")
78
+ run_inference(model, TEST_IMAGE_PATH, CONF_THRESHOLD, SUBMISSION_SAVE_PATH)