import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys
import os
import requests
from groq import Groq
from dotenv import load_dotenv
import math
import plotly.graph_objects as go
from skimage import measure
import shutil
import traceback
from git import Repo
from huggingface_hub import hf_hub_download

# Load environment variables (local ../.env file, or HuggingFace Space secrets).
load_dotenv(dotenv_path=Path.cwd().parent / '.env')

GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
# The Groq SDK refuses to construct a client without an API key. Build it only
# when a key is configured, so the segmentation UI still starts and the notes
# panel can report the missing key instead of the whole app failing at import.
groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

# Additional support for local execution (left disabled). NOTE(review):
# presumably needed on Windows to load pickles that reference PosixPath:
#   import pathlib
#   temp = pathlib.PosixPath
#   pathlib.PosixPath = pathlib.WindowsPath
#   pathlib.PosixPath = temp

# Local execution setup
clone_dir = Path.cwd()
# Optional: clone a private repository (URI taken from PAT_Token_URI) into clone_dir:
# URI = os.getenv('PAT_Token_URI')
# if not os.path.exists(clone_dir):
#     Repo.clone_from(URI, clone_dir)

# Volume axis that is sliced for each UI view name; volumes are indexed [W, H, D].
_VIEW_AXIS = {"Coronal": 0, "Axial": 1, "Sagittal": 2}


def extract_slices_from_mask(img, mask_data, view, target_size=(320, 320)):
    """Extract display-ready 2D slices from a 3D [W, H, D] image and its mask.

    Args:
        img: 3D image volume (NumPy array or torch tensor).
        mask_data: 3D segmentation mask with the same shape as ``img``.
        view: "Sagittal" (slices along axis 2), "Axial" (axis 1) or
            "Coronal" (axis 0).
        target_size: ``(width, height)`` of every returned slice.

    Returns:
        A list of ``(image_slice, mask_slice)`` NumPy array pairs, one per
        index along the sliced axis. Each slice is rotated 90 degrees
        clockwise, mirrored left-right, and letterboxed to ``target_size``.

    Raises:
        ValueError: If ``view`` is not one of the supported names.
    """
    if view not in _VIEW_AXIS:
        raise ValueError(f"Unsupported view {view!r}; expected one of {sorted(_VIEW_AXIS)}")
    axis = _VIEW_AXIS[view]

    slices = []
    for idx in range(img.shape[axis]):
        # Fix index `idx` on the sliced axis and keep the other two axes whole.
        selector = tuple(idx if dim == axis else slice(None) for dim in range(3))
        slice_img = np.fliplr(np.rot90(img[selector], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[selector], -1))
        slices.append(resize_and_pad(slice_img, slice_mask, target_size))
    return slices


def resize_and_pad(slice_img, slice_mask, target_size):
    """Letterbox an image slice and its mask into ``target_size``.

    Both arrays are scaled by the same factor so the aspect ratio is kept and
    the result fits inside ``target_size`` (``(width, height)``, the OpenCV
    convention). They are then zero-padded symmetrically to exactly that size.
    The image uses bilinear interpolation; the mask uses nearest-neighbour so
    label values are not blended.

    Returns:
        ``(padded_img, padded_mask)``, each of shape ``(height, width)``.
    """
    target_w, target_h = target_size
    h, w = slice_img.shape
    scale = min(target_w / w, target_h / h)
    new_w, new_h = int(w * scale), int(h * scale)

    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    pad_w = (target_w - new_w) // 2
    pad_h = (target_h - new_h) // 2
    padding = ((pad_h, target_h - new_h - pad_h), (pad_w, target_w - new_w - pad_w))
    padded_img = np.pad(resized_img, padding, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, padding, mode='constant', constant_values=0)
    return padded_img, padded_mask


def normalize_image(slice_img):
    """Min-max scale a slice to uint8 in [0, 255]; a constant slice becomes all zeros."""
    slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
    if slice_img_min == slice_img_max:  # constant slice: avoid division by zero
        return np.zeros_like(slice_img, dtype=np.uint8)
    normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
    return normalized_img.astype(np.uint8)


def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a grayscale slice with a red mask overlay and orient it for display.

    Args:
        img: 2D uint8 grayscale slice (e.g. the output of ``normalize_image``).
        pred_mask: 2D mask of the same shape; each value scales the overlay colour.
        view: "Sagittal" flips the blend both vertically and horizontally.
            "Coronal" and "Axial" rotate it 90 degrees counter-clockwise and
            then mirror it horizontally.
        alpha: Weight of the grayscale image; the overlay gets ``1 - alpha``.

    Returns:
        A 3-channel uint8 image.

    Raises:
        ValueError: If ``view`` is not one of the supported names.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # (255, 0, 0) renders as red because Gradio displays NumPy images as RGB.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        return cv2.flip(fused, -1)  # flip code -1: both vertically and horizontally
    if view in ('Coronal', 'Axial'):
        # flip code 1: horizontal mirror after the rotation
        return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
    raise ValueError(f"Unsupported view {view!r}; expected 'Sagittal', 'Coronal' or 'Axial'")


def get_bsa(height, weight):
    """Return Body Surface Area (m^2) via the Mosteller formula.

    BSA = sqrt(height_cm * weight_kg / 3600)

    Raises:
        ValueError: If height or weight is missing or not positive.
    """
    if height is None or weight is None or height <= 0 or weight <= 0:
        raise ValueError(
            f"Height and weight must be positive numbers (got height={height!r}, weight={weight!r})"
        )
    return math.sqrt((height * weight) / 3600)


def create_3d_mesh_file(mask_data, spacing, save_dir, min_voxels=100):
    """Extract a surface mesh from a binary mask with marching cubes and export it as GLB.

    This is best-effort. Any failure (including trimesh not being installed)
    is logged and ``None`` is returned, so the rest of the pipeline still runs.

    Args:
        mask_data: Binary mask as a torch tensor or array-like. It may be 3D,
            or 4D with a leading channel axis (only channel 0 is used).
        spacing: Voxel size along each of the three mask axes, passed to
            marching cubes so vertices are scaled by it.
        save_dir: Directory in which ``la_mesh.glb`` is written (overwritten
            on each call).
        min_voxels: Masks whose voxel sum is below this are treated as empty.

    Returns:
        The exported GLB path as a string, or ``None``.
    """
    try:
        import trimesh  # optional dependency: if missing, only the 3D view is lost

        if isinstance(mask_data, torch.Tensor):
            mask_np = mask_data.detach().cpu().numpy().astype(np.float32)
        else:
            mask_np = np.asarray(mask_data, dtype=np.float32)

        # Drop the leading channel axis of a [C, W, H, D] mask.
        if mask_np.ndim == 4:
            mask_np = mask_np[0]

        print(f"[DEBUG] Mask shape: {mask_np.shape}, spacing: {spacing}, sum: {np.sum(mask_np)}")

        if np.sum(mask_np) < min_voxels:
            print("[DEBUG] Mask has too few positive voxels")
            return None

        # level=0.5 places the surface halfway between background (0) and foreground (1).
        verts, faces, normals, _values = measure.marching_cubes(mask_np, level=0.5, spacing=spacing)
        print(f"[DEBUG] Marching cubes: {len(verts)} vertices, {len(faces)} faces")

        mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
        mesh.visual.vertex_colors = [220, 20, 60, 255]  # crimson RGBA

        mesh_path = Path(save_dir) / "la_mesh.glb"
        mesh.export(str(mesh_path), file_type='glb')
        print(f"[DEBUG] Mesh exported to: {mesh_path}")
        return str(mesh_path)
    except Exception as e:
        # Deliberately broad: the mesh is an optional extra and must never
        # break the segmentation result.
        print(f"[DEBUG] Error creating 3D mesh: {e}")
        traceback.print_exc()
        return None


MIRACLE_API_URL = "https://ref.miracle-api.workers.dev/exec"


def fetch_miracle_ref(gender, bsa_indexed=False, timeout=10):
    """Fetch left-atrial volume reference values from MIRACLE-API.

    Args:
        gender: e.g. "Male"/"Female"; lower-cased for the query.
        bsa_indexed: Request the BSA-indexed parameter ``MXLAVi`` instead of ``MXLAV``.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The ``results`` object of the JSON response. Returns ``{}`` on any
        failure (network error, non-200 status, malformed body), because the
        reference data is optional context for the interpretation.
    """
    try:
        params = {
            "domain": "LA_VF",
            "parameter": "MXLAVi" if bsa_indexed else "MXLAV",
            "gender": gender.lower(),
            "method": "SM_AI",
        }
        # Let requests build and encode the query string. The timeout keeps a
        # slow service from hanging the Gradio request indefinitely.
        response = requests.get(MIRACLE_API_URL, params=params, timeout=timeout)
        if response.status_code == 200:
            return response.json().get('results', {})
    except (requests.RequestException, ValueError, AttributeError) as e:
        print(f"Error fetching MIRACLE-API: {e}")
    return {}


def get_interpretation(volume, height, weight, gender, voxel_info):
    """Ask the Groq-hosted LLM to interpret the LA volume against MIRACLE-API references.

    Returns:
        The model's reply text. If the LLM is not configured or the request
        fails, returns an explanatory message instead.

    Raises:
        ValueError: From ``get_bsa`` if height/weight are missing or not positive.
    """
    bsa = get_bsa(height, weight)
    lavi = volume / bsa

    if groq_client is None:
        return "Error generating interpretation: GROQ_API_KEY is not configured."

    ref_lav = fetch_miracle_ref(gender, bsa_indexed=False)
    ref_lavi = fetch_miracle_ref(gender, bsa_indexed=True)

    system_prompt = f"""
    You are a medical imaging assistant. You will be provided with patient data and cardiac segmentation results (specifically Left Atrium Volume - LAV).
    Your task is to interpret these results using reference data from MIRACLE-API.

    Input Data:
    - LAV: {volume} mL
    - Height: {height} cm, Weight: {weight} kg, Gender: {gender}
    - Calculated BSA: {bsa:.2f} m²
    - Calculated LAVi: {lavi:.2f} mL/m²
    - Voxel Info: {voxel_info}
    - Reference LAV (MIRACLE-API): {ref_lav}
    - Reference LAVi (MIRACLE-API): {ref_lavi}

    Instructions:
    1. Acknowledge the calculation method using the voxel info.
    2. Compare the volume and LAVi against the reference mean and ranges (ll: lower limit, ul: upper limit).
    3. State if the volume is enlarged or normal based on the Z-score/percentile (if you can estimate) or simply by comparing against the upper limit (ul).
    4. Format the response strictly as requested by the user, starting with 'MIRACLE-API'.
    """

    try:
        completion = groq_client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Interpret the results."},
            ],
            temperature=0.1,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Surface the failure in the notes panel rather than failing the whole request.
        return f"Error generating interpretation: {e}"


def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Segment the left atrium in an uploaded volume and build display artefacts.

    Writes the predicted mask to ``save_dir / ('pred_' + <input stem>)``.

    Returns:
        A tuple of:
        - the list of fused slice images for ``view``,
        - the LA volume rounded to 2 decimals,
        - a voxel-count/voxel-volume description string,
        - the exported GLB mesh path, or ``None``.
    """
    if view is None:
        view = 'Sagittal'

    # Gradio may pass a tempfile-like object or a path string; use .name when present.
    img_path = Path(getattr(fileobj, "name", fileobj))
    save_path = save_dir / ('pred_' + img_path.stem)

    org_img, input_img, org_size = med_img_reader(
        img_path, reorder=reorder, resample=resample, only_tensor=False
    )
    mask_data = inference(
        learn, reorder=reorder, resample=resample,
        org_img=org_img, input_img=input_img, org_size=org_size,
    ).data

    # For LSA-oriented inputs: swap the last two spatial axes, then reverse the
    # new middle spatial axis. NOTE(review): presumably realigns the prediction
    # with the original voxel grid; confirm.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    img = org_img.data
    org_img.set_data(mask_data)  # org_img now holds the mask and is saved as the prediction
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [
        get_fused_image(normalize_image(slice_img), slice_mask, view)
        for slice_img, slice_mask in slices
    ]

    volume = compute_binary_tumor_volume(org_img)

    # Voxel info for the notes (spacing in mm, so mm^3 / 1000 = mL).
    dx, dy, dz = org_img.spacing
    voxel_vol = dx * dy * dz / 1000
    total_voxels = int(np.sum(mask_data.numpy()))
    voxel_info = f"{total_voxels:,} voxels with each voxel volume of {voxel_vol:.4f} mL"

    mesh_path = create_3d_mesh_file(mask_data, spacing=(dx, dy, dz), save_dir=save_dir)

    return fused_images, round(float(volume), 2), voxel_info, mesh_path


def _validate_submission(fileobj, height, weight, gender):
    """Raise gr.Error (shown to the user) for missing or invalid form inputs."""
    if fileobj is None:
        raise gr.Error("Please upload an MRI volume (.nii or .nii.gz).")
    if height is None or weight is None or height <= 0 or weight <= 0:
        raise gr.Error("Please enter a positive height (cm) and weight (kg).")
    if not gender:
        raise gr.Error("Please select a gender.")


def wrapped_segmentation(fileobj, height, weight, gender, view, display_mode):
    """Submit handler: validate inputs, segment, interpret, and build the 3D view.

    Inputs are validated before the (expensive) inference runs.
    """
    _validate_submission(fileobj, height, weight, gender)

    fused_images, volume, voxel_info, mesh_path = gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir, view
    )
    notes = get_interpretation(volume, height, weight, gender, voxel_info)

    # Return a Model3D so the selected display_mode is applied to the output.
    model3d = gr.Model3D(
        value=mesh_path, height=420, zoom_speed=0.5, pan_speed=0.5, display_mode=display_mode
    )
    return fused_images, volume, notes, model3d


# Initialize the system
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Model files live in a private HuggingFace repo.
HF_TOKEN = os.environ.get('HF_TOKEN')
MODEL_REPO = 'drankush-ai/laviz-model'


def _ensure_model_file(filename):
    """Download ``filename`` from MODEL_REPO into models_path unless it is already there."""
    target = models_path / filename
    if not target.exists():
        print(f"[DEBUG] Downloading {filename} from {MODEL_REPO}...")
        downloaded = hf_hub_download(repo_id=MODEL_REPO, filename=filename, token=HF_TOKEN)
        shutil.copy(downloaded, target)
    return target


model_path = _ensure_model_file('heart_model.pkl')
vars_path = _ensure_model_file('vars.pkl')

# Load the learner and its preprocessing settings (reorder/resample).
# These are .pkl files: only load them from the trusted MODEL_REPO.
learn, reorder, resample = load_system_resources(
    models_path=models_path, learner_fn='heart_model.pkl', variables_fn='vars.pkl'
)

# Gradio interface setup with light theme
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# LAViz - Left Atrium Visualization & Analysis")

    with gr.Row():
        # Left Column - Inputs
        with gr.Column():
            input_file = gr.File(label="Upload MRI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"])
            view_selector = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value='Sagittal',
                label="Select View (Sagittal by default)",
            )
            with gr.Row():
                height_in = gr.Number(label="Height (cm)", value=None)
                weight_in = gr.Number(label="Weight (kg)", value=None)
                gender_in = gr.Radio(choices=["Male", "Female"], value=None, label="Gender")

            # 3D Display Mode selector (before Submit)
            display_mode_selector = gr.Radio(
                choices=["solid", "point_cloud", "wireframe"],
                value="solid",
                label="3D Display Mode",
                info="Select display mode before clicking Submit. To change mode, click Clear and re-submit.",
            )

            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")

            # 3D Visualization below buttons
            mesh_out = gr.Model3D(label="3D Left Atrium Model", height=420, zoom_speed=0.5, pan_speed=0.5)

        # Right Column - Outputs
        with gr.Column():
            gallery_out = gr.Gallery(
                label="Click an Image, and use Arrow Keys to scroll slices",
                columns=3,
                height=450,
            )
            vol_out = gr.Textbox(label="Volume of the Left Atrium (mL):")
            notes_out = gr.Markdown(label="Notes")

    # Example handling - clicking fills all fields
    gr.Examples(
        examples=[[str(Path.cwd() / "sample.nii.gz"), "Sagittal", 172, 80, "Male"]],
        inputs=[input_file, view_selector, height_in, weight_in, gender_in],
        label="Examples",
    )

    def clear_all():
        """Reset every input to its default and blank every output."""
        return (
            None,        # input_file
            "Sagittal",  # view_selector (reset to default)
            None,        # height_in
            None,        # weight_in
            None,        # gender_in
            "solid",     # display_mode_selector (reset to default)
            None,        # gallery_out
            "",          # vol_out
            "",          # notes_out
            None,        # mesh_out
        )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_file, view_selector, height_in, weight_in, gender_in,
                 display_mode_selector, gallery_out, vol_out, notes_out, mesh_out],
    )

    submit_btn.click(
        fn=wrapped_segmentation,
        inputs=[input_file, height_in, weight_in, gender_in, view_selector, display_mode_selector],
        outputs=[gallery_out, vol_out, notes_out, mesh_out],
    )

# Launch the Gradio interface
demo.launch()