datasciencesage commited on
Commit
b4f34db
·
verified ·
1 Parent(s): fcf400f

Update preprocess_test.py

Browse files
Files changed (1) hide show
  1. preprocess_test.py +7 -3
preprocess_test.py CHANGED
@@ -136,7 +136,8 @@ class Preprocess_Test:
136
  BATCH_SIZE = 256
137
  correct = 0
138
  total = 0
139
-
 
140
  with torch.no_grad():
141
  for i in range(0, len(X_test_t), BATCH_SIZE):
142
  batch_x = X_test_t[i:i + BATCH_SIZE].to(self.device)
@@ -146,12 +147,15 @@ class Preprocess_Test:
146
  predicted = torch.argmax(outputs, dim=1)
147
  total += batch_y.size(0)
148
  correct += (predicted == batch_y).sum().item()
149
-
 
150
  if i == 0:
151
  print(f"Test batch - Predicted: {predicted.cpu().numpy()[:10]}")
152
  print(f"Test batch - Actual: {batch_y.cpu().numpy()[:10]}")
153
 
154
 
155
-
 
 
156
 
157
 
 
136
  BATCH_SIZE = 256
137
  correct = 0
138
  total = 0
139
+ all_predictions = []
140
+
141
  with torch.no_grad():
142
  for i in range(0, len(X_test_t), BATCH_SIZE):
143
  batch_x = X_test_t[i:i + BATCH_SIZE].to(self.device)
 
147
  predicted = torch.argmax(outputs, dim=1)
148
  total += batch_y.size(0)
149
  correct += (predicted == batch_y).sum().item()
150
+ all_predictions.extend(predicted.cpu().numpy().tolist())
151
+
152
  if i == 0:
153
  print(f"Test batch - Predicted: {predicted.cpu().numpy()[:10]}")
154
  print(f"Test batch - Actual: {batch_y.cpu().numpy()[:10]}")
155
 
156
 
157
+ return {
158
+ "predictions": all_predictions}
159
+
160
 
161