Update app.py
Browse files
app.py
CHANGED
|
@@ -113,16 +113,15 @@ def inference(learn, reorder, resample, org_img, input_img, org_size):
|
|
| 113 |
|
| 114 |
# Ensure we have exactly 5 dimensions [1,1,320,320,140]
|
| 115 |
if input_img.dim() == 6:
|
| 116 |
-
print("Removing extra batch dimension")
|
| 117 |
input_img = input_img.squeeze(0) # Remove extra batch dim
|
| 118 |
elif input_img.dim() == 4:
|
| 119 |
-
print("Adding missing batch dimension")
|
| 120 |
input_img = input_img.unsqueeze(0) # Add batch dim
|
| 121 |
-
elif input_img.dim() != 5:
|
| 122 |
-
raise ValueError(f"Unexpected input dimensions: {input_img.dim()}")
|
| 123 |
|
| 124 |
-
#
|
| 125 |
-
input_img =
|
|
|
|
|
|
|
|
|
|
| 126 |
print(f"Final input shape to model: {input_img.shape}")
|
| 127 |
pred = learn.predict(input_img)
|
| 128 |
print(f"Prediction output shape: {pred[0].shape if isinstance(pred, (list, tuple)) else pred.shape}")
|
|
@@ -310,6 +309,11 @@ try:
|
|
| 310 |
print(f"Input shape: {tensor_data.shape}")
|
| 311 |
print(f"Input dtype: {tensor_data.dtype}")
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
# Calculate mean and std
|
| 314 |
mean = tensor_data.mean()
|
| 315 |
std = tensor_data.std()
|
|
@@ -324,8 +328,8 @@ try:
|
|
| 324 |
normalized = (tensor_data - mean) / std
|
| 325 |
print(f"Normalized tensor shape: {normalized.shape}")
|
| 326 |
|
| 327 |
-
# Return as MedImage
|
| 328 |
-
return
|
| 329 |
|
| 330 |
except Exception as e:
|
| 331 |
print(f"Error in custom normalization: {str(e)}")
|
|
|
|
| 113 |
|
| 114 |
# Ensure we have exactly 5 dimensions [1,1,320,320,140]
|
| 115 |
if input_img.dim() == 6:
|
|
|
|
| 116 |
input_img = input_img.squeeze(0) # Remove extra batch dim
|
| 117 |
elif input_img.dim() == 4:
|
|
|
|
| 118 |
input_img = input_img.unsqueeze(0) # Add batch dim
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
# Verify final shape
|
| 121 |
+
if input_img.shape != torch.Size([1,1,320,320,140]):
|
| 122 |
+
input_img = input_img.view(1, 1, 320, 320, 140)
|
| 123 |
+
print(f"Reshaped to required dimensions: {input_img.shape}")
|
| 124 |
+
|
| 125 |
print(f"Final input shape to model: {input_img.shape}")
|
| 126 |
pred = learn.predict(input_img)
|
| 127 |
print(f"Prediction output shape: {pred[0].shape if isinstance(pred, (list, tuple)) else pred.shape}")
|
|
|
|
| 309 |
print(f"Input shape: {tensor_data.shape}")
|
| 310 |
print(f"Input dtype: {tensor_data.dtype}")
|
| 311 |
|
| 312 |
+
# Ensure we have the right dimensions
|
| 313 |
+
if tensor_data.dim() == 6:
|
| 314 |
+
tensor_data = tensor_data.squeeze(0)
|
| 315 |
+
print(f"Removed extra dim - new shape: {tensor_data.shape}")
|
| 316 |
+
|
| 317 |
# Calculate mean and std
|
| 318 |
mean = tensor_data.mean()
|
| 319 |
std = tensor_data.std()
|
|
|
|
| 328 |
normalized = (tensor_data - mean) / std
|
| 329 |
print(f"Normalized tensor shape: {normalized.shape}")
|
| 330 |
|
| 331 |
+
# Return as plain tensor to avoid MedImage wrapping issues
|
| 332 |
+
return normalized
|
| 333 |
|
| 334 |
except Exception as e:
|
| 335 |
print(f"Error in custom normalization: {str(e)}")
|