Gradio app: 3D left-atrium segmentation with volume-based cardiomegaly classification.
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from pathlib import Path | |
| from huggingface_hub import snapshot_download | |
| from fastMONAI.vision_all import * | |
| from git import Repo | |
| import os | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.svm import SVC | |
| #import pathlib | |
| #temp = pathlib.PosixPath | |
| #pathlib.PosixPath = pathlib.WindowsPath | |
| #pathlib.PosixPath = temp | |
# Clone the private model repository (token-authenticated URI supplied via the
# PAT_Token_URI environment variable) into ./clone_dir on first run only.
clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')
if not clone_dir.exists():
    # NOTE(review): URI may be None if the env var is unset — Repo.clone_from
    # would then fail; presumably the Space always provides the secret.
    Repo.clone_from(URI, clone_dir)
def extract_slices_from_mask(img, mask_data, view):
    """Cut a 3D [W, H, D] image/mask pair into oriented, letterboxed 2D slices.

    Args:
        img: 3D image volume.
        mask_data: 3D mask aligned with ``img``.
        view: "Sagittal", "Axial" or "Coronal" — selects the slicing axis.

    Returns:
        List of (image_slice, mask_slice) pairs, each resized/padded to 320x320.
    """
    target_size = (320, 320)
    # Slice count along the axis that corresponds to the requested view.
    if view == "Sagittal":
        n_slices = img.shape[2]
    elif view == "Axial":
        n_slices = img.shape[1]
    else:
        n_slices = img.shape[0]
    slices = []
    for idx in range(n_slices):
        if view == "Sagittal":
            slice_img, slice_mask = img[:, :, idx], mask_data[:, :, idx]
        elif view == "Axial":
            slice_img, slice_mask = img[:, idx, :], mask_data[:, idx, :]
        elif view == "Coronal":
            slice_img, slice_mask = img[idx, :, :], mask_data[idx, :, :]
        # Rotate 90° clockwise then mirror so the slice displays upright.
        slice_img = np.fliplr(np.rot90(slice_img, -1))
        slice_mask = np.fliplr(np.rot90(slice_mask, -1))
        # Letterbox both planes onto a common 320x320 canvas.
        pair = resize_and_pad(slice_img, slice_mask, target_size)
        slices.append(pair)
    return slices
def resize_and_pad(slice_img, slice_mask, target_size):
    """Scale an image/mask pair to fit ``target_size`` (keeping aspect ratio),
    then zero-pad so the content is centered on the target canvas.

    Args:
        slice_img: 2D grayscale slice.
        slice_mask: 2D mask aligned with ``slice_img``.
        target_size: (width, height) of the output canvas.

    Returns:
        Tuple of (padded_image, padded_mask), both shaped (height, width).
    """
    h, w = slice_img.shape
    # Uniform scale so the tighter-fitting dimension fills the canvas.
    scale = min(target_size[0] / w, target_size[1] / h)
    new_w, new_h = int(w * scale), int(h * scale)
    # Bilinear for intensities; nearest-neighbour keeps mask labels crisp.
    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    # Split leftover space evenly; any odd pixel goes to the far edge.
    left = (target_size[0] - new_w) // 2
    top = (target_size[1] - new_h) // 2
    pad_spec = ((top, target_size[1] - new_h - top),
                (left, target_size[0] - new_w - left))
    padded_img = np.pad(resized_img, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_spec, mode='constant', constant_values=0)
    return padded_img, padded_mask
def normalize_image(slice_img):
    """Min-max scale an image into [0, 255] and return it as uint8.

    A constant image would cause a divide-by-zero, so it maps to all zeros.
    """
    lo, hi = slice_img.min(), slice_img.max()
    if lo == hi:
        # Flat image: nothing to stretch; return black instead of dividing by 0.
        return np.zeros_like(slice_img, dtype=np.uint8)
    scaled = (slice_img - lo) / (hi - lo) * 255
    return scaled.astype(np.uint8)
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a mask overlay onto a grayscale slice and orient it for display.

    Args:
        img: 2D uint8 grayscale slice.
        pred_mask: 2D binary mask aligned with ``img``.
        view: "Sagittal", "Axial" or "Coronal" — controls display orientation.
        alpha: weight of the grayscale image in the blend (mask gets 1 - alpha).

    Returns:
        3-channel fused image ready for display.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Channel-0 overlay color — renders red when the UI interprets it as RGB.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
    if view == 'Sagittal':
        # Flip both vertically and horizontally.
        return cv2.flip(fused, -1)
    # BUGFIX: the original `elif view=='Coronal' or 'Axial'` was always true
    # (non-empty string is truthy), so every non-sagittal view took this
    # branch; an explicit else preserves that behavior without the trap.
    # Rotate 90° counter-clockwise, then mirror horizontally.
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
# ---------------------------------------------------------------------------
# SVM classifier for Left Atrium Enlargement (LAE) based on atrial volume.
# dataset.csv layout (positional): column 1 = volume feature, column 2 = label.
# ---------------------------------------------------------------------------
df = pd.read_csv(str(Path.cwd() / "dataset.csv"))
X = df.iloc[:, 1].values.reshape(-1, 1)  # single-feature matrix for sklearn
y = df.iloc[:, 2].values                 # class labels
# 80/20 split with a fixed seed so the split is reproducible across restarts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)
# Polynomial-kernel support-vector classifier, seeded for determinism;
# `model` is used by gradio_volume_classification at prediction time.
model = SVC(kernel='poly', random_state=0)
model.fit(X_train, y_train)
def gradio_volume_classification(volume):
    """Classify cardiomegaly from a left-atrium volume via the module-level SVM.

    Args:
        volume: left-atrium volume in mL (scalar).

    Returns:
        Human-readable classification string.
    """
    # `model` is the module-level SVC; `global` is only needed for assignment,
    # not for a read-only reference, so the original declaration was dropped.
    y_pred = model.predict([[volume]])
    if int(y_pred[0]) == 0:
        return "Cardiomegaly: Negative (Left Atrium Enlargement Not Detected)"
    return "Cardiomegaly: Positive (Left Atrium Enlargement Detected)"
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Segment the left atrium in an uploaded 3D image, save the predicted
    mask, and return display slices plus the volume-based classification.

    Args:
        fileobj: uploaded file object; ``fileobj.name`` holds its path.
        learn: fastMONAI learner used for inference.
        reorder: reorder setting matching the trained model's preprocessing.
        resample: resample setting matching the trained model's preprocessing.
        save_dir: directory where the predicted mask volume is written.
        view: "Sagittal", "Axial" or "Coronal"; None falls back to "Sagittal".

    Returns:
        Tuple of (fused slice images, volume in mL rounded to 2 dp,
        classification text).
    """
    if view is None:  # BUGFIX: was `view == None`; identity check is the idiom
        view = 'Sagittal'
    img_path = Path(fileobj.name)
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn
    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data
    # Reorient LSA-oriented volumes so mask axes agree with the image axes.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]
    img = org_img.data
    # Persist the predicted mask alongside the input's stem name.
    org_img.set_data(mask_data)
    org_img.save(save_path)
    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [get_fused_image(normalize_image(slice_img),  # safe min-max
                                    slice_mask, view)
                    for slice_img, slice_mask in slices]
    volume = compute_binary_tumor_volume(org_img)
    classification_result = gradio_volume_classification(volume)
    return fused_images, round(volume, 2), classification_result
# --- Application bootstrap -------------------------------------------------
models_path = Path.cwd() / 'clone_dir'
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Load the trained segmentation learner and its preprocessing settings.
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')


def _segment(fileobj, view='Sagittal'):
    """Bind the loaded learner and directories into the segmentation entry point."""
    return gradio_image_segmentation(fileobj, learn, reorder, resample,
                                     save_dir, view)


# --- Gradio UI -------------------------------------------------------------
view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"],
                         value='Sagittal',
                         label="Select View (Sagittal by default)")
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
output_classification = gr.Textbox(label="Left Atrium Cardiomegaly Classification:")

demo = gr.Interface(
    fn=_segment,
    inputs=["file", view_selector],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices",
                        columns=3, height=450),
             output_text,
             output_classification],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    allow_flagging='never')

# Launch the Gradio interface (blocks until the server stops).
demo.launch()