titi
now supporting mcp server
bdd3d9c
import numpy as np
import tifffile
import os
import tempfile
import urllib.request
from PIL import Image
from pathlib import Path
import time, uuid, atexit
from unet_lungs_segmentation import LungsPredict
import gradio as gr
model = LungsPredict()
APP_TMP_DIR = Path(tempfile.gettempdir()) / "lungs_seg_tmp"
APP_TMP_DIR.mkdir(parents=True, exist_ok=True)
# ---------- Example file ----------
def get_example_file():
url = "https://zenodo.org/record/8099852/files/lungs_ct.tif?download=1"
tmp_path = APP_TMP_DIR / "example_lungs.tif"
if not tmp_path.exists():
urllib.request.urlretrieve(url, tmp_path)
return str(tmp_path)
example_file_path = get_example_file()
PROTECTED_PATHS = {Path(example_file_path).resolve()}
def new_tmp_path(basename: str = "tmp.tif") -> str:
"""Return a unique path inside the app temp dir."""
uid = uuid.uuid4().hex[:8]
return str(APP_TMP_DIR / f"{uid}_{basename}")
def clean_temp(max_age_hours: float = 6.0) -> None:
cutoff = time.time() - max_age_hours * 3600 if max_age_hours > 0 else float("inf")
protected = PROTECTED_PATHS
for p in APP_TMP_DIR.glob("*"):
try:
rp = p.resolve()
if rp in protected:
continue
if max_age_hours == 0 or p.stat().st_mtime < cutoff:
p.unlink(missing_ok=True)
except Exception as e:
print(f"[cleanup] could not remove {p}: {e}")
atexit.register(lambda: clean_temp(0)) # purge on shutdown
def write_mask_tif(mask: np.ndarray) -> str:
"""Write a mask volume to a compressed TIFF in app temp and return the path."""
out_path = new_tmp_path("mask.tif")
tifffile.imwrite(out_path, mask.astype(np.uint8), compression="zlib")
return out_path
# ---------- Reading helpers ----------
def _read_tif_from_path(path: str):
"""Read a tif from a local filesystem path; only auto-delete files in APP_TMP_DIR (not protected)."""
arr = tifffile.imread(path)
try:
if path and os.path.exists(path):
rp = Path(path).resolve()
if (rp not in PROTECTED_PATHS) and (APP_TMP_DIR in rp.parents):
os.remove(rp)
except Exception as e:
print(f"[load_volume] couldn't remove temp file {path}: {e}")
return arr
def load_volume(file_obj):
"""
Backward-compatible wrapper used by older code that passes in a path-like object.
Prefer _load_volume_from_any() in new code.
"""
if not file_obj:
return None
path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or file_obj
if isinstance(path, (str, os.PathLike)):
return _read_tif_from_path(str(path))
# If a dict/FileData slipped through, delegate to the robust path:
return _load_volume_from_any(file_obj)
def _load_volume_from_any(file_obj):
"""
Normalize different inputs to a real filesystem path and read via _read_tif_from_path.
Accepts:
- dict with 'path' or 'url' (Gradio FileData / programmatic)
- str local path or URL
- bytes / bytearray
- file-like object with .read()
"""
try:
# Gradio FileData-like dict
if isinstance(file_obj, dict):
path = file_obj.get("path") or file_obj.get("url")
if not path:
raise gr.Error("Invalid file object (missing 'path' or 'url').")
if isinstance(path, str) and (path.startswith("http://") or path.startswith("https://")):
fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
os.close(fd)
urllib.request.urlretrieve(path, tmp_path)
return _read_tif_from_path(tmp_path)
return _read_tif_from_path(path)
# String path or URL
if isinstance(file_obj, (str, os.PathLike)):
s = str(file_obj)
if s.startswith("http://") or s.startswith("https://"):
fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
os.close(fd)
urllib.request.urlretrieve(s, tmp_path)
return _read_tif_from_path(tmp_path)
return _read_tif_from_path(s)
# Raw bytes
if isinstance(file_obj, (bytes, bytearray)):
fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
os.close(fd)
with open(tmp_path, "wb") as w:
w.write(file_obj)
return _read_tif_from_path(tmp_path)
# File-like object
if hasattr(file_obj, "read"):
data = file_obj.read()
fd, tmp_path = tempfile.mkstemp(suffix=".tif", dir=str(APP_TMP_DIR))
os.close(fd)
with open(tmp_path, "wb") as w:
w.write(data)
return _read_tif_from_path(tmp_path)
raise gr.Error(f"Unsupported input type for file_obj: {type(file_obj)}")
except Exception as e:
raise gr.Error(f"Failed to read input file: {e}")
# ---------- Model + viz ----------
def segment_volume(volume):
"""Run segmentation on the loaded volume (return shape (Z, Y, X))."""
if volume is None:
return None
return model.segment_lungs(volume)
def volume_stats(volume):
"""Return (min, max) as floats for global 8-bit scaling."""
if volume is None:
return (0.0, 1.0)
return float(volume.min()), float(volume.max())
def _to_8bit_stats(arr, mn, mx):
rng = max(mx - mn, 1e-8)
return np.clip((arr - mn) / rng * 255.0, 0, 255).astype(np.uint8)
def browse_axis_fast(axis, idx, volume, stats):
"""Same as browse_axis but uses precomputed global stats."""
if volume is None:
return None
mn, mx = stats
if axis == "Z":
slice_ = volume[idx]
elif axis == "Y":
slice_ = volume[:, idx, :]
elif axis == "X":
slice_ = volume[:, :, idx]
else:
return None
return Image.fromarray(_to_8bit_stats(slice_, mn, mx))
def browse_overlay_axis_fast(axis, idx, volume, seg, stats, alpha=0.35):
"""Overlay using global stats (fewer allocations, faster)."""
if volume is None or seg is None:
return None
mn, mx = stats
if axis == "Z":
raw = volume[idx]; mask = seg[idx]
elif axis == "Y":
raw = volume[:, idx, :]; mask = seg[:, idx, :]
elif axis == "X":
raw = volume[:, :, idx]; mask = seg[:, :, idx]
else:
return None
raw8 = _to_8bit_stats(raw, mn, mx)
rgb = np.repeat(raw8[..., None], 3, axis=-1)
mask_rgb = np.zeros_like(rgb)
mask_rgb[..., 0] = (mask.astype(np.uint8) * 255)
blended = rgb.astype(np.float32) * (1 - alpha) + mask_rgb.astype(np.float32) * alpha
return Image.fromarray(blended.astype(np.uint8))