Update ink_detection_pipeline.py
Browse filesfixing bfloat16 handling on certain devices
ink_detection_pipeline.py
CHANGED
|
@@ -72,7 +72,7 @@ class InkDetectionPipeline(Pipeline):
|
|
| 72 |
sub_y_preds = torch.sigmoid(sub_y_preds)
|
| 73 |
|
| 74 |
# Move to CPU and numpy
|
| 75 |
-
sub_y_preds = sub_y_preds.detach().cpu().numpy()
|
| 76 |
# shape (subB, 1, tile_size, tile_size)
|
| 77 |
|
| 78 |
all_preds.append(sub_y_preds)
|
|
|
|
| 72 |
sub_y_preds = torch.sigmoid(sub_y_preds)
|
| 73 |
|
| 74 |
# Move to CPU and numpy
|
| 75 |
+
sub_y_preds = sub_y_preds.detach().cpu().float().numpy()
|
| 76 |
# shape (subB, 1, tile_size, tile_size)
|
| 77 |
|
| 78 |
all_preds.append(sub_y_preds)
|