Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from pathlib import Path | |
| from huggingface_hub import snapshot_download | |
| from fastMONAI.vision_all import * | |
| from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume | |
| import sys | |
| import os | |
| import requests | |
| from groq import Groq | |
| from dotenv import load_dotenv | |
| import math | |
| import plotly.graph_objects as go | |
| from skimage import measure | |
# Load environment variables (a local .env file, or HuggingFace Secrets already
# present in os.environ). Locally, the .env file is looked up in the PARENT of
# the current working directory.
load_dotenv(dotenv_path=Path.cwd().parent / '.env')
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
# Only build the Groq client when a key is available: the Groq SDK raises at
# construction time when no API key can be found, which would otherwise crash
# the whole app on import. Without a client the segmentation still works and
# the interpretation step reports the error instead.
groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
# Debug: dump the module namespace (this includes every name star-imported
# from fastMONAI.vision_all above).
print("[DEBUG] fastMONAI.vision_all symbols:", dir())
from git import Repo  # only needed by the (currently disabled) repo-clone setup below
import os
# Additional support for local execution on Windows (disabled):
# import pathlib
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath
# pathlib.PosixPath = temp
# Local execution setup
clone_dir = Path.cwd()
# Disabled: clone a private repo into clone_dir using a PAT-bearing URI.
# URI = os.getenv('PAT_Token_URI')
# if os.path.exists(clone_dir):
#     pass
# else:
#     Repo.clone_from(URI, clone_dir)
# Array axis that is sliced for each supported view of a [W, H, D] volume.
_VIEW_AXIS = {"Sagittal": 2, "Axial": 1, "Coronal": 0}


def extract_slices_from_mask(img, mask_data, view, target_size=(320, 320)):
    """Extract 2-D slices from a 3-D [W, H, D] image and mask for one view.

    Every slice along the view's axis is rotated 90 degrees clockwise and
    mirrored left-right, then resized and zero-padded to ``target_size`` with
    ``resize_and_pad``.

    Args:
        img: 3-D image volume (numpy array or torch tensor).
        mask_data: 3-D mask volume with the same shape as ``img``.
        view: "Sagittal" (slices along axis 2), "Axial" (axis 1) or
            "Coronal" (axis 0).
        target_size: (width, height) of every output slice; default (320, 320).

    Returns:
        List of ``(image_slice, mask_slice)`` tuples, one per index along the
        view's axis.

    Raises:
        ValueError: if ``view`` is not one of the supported views.
    """
    try:
        axis = _VIEW_AXIS[view]
    except KeyError:
        # Previously an unknown view fell through every branch and failed with
        # an UnboundLocalError on the first iteration.
        raise ValueError(
            f"Unknown view {view!r}; expected one of {sorted(_VIEW_AXIS)}"
        ) from None

    slices = []
    for idx in range(img.shape[axis]):
        # Equivalent to img[:, :, idx] / img[:, idx, :] / img[idx, :, :]; works
        # for both numpy arrays and torch tensors.
        selector = [slice(None)] * 3
        selector[axis] = idx
        selector = tuple(selector)
        slice_img = np.fliplr(np.rot90(img[selector], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[selector], -1))
        slices.append(resize_and_pad(slice_img, slice_mask, target_size))
    return slices
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize an image/mask pair to fit ``target_size`` and zero-pad to exactly that size.

    The aspect ratio is preserved: both arrays are scaled by the same factor so
    that the larger relative dimension fits, then centred with zero padding.
    The image uses bilinear interpolation; the mask uses nearest-neighbour so
    label values are not blended.

    Args:
        slice_img: 2-D image slice of shape (h, w).
        slice_mask: 2-D mask slice with the same shape as ``slice_img``.
        target_size: (width, height) of the output arrays.

    Returns:
        ``(padded_img, padded_mask)``, each of shape (height, width).

    Raises:
        ValueError: if the slice has a zero-length dimension.
    """
    h, w = slice_img.shape
    target_w, target_h = target_size
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty slice of shape {slice_img.shape}")

    scale = min(target_w / w, target_h / h)
    # round() instead of int() avoids losing a pixel to float error
    # (e.g. 319.9999 -> 319); clamp to [1, target] so a very thin slice never
    # produces a zero-sized resize and the padding is never negative.
    new_w = max(1, min(target_w, round(w * scale)))
    new_h = max(1, min(target_h, round(h * scale)))

    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Centre the resized content; any odd remainder goes to the bottom/right.
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    pad_spec = (
        (pad_top, target_h - new_h - pad_top),
        (pad_left, target_w - new_w - pad_left),
    )
    padded_img = np.pad(resized_img, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_spec, mode='constant', constant_values=0)
    return padded_img, padded_mask
def normalize_image(slice_img):
    """Linearly rescale an array so its minimum maps to 0 and its maximum to 255.

    A constant array (no dynamic range) yields an all-zero uint8 array instead
    of dividing by zero. The result is truncated to uint8.
    """
    lo = slice_img.min()
    hi = slice_img.max()
    if lo == hi:
        # No contrast to stretch; avoid 0/0.
        return np.zeros_like(slice_img, dtype=np.uint8)
    span = hi - lo
    scaled = (slice_img - lo) / span * 255
    return scaled.astype(np.uint8)
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Overlay a mask on a grayscale slice and orient the result for display.

    Args:
        img: 2-D uint8 grayscale slice (e.g. the output of ``normalize_image``);
            it must be uint8 to match the uint8 mask layer in ``cv2.addWeighted``.
        pred_mask: 2-D mask with the same height/width as ``img``. It is
            multiplied by the overlay colour, so values are assumed to be 0/1
            (larger values would not fit in uint8 after the multiply).
        view: "Sagittal", "Coronal" or "Axial".
        alpha: weight of the grayscale image in the blend; the mask layer gets
            ``1 - alpha``.

    Returns:
        3-channel uint8 image. Sagittal slices are flipped both vertically and
        horizontally. Coronal/Axial slices are rotated 90 degrees
        counter-clockwise and then mirrored horizontally.

    Raises:
        ValueError: for any other ``view``.
    """
    # The original check `view == 'Coronal' or 'Axial'` was always truthy, so
    # unknown views were silently treated as Coronal/Axial.
    if view not in ("Sagittal", "Coronal", "Axial"):
        raise ValueError(f"Unknown view {view!r}; expected 'Sagittal', 'Coronal' or 'Axial'")

    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # (255, 0, 0): red if the array is interpreted as RGB for display. Grey
    # pixels have identical channels, so BGR vs RGB does not matter for them.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    # The mask layer is black outside the mask, so the whole slice is weighted
    # by alpha and the masked region gains (1 - alpha) of the overlay colour.
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        return cv2.flip(fused, -1)  # flip vertically and horizontally
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
def get_bsa(height, weight):
    """Calculate Body Surface Area (m^2) with the Mosteller formula.

    BSA = sqrt(height_cm * weight_kg / 3600)

    Args:
        height: height in centimetres.
        weight: weight in kilograms.

    Returns:
        Body surface area in square metres.

    Raises:
        ValueError: if height or weight is missing or not positive. A zero BSA
            would otherwise cause a ZeroDivisionError when the LA volume is
            indexed by it.
    """
    if height is None or weight is None or height <= 0 or weight <= 0:
        raise ValueError(
            f"height and weight must be positive numbers, got height={height!r}, weight={weight!r}"
        )
    return math.sqrt((height * weight) / 3600)
def create_3d_mesh_file(mask_data, spacing, save_dir, min_voxels=100):
    """Create a GLB surface mesh of the segmentation mask using marching cubes.

    Args:
        mask_data: segmentation mask as a torch tensor or array-like. It may be
            3-D, or 4-D with a leading channel axis (only index 0 is used).
        spacing: voxel spacing passed to ``skimage.measure.marching_cubes``,
            one value per axis of the 3-D mask.
        save_dir: ``pathlib.Path`` directory; the mesh is written to
            ``save_dir / "la_mesh.glb"``.
        min_voxels: masks whose summed value is below this threshold are
            treated as empty (default 100).

    Returns:
        The mesh path as a string, or ``None`` when the mask is (nearly) empty
        or any step fails. Failures are printed with a traceback rather than
        raised, because the 3-D view is optional output.
    """
    try:
        # detach().cpu() lets tensors that require grad or live on an
        # accelerator convert too; plain .numpy() fails for those.
        if isinstance(mask_data, torch.Tensor):
            mask_data = mask_data.detach().cpu()
        mask_np = np.asarray(mask_data, dtype=np.float32)

        # Drop the leading channel axis of a 4-D mask.
        if mask_np.ndim == 4:
            mask_np = mask_np[0]

        voxel_sum = np.sum(mask_np)
        print(f"[DEBUG] Mask shape: {mask_np.shape}, spacing: {spacing}, sum: {voxel_sum}")
        if voxel_sum < min_voxels:
            print("[DEBUG] Mask has too few positive voxels")
            return None

        # Imported lazily inside the try, so a missing optional dependency
        # disables the 3-D view instead of failing the whole request.
        import trimesh

        # Extract the 0.5 iso-surface of the mask.
        verts, faces, normals, _values = measure.marching_cubes(
            mask_np, level=0.5, spacing=spacing
        )
        print(f"[DEBUG] Marching cubes: {len(verts)} vertices, {len(faces)} faces")

        mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
        mesh.visual.vertex_colors = [220, 20, 60, 255]  # Crimson RGBA

        mesh_path = save_dir / "la_mesh.glb"
        mesh.export(str(mesh_path), file_type='glb')
        print(f"[DEBUG] Mesh exported to: {mesh_path}")
        return str(mesh_path)
    except Exception as e:
        # Best effort: the mesh is an optional extra, so log and carry on.
        print(f"[DEBUG] Error creating 3D mesh: {e}")
        import traceback
        traceback.print_exc()
        return None
def fetch_miracle_ref(gender, bsa_indexed=False, timeout=10):
    """Fetch left-atrial volume reference values from MIRACLE-API.

    Args:
        gender: "Male"/"Female" (case-insensitive). A falsy value (e.g. no
            selection in the UI) returns ``{}`` without a network call.
        bsa_indexed: if True, request BSA-indexed values (parameter ``MXLAVi``);
            otherwise absolute values (``MXLAV``).
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The ``results`` object from the API response. ``{}`` is returned on any
        network error, non-200 status, invalid JSON or unexpected payload shape;
        problems are printed, not raised.
    """
    if not gender:
        return {}
    # Let requests build the query string. The previous hand-built URL had a
    # garbled "¶meter" (an HTML "&para" entity) in place of "&parameter".
    params = {
        "domain": "LA_VF",
        "parameter": "MXLAVi" if bsa_indexed else "MXLAV",
        "gender": gender.lower(),
        "method": "SM_AI",
    }
    try:
        response = requests.get(
            "https://ref.miracle-api.workers.dev/exec", params=params, timeout=timeout
        )
        if response.status_code != 200:
            print(f"MIRACLE-API returned HTTP {response.status_code}")
            return {}
        payload = response.json()
    except Exception as e:
        print(f"Error fetching MIRACLE-API: {e}")
        return {}
    return payload.get('results', {}) if isinstance(payload, dict) else {}
def get_interpretation(volume, height, weight, gender, voxel_info):
    """Generate a narrative interpretation of the LA volume with the Groq LLM.

    The LA volume (LAV) is indexed to body surface area (LAVi = LAV / BSA). It
    is sent to the LLM together with MIRACLE-API reference values.

    Args:
        volume: left-atrium volume (labelled mL in the UI).
        height: height in cm, used for the Mosteller BSA.
        weight: weight in kg, used for the Mosteller BSA.
        gender: gender as selected in the UI ("Male"/"Female"); may be None.
        voxel_info: human-readable voxel summary included in the prompt.

    Returns:
        The LLM's text. If inputs are missing or invalid, or the LLM call
        fails, a short explanatory message is returned instead. It does not
        raise in those cases, so the segmentation outputs are still displayed.
    """
    # The UI inputs default to None. A missing or non-positive height/weight
    # previously crashed the whole submission (TypeError/ZeroDivisionError)
    # before any result was shown.
    try:
        bsa = get_bsa(height, weight)
    except (TypeError, ValueError):
        bsa = None
    if bsa is None or not bsa > 0:
        return ("Cannot compute the BSA-indexed LA volume (LAVi): please enter a "
                "positive height (cm) and weight (kg), then submit again.")
    if not gender:
        return ("Please select a gender so the matching MIRACLE-API reference "
                "values can be used, then submit again.")
    # Defensive: tolerate a module that was set up without an LLM client.
    if groq_client is None:
        return "Error generating interpretation: no Groq client is configured (is GROQ_API_KEY set?)."

    lavi = volume / bsa
    ref_lav = fetch_miracle_ref(gender, bsa_indexed=False)
    ref_lavi = fetch_miracle_ref(gender, bsa_indexed=True)
    system_prompt = f"""
You are a medical imaging assistant. You will be provided with patient data and cardiac segmentation results (specifically Left Atrium Volume - LAV).
Your task is to interpret these results using reference data from MIRACLE-API.
Input Data:
- LAV: {volume} mL
- Height: {height} cm, Weight: {weight} kg, Gender: {gender}
- Calculated BSA: {bsa:.2f} m²
- Calculated LAVi: {lavi:.2f} mL/m²
- Voxel Info: {voxel_info}
- Reference LAV (MIRACLE-API): {ref_lav}
- Reference LAVi (MIRACLE-API): {ref_lavi}
Instructions:
1. Acknowledge the calculation method using the voxel info.
2. Compare the volume and LAVi against the reference mean and ranges (ll: lower limit, ul: upper limit).
3. State if the volume is enlarged or normal based on the Z-score/percentile (if you can estimate) or simply by comparing against the upper limit (ul).
4. Format the response strictly as requested by the user, starting with 'MIRACLE-API'.
"""
    try:
        completion = groq_client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Interpret the results."}
            ],
            temperature=0.1
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error generating interpretation: {e}"
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Segment the left atrium in an uploaded MRI volume and build display artefacts.

    Args:
        fileobj: upload from ``gr.File``; either an object with a ``.name``
            path attribute or a plain path string.
        learn, reorder, resample: model and preprocessing settings as returned
            by ``load_system_resources``.
        save_dir: ``pathlib.Path`` where the predicted mask (``pred_<stem>``)
            and the 3-D mesh are written.
        view: "Sagittal", "Axial" or "Coronal"; ``None`` falls back to "Sagittal".

    Returns:
        ``(fused_images, volume, voxel_info, mesh_path)``:
        - a list of overlay images, one per slice;
        - the volume rounded to 2 decimals;
        - a voxel-count summary string;
        - the GLB mesh path, or ``None`` if mesh creation failed.

    Raises:
        gr.Error: if no file was uploaded (shown to the user by Gradio).
    """
    if fileobj is None:
        raise gr.Error("Please upload an MRI volume (.nii or .nii.gz) before submitting.")
    if view is None:
        view = 'Sagittal'
    img_path = Path(getattr(fileobj, "name", fileobj))
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn
    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data
    if "".join(org_img.orientation) == "LSA":
        # For LSA-oriented inputs:
        # - swap the mask's last two axes;
        # - drop the channel axis and flip spatial axis 1;
        # - restore the leading channel axis.
        # NOTE(review): presumably this re-aligns the mask with the image data
        # layout; confirm against the fastMONAI inference output.
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]
    # Keep the image tensor before org_img's data is replaced by the mask.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)
    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [(get_fused_image(
        normalize_image(slice_img),  # Normalize safely
        slice_mask, view))
        for slice_img, slice_mask in slices]
    # org_img now holds the predicted mask, so this measures the segmentation.
    volume = compute_binary_tumor_volume(org_img)
    # Voxel info for the notes. The /1000 converts mm^3 to mL, assuming the
    # spacing is in millimetres.
    dx, dy, dz = org_img.spacing
    voxel_vol = dx * dy * dz / 1000
    total_voxels = int(np.sum(mask_data.numpy()))
    voxel_info = f"{total_voxels:,} voxels with each voxel volume of {voxel_vol:.4f} mL"
    # Create 3D mesh file (best effort; None on failure).
    mesh_path = create_3d_mesh_file(mask_data, spacing=(dx, dy, dz), save_dir=save_dir)
    return fused_images, round(float(volume), 2), voxel_info, mesh_path
def wrapped_segmentation(fileobj, height, weight, gender, view, display_mode):
    """Gradio submit handler: segment the upload, interpret it, and refresh the 3-D viewer.

    Uses the module-level ``learn``, ``reorder``, ``resample`` and ``save_dir``.
    Returns the values for (gallery, volume box, notes, Model3D) in that order.
    The viewer is rebuilt so that the chosen ``display_mode`` takes effect.
    """
    images, la_volume, voxel_summary, glb_path = gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir, view
    )
    interpretation = get_interpretation(la_volume, height, weight, gender, voxel_summary)
    viewer = gr.Model3D(
        value=glb_path,
        height=420,
        zoom_speed=0.5,
        pan_speed=0.5,
        display_mode=display_mode,
    )
    return images, la_volume, interpretation, viewer
# Initialize the system: model artefacts live in the working directory and
# predictions are written under ./hs_pred.
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Model files are hosted in a (private) HuggingFace repo; HF_TOKEN grants access.
import shutil
from huggingface_hub import hf_hub_download

HF_TOKEN = os.environ.get('HF_TOKEN')
MODEL_REPO = 'drankush-ai/laviz-model'

model_path = models_path / 'heart_model.pkl'
vars_path = models_path / 'vars.pkl'

# Fetch each artefact only if it is not already present locally, then copy it
# out of the HuggingFace cache into models_path.
for _artefact in (model_path, vars_path):
    if _artefact.exists():
        continue
    print(f"[DEBUG] Downloading {_artefact.name} from {MODEL_REPO}...")
    _cached = hf_hub_download(repo_id=MODEL_REPO, filename=_artefact.name, token=HF_TOKEN)
    shutil.copy(_cached, _artefact)

# NOTE(review): these are .pkl files; if load_system_resources unpickles them,
# only load artefacts from this trusted repo.
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')
# Gradio interface setup with light theme.
# NOTE(review): the source's indentation was lost. The nesting here is
# reconstructed: gender_in sits below (not inside) the height/weight row, and
# gr.Examples sits below the two-column row. Confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# LAViz - Left Atrium Visualization & Analysis")
    with gr.Row():
        # Left Column - Inputs, action buttons and the 3-D viewer
        with gr.Column():
            input_file = gr.File(label="Upload MRI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"])
            view_selector = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value='Sagittal',
                label="Select View (Sagittal by default)"
            )
            with gr.Row():
                height_in = gr.Number(label="Height (cm)", value=None)
                weight_in = gr.Number(label="Weight (kg)", value=None)
            gender_in = gr.Radio(choices=["Male", "Female"], value=None, label="Gender")
            # 3D display mode is only read when Submit is clicked.
            display_mode_selector = gr.Radio(
                choices=["solid", "point_cloud", "wireframe"],
                value="solid",
                label="3D Display Mode",
                info="Select display mode before clicking Submit. To change mode, click Clear and re-submit."
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")
            # 3D Visualization below buttons
            mesh_out = gr.Model3D(label="3D Left Atrium Model", height=420, zoom_speed=0.5, pan_speed=0.5)
        # Right Column - Outputs
        with gr.Column():
            gallery_out = gr.Gallery(
                label="Click an Image, and use Arrow Keys to scroll slices",
                columns=3,
                height=450
            )
            vol_out = gr.Textbox(label="Volume of the Left Atrium (mL):")
            notes_out = gr.Markdown(label="Notes")
    # Example handling - clicking fills file, view and patient fields
    gr.Examples(
        examples=[[str(Path.cwd() / "sample.nii.gz"), "Sagittal", 172, 80, "Male"]],
        inputs=[input_file, view_selector, height_in, weight_in, gender_in],
        label="Examples"
    )

    def clear_all():
        """Reset every input to its default and blank every output (order matches `outputs`)."""
        return (
            None,        # input_file
            "Sagittal",  # view_selector (reset to default)
            None,        # height_in
            None,        # weight_in
            None,        # gender_in
            "solid",     # display_mode_selector (reset to default)
            None,        # gallery_out
            "",          # vol_out
            "",          # notes_out
            None,        # mesh_out
        )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_file, view_selector, height_in, weight_in, gender_in, display_mode_selector, gallery_out, vol_out, notes_out, mesh_out]
    )
    # Submit action
    submit_btn.click(
        fn=wrapped_segmentation,
        inputs=[input_file, height_in, weight_in, gender_in, view_selector, display_mode_selector],
        outputs=[gallery_out, vol_out, notes_out, mesh_out]
    )

# Launch only when run as a script (e.g. `python app.py`), so that importing
# this module (tests, notebooks) does not start a web server.
if __name__ == "__main__":
    demo.launch()