Amodit commited on
Commit
f91c0f5
·
verified ·
1 Parent(s): b042b0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -34
app.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  from monai.networks.nets import SwinUNETR
6
  from monai.inferers import sliding_window_inference
7
  from monai.transforms import (
8
- Compose, LoadImage, EnsureChannelFirstd, Orientationd,
9
  Spacingd, ScaleIntensityRanged, CropForegroundd, Resized, EnsureTyped
10
  )
11
  import matplotlib.pyplot as plt
@@ -15,7 +15,7 @@ import os
15
  import tempfile
16
 
17
  print("Loading model...")
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  print(f"Device: {device}")
20
 
21
  # Model config
@@ -27,33 +27,38 @@ model = SwinUNETR(
27
  spatial_dims=3,
28
  ).to(device)
29
 
30
- try:
31
- model.load_state_dict(torch.load("best_metric_model.pth", map_location=device, weights_only=False))
32
- print("Model loaded!")
33
- except Exception as e:
34
- print(f"Error loading model: {e}")
 
 
 
 
35
 
36
  model.eval()
37
 
38
- # CORRECTED TRANSFORMS
39
- # Note: Using Resized instead of Resize
40
  test_transforms = Compose([
41
- LoadImage(image_only=True),
42
  EnsureChannelFirstd(keys=["image"]),
43
  Orientationd(keys=["image"], axcodes="RAS"),
44
  Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
45
- ScaleIntensityRanged(keys="image", a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
46
- CropForegroundd(keys=["image"], source_key="image", allow_smaller=False), # Fix deprecated warning
47
- Resized(keys=["image"], spatial_size=(128, 128, 128)), # Fix: Resized instead of Resize
48
  EnsureTyped(keys=["image"], dtype=torch.float32),
49
  ])
50
 
51
- def segment_liver(file_path, slice_num=64):
52
- """Simplified segmentation function"""
53
  try:
54
- if file_path is None:
55
  return None, None
56
 
 
 
57
  print(f"Processing: {file_path}")
58
 
59
  # Load and preprocess
@@ -66,7 +71,7 @@ def segment_liver(file_path, slice_num=64):
66
  # Inference
67
  with torch.no_grad():
68
  outputs = sliding_window_inference(
69
- volume, (96, 96, 96), 1, model, overlap=0.5
70
  )
71
  pred = torch.argmax(outputs, dim=1).float()
72
 
@@ -77,7 +82,10 @@ def segment_liver(file_path, slice_num=64):
77
  # Normalize volume
78
  vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
79
 
80
- slice_idx = min(int(slice_num), vol_np.shape[2] - 1)
 
 
 
81
 
82
  # Plot
83
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
@@ -119,23 +127,13 @@ def segment_liver(file_path, slice_num=64):
119
  print(f"Error: {e}")
120
  import traceback
121
  traceback.print_exc()
122
-
123
- fig, ax = plt.subplots(figsize=(10, 6))
124
- ax.text(0.5, 0.5, f'Error: {str(e)}',
125
- ha='center', va='center', fontsize=12, color='red')
126
- ax.axis('off')
127
- buf = BytesIO()
128
- plt.savefig(buf, format='png')
129
- buf.seek(0)
130
- img = Image.open(buf)
131
- plt.close()
132
- return img, None
133
 
134
- # Simple Gradio interface
135
  iface = gr.Interface(
136
  fn=segment_liver,
137
  inputs=[
138
- gr.File(label="Upload NIfTI"), # Removed file_types constraint
139
  gr.Slider(0, 127, value=64, label="Slice")
140
  ],
141
  outputs=[
@@ -143,9 +141,8 @@ iface = gr.Interface(
143
  gr.File(label="Download Mask")
144
  ],
145
  title="Liver Segmentation (95.5% Dice)",
146
- description="Upload file from TestVolumes folder (like liver_104.nii.gz)"
147
  )
148
 
149
-
150
  if __name__ == "__main__":
151
- iface.launch()
 
5
  from monai.networks.nets import SwinUNETR
6
  from monai.inferers import sliding_window_inference
7
  from monai.transforms import (
8
+ Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
9
  Spacingd, ScaleIntensityRanged, CropForegroundd, Resized, EnsureTyped
10
  )
11
  import matplotlib.pyplot as plt
 
15
  import tempfile
16
 
17
  print("Loading model...")
18
+ device = torch.device("cpu") # Keep CPU for stability
19
  print(f"Device: {device}")
20
 
21
  # Model config
 
27
  spatial_dims=3,
28
  ).to(device)
29
 
30
+ model_path = "best_metric_model.pth"
31
+ if os.path.exists(model_path):
32
+ try:
33
+ model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
34
+ print("Model loaded!")
35
+ except Exception as e:
36
+ print(f"Error loading model: {e}")
37
+ else:
38
+ print("WARNING: Model file not found")
39
 
40
  model.eval()
41
 
42
+ # --- CORRECTED TRANSFORMS ---
43
+ # Using LoadImaged (dictionary version) to match the rest of the pipeline
44
  test_transforms = Compose([
45
+ LoadImaged(keys=["image"]), # FIXED: Dictionary transform
46
  EnsureChannelFirstd(keys=["image"]),
47
  Orientationd(keys=["image"], axcodes="RAS"),
48
  Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
49
+ ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
50
+ CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
51
+ Resized(keys=["image"], spatial_size=(128, 128, 128)),
52
  EnsureTyped(keys=["image"], dtype=torch.float32),
53
  ])
54
 
55
+ def segment_liver(file_obj, slice_num=64):
 
56
  try:
57
+ if file_obj is None:
58
  return None, None
59
 
60
+ # Handle Gradio file object
61
+ file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
62
  print(f"Processing: {file_path}")
63
 
64
  # Load and preprocess
 
71
  # Inference
72
  with torch.no_grad():
73
  outputs = sliding_window_inference(
74
+ volume, (96, 96, 96), 1, model, overlap=0.25
75
  )
76
  pred = torch.argmax(outputs, dim=1).float()
77
 
 
82
  # Normalize volume
83
  vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
84
 
85
+ # Slice selection
86
+ z_dim = vol_np.shape[2]
87
+ slice_idx = int(slice_num)
88
+ if slice_idx >= z_dim: slice_idx = z_dim // 2
89
 
90
  # Plot
91
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
127
  print(f"Error: {e}")
128
  import traceback
129
  traceback.print_exc()
130
+ return None, None
 
 
 
 
 
 
 
 
 
 
131
 
132
+ # Interface
133
  iface = gr.Interface(
134
  fn=segment_liver,
135
  inputs=[
136
+ gr.File(label="Upload NIfTI"),
137
  gr.Slider(0, 127, value=64, label="Slice")
138
  ],
139
  outputs=[
 
141
  gr.File(label="Download Mask")
142
  ],
143
  title="Liver Segmentation (95.5% Dice)",
144
+ description="Upload file from TestVolumes folder"
145
  )
146
 
 
147
  if __name__ == "__main__":
148
+ iface.launch(server_name="0.0.0.0", server_port=7860)