supunnadeera commited on
Commit
faf2829
·
1 Parent(s): 2f81e5d

Option B: dual-venv subprocess dispatch (venv_v3=Cellpose3.1, venv_v4=CellposeSAM)

Browse files
Files changed (7) hide show
  1. app.py +3 -0
  2. requirements.txt +1 -3
  3. requirements_v3.txt +5 -0
  4. requirements_v4.txt +6 -0
  5. segmentation.py +61 -95
  6. setup_venvs.py +54 -0
  7. worker.py +115 -0
app.py CHANGED
@@ -6,6 +6,9 @@ import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
 
 
 
 
9
  from segmentation import MODELS, run_segmentation
10
  from mask_utils import (
11
  create_colored_overlay,
 
6
  import numpy as np
7
  from PIL import Image
8
 
9
+ from setup_venvs import setup_venvs
10
+ setup_venvs() # create venv_v3 / venv_v4 on first launch
11
+
12
  from segmentation import MODELS, run_segmentation
13
  from mask_utils import (
14
  create_colored_overlay,
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
- cellpose>=3.1.1
2
  gradio>=5.0.0
3
  numpy<2
4
- opencv-python-headless
5
  Pillow
6
  scikit-image
7
  pandas
8
- timm
 
 
1
  gradio>=5.0.0
2
  numpy<2
 
3
  Pillow
4
  scikit-image
5
  pandas
6
+ opencv-python-headless
requirements_v3.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ cellpose==3.1.1.2
2
+ numpy<2
3
+ opencv-python-headless
4
+ Pillow
5
+ scikit-image
requirements_v4.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ cellpose
2
+ numpy<2
3
+ opencv-python-headless
4
+ Pillow
5
+ scikit-image
6
+ timm
segmentation.py CHANGED
@@ -1,40 +1,21 @@
 
 
 
 
 
1
  import numpy as np
2
- import cv2
3
  from PIL import Image
4
- from cellpose import models
5
 
 
 
 
6
 
7
  MODELS = {
8
  "CellposeSAM": ["cpsam"],
9
- "Cellpose3.1": ["cyto3"],
10
  }
11
 
12
 
13
- def load_image(image_input) -> np.ndarray:
14
- """Accept a numpy array (from Gradio) or file path, return H×W×3 uint8 RGB."""
15
- if isinstance(image_input, np.ndarray):
16
- img = image_input.copy()
17
- if img.ndim == 2:
18
- img = np.stack([img, img, img], axis=-1)
19
- elif img.shape[2] == 4:
20
- img = img[:, :, :3]
21
- return img.astype(np.uint8)
22
-
23
- # file path
24
- img = cv2.imread(str(image_input), cv2.IMREAD_UNCHANGED)
25
- if img is None:
26
- img = np.array(Image.open(str(image_input)).convert("RGB"))
27
- return img.astype(np.uint8)
28
-
29
- if img.ndim == 2:
30
- img = np.stack([img, img, img], axis=-1)
31
- elif img.shape[2] == 4:
32
- img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
33
- else:
34
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
35
- return img.astype(np.uint8)
36
-
37
-
38
  def run_segmentation(
39
  image_input,
40
  model_type: str,
@@ -46,73 +27,58 @@ def run_segmentation(
46
  use_gpu: bool,
47
  ) -> tuple[np.ndarray, int]:
48
  """
49
- Run Cellpose segmentation.
50
-
51
- Returns:
52
- masks: 2-D integer array (0 = background, 1..N = cell labels)
53
- num_cells: number of detected cells
54
  """
55
- img = load_image(image_input)
56
-
57
- diam = diameter if diameter > 0 else None
58
-
59
- normalize_param = {"percentile": [1.0, 99.0], "tile_norm_blocksize": 0}
60
 
61
- if model_type == "CellposeSAM":
62
- if model_name in ("cpsam",):
63
- model = models.CellposeModel(gpu=use_gpu, model_type=model_name)
 
 
64
  else:
65
- model = models.CellposeModel(gpu=use_gpu, pretrained_model=model_name)
66
-
67
- # Pack channels into front of array (same logic as worker.py CellposeSAM path)
68
- chan_indices = _parse_channels(channels_text)
69
- valid = [c for c in chan_indices if c < img.shape[2]]
70
- img_input = np.zeros_like(img)
71
- for i, c in enumerate(valid):
72
- img_input[:, :, i] = img[:, :, c]
73
-
74
- masks, _, _ = model.eval(
75
- img_input,
76
- diameter=diam,
77
- batch_size=8,
78
- resample=True,
79
- normalize=normalize_param,
80
- flow_threshold=flow_threshold,
81
- cellprob_threshold=cellprob_threshold,
82
- )[:3]
83
-
84
- else: # Cellpose3.1
85
- if model_name in ("cyto3",):
86
- model = models.Cellpose(gpu=use_gpu, model_type=model_name)
87
- else:
88
- model = models.CellposeModel(gpu=use_gpu, pretrained_model=model_name)
89
-
90
- chan_indices = _parse_channels(channels_text)
91
- chan_arg = (chan_indices + [0, 0])[:2]
92
-
93
- masks, _, _ = model.eval(
94
- img,
95
- diameter=diam,
96
- channels=chan_arg,
97
- batch_size=8,
98
- resample=True,
99
- normalize=normalize_param,
100
- flow_threshold=flow_threshold,
101
- cellprob_threshold=cellprob_threshold,
102
- )[:3]
103
-
104
- num_cells = int(masks.max())
105
- return masks.astype(np.int32), num_cells
106
-
107
-
108
- def _parse_channels(channels_text: str) -> list[int]:
109
- """Parse '0,0' style channel string into a list of ints."""
110
- try:
111
- parts = [p.strip() for p in str(channels_text).split(",")]
112
- result = []
113
- for p in parts:
114
- val = int(p)
115
- result.append(max(0, min(3, val)))
116
- return result
117
- except Exception:
118
- return [0, 0]
 
1
+ import json
2
+ import subprocess
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
  import numpy as np
 
7
  from PIL import Image
 
8
 
9
+ from setup_venvs import get_python_executable
10
+
11
+ WORKER = str(Path(__file__).parent / "worker.py")
12
 
13
  MODELS = {
14
  "CellposeSAM": ["cpsam"],
15
+ "Cellpose3.1": ["cyto3", "nuclei"],
16
  }
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def run_segmentation(
20
  image_input,
21
  model_type: str,
 
27
  use_gpu: bool,
28
  ) -> tuple[np.ndarray, int]:
29
  """
30
+ Dispatch segmentation to the correct venv via subprocess.
31
+ Returns (masks, num_cells) where masks is a 2-D integer array.
 
 
 
32
  """
33
+ python_exe = get_python_executable(model_type)
 
 
 
 
34
 
35
+ # Write input image to a temp file so worker.py can read it
36
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_tmp:
37
+ img_path = img_tmp.name
38
+ if isinstance(image_input, np.ndarray):
39
+ Image.fromarray(image_input.astype(np.uint8)).save(img_path)
40
  else:
41
+ Image.open(str(image_input)).convert("RGB").save(img_path)
42
+
43
+ with tempfile.NamedTemporaryFile(suffix="_mask.png", delete=False) as mask_tmp:
44
+ mask_path = mask_tmp.name
45
+
46
+ command = [
47
+ python_exe, WORKER,
48
+ "--image_path", img_path,
49
+ "--model_type", model_type,
50
+ "--model_name", model_name,
51
+ "--diameter", str(diameter),
52
+ "--flow_threshold", str(flow_threshold),
53
+ "--cellprob_threshold", str(cellprob_threshold),
54
+ "--channels", channels_text,
55
+ "--output_path", mask_path,
56
+ ]
57
+
58
+ result = subprocess.run(
59
+ command,
60
+ capture_output=True,
61
+ text=True,
62
+ )
63
+
64
+ # Print worker stderr so HF Spaces logs show Cellpose progress
65
+ if result.stderr:
66
+ for line in result.stderr.strip().splitlines():
67
+ print(f"[worker] {line}")
68
+
69
+ if result.returncode != 0:
70
+ raise RuntimeError(
71
+ f"Worker failed (exit {result.returncode}):\n{result.stderr}"
72
+ )
73
+
74
+ output = json.loads(result.stdout.strip())
75
+ if output.get("status") == "error":
76
+ raise RuntimeError(output.get("message", "Unknown worker error"))
77
+
78
+ num_cells = output.get("num_cells", 0)
79
+ if num_cells == 0:
80
+ return np.zeros((1, 1), dtype=np.int32), 0
81
+
82
+ # Read the 16-bit PNG mask back as integer array
83
+ masks = np.array(Image.open(mask_path)).astype(np.int32)
84
+ return masks, num_cells
 
 
 
 
 
 
 
 
 
 
setup_venvs.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Creates venv_v3 (cellpose==3.1.1.2) and venv_v4 (cellpose latest) at Space startup.
3
+ Called once from app.py before the Gradio interface launches.
4
+ """
5
+ import subprocess
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ BASE = Path(__file__).parent
10
+
11
+ VENVS = [
12
+ ("venv_v3", "requirements_v3.txt"),
13
+ ("venv_v4", "requirements_v4.txt"),
14
+ ]
15
+
16
+
17
+ def _python(venv_name: str) -> Path:
18
+ return BASE / venv_name / "bin" / "python"
19
+
20
+
21
+ def setup_venvs():
22
+ for venv_name, req_file in VENVS:
23
+ python_path = _python(venv_name)
24
+ if python_path.exists():
25
+ print(f"[Setup] {venv_name} already exists — skipping.")
26
+ continue
27
+
28
+ print(f"[Setup] Creating {venv_name} ...")
29
+ subprocess.run([sys.executable, "-m", "venv", str(BASE / venv_name)], check=True)
30
+
31
+ print(f"[Setup] Upgrading pip in {venv_name} ...")
32
+ subprocess.run([str(python_path), "-m", "pip", "install", "--upgrade", "pip"],
33
+ check=True)
34
+
35
+ print(f"[Setup] Installing {req_file} into {venv_name} ...")
36
+ subprocess.run(
37
+ [str(python_path), "-m", "pip", "install", "-r", str(BASE / req_file)],
38
+ check=True,
39
+ )
40
+ print(f"[Setup] {venv_name} ready.")
41
+
42
+
43
+ def get_python_executable(model_type: str) -> str:
44
+ venv_name = "venv_v3" if model_type == "Cellpose3.1" else "venv_v4"
45
+ python_path = _python(venv_name)
46
+ if not python_path.exists():
47
+ raise RuntimeError(
48
+ f"venv {venv_name} not found. Call setup_venvs() first."
49
+ )
50
+ return str(python_path)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ setup_venvs()
worker.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standalone inference script.
3
+ Runs inside either venv_v3 (Cellpose3.1) or venv_v4 (CellposeSAM).
4
+ Writes the integer-label mask as a 16-bit PNG and prints a JSON result to stdout.
5
+
6
+ Usage:
7
+ python worker.py --image_path <path> --model_type <CellposeSAM|Cellpose3.1>
8
+ --model_name <cpsam|cyto3> --diameter 30
9
+ --flow_threshold 0.4 --cellprob_threshold 0.0
10
+ --channels 0,0 --output_path <path>
11
+ """
12
+ import argparse
13
+ import json
14
+ import sys
15
+
16
+ import numpy as np
17
+ import cv2
18
+ from PIL import Image
19
+ from cellpose import models
20
+
21
+
22
+ def load_image(path: str) -> np.ndarray:
23
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
24
+ if img is None:
25
+ img = np.array(Image.open(path).convert("RGB"))
26
+ return img.astype(np.uint8)
27
+ if img.ndim == 2:
28
+ img = np.stack([img, img, img], axis=-1)
29
+ elif img.shape[2] == 4:
30
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
31
+ else:
32
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
+ return img.astype(np.uint8)
34
+
35
+
36
+ def parse_channels(text: str) -> list:
37
+ try:
38
+ return [max(0, min(3, int(p.strip()))) for p in text.split(",")]
39
+ except Exception:
40
+ return [0, 0]
41
+
42
+
43
+ def main():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("--image_path", required=True)
46
+ parser.add_argument("--model_type", required=True)
47
+ parser.add_argument("--model_name", required=True)
48
+ parser.add_argument("--diameter", type=float, default=30.0)
49
+ parser.add_argument("--flow_threshold", type=float, default=0.4)
50
+ parser.add_argument("--cellprob_threshold", type=float, default=0.0)
51
+ parser.add_argument("--channels", default="0,0")
52
+ parser.add_argument("--output_path", required=True)
53
+ args = parser.parse_args()
54
+
55
+ img = load_image(args.image_path)
56
+ diam = args.diameter if args.diameter > 0 else None
57
+ normalize_param = {"percentile": [1.0, 99.0], "tile_norm_blocksize": 0}
58
+
59
+ try:
60
+ if args.model_type == "CellposeSAM":
61
+ if args.model_name in ("cpsam",):
62
+ model = models.CellposeModel(gpu=False, model_type=args.model_name)
63
+ else:
64
+ model = models.CellposeModel(gpu=False, pretrained_model=args.model_name)
65
+
66
+ chan_indices = parse_channels(args.channels)
67
+ valid = [c for c in chan_indices if c < img.shape[2]]
68
+ img_input = np.zeros_like(img)
69
+ for i, c in enumerate(valid):
70
+ img_input[:, :, i] = img[:, :, c]
71
+
72
+ masks, _, _ = model.eval(
73
+ img_input,
74
+ diameter=diam,
75
+ batch_size=8,
76
+ resample=True,
77
+ normalize=normalize_param,
78
+ flow_threshold=args.flow_threshold,
79
+ cellprob_threshold=args.cellprob_threshold,
80
+ )[:3]
81
+
82
+ else: # Cellpose3.1
83
+ if args.model_name in ("cyto3", "nuclei"):
84
+ model = models.Cellpose(gpu=False, model_type=args.model_name)
85
+ else:
86
+ model = models.CellposeModel(gpu=False, pretrained_model=args.model_name)
87
+
88
+ chan_arg = (parse_channels(args.channels) + [0, 0])[:2]
89
+
90
+ masks, _, _ = model.eval(
91
+ img,
92
+ diameter=diam,
93
+ channels=chan_arg,
94
+ batch_size=8,
95
+ resample=True,
96
+ normalize=normalize_param,
97
+ flow_threshold=args.flow_threshold,
98
+ cellprob_threshold=args.cellprob_threshold,
99
+ )[:3]
100
+
101
+ num_cells = int(masks.max())
102
+
103
+ # Save integer-label mask as 16-bit PNG
104
+ mask16 = masks.astype(np.uint16)
105
+ Image.fromarray(mask16, mode="I;16").save(args.output_path)
106
+
107
+ print(json.dumps({"status": "success", "num_cells": num_cells}))
108
+
109
+ except Exception as e:
110
+ print(json.dumps({"status": "error", "message": str(e)}))
111
+ sys.exit(1)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()