# LAViz / app.py
# Uploaded by drankush-ai via huggingface_hub (commit edfa726, verified).
import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys
import os
import requests
from groq import Groq
from dotenv import load_dotenv
import math
import plotly.graph_objects as go
from skimage import measure
# Load environment variables (local .env or HuggingFace Secrets).
load_dotenv(dotenv_path=Path.cwd().parent / '.env')
# NOTE(review): GROQ_API_KEY is None when the secret is missing; the Groq
# client is still constructed, so the failure surfaces only at call time —
# verify the deployment has the secret set.
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
groq_client = Groq(api_key=GROQ_API_KEY)
# Debug: list all symbols imported from fastMONAI.vision_all (star import above).
print("[DEBUG] fastMONAI.vision_all symbols:", dir())
from git import Repo
import os
# Additional support for local execution (Windows pathlib shim), kept for reference:
# import pathlib
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath
# pathlib.PosixPath = temp

# Local execution setup: everything resolves relative to the working directory.
clone_dir = Path.cwd()

# Optional: clone a private repo via a PAT token (currently disabled).
# URI = os.getenv('PAT_Token_URI')
# if os.path.exists(clone_dir):
#     pass
# else:
#     Repo.clone_from(URI, clone_dir)
def extract_slices_from_mask(img, mask_data, view):
    """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view."""
    target_size = (320, 320)

    # Map each view onto the axis to walk and a slicer for that axis.
    if view == "Sagittal":
        axis, take = 2, lambda vol, i: vol[:, :, i]
    elif view == "Axial":
        axis, take = 1, lambda vol, i: vol[:, i, :]
    elif view == "Coronal":
        axis, take = 0, lambda vol, i: vol[i, :, :]

    slices = []
    for i in range(img.shape[axis]):
        raw_img = take(img, i)
        raw_mask = take(mask_data, i)
        # Rotate clockwise then mirror so the slice matches display orientation.
        oriented_img = np.fliplr(np.rot90(raw_img, -1))
        oriented_mask = np.fliplr(np.rot90(raw_mask, -1))
        slices.append(resize_and_pad(oriented_img, oriented_mask, target_size))
    return slices
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio."""
    height, width = slice_img.shape
    target_w, target_h = target_size
    factor = min(target_w / width, target_h / height)
    out_w, out_h = int(width * factor), int(height * factor)

    # Linear interpolation for intensities; nearest-neighbour keeps mask labels crisp.
    img_small = cv2.resize(slice_img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    mask_small = cv2.resize(slice_mask, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

    # Centre the resized slice on a zero-filled canvas of the target size.
    left = (target_w - out_w) // 2
    top = (target_h - out_h) // 2
    pad_spec = ((top, target_h - out_h - top), (left, target_w - out_w - left))
    padded_img = np.pad(img_small, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(mask_small, pad_spec, mode='constant', constant_values=0)
    return padded_img, padded_mask
def normalize_image(slice_img):
    """Normalize the image to the range [0, 255] safely."""
    lo, hi = slice_img.min(), slice_img.max()
    # A flat slice would divide by zero — map it to all-black instead.
    if lo == hi:
        return np.zeros_like(slice_img, dtype=np.uint8)
    scaled = (slice_img - lo) / (hi - lo) * 255
    return scaled.astype(np.uint8)
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Fuse a grayscale slice with a red mask overlay, oriented for display.

    Args:
        img: 2D uint8 grayscale slice.
        pred_mask: 2D binary mask aligned with `img`.
        view: One of 'Sagittal', 'Axial', 'Coronal'.
        alpha: Blend weight of the grayscale image (mask gets 1 - alpha).

    Returns:
        3-channel uint8 image with the mask blended in and the orientation
        adjusted for the requested view.

    Raises:
        ValueError: If `view` is not one of the three supported views.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
    if view == 'Sagittal':
        # Flip both vertically and horizontally (moved here from unconditional
        # code: the flipped image was only ever used for the Sagittal view).
        return cv2.flip(fused, -1)
    # BUG FIX: the original condition `view == 'Coronal' or 'Axial'` was always
    # truthy ('Axial' is a non-empty string). Both views intentionally share
    # this rotation, so test membership explicitly.
    if view in ('Coronal', 'Axial'):
        return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
    raise ValueError(f"Unknown view: {view!r}")
def get_bsa(height, weight):
    """Calculate Body Surface Area using the Mosteller formula."""
    # Mosteller: BSA [m²] = sqrt(height[cm] * weight[kg] / 3600)
    product = height * weight
    return math.sqrt(product / 3600)
def create_3d_mesh_file(mask_data, spacing, save_dir):
    """Create a 3D mesh file from the segmentation mask using marching cubes."""
    import trimesh
    try:
        # Accept either a torch tensor (has .numpy()) or anything array-like.
        if hasattr(mask_data, 'numpy'):
            vol = mask_data.numpy().astype(np.float32)
        else:
            vol = np.array(mask_data).astype(np.float32)

        # Drop a leading channel axis if present so marching cubes sees 3D data.
        if vol.ndim == 4:
            vol = vol[0]
        print(f"[DEBUG] Mask shape: {vol.shape}, spacing: {spacing}, sum: {np.sum(vol)}")

        # Bail out on (near-)empty segmentations — no meaningful surface exists.
        if np.sum(vol) < 100:
            print("[DEBUG] Mask has too few positive voxels")
            return None

        # Extract the 0.5 iso-surface in physical (voxel-spacing) coordinates.
        verts, faces, normals, values = measure.marching_cubes(
            vol, level=0.5, spacing=spacing
        )
        print(f"[DEBUG] Marching cubes: {len(verts)} vertices, {len(faces)} faces")

        surface = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
        surface.visual.vertex_colors = [220, 20, 60, 255]  # Crimson RGBA

        # Export to GLB (binary glTF) for the Gradio Model3D viewer.
        out_path = save_dir / "la_mesh.glb"
        surface.export(str(out_path), file_type='glb')
        print(f"[DEBUG] Mesh exported to: {out_path}")
        return str(out_path)
    except Exception as e:
        # Best-effort: log the failure and signal "no mesh" to the caller.
        print(f"[DEBUG] Error creating 3D mesh: {e}")
        import traceback
        traceback.print_exc()
        return None
def fetch_miracle_ref(gender, bsa_indexed=False):
    """Fetch reference values from MIRACLE-API.

    Args:
        gender: 'Male' or 'Female' (lower-cased for the query string).
        bsa_indexed: When True, request BSA-indexed values (MXLAVi);
            otherwise absolute volume (MXLAV).

    Returns:
        dict: The API's 'results' payload, or {} on any failure
        (non-200 status, network error, timeout).
    """
    param = "MXLAVi" if bsa_indexed else "MXLAV"
    url = f"https://ref.miracle-api.workers.dev/exec?domain=LA_VF&parameter={param}&gender={gender.lower()}&method=SM_AI"
    try:
        # BUG FIX: requests.get without a timeout can block the request thread
        # indefinitely; bound it so the UI fails fast and falls back to {}.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            return response.json().get('results', {})
    except Exception as e:
        print(f"Error fetching MIRACLE-API: {e}")
    return {}
def get_interpretation(volume, height, weight, gender, voxel_info):
    """Generate interpretation using Groq LLM.

    Builds a system prompt containing the measured LAV, the patient's
    anthropometrics, and MIRACLE-API reference ranges, then asks the
    Groq-hosted model for an interpretation.  Returns the model's text,
    or an error string when the API call fails.
    """
    bsa = get_bsa(height, weight)
    # LAVi: left-atrial volume indexed to body surface area (mL/m²).
    lavi = volume / bsa
    # Reference ranges for both absolute and BSA-indexed volumes.
    ref_lav = fetch_miracle_ref(gender, bsa_indexed=False)
    ref_lavi = fetch_miracle_ref(gender, bsa_indexed=True)
    system_prompt = f"""
You are a medical imaging assistant. You will be provided with patient data and cardiac segmentation results (specifically Left Atrium Volume - LAV).
Your task is to interpret these results using reference data from MIRACLE-API.
Input Data:
- LAV: {volume} mL
- Height: {height} cm, Weight: {weight} kg, Gender: {gender}
- Calculated BSA: {bsa:.2f}
- Calculated LAVi: {lavi:.2f} mL/m²
- Voxel Info: {voxel_info}
- Reference LAV (MIRACLE-API): {ref_lav}
- Reference LAVi (MIRACLE-API): {ref_lavi}
Instructions:
1. Acknowledge the calculation method using the voxel info.
2. Compare the volume and LAVi against the reference mean and ranges (ll: lower limit, ul: upper limit).
3. State if the volume is enlarged or normal based on the Z-score/percentile (if you can estimate) or simply by comparing against the upper limit (ul).
4. Format the response strictly as requested by the user, starting with 'MIRACLE-API'.
"""
    try:
        # Low temperature keeps the clinical summary near-deterministic.
        completion = groq_client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Interpret the results."}
            ],
            temperature=0.1
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Surface the failure in the notes panel rather than crashing the app.
        return f"Error generating interpretation: {e}"
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Run LA segmentation on an uploaded NIfTI file and prepare display artifacts.

    Args:
        fileobj: Gradio file wrapper; `.name` is the path to the uploaded image.
        learn: fastMONAI learner used for inference.
        reorder: Reorder flag matching the trained model's preprocessing.
        resample: Resample setting matching the trained model's preprocessing.
        save_dir: Directory where the predicted mask and 3D mesh are written.
        view: 'Sagittal', 'Axial' or 'Coronal'; defaults to 'Sagittal' if None.

    Returns:
        tuple: (fused slice images, volume in mL rounded to 2 decimals,
                voxel-info string, path to the exported 3D mesh or None).
    """
    # FIX: compare against None by identity (was `view == None`, PEP 8 E711).
    if view is None:
        view = 'Sagittal'
    img_path = Path(fileobj.name)
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn
    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data
    # NOTE(review): LSA-oriented inputs get their last two axes swapped and one
    # axis flipped so the mask lines up with the original image — confirm
    # against med_img_reader's orientation handling.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]
    img = org_img.data
    # Persist the prediction: write the mask into the image object and save it.
    org_img.set_data(mask_data)
    org_img.save(save_path)
    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [get_fused_image(normalize_image(slice_img), slice_mask, view)
                    for slice_img, slice_mask in slices]
    volume = compute_binary_tumor_volume(org_img)
    # Voxel info for the notes: positive-voxel count and per-voxel volume (mL).
    dx, dy, dz = org_img.spacing
    voxel_vol = dx * dy * dz / 1000
    total_voxels = int(np.sum(mask_data.numpy()))
    voxel_info = f"{total_voxels:,} voxels with each voxel volume of {voxel_vol:.4f} mL"
    # Create 3D mesh file for the Model3D viewer (None if it fails).
    mesh_path = create_3d_mesh_file(mask_data, spacing=(dx, dy, dz), save_dir=save_dir)
    return fused_images, round(float(volume), 2), voxel_info, mesh_path
def wrapped_segmentation(fileobj, height, weight, gender, view, display_mode):
    """Run segmentation, generate LLM notes, and build the 3D viewer component."""
    gallery, lav_ml, voxel_summary, mesh_file = gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir, view)
    interpretation = get_interpretation(lav_ml, height, weight, gender, voxel_summary)
    # Rebuild the Model3D component each submit so display_mode takes effect.
    viewer = gr.Model3D(value=mesh_file, height=420, zoom_speed=0.5,
                        pan_speed=0.5, display_mode=display_mode)
    return gallery, lav_ml, interpretation, viewer
# ---- System initialisation (runs once at import) ----
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'  # predictions and meshes are written here
save_dir.mkdir(parents=True, exist_ok=True)

# Download model files from private HuggingFace repo.
from huggingface_hub import hf_hub_download
HF_TOKEN = os.environ.get('HF_TOKEN')  # required because the model repo is private
MODEL_REPO = 'drankush-ai/laviz-model'

# Download model files if not already present (existing local copies are reused).
model_path = models_path / 'heart_model.pkl'
vars_path = models_path / 'vars.pkl'
if not model_path.exists():
    print(f"[DEBUG] Downloading heart_model.pkl from {MODEL_REPO}...")
    downloaded_model = hf_hub_download(repo_id=MODEL_REPO, filename='heart_model.pkl', token=HF_TOKEN)
    import shutil
    shutil.copy(downloaded_model, model_path)
if not vars_path.exists():
    print(f"[DEBUG] Downloading vars.pkl from {MODEL_REPO}...")
    downloaded_vars = hf_hub_download(repo_id=MODEL_REPO, filename='vars.pkl', token=HF_TOKEN)
    import shutil
    shutil.copy(downloaded_vars, vars_path)

# Load the learner plus the reorder/resample settings it was trained with.
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')
# Gradio interface setup with light theme.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# LAViz - Left Atrium Visualization & Analysis")
    with gr.Row():
        # Left column - inputs and the 3D viewer.
        with gr.Column():
            input_file = gr.File(label="Upload MRI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"])
            view_selector = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value='Sagittal',
                label="Select View (Sagittal by default)"
            )
            # NOTE(review): source indentation was lost; grouping of the three
            # patient fields in one row is reconstructed — confirm layout.
            with gr.Row():
                height_in = gr.Number(label="Height (cm)", value=None)
                weight_in = gr.Number(label="Weight (kg)", value=None)
                gender_in = gr.Radio(choices=["Male", "Female"], value=None, label="Gender")
            # 3D display mode selector (must be chosen before Submit).
            display_mode_selector = gr.Radio(
                choices=["solid", "point_cloud", "wireframe"],
                value="solid",
                label="3D Display Mode",
                info="Select display mode before clicking Submit. To change mode, click Clear and re-submit."
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn = gr.Button("Submit", variant="primary")
            # 3D visualization below the buttons.
            mesh_out = gr.Model3D(label="3D Left Atrium Model", height=420, zoom_speed=0.5, pan_speed=0.5)
        # Right column - outputs.
        with gr.Column():
            gallery_out = gr.Gallery(
                label="Click an Image, and use Arrow Keys to scroll slices",
                columns=3,
                height=450
            )
            vol_out = gr.Textbox(label="Volume of the Left Atrium (mL):")
            notes_out = gr.Markdown(label="Notes")

    # Example handling - clicking an example fills all input fields.
    gr.Examples(
        examples=[[str(Path.cwd() / "sample.nii.gz"), "Sagittal", 172, 80, "Male"]],
        inputs=[input_file, view_selector, height_in, weight_in, gender_in],
        label="Examples"
    )

    # Clear action - resets all inputs AND outputs to their defaults.
    def clear_all():
        """Return default values for every input and output component."""
        return (
            None,        # input_file
            "Sagittal",  # view_selector (reset to default)
            None,        # height_in
            None,        # weight_in
            None,        # gender_in
            "solid",     # display_mode_selector (reset to default)
            None,        # gallery_out
            "",          # vol_out
            "",          # notes_out
            None,        # mesh_out
        )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_file, view_selector, height_in, weight_in, gender_in, display_mode_selector, gallery_out, vol_out, notes_out, mesh_out]
    )

    # Submit action - runs segmentation and fills the output column.
    submit_btn.click(
        fn=wrapped_segmentation,
        inputs=[input_file, height_in, weight_in, gender_in, view_selector, display_mode_selector],
        outputs=[gallery_out, vol_out, notes_out, mesh_out]
    )

# Launch the Gradio interface.
demo.launch()