Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import numpy as np
|
|
| 5 |
from monai.networks.nets import SwinUNETR
|
| 6 |
from monai.inferers import sliding_window_inference
|
| 7 |
from monai.transforms import (
|
| 8 |
-
Compose,
|
| 9 |
Spacingd, ScaleIntensityRanged, CropForegroundd, Resized, EnsureTyped
|
| 10 |
)
|
| 11 |
import matplotlib.pyplot as plt
|
|
@@ -15,7 +15,7 @@ import os
|
|
| 15 |
import tempfile
|
| 16 |
|
| 17 |
print("Loading model...")
|
| 18 |
-
device = torch.device("
|
| 19 |
print(f"Device: {device}")
|
| 20 |
|
| 21 |
# Model config
|
|
@@ -27,33 +27,38 @@ model = SwinUNETR(
|
|
| 27 |
spatial_dims=3,
|
| 28 |
).to(device)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
model.eval()
|
| 37 |
|
| 38 |
-
# CORRECTED TRANSFORMS
|
| 39 |
-
#
|
| 40 |
test_transforms = Compose([
|
| 41 |
-
|
| 42 |
EnsureChannelFirstd(keys=["image"]),
|
| 43 |
Orientationd(keys=["image"], axcodes="RAS"),
|
| 44 |
Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
|
| 45 |
-
ScaleIntensityRanged(keys="image", a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
|
| 46 |
-
CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
|
| 47 |
-
Resized(keys=["image"], spatial_size=(128, 128, 128)),
|
| 48 |
EnsureTyped(keys=["image"], dtype=torch.float32),
|
| 49 |
])
|
| 50 |
|
| 51 |
-
def segment_liver(
|
| 52 |
-
"""Simplified segmentation function"""
|
| 53 |
try:
|
| 54 |
-
if
|
| 55 |
return None, None
|
| 56 |
|
|
|
|
|
|
|
| 57 |
print(f"Processing: {file_path}")
|
| 58 |
|
| 59 |
# Load and preprocess
|
|
@@ -66,7 +71,7 @@ def segment_liver(file_path, slice_num=64):
|
|
| 66 |
# Inference
|
| 67 |
with torch.no_grad():
|
| 68 |
outputs = sliding_window_inference(
|
| 69 |
-
volume, (96, 96, 96), 1, model, overlap=0.
|
| 70 |
)
|
| 71 |
pred = torch.argmax(outputs, dim=1).float()
|
| 72 |
|
|
@@ -77,7 +82,10 @@ def segment_liver(file_path, slice_num=64):
|
|
| 77 |
# Normalize volume
|
| 78 |
vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
|
| 79 |
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# Plot
|
| 83 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
|
@@ -119,23 +127,13 @@ def segment_liver(file_path, slice_num=64):
|
|
| 119 |
print(f"Error: {e}")
|
| 120 |
import traceback
|
| 121 |
traceback.print_exc()
|
| 122 |
-
|
| 123 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
| 124 |
-
ax.text(0.5, 0.5, f'Error: {str(e)}',
|
| 125 |
-
ha='center', va='center', fontsize=12, color='red')
|
| 126 |
-
ax.axis('off')
|
| 127 |
-
buf = BytesIO()
|
| 128 |
-
plt.savefig(buf, format='png')
|
| 129 |
-
buf.seek(0)
|
| 130 |
-
img = Image.open(buf)
|
| 131 |
-
plt.close()
|
| 132 |
-
return img, None
|
| 133 |
|
| 134 |
-
#
|
| 135 |
iface = gr.Interface(
|
| 136 |
fn=segment_liver,
|
| 137 |
inputs=[
|
| 138 |
-
gr.File(label="Upload NIfTI"),
|
| 139 |
gr.Slider(0, 127, value=64, label="Slice")
|
| 140 |
],
|
| 141 |
outputs=[
|
|
@@ -143,9 +141,8 @@ iface = gr.Interface(
|
|
| 143 |
gr.File(label="Download Mask")
|
| 144 |
],
|
| 145 |
title="Liver Segmentation (95.5% Dice)",
|
| 146 |
-
description="Upload file from TestVolumes folder
|
| 147 |
)
|
| 148 |
|
| 149 |
-
|
| 150 |
if __name__ == "__main__":
|
| 151 |
-
iface.launch()
|
|
|
|
| 5 |
from monai.networks.nets import SwinUNETR
|
| 6 |
from monai.inferers import sliding_window_inference
|
| 7 |
from monai.transforms import (
|
| 8 |
+
Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
|
| 9 |
Spacingd, ScaleIntensityRanged, CropForegroundd, Resized, EnsureTyped
|
| 10 |
)
|
| 11 |
import matplotlib.pyplot as plt
|
|
|
|
| 15 |
import tempfile
|
| 16 |
|
| 17 |
print("Loading model...")
|
| 18 |
+
device = torch.device("cpu") # Keep CPU for stability
|
| 19 |
print(f"Device: {device}")
|
| 20 |
|
| 21 |
# Model config
|
|
|
|
| 27 |
spatial_dims=3,
|
| 28 |
).to(device)
|
| 29 |
|
| 30 |
+
model_path = "best_metric_model.pth"
|
| 31 |
+
if os.path.exists(model_path):
|
| 32 |
+
try:
|
| 33 |
+
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
|
| 34 |
+
print("Model loaded!")
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"Error loading model: {e}")
|
| 37 |
+
else:
|
| 38 |
+
print("WARNING: Model file not found")
|
| 39 |
|
| 40 |
model.eval()
|
| 41 |
|
| 42 |
+
# --- CORRECTED TRANSFORMS ---
|
| 43 |
+
# Using LoadImaged (dictionary version) to match the rest of the pipeline
|
| 44 |
test_transforms = Compose([
|
| 45 |
+
LoadImaged(keys=["image"]), # FIXED: Dictionary transform
|
| 46 |
EnsureChannelFirstd(keys=["image"]),
|
| 47 |
Orientationd(keys=["image"], axcodes="RAS"),
|
| 48 |
Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
|
| 49 |
+
ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
|
| 50 |
+
CropForegroundd(keys=["image"], source_key="image", allow_smaller=False),
|
| 51 |
+
Resized(keys=["image"], spatial_size=(128, 128, 128)),
|
| 52 |
EnsureTyped(keys=["image"], dtype=torch.float32),
|
| 53 |
])
|
| 54 |
|
| 55 |
+
def segment_liver(file_obj, slice_num=64):
|
|
|
|
| 56 |
try:
|
| 57 |
+
if file_obj is None:
|
| 58 |
return None, None
|
| 59 |
|
| 60 |
+
# Handle Gradio file object
|
| 61 |
+
file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
|
| 62 |
print(f"Processing: {file_path}")
|
| 63 |
|
| 64 |
# Load and preprocess
|
|
|
|
| 71 |
# Inference
|
| 72 |
with torch.no_grad():
|
| 73 |
outputs = sliding_window_inference(
|
| 74 |
+
volume, (96, 96, 96), 1, model, overlap=0.25
|
| 75 |
)
|
| 76 |
pred = torch.argmax(outputs, dim=1).float()
|
| 77 |
|
|
|
|
| 82 |
# Normalize volume
|
| 83 |
vol_display = (vol_np - vol_np.min()) / (vol_np.max() - vol_np.min() + 1e-8)
|
| 84 |
|
| 85 |
+
# Slice selection
|
| 86 |
+
z_dim = vol_np.shape[2]
|
| 87 |
+
slice_idx = int(slice_num)
|
| 88 |
+
if slice_idx >= z_dim: slice_idx = z_dim // 2
|
| 89 |
|
| 90 |
# Plot
|
| 91 |
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
|
|
|
| 127 |
print(f"Error: {e}")
|
| 128 |
import traceback
|
| 129 |
traceback.print_exc()
|
| 130 |
+
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
# Interface
|
| 133 |
iface = gr.Interface(
|
| 134 |
fn=segment_liver,
|
| 135 |
inputs=[
|
| 136 |
+
gr.File(label="Upload NIfTI"),
|
| 137 |
gr.Slider(0, 127, value=64, label="Slice")
|
| 138 |
],
|
| 139 |
outputs=[
|
|
|
|
| 141 |
gr.File(label="Download Mask")
|
| 142 |
],
|
| 143 |
title="Liver Segmentation (95.5% Dice)",
|
| 144 |
+
description="Upload file from TestVolumes folder"
|
| 145 |
)
|
| 146 |
|
|
|
|
| 147 |
if __name__ == "__main__":
|
| 148 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|