drankush-ai commited on
Commit
8b33a25
·
verified ·
1 Parent(s): b5e31c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -113,16 +113,15 @@ def inference(learn, reorder, resample, org_img, input_img, org_size):
113
 
114
  # Ensure we have exactly 5 dimensions [1,1,320,320,140]
115
  if input_img.dim() == 6:
116
- print("Removing extra batch dimension")
117
  input_img = input_img.squeeze(0) # Remove extra batch dim
118
  elif input_img.dim() == 4:
119
- print("Adding missing batch dimension")
120
  input_img = input_img.unsqueeze(0) # Add batch dim
121
- elif input_img.dim() != 5:
122
- raise ValueError(f"Unexpected input dimensions: {input_img.dim()}")
123
 
124
- # Final reshape to exact required dimensions
125
- input_img = input_img.view(1, 1, 320, 320, 140)
 
 
 
126
  print(f"Final input shape to model: {input_img.shape}")
127
  pred = learn.predict(input_img)
128
  print(f"Prediction output shape: {pred[0].shape if isinstance(pred, (list, tuple)) else pred.shape}")
@@ -310,6 +309,11 @@ try:
310
  print(f"Input shape: {tensor_data.shape}")
311
  print(f"Input dtype: {tensor_data.dtype}")
312
 
 
 
 
 
 
313
  # Calculate mean and std
314
  mean = tensor_data.mean()
315
  std = tensor_data.std()
@@ -324,8 +328,8 @@ try:
324
  normalized = (tensor_data - mean) / std
325
  print(f"Normalized tensor shape: {normalized.shape}")
326
 
327
- # Return as MedImage
328
- return MedImage.create(normalized)
329
 
330
  except Exception as e:
331
  print(f"Error in custom normalization: {str(e)}")
 
113
 
114
  # Ensure we have exactly 5 dimensions [1,1,320,320,140]
115
  if input_img.dim() == 6:
 
116
  input_img = input_img.squeeze(0) # Remove extra batch dim
117
  elif input_img.dim() == 4:
 
118
  input_img = input_img.unsqueeze(0) # Add batch dim
 
 
119
 
120
+ # Verify final shape
121
+ if input_img.shape != torch.Size([1,1,320,320,140]):
122
+ input_img = input_img.view(1, 1, 320, 320, 140)
123
+ print(f"Reshaped to required dimensions: {input_img.shape}")
124
+
125
  print(f"Final input shape to model: {input_img.shape}")
126
  pred = learn.predict(input_img)
127
  print(f"Prediction output shape: {pred[0].shape if isinstance(pred, (list, tuple)) else pred.shape}")
 
309
  print(f"Input shape: {tensor_data.shape}")
310
  print(f"Input dtype: {tensor_data.dtype}")
311
 
312
+ # Ensure we have the right dimensions
313
+ if tensor_data.dim() == 6:
314
+ tensor_data = tensor_data.squeeze(0)
315
+ print(f"Removed extra dim - new shape: {tensor_data.shape}")
316
+
317
  # Calculate mean and std
318
  mean = tensor_data.mean()
319
  std = tensor_data.std()
 
328
  normalized = (tensor_data - mean) / std
329
  print(f"Normalized tensor shape: {normalized.shape}")
330
 
331
+ # Return as plain tensor to avoid MedImage wrapping issues
332
+ return normalized
333
 
334
  except Exception as e:
335
  print(f"Error in custom normalization: {str(e)}")