mmrech committed on
Commit
69aa7a5
·
1 Parent(s): 208761f

Add comprehensive enhancements: Point/Box prompts, ROI statistics, NIFTI export, annotations

Browse files

NEW FEATURES:
- 🎯 Point/Box Prompts Tab: Interactive point and bounding box-based segmentation
- 📊 ROI Statistics & Export Tab:
  - Detailed statistics (area, intensity, centroid, bounding box)
  - NIFTI format export for medical imaging software
  - Annotation save/load functionality (ZIP format)
- 🎭 Multi-Mask Output Tab: Generate multiple mask candidates with confidence scores
- ▶️ Auto-play button now functional in Interactive Slice Viewer

TECHNICAL IMPROVEMENTS:
- Added nibabel and scipy dependencies for NIFTI export and ROI calculations
- Added JSON-based annotation storage with mask compression
- Enhanced image processing with point/box region filtering
- Added progress tracking for auto-play functionality

Files changed (2) hide show
  1. app.py +1024 -1
  2. requirements.txt +2 -0
app.py CHANGED
@@ -7,14 +7,26 @@ import os
7
  import tempfile
8
  import zipfile
9
  import io
 
 
10
  from datetime import datetime
11
  import gradio as gr
12
  import torch
13
  import pydicom
14
  import numpy as np
15
- from PIL import Image, ImageEnhance
16
  from transformers import AutoImageProcessor, AutoModel
17
  import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Hugging Face Token (must be set as HF_TOKEN environment variable in Space settings)
20
  hf_token = os.getenv("HF_TOKEN")
@@ -718,6 +730,732 @@ def process_batch_enhanced(image_files, prompt_text, modality, window_type,
718
  status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download."
719
  return results, zip_path.name, status
720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  # Create Gradio Interface
722
  demo_file_path = demo_dicom_path if demo_file_available and os.path.exists(demo_dicom_path) else None
723
 
@@ -1532,6 +2270,205 @@ with gr.Blocks(css="""
1532
  interactive=False,
1533
  lines=4
1534
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1535
 
1536
  # Single image processing
1537
  load_demo_btn.click(
@@ -1639,6 +2576,92 @@ with gr.Blocks(css="""
1639
  ],
1640
  outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
1641
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1642
 
1643
  if __name__ == "__main__":
1644
  demo.launch()
 
7
  import tempfile
8
  import zipfile
9
  import io
10
+ import json
11
+ import time
12
  from datetime import datetime
13
  import gradio as gr
14
  import torch
15
  import pydicom
16
  import numpy as np
17
+ from PIL import Image, ImageEnhance, ImageDraw
18
  from transformers import AutoImageProcessor, AutoModel
19
  import matplotlib.pyplot as plt
20
+ from matplotlib.patches import Rectangle
21
+ from scipy import ndimage
22
+
23
+ # Try to import nibabel for NIFTI support (optional)
24
+ try:
25
+ import nibabel as nib
26
+ NIBABEL_AVAILABLE = True
27
+ except ImportError:
28
+ NIBABEL_AVAILABLE = False
29
+ print("⚠️ nibabel not available - NIFTI export disabled")
30
 
31
  # Hugging Face Token (must be set as HF_TOKEN environment variable in Space settings)
32
  hf_token = os.getenv("HF_TOKEN")
 
730
  status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download."
731
  return results, zip_path.name, status
732
 
733
+ # ============================================================================
734
+ # ENHANCED FEATURES - Auto-play, Point/Box Prompts, ROI Stats, NIFTI Export
735
+ # ============================================================================
736
+
737
+ # Global state for auto-play
738
+ auto_play_state = {"running": False, "current_idx": 0}
739
+
740
def _empty_roi_stats(message):
    """Build the zero-valued statistics dict used when no ROI can be measured."""
    return {
        "error": message,
        "area_pixels": 0,
        "area_percentage": 0,
        "mean_intensity": 0,
        "std_intensity": 0,
        "min_intensity": 0,
        "max_intensity": 0,
        "centroid": (0, 0),
        "bounding_box": (0, 0, 0, 0),
    }


def calculate_roi_statistics(image_file, mask, modality):
    """Calculate ROI statistics from the segmented region.

    The image is re-read from disk so that intensities come from the original
    pixel data (rescaled with RescaleSlope/RescaleIntercept for DICOM) rather
    than from a display-normalised copy.

    Args:
        image_file: Path (or object whose ``str()`` is a path) to a ``.dcm``
            file or any image PIL can open.
        mask: 2-D ``np.ndarray`` marking the ROI; singleton leading/trailing
            axes are squeezed away. It is resized (nearest neighbour) to the
            image shape when the shapes differ.
        modality: Modality label; ``"CT"`` adds ``mean_hu``/``std_hu`` keys.

    Returns:
        dict: ``area_pixels``, ``area_percentage``, ``mean_intensity``,
        ``std_intensity``, ``min_intensity``, ``max_intensity``, ``centroid``
        as (x, y), ``bounding_box`` as (x1, y1, x2, y2) and
        ``num_components``. When nothing can be measured, the dict carries an
        ``error`` message and zero-valued fields instead.
    """
    if mask is None or not isinstance(mask, np.ndarray):
        return _empty_roi_stats("No valid mask available")

    try:
        # Load original image for intensity statistics
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
        else:
            with Image.open(file_path) as img:
                # Any multi-band (RGB, RGBA, LA, ...) or palette image must be
                # collapsed to one intensity channel; checking only for 'RGB'
                # left RGBA/P inputs as 3-D or index arrays.
                if img.mode == 'P' or len(img.getbands()) > 1:
                    img = img.convert('L')
                img_array = np.array(img).astype(np.float32)

        if mask.ndim > 2:
            mask = np.squeeze(mask)
        if mask.ndim != 2 or img_array.ndim != 2:
            raise ValueError(
                f"Expected a 2-D image and mask, got image {img_array.shape} and mask {mask.shape}"
            )

        # Resize mask if needed (order=0 keeps it binary)
        if mask.shape != img_array.shape:
            zoom_factors = (
                img_array.shape[0] / mask.shape[0],
                img_array.shape[1] / mask.shape[1],
            )
            mask = ndimage.zoom(mask.astype(float), zoom_factors, order=0) > 0.5

        mask_bool = mask.astype(bool)
        total_pixels = mask.size
        roi_pixels = np.sum(mask_bool)

        if roi_pixels == 0:
            return _empty_roi_stats("No pixels in ROI")

        roi_intensities = img_array[mask_bool]

        # Only the component count is needed from the labelling.
        _, num_features = ndimage.label(mask_bool)
        centroid = ndimage.center_of_mass(mask_bool)

        # Bounding box from the first/last occupied row and column
        rows = np.any(mask_bool, axis=1)
        cols = np.any(mask_bool, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        stats = {
            "area_pixels": int(roi_pixels),
            "area_percentage": float(roi_pixels / total_pixels * 100),
            "mean_intensity": float(np.mean(roi_intensities)),
            "std_intensity": float(np.std(roi_intensities)),
            "min_intensity": float(np.min(roi_intensities)),
            "max_intensity": float(np.max(roi_intensities)),
            "centroid": (float(centroid[1]), float(centroid[0])),  # (x, y)
            "bounding_box": (int(cmin), int(rmin), int(cmax), int(rmax)),  # (x1, y1, x2, y2)
            "num_components": int(num_features),
        }

        # Add HU statistics for CT
        if modality == "CT":
            stats["mean_hu"] = stats["mean_intensity"]
            stats["std_hu"] = stats["std_intensity"]

        return stats

    except Exception as e:
        print(f"Error calculating ROI statistics: {e}")
        # Same shape as the other failure results so callers can rely on the keys.
        return _empty_roi_stats(str(e))
834
+
835
def format_roi_statistics(stats):
    """Render a ROI statistics dict as Markdown-flavoured text.

    Args:
        stats: Dict as produced by ``calculate_roi_statistics``.

    Returns:
        str: A warning line when ``stats`` carries an ``error`` and no area,
        otherwise a multi-line summary (with an extra Hounsfield-unit section
        when ``mean_hu`` is present).
    """
    nothing_measured = stats.get("area_pixels", 0) == 0
    if "error" in stats and nothing_measured:
        return f"⚠️ {stats.get('error', 'No statistics available')}"

    lines = [
        "📊 **ROI Statistics**",
        "",
        f"**Area:** {stats['area_pixels']:,} pixels ({stats['area_percentage']:.2f}%)",
        f"**Intensity:** {stats['mean_intensity']:.2f} ± {stats['std_intensity']:.2f}",
        f"**Range:** [{stats['min_intensity']:.2f}, {stats['max_intensity']:.2f}]",
        f"**Centroid:** ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})",
        f"**Bounding Box:** {stats['bounding_box']}",
        f"**Components:** {stats.get('num_components', 1)}",
    ]

    if "mean_hu" in stats:
        lines.extend([
            "",
            "**CT (Hounsfield Units):**",
            f"Mean HU: {stats['mean_hu']:.1f} ± {stats['std_hu']:.1f}",
        ])

    return "\n".join(lines)
853
+
854
def process_with_roi_stats(image_file, prompt_text, modality, window_type):
    """Segment an image and report ROI statistics for the resulting mask.

    Returns:
        tuple: (segmentation result or None, status message, formatted
        statistics text — empty when nothing was processed).
    """
    model_ready = model is not None and processor is not None
    if not model_ready:
        return None, "❌ Error: Model not loaded.", ""

    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""

    segmentation = process_medical_image(
        image_file, prompt_text, modality, window_type, return_mask=True
    )
    result, mask = segmentation
    if result is None:
        return None, "❌ Processing failed.", ""

    # Statistics are computed on the mask and rendered for display
    roi_summary = format_roi_statistics(
        calculate_roi_statistics(image_file, mask, modality)
    )
    return result, "✅ Segmentation complete!", roi_summary
872
+
873
def process_with_point_prompt(image_file, point_x, point_y, modality, window_type, colormap='spring', transparency=0.5):
    """Segment the image and keep the predicted mask that covers a given point.

    Note: the model is still driven by a fixed text prompt. The point is only
    used afterwards to pick, among the predicted masks, the first one that
    contains it; if none does, the union of all masks is shown instead.

    Args:
        image_file: Path (or object whose ``str()`` is a path) to a ``.dcm``
            file or any image PIL can open.
        point_x, point_y: Pixel coordinates of the point; clamped to the image.
        modality: Accepted for interface parity with the other handlers; unused.
        window_type: Accepted for interface parity; unused.
        colormap: Matplotlib colormap name for the mask overlay.
        transparency: Alpha of the mask overlay.

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."

    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    if point_x is None or point_y is None:
        # A cleared number field arrives as None; report it plainly instead of
        # surfacing an int() TypeError.
        return None, "⚠️ Please provide both point coordinates."

    try:
        # Load image
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            # Normalize to 8-bit using the 1st-99th percentile range
            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        w, h = pil_image.size

        # Clamp point coordinates
        point_x = max(0, min(int(point_x), w - 1))
        point_y = max(0, min(int(point_y), h - 1))

        prompt_text = "segment region at point"

        # Process with SAM
        inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Extract masks. Explicit None checks: `a or b` on a multi-element
        # tensor raises "Boolean value of Tensor ... is ambiguous".
        masks = getattr(outputs, 'pred_masks', None)
        if masks is None and isinstance(outputs, dict):
            masks = outputs.get('pred_masks')
            if masks is None:
                masks = outputs.get('masks')

        def _fit_to_image(binary_mask):
            """Resize a boolean mask to the image size and re-binarise it."""
            resized = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize((w, h))
            return np.array(resized) > 127

        final_mask = None
        if masks is not None:
            if isinstance(masks, torch.Tensor):
                masks = masks.detach().cpu().numpy()

            if masks.ndim == 4:
                masks = masks[0]  # first batch element

            if masks.dtype != bool:
                masks = masks > 0.5

            if masks.ndim == 3:
                # Select the first mask containing the point
                for candidate in masks:
                    candidate_resized = _fit_to_image(candidate)
                    if candidate_resized[point_y, point_x]:
                        final_mask = candidate_resized
                        break

                if final_mask is None:
                    # No mask covers the point: fall back to the union
                    final_mask = _fit_to_image(np.any(masks, axis=0))
            else:
                final_mask = _fit_to_image(masks)

        # Draw result with point marker; close the figure even if drawing fails
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(pil_image)

            if final_mask is not None:
                plt.imshow(final_mask, alpha=transparency, cmap=colormap)

            plt.scatter([point_x], [point_y], c='red', s=200, marker='+', linewidths=3)
            plt.scatter([point_x], [point_y], c='red', s=100, marker='o', facecolors='none', linewidths=2)

            plt.axis('off')
            plt.title(f"Point Prompt Segmentation at ({point_x}, {point_y})", fontsize=12)

            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
                output_path = output_file.name

            plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        finally:
            plt.close(fig)

        return output_path, f"✅ Point-based segmentation at ({point_x}, {point_y})"

    except Exception as e:
        print(f"Error in point prompt processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
988
+
989
def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, colormap='spring', transparency=0.5):
    """Segment the image and keep only the part of the result inside a box.

    The model is driven by a fixed text prompt; the box is applied afterwards
    by intersecting the union of all predicted masks with the box region.

    Args:
        image_file: Path (or object whose ``str()`` is a path) to a ``.dcm``
            file or any image PIL can open.
        x1, y1, x2, y2: Box corners in pixels, in any order; clamped to the image.
        modality: Accepted for interface parity with the other handlers; unused.
        window_type: Accepted for interface parity; unused.
        colormap: Matplotlib colormap name for the mask overlay.
        transparency: Alpha of the mask overlay.

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."

    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    if None in (x1, y1, x2, y2):
        return None, "⚠️ Please provide all four box coordinates."

    try:
        # Load image
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        w, h = pil_image.size

        # Order the corners, then clamp BOTH ends into the image. Clamping only
        # the lower end at 0 and the upper end at w/h let a negative x2/y2
        # through, which the slice below would read as "from the end".
        x1, x2 = sorted((int(x1), int(x2)))
        y1, y2 = sorted((int(y1), int(y2)))
        x1, x2 = max(0, min(x1, w)), max(0, min(x2, w))
        y1, y2 = max(0, min(y1, h)), max(0, min(y2, h))

        prompt_text = "segment region in bounding box"

        # Process with SAM
        inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Extract masks. Explicit None checks: `a or b` on a multi-element
        # tensor raises "Boolean value of Tensor ... is ambiguous".
        masks = getattr(outputs, 'pred_masks', None)
        if masks is None and isinstance(outputs, dict):
            masks = outputs.get('pred_masks')
            if masks is None:
                masks = outputs.get('masks')

        final_mask = None
        if masks is not None:
            if isinstance(masks, torch.Tensor):
                masks = masks.detach().cpu().numpy()

            if masks.ndim == 4:
                masks = masks[0]  # first batch element

            if masks.dtype != bool:
                masks = masks > 0.5

            combined = np.any(masks, axis=0) if masks.ndim == 3 else masks

            # Resize to image size
            combined_resized = np.array(
                Image.fromarray(combined.astype(np.uint8) * 255).resize((w, h))
            ) > 127

            # Create box mask and intersect
            box_mask = np.zeros((h, w), dtype=bool)
            box_mask[y1:y2, x1:x2] = True
            final_mask = combined_resized & box_mask

        # Draw result with box; close the figure even if drawing fails
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(pil_image)

            if final_mask is not None:
                plt.imshow(final_mask, alpha=transparency, cmap=colormap)

            rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=3, edgecolor='red', facecolor='none')
            plt.gca().add_patch(rect)

            plt.axis('off')
            plt.title(f"Box Prompt Segmentation [{x1}, {y1}, {x2}, {y2}]", fontsize=12)

            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
                output_path = output_file.name

            plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        finally:
            plt.close(fig)

        return output_path, f"✅ Box-based segmentation at [{x1}, {y1}, {x2}, {y2}]"

    except Exception as e:
        print(f"Error in box prompt processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
1094
+
1095
def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks=3):
    """Render up to ``num_masks`` mask candidates, each as its own overlay image.

    Args:
        image_file: Path (or object whose ``str()`` is a path) to a ``.dcm``
            file or any image PIL can open.
        prompt_text: Text prompt; falls back to ``"brain"`` when blank.
        modality: Accepted for interface parity with the other handlers; unused.
        window_type: Accepted for interface parity; unused.
        num_masks: Maximum number of candidates to render (coerced to int).

    Returns:
        tuple: (list of PNG paths, status message, per-mask info text).
        Confidence values come from the model output when it exposes
        ``iou_scores``/``scores``; otherwise they are simulated as 1/(rank).
        Pixel counts are reported at image resolution.
    """
    if model is None or processor is None:
        return [], "❌ Error: Model not loaded.", ""

    if image_file is None:
        return [], "⚠️ Please upload a medical image file.", ""

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        width, height = pil_image.size

        if not prompt_text or not prompt_text.strip():
            prompt_text = "brain"

        # Process with SAM
        inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        def _first_available(*keys):
            """Return the first non-None output field among ``keys``.

            Looks at attributes and dict entries alike, and uses explicit None
            checks because `a or b` on a multi-element tensor raises.
            """
            for key in keys:
                value = getattr(outputs, key, None)
                if value is None and isinstance(outputs, dict):
                    value = outputs.get(key)
                if value is not None:
                    return value
            return None

        def _to_image_mask(raw_mask):
            """Binarise a mask and resize it to the image so the overlay lines up."""
            binary = raw_mask if raw_mask.dtype == bool else raw_mask > 0.5
            if binary.shape != (height, width):
                resized = Image.fromarray(binary.astype(np.uint8) * 255).resize((width, height))
                binary = np.array(resized) > 127
            return binary

        def _save_overlay(binary_mask, cmap_name, title):
            """Draw the mask over the image, save it to a temp PNG, return the path."""
            fig = plt.figure(figsize=(8, 8))
            try:
                plt.imshow(pil_image)
                plt.imshow(binary_mask, alpha=0.5, cmap=cmap_name)
                plt.axis('off')
                plt.title(title, fontsize=12)

                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
                    output_path = output_file.name

                plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
            finally:
                # Close even when drawing fails so figures do not accumulate
                plt.close(fig)
            return output_path

        masks = _first_available('pred_masks', 'masks')
        scores = _first_available('iou_scores', 'scores')

        results = []
        mask_info = []

        if masks is not None:
            if isinstance(masks, torch.Tensor):
                masks = masks.detach().cpu().numpy()
            if scores is not None:
                if isinstance(scores, torch.Tensor):
                    scores = scores.detach().cpu().numpy()
                scores = np.asarray(scores, dtype=float).ravel()

            if masks.ndim == 4:
                masks = masks[0]  # first batch element

            if masks.ndim == 3:
                num_available = masks.shape[0]
                # A UI slider may deliver a float; range() needs an int
                num_to_show = min(int(num_masks), num_available)

                # Generate confidence scores if not available
                if scores is None:
                    scores = [1.0 / (i + 1) for i in range(num_available)]  # Simulated scores

                colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma']

                for i in range(num_to_show):
                    mask = _to_image_mask(masks[i])
                    score = scores[i] if i < len(scores) else 0.5

                    output_path = _save_overlay(
                        mask,
                        colormaps[i % len(colormaps)],
                        f"Mask {i+1} - Confidence: {score:.2%}",
                    )

                    results.append(output_path)
                    mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask):,} pixels")
            else:
                # Single mask case
                mask = _to_image_mask(masks)
                output_path = _save_overlay(mask, 'spring', "Single Mask Output")

                results.append(output_path)
                mask_info.append(f"Single mask: {np.sum(mask):,} pixels")

        status = f"✅ Generated {len(results)} mask candidate(s)"
        info = "\n".join(mask_info) if mask_info else "No mask information available"

        return results, status, info

    except Exception as e:
        print(f"Error in multi-mask processing: {e}")
        import traceback
        traceback.print_exc()
        return [], f"❌ Error: {str(e)}", ""
1219
+
1220
def export_to_nifti(image_file, mask, output_name="segmentation"):
    """Export a segmentation mask to a gzip-compressed NIFTI file.

    Args:
        image_file: Optional path to the source image. When it is a ``.dcm``
            file, PixelSpacing and SliceThickness are read from it to scale the
            affine; otherwise an identity (1 mm isotropic) affine is used.
        mask: ``np.ndarray`` to export; stored as float32.
        output_name: Used as the prefix of the temporary output file name.

    Returns:
        tuple: (path to the ``.nii.gz`` file or None, status message).
    """
    if not NIBABEL_AVAILABLE:
        return None, "⚠️ NIFTI export not available - nibabel not installed"

    if mask is None or not isinstance(mask, np.ndarray):
        return None, "⚠️ No valid mask to export"

    try:
        # Convert mask to appropriate format
        mask_data = mask.astype(np.float32)

        # Start from an identity affine (1mm isotropic)
        affine = np.eye(4)

        # Try to get spacing from DICOM if available
        if image_file:
            file_path = image_file if isinstance(image_file, str) else str(image_file)
            if file_path.lower().endswith('.dcm'):
                try:
                    ds = pydicom.dcmread(file_path, stop_before_pixels=True)
                    pixel_spacing = getattr(ds, 'PixelSpacing', [1.0, 1.0])
                    slice_thickness = getattr(ds, 'SliceThickness', 1.0)
                    # Parse all three values before touching the affine so a
                    # bad SliceThickness cannot leave it half-updated.
                    spacing = (
                        float(pixel_spacing[0]),
                        float(pixel_spacing[1]),
                        float(slice_thickness),
                    )
                except Exception as spacing_error:
                    # Best effort: spacing is optional metadata, so keep the
                    # identity affine but say why instead of failing silently.
                    print(f"Could not read DICOM spacing, using 1mm isotropic: {spacing_error}")
                else:
                    affine[0, 0], affine[1, 1], affine[2, 2] = spacing

        nifti_img = nib.Nifti1Image(mask_data, affine)

        # Save to temp file; strip any directory part from the requested name
        prefix = os.path.basename(str(output_name)) or "segmentation"
        with tempfile.NamedTemporaryFile(delete=False, prefix=f"{prefix}_", suffix='.nii.gz') as output_file:
            output_path = output_file.name

        nib.save(nifti_img, output_path)

        return output_path, f"✅ Exported to NIFTI: {output_path}"

    except Exception as e:
        print(f"Error exporting to NIFTI: {e}")
        return None, f"❌ Export failed: {str(e)}"
1268
+
1269
def _annotation_json_default(value):
    """Convert numpy scalars/arrays to plain Python values for json.dumps."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_annotation(image_file, mask, prompt_text, modality, stats=None):
    """Save an annotation as a ZIP holding ``annotation.json`` and ``mask.npz``.

    Both members are built in memory, so the only file left on disk is the
    returned ZIP. The JSON's ``mask_file`` field names the archive member
    rather than a path on the machine that created it.

    Args:
        image_file: Path of the annotated image; only its base name is stored.
        mask: Mask array to store (under the key ``mask`` in ``mask.npz``).
        prompt_text: Prompt that produced the mask.
        modality: Imaging modality label.
        stats: Optional statistics dict stored alongside the metadata.

    Returns:
        tuple: (path to the ZIP file or None, status message).
    """
    if mask is None:
        return None, "⚠️ No annotation to save"

    try:
        annotation = {
            "timestamp": datetime.now().isoformat(),
            "image_file": os.path.basename(image_file) if image_file else "unknown",
            "prompt": prompt_text,
            "modality": modality,
            "mask_shape": list(mask.shape),
            "mask_sum": int(np.sum(mask)),
            "mask_base64": None,  # mask is stored as a separate archive member
            "statistics": stats if stats else {},
            "mask_file": "mask.npz",
        }

        # Serialise both members before creating any file, so a failure here
        # leaves nothing behind on disk.
        annotation_json = json.dumps(annotation, indent=2, default=_annotation_json_default)
        mask_buffer = io.BytesIO()
        np.savez_compressed(mask_buffer, mask=mask)

        with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as zip_file:
            zip_path = zip_file.name
            with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zf:
                zf.writestr('annotation.json', annotation_json)
                zf.writestr('mask.npz', mask_buffer.getvalue())

        return zip_path, f"✅ Annotation saved: {os.path.basename(zip_path)}"

    except Exception as e:
        print(f"Error saving annotation: {e}")
        return None, f"❌ Save failed: {str(e)}"
1316
+
1317
def load_annotation(annotation_file):
    """Load a previously saved annotation ZIP.

    The archive must contain ``annotation.json`` and ``mask.npz`` (with the
    array stored under the key ``mask``). The mask is read from memory, so no
    temporary file is created.

    Args:
        annotation_file: Path (or object whose ``str()`` is a path) to the ZIP.

    Returns:
        tuple: (mask array, annotation dict, info text). On any failure the
        first two items are None and the third is the reason.
    """
    if annotation_file is None:
        return None, None, "⚠️ No file selected"

    try:
        file_path = annotation_file if isinstance(annotation_file, str) else str(annotation_file)

        # Case-insensitive so an upload named "ANNOTATION.ZIP" is accepted too
        if not file_path.lower().endswith('.zip'):
            return None, None, "⚠️ Invalid file format. Please upload a .zip annotation file."

        with zipfile.ZipFile(file_path, 'r') as zf:
            annotation = json.loads(zf.read('annotation.json'))
            mask_bytes = zf.read('mask.npz')

        # The archive is user-supplied: keep pickle loading disabled.
        with np.load(io.BytesIO(mask_bytes), allow_pickle=False) as mask_data:
            mask = mask_data['mask']

        info = f"📋 **Loaded Annotation**\n"
        info += f"Image: {annotation.get('image_file', 'unknown')}\n"
        info += f"Prompt: {annotation.get('prompt', 'N/A')}\n"
        info += f"Modality: {annotation.get('modality', 'N/A')}\n"
        info += f"Saved: {annotation.get('timestamp', 'N/A')}\n"
        info += f"Mask size: {annotation.get('mask_sum', 0):,} pixels"

        return mask, annotation, info

    except Exception as e:
        print(f"Error loading annotation: {e}")
        return None, None, f"❌ Load failed: {str(e)}"
1354
+
1355
def visualize_loaded_annotation(image_file, annotation_file, colormap='spring', transparency=0.5):
    """Overlay a saved annotation's mask on its original image.

    Returns:
        tuple: (path to the rendered PNG or None, info or error text).
    """
    mask, annotation, info = load_annotation(annotation_file)

    if mask is None:
        return None, info

    if image_file is None:
        return None, "⚠️ Please upload the original image to visualize"

    try:
        source_path = image_file if isinstance(image_file, str) else str(image_file)
        _, extension = os.path.splitext(source_path)

        if extension.lower() != '.dcm':
            pil_image = Image.open(source_path).convert('RGB')
        else:
            dataset = pydicom.dcmread(source_path)
            pixels = dataset.pixel_array.astype(np.float32)
            rescale_slope = getattr(dataset, 'RescaleSlope', 1)
            rescale_intercept = getattr(dataset, 'RescaleIntercept', 0)
            pixels = pixels * rescale_slope + rescale_intercept

            # Stretch the 1st-99th percentile range onto 0..255
            low = np.percentile(pixels, 1)
            high = np.percentile(pixels, 99)
            scaled = np.clip((pixels - low) / (high - low + 1e-8), 0, 1)
            pil_image = Image.fromarray((scaled * 255).astype(np.uint8)).convert('RGB')

        # Bring the mask to the image size when they differ
        width, height = pil_image.size
        if mask.shape != (height, width):
            mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
            mask = np.array(mask_image.resize((width, height))) > 127

        # Render the overlay
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)
        plt.imshow(mask, alpha=transparency, cmap=colormap)
        plt.axis('off')
        plt.title(f"Loaded Annotation: {annotation.get('prompt', 'N/A')}", fontsize=12)

        png_handle = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        png_path = png_handle.name
        png_handle.close()

        plt.savefig(png_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()

        return png_path, info

    except Exception as exc:
        print(f"Error visualizing annotation: {exc}")
        return None, f"❌ Visualization failed: {str(exc)}"
1408
+
1409
+ # Store last mask for export/save operations
1410
+ last_processed_mask = {"mask": None, "image_file": None, "prompt": None, "modality": None}
1411
+
1412
def process_and_store_mask(image_file, prompt_text, modality, window_type):
    """Segment an image and remember the mask for later export/save.

    On success the mask and its inputs are written into the module-level
    ``last_processed_mask`` dict.

    Returns:
        tuple: (segmentation result, status message, formatted statistics
        text — empty on failure).
    """
    result, mask = process_medical_image(
        image_file, prompt_text, modality, window_type, return_mask=True
    )

    if not result or mask is None:
        return result, "❌ Processing failed.", ""

    last_processed_mask.update(
        mask=mask,
        image_file=image_file,
        prompt=prompt_text,
        modality=modality,
    )

    # Statistics for the freshly stored mask
    roi_summary = format_roi_statistics(
        calculate_roi_statistics(image_file, mask, modality)
    )

    return result, "✅ Segmentation complete! Ready for export.", roi_summary
1429
+
1430
def export_last_mask_nifti():
    """Export the most recently stored mask to NIFTI.

    Returns:
        tuple: whatever ``export_to_nifti`` returns, or (None, warning) when
        no mask has been stored yet.
    """
    stored_mask = last_processed_mask["mask"]
    if stored_mask is None:
        return None, "⚠️ No mask to export. Process an image first."

    source_image = last_processed_mask["image_file"]
    return export_to_nifti(source_image, stored_mask)
1439
+
1440
def save_last_annotation():
    """Save the most recently stored mask as an annotation archive.

    Statistics are recomputed from the stored image and mask and saved with it.

    Returns:
        tuple: whatever ``save_annotation`` returns, or (None, warning) when
        no mask has been stored yet.
    """
    stored_mask = last_processed_mask["mask"]
    if stored_mask is None:
        return None, "⚠️ No annotation to save. Process an image first."

    source_image = last_processed_mask["image_file"]
    stored_modality = last_processed_mask["modality"]

    roi_stats = calculate_roi_statistics(source_image, stored_mask, stored_modality)

    return save_annotation(
        source_image,
        stored_mask,
        last_processed_mask["prompt"],
        stored_modality,
        roi_stats,
    )
1458
+
1459
  # Create Gradio Interface
1460
  demo_file_path = demo_dicom_path if demo_file_available and os.path.exists(demo_dicom_path) else None
1461
 
 
2270
  interactive=False,
2271
  lines=4
2272
  )
2273
+
2274
+ # NEW: Point/Box Prompts Tab
2275
+ with gr.Tab("🎯 Point/Box Prompts"):
2276
+ gr.Markdown("""
2277
+ **Interactive Point and Box-based Segmentation**
2278
+
2279
+ Use precise point clicks or bounding boxes to guide the segmentation.
2280
+ - **Point Prompt**: Click on the region you want to segment
2281
+ - **Box Prompt**: Define a bounding box around the region of interest
2282
+ """)
2283
+
2284
+ with gr.Tabs():
2285
+ with gr.Tab("Point Prompt"):
2286
+ with gr.Row():
2287
+ with gr.Column():
2288
+ file_input_point = gr.File(
2289
+ label="Upload Medical Image",
2290
+ file_types=[".dcm", ".png", ".jpg", ".jpeg"],
2291
+ type="filepath"
2292
+ )
2293
+
2294
+ gr.Markdown("### Point Coordinates")
2295
+ with gr.Row():
2296
+ point_x = gr.Number(label="X coordinate", value=128, precision=0)
2297
+ point_y = gr.Number(label="Y coordinate", value=128, precision=0)
2298
+
2299
+ with gr.Row():
2300
+ modality_point = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
2301
+ window_point = gr.Dropdown(
2302
+ ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
2303
+ label="Windowing", value="Brain (Grey Matter)"
2304
+ )
2305
+
2306
+ with gr.Row():
2307
+ colormap_point = gr.Dropdown(
2308
+ ["spring", "cool", "hot", "viridis", "plasma"],
2309
+ label="Colormap", value="spring"
2310
+ )
2311
+ transparency_point = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
2312
+
2313
+ submit_point_btn = gr.Button("Segment at Point", variant="primary")
2314
+
2315
+ with gr.Column():
2316
+ output_point = gr.Image(label="Point Segmentation Result", type="filepath")
2317
+ status_point = gr.Textbox(label="Status", interactive=False)
2318
+
2319
+ with gr.Tab("Box Prompt"):
2320
+ with gr.Row():
2321
+ with gr.Column():
2322
+ file_input_box = gr.File(
2323
+ label="Upload Medical Image",
2324
+ file_types=[".dcm", ".png", ".jpg", ".jpeg"],
2325
+ type="filepath"
2326
+ )
2327
+
2328
+ gr.Markdown("### Bounding Box Coordinates")
2329
+ with gr.Row():
2330
+ box_x1 = gr.Number(label="X1 (left)", value=50, precision=0)
2331
+ box_y1 = gr.Number(label="Y1 (top)", value=50, precision=0)
2332
+ with gr.Row():
2333
+ box_x2 = gr.Number(label="X2 (right)", value=200, precision=0)
2334
+ box_y2 = gr.Number(label="Y2 (bottom)", value=200, precision=0)
2335
+
2336
+ with gr.Row():
2337
+ modality_box = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
2338
+ window_box = gr.Dropdown(
2339
+ ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
2340
+ label="Windowing", value="Brain (Grey Matter)"
2341
+ )
2342
+
2343
+ with gr.Row():
2344
+ colormap_box = gr.Dropdown(
2345
+ ["spring", "cool", "hot", "viridis", "plasma"],
2346
+ label="Colormap", value="spring"
2347
+ )
2348
+ transparency_box = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
2349
+
2350
+ submit_box_btn = gr.Button("Segment in Box", variant="primary")
2351
+
2352
+ with gr.Column():
2353
+ output_box = gr.Image(label="Box Segmentation Result", type="filepath")
2354
+ status_box = gr.Textbox(label="Status", interactive=False)
2355
+
2356
+ # NEW: ROI Statistics & Export Tab
2357
+ with gr.Tab("📊 ROI Statistics & Export"):
2358
+ gr.Markdown("""
2359
+ **ROI Statistics and Export Options**
2360
+
2361
+ Process an image and get detailed statistics about the segmented region:
2362
+ - Area (pixels and percentage)
2363
+ - Intensity statistics (mean, std, min, max)
2364
+ - Centroid and bounding box
2365
+ - Export to NIFTI format for medical imaging software
2366
+ - Save/Load annotations for later use
2367
+ """)
2368
+
2369
+ with gr.Row():
2370
+ with gr.Column():
2371
+ file_input_stats = gr.File(
2372
+ label="Upload Medical Image",
2373
+ file_types=[".dcm", ".png", ".jpg", ".jpeg"],
2374
+ type="filepath"
2375
+ )
2376
+
2377
+ text_input_stats = gr.Textbox(
2378
+ label="Text Prompt", value="brain",
2379
+ placeholder="e.g. brain, tumor, skull"
2380
+ )
2381
+
2382
+ with gr.Row():
2383
+ modality_stats = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
2384
+ window_stats = gr.Dropdown(
2385
+ ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
2386
+ label="Windowing", value="Brain (Grey Matter)"
2387
+ )
2388
+
2389
+ submit_stats_btn = gr.Button("Process & Calculate Statistics", variant="primary")
2390
+
2391
+ gr.Markdown("### Export Options")
2392
+ with gr.Row():
2393
+ export_nifti_btn = gr.Button("📥 Export to NIFTI", size="sm")
2394
+ save_annotation_btn = gr.Button("💾 Save Annotation", size="sm")
2395
+
2396
+ with gr.Column():
2397
+ output_stats = gr.Image(label="Segmentation Result", type="filepath")
2398
+ status_stats = gr.Textbox(label="Status", interactive=False)
2399
+
2400
+ gr.Markdown("### 📊 ROI Statistics")
2401
+ roi_stats_text = gr.Markdown(value="*Process an image to see statistics*")
2402
+
2403
+ nifti_download = gr.File(label="Download NIFTI", visible=True)
2404
+ annotation_download = gr.File(label="Download Annotation", visible=True)
2405
+
2406
+ gr.Markdown("---")
2407
+ gr.Markdown("### Load Saved Annotation")
2408
+ with gr.Row():
2409
+ with gr.Column():
2410
+ annotation_upload = gr.File(
2411
+ label="Upload Annotation (.zip)",
2412
+ file_types=[".zip"],
2413
+ type="filepath"
2414
+ )
2415
+
2416
+ original_image_upload = gr.File(
2417
+ label="Upload Original Image (for visualization)",
2418
+ file_types=[".dcm", ".png", ".jpg", ".jpeg"],
2419
+ type="filepath"
2420
+ )
2421
+
2422
+ load_annotation_btn = gr.Button("Load & Visualize Annotation", variant="secondary")
2423
+
2424
+ with gr.Column():
2425
+ loaded_annotation_output = gr.Image(label="Loaded Annotation", type="filepath")
2426
+ loaded_annotation_info = gr.Markdown(value="*Upload an annotation file to load*")
2427
+
2428
+ # NEW: Multi-Mask Output Tab
2429
+ with gr.Tab("🎭 Multi-Mask Output"):
2430
+ gr.Markdown("""
2431
+ **Generate Multiple Mask Candidates**
2432
+
2433
+ SAM can generate multiple segmentation hypotheses with confidence scores.
2434
+ This is useful when the segmentation is ambiguous or you want to compare alternatives.
2435
+ """)
2436
+
2437
+ with gr.Row():
2438
+ with gr.Column():
2439
+ file_input_multi = gr.File(
2440
+ label="Upload Medical Image",
2441
+ file_types=[".dcm", ".png", ".jpg", ".jpeg"],
2442
+ type="filepath"
2443
+ )
2444
+
2445
+ text_input_multi = gr.Textbox(
2446
+ label="Text Prompt", value="brain",
2447
+ placeholder="e.g. brain, tumor, skull"
2448
+ )
2449
+
2450
+ with gr.Row():
2451
+ modality_multi = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
2452
+ window_multi = gr.Dropdown(
2453
+ ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
2454
+ label="Windowing", value="Brain (Grey Matter)"
2455
+ )
2456
+
2457
+ num_masks_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Masks")
2458
+
2459
+ submit_multi_btn = gr.Button("Generate Multiple Masks", variant="primary")
2460
+
2461
+ with gr.Column():
2462
+ gallery_multi = gr.Gallery(
2463
+ label="Mask Candidates",
2464
+ show_label=True,
2465
+ columns=2,
2466
+ rows=2,
2467
+ height="auto"
2468
+ )
2469
+
2470
+ status_multi = gr.Textbox(label="Status", interactive=False)
2471
+ mask_info_multi = gr.Textbox(label="Mask Information", lines=5, interactive=False)
2472
 
2473
  # Single image processing
2474
  load_demo_btn.click(
 
2576
  ],
2577
  outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
2578
  )
2579
+
2580
+ # Point prompt processing
2581
+ submit_point_btn.click(
2582
+ fn=process_with_point_prompt,
2583
+ inputs=[file_input_point, point_x, point_y, modality_point, window_point, colormap_point, transparency_point],
2584
+ outputs=[output_point, status_point]
2585
+ )
2586
+
2587
+ # Box prompt processing
2588
+ submit_box_btn.click(
2589
+ fn=process_with_box_prompt,
2590
+ inputs=[file_input_box, box_x1, box_y1, box_x2, box_y2, modality_box, window_box, colormap_box, transparency_box],
2591
+ outputs=[output_box, status_box]
2592
+ )
2593
+
2594
+ # ROI Statistics processing
2595
+ submit_stats_btn.click(
2596
+ fn=process_and_store_mask,
2597
+ inputs=[file_input_stats, text_input_stats, modality_stats, window_stats],
2598
+ outputs=[output_stats, status_stats, roi_stats_text]
2599
+ )
2600
+
2601
+ # NIFTI Export
2602
+ export_nifti_btn.click(
2603
+ fn=export_last_mask_nifti,
2604
+ inputs=[],
2605
+ outputs=[nifti_download, status_stats]
2606
+ )
2607
+
2608
+ # Save Annotation
2609
+ save_annotation_btn.click(
2610
+ fn=save_last_annotation,
2611
+ inputs=[],
2612
+ outputs=[annotation_download, status_stats]
2613
+ )
2614
+
2615
+ # Load Annotation
2616
+ load_annotation_btn.click(
2617
+ fn=visualize_loaded_annotation,
2618
+ inputs=[original_image_upload, annotation_upload],
2619
+ outputs=[loaded_annotation_output, loaded_annotation_info]
2620
+ )
2621
+
2622
+ # Multi-Mask processing
2623
+ submit_multi_btn.click(
2624
+ fn=process_multi_mask,
2625
+ inputs=[file_input_multi, text_input_multi, modality_multi, window_multi, num_masks_slider],
2626
+ outputs=[gallery_multi, status_multi, mask_info_multi]
2627
+ )
2628
+
2629
+ # Auto-play functionality for slice viewer
2630
def auto_play_slices(files, selected_subject, prompt, mod, window):
    """Step through already-processed slices of one subject, one frame at a time.

    This is a generator: each yielded tuple is
    ``(slice_result, info_text, slice_index)``.

    Args:
        files: Uploaded slice files; grouped with ``group_images_by_subject``.
        selected_subject: Dropdown label for the subject. Only the text
            before the first ``" ("`` is used as the subject id. When empty,
            the first subject group is used.
        prompt: Text prompt; part of the cache key.
        mod: Modality; part of the cache key.
        window: Accepted for signature compatibility with the event wiring;
            it is not used and is not part of the cache key.

    Yields:
        ``(None, message, 0)`` once when there is nothing to play (no files,
        unknown subject, or no cached results for the key); otherwise one
        tuple per cached result, in order.
    """
    if not files:
        yield None, "No slices loaded", 0
        return

    subject_groups = group_images_by_subject(files)
    if selected_subject:
        subject_id = selected_subject.split(" (")[0]
    else:
        # First subject in the grouping, or None when the grouping is empty.
        subject_id = next(iter(subject_groups), None)

    if not subject_id or subject_id not in subject_groups:
        yield None, "No slices loaded", 0
        return

    subject_files = subject_groups[subject_id]['files']
    cache_key = f"{subject_id}_{len(subject_files)}_{prompt}_{mod}"

    if cache_key not in processed_results_cache:
        yield None, "Please process slices first", 0
        return

    results = processed_results_cache[cache_key]
    total = len(results)

    for idx, frame in enumerate(results):
        if idx:
            # 500ms pause *between* frames only; sleeping after the final
            # frame would just keep the handler busy with nothing to show.
            time.sleep(0.5)
        slice_info = f"Slice {idx + 1}/{total} ({subject_id}) - Auto-playing..."
        yield frame, slice_info, idx
2659
+
2660
+ auto_play_btn.click(
2661
+ fn=auto_play_slices,
2662
+ inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
2663
+ outputs=[current_slice_output, slice_info_text, slice_slider]
2664
+ )
2665
 
2666
  if __name__ == "__main__":
2667
  demo.launch()
requirements.txt CHANGED
@@ -7,4 +7,6 @@ torch>=2.0.0
7
  torchvision>=0.15.0
8
  transformers>=4.45.0
9
  huggingface-hub>=0.20.0
 
 
10
 
 
7
  torchvision>=0.15.0
8
  transformers>=4.45.0
9
  huggingface-hub>=0.20.0
10
+ nibabel>=5.0.0
11
+ scipy>=1.10.0
12