sharvari0b26 commited on
Commit
dbe218e
·
1 Parent(s): f6526ef

Change script.py

Browse files
Files changed (1) hide show
  1. script.py +13 -4
script.py CHANGED
@@ -22,12 +22,21 @@ def run_inference(TEST_IMAGE_PATH, pipeline_model, SUBMISSION_CSV_SAVE_PATH):
22
 
23
  multiclass_predictions = pipeline_model.predict(features_multiclass)
24
 
25
- df_predictions = pd.DataFrame({
26
- "file_name": test_images,
27
- "category_id": multiclass_predictions
28
- })
 
 
 
 
 
 
29
 
30
  df_predictions.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
 
 
 
31
 
32
 
33
 
 
22
 
23
  multiclass_predictions = pipeline_model.predict(features_multiclass)
24
 
25
+ df_predictions = pd.DataFrame(columns=["file_name", "category_id"])
26
+
27
+ for i in range(len(test_images)):
28
+ file_name = test_images[i]
29
+ new_row = pd.DataFrame({
30
+ "file_name": file_name,
31
+ "category_id": multiclass_predictions[i]
32
+ }, index=[0])
33
+
34
+ df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
35
 
36
  df_predictions.to_csv(SUBMISSION_CSV_SAVE_PATH, index=False)
37
+
38
+ print(f"Saved predictions to: {SUBMISSION_CSV_SAVE_PATH}")
39
+
40
 
41
 
42