supunnadeera commited on
Commit
b340283
·
1 Parent(s): 0b74e31

rollaback to fastapi

Browse files
Files changed (1) hide show
  1. app.py +109 -83
app.py CHANGED
@@ -1,96 +1,122 @@
1
- import gradio as gr
2
- from cellpose import models, utils
3
  import numpy as np
4
- import os
5
  import cv2
6
- from cellpose.plot import mask_overlay
 
 
 
 
 
 
 
 
 
7
 
8
- # --- SETUP ---
9
- # Point to baked-in models (if you are using a Docker container)
10
  os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = "/app/models"
11
 
12
- # 1. LOAD MODEL GLOBALLY
13
- # We load the model once at startup. This fixes the "slow" issue.
14
- # Gradio handles concurrency automatically, so we don't need manual multiprocessing.
15
- print(">>> Loading Cellpose Model...")
16
- model = models.CellposeModel(model_type='cpsam', gpu=False)
17
- print(">>> Model Loaded.")
18
-
19
- def segment(image, diameter, flow_threshold, cellprob_threshold, channels_str):
20
- """
21
- Args:
22
- image: Numpy array from Gradio (H, W, 3)
23
- diameter: Number
24
- flow_threshold: Number
25
- cellprob_threshold: Number
26
- channels_str: String "0,0"
27
- """
28
- if image is None:
29
- return None, "No image provided"
30
-
31
- # 2. PARSE INPUTS
32
- try:
33
- chan_list = [int(c) for c in channels_str.split(',')]
34
- except:
35
- chan_list = [0, 0]
36
 
37
- print(">>> Running Inference...")
38
-
39
- # 3. RUN INFERENCE
40
- # Gradio handles threading, so we just run the eval directly.
41
- masks, flows, styles = model.eval(
42
- image,
43
- diameter=float(diameter) if diameter > 0 else None,
44
- channels=chan_list,
45
- flow_threshold=float(flow_threshold),
46
- cellprob_threshold=float(cellprob_threshold)
47
- )
48
 
49
- # 4. GENERATE TEXT OUTPUT (Your original CSV format)
50
- outlines = utils.outlines_list(masks)
51
- response_lines = []
52
- for outline in outlines:
53
- roi_coords = []
54
- for point in outline:
55
- x, y = point
56
- roi_coords.append(f"{x},{y}")
57
- response_lines.append(",".join(roi_coords))
58
-
59
- text_result = "\n".join(response_lines)
60
 
61
- # 5. GENERATE VISUAL OUTPUT
62
- # Since we are using Gradio, let's return an image with masks drawn on it
63
- # so you can verify the results visually.
64
- overlay_image = mask_overlay(image, masks)
65
 
66
- return overlay_image, text_result
 
 
67
 
68
- # --- UI DEFINITION ---
69
- with gr.Blocks(title="Cellpose Segmentation API") as app:
70
- gr.Markdown("## Cellpose Segmentation (CPSAM)")
71
-
72
- with gr.Row():
73
- with gr.Column():
74
- input_img = gr.Image(label="Input Image", type="numpy")
75
- # Settings
76
- diam = gr.Number(label="Diameter (0 = auto)", value=0)
77
- flow = gr.Slider(label="Flow Threshold", minimum=0.0, maximum=1.0, value=0.4)
78
- prob = gr.Slider(label="Cellprob Threshold", minimum=-6.0, maximum=6.0, value=0.0)
79
- chans = gr.Textbox(label="Channels", value="0,0")
80
- run_btn = gr.Button("Segment", variant="primary")
81
 
82
- with gr.Column():
83
- output_img = gr.Image(label="Visual Result")
84
- output_text = gr.Textbox(label="ROI Coordinates (CSV)", show_copy_button=True, lines=10)
85
-
86
- # Connect logic
87
- run_btn.click(
88
- fn=segment,
89
- inputs=[input_img, diam, flow, prob, chans],
90
- outputs=[output_img, output_text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
 
 
 
 
92
 
93
- # Launch with queue enabled (handles the multiprocessing logic automatically)
94
- if __name__ == "__main__":
95
- app.queue(max_size=10) # Manages traffic automatically
96
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Response
2
+ from cellpose import models, utils, io
3
  import numpy as np
 
4
  import cv2
5
+ import os
6
+ import multiprocessing
7
+ import time
8
+ import sys
9
+ import logging
10
+
11
+ # --- LOGGING SETUP ---
12
+ io.logger_setup()
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', stream=sys.stdout)
14
+ logger = logging.getLogger(__name__)
15
 
16
+ # Point to baked-in models
 
17
  os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = "/app/models"
18
 
19
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Global process management
22
+ current_process = None
23
+ result_queue = multiprocessing.Queue()
 
 
 
 
 
 
 
 
24
 
25
+ # --- WORKER FUNCTION ---
26
+ def run_segmentation(img_data, diameter, flow, prob, channels, queue):
27
+ try:
28
+ print(">>> START SEGMENTATION JOB")
29
+ sys.stdout.flush()
 
 
 
 
 
 
30
 
31
+ # Load Model (Fast because cached)
32
+ model = models.CellposeModel(model_type='cpsam', gpu=False)
 
 
33
 
34
+ # Decode Image
35
+ nparr = np.frombuffer(img_data, np.uint8)
36
+ img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
37
 
38
+ # Run Inference
39
+ chan_list = [int(c) for c in channels.split(',')]
40
+ masks, flows, styles = model.eval(
41
+ img,
42
+ diameter=diameter if diameter > 0 else None,
43
+ channels=chan_list,
44
+ flow_threshold=flow,
45
+ cellprob_threshold=prob
46
+ )
47
+
48
+ # Calculate Outlines
49
+ outlines = utils.outlines_list(masks)
 
50
 
51
+ # Format output
52
+ response_lines = []
53
+ for outline in outlines:
54
+ roi_coords = []
55
+ for point in outline:
56
+ x, y = point
57
+ roi_coords.append(f"{x},{y}")
58
+ response_lines.append(",".join(roi_coords))
59
+
60
+ result_text = "\n".join(response_lines)
61
+
62
+ queue.put({"status": "success", "data": result_text})
63
+ print("<<< END SEGMENTATION JOB")
64
+ sys.stdout.flush()
65
+
66
+ except Exception as e:
67
+ print(f"!!! WORKER ERROR: {e}")
68
+ queue.put({"status": "error", "message": str(e)})
69
+
70
+ @app.post("/segment")
71
+ async def segment(
72
+ image: UploadFile = File(...),
73
+ diameter: float = 0.0,
74
+ flow_threshold: float = 0.4,
75
+ cellprob_threshold: float = 0.0,
76
+ channels: str = "0,0"
77
+ ):
78
+ global current_process
79
+
80
+ # 1. KILL EXISTING JOB (If a new request comes in)
81
+ if current_process is not None and current_process.is_alive():
82
+ logger.info(f"⚠️ Stopping previous job (PID: {current_process.pid}) for new request")
83
+ current_process.terminate()
84
+ current_process.join()
85
+ while not result_queue.empty():
86
+ result_queue.get()
87
+
88
+ # 2. READ DATA
89
+ img_data = await image.read()
90
+
91
+ # 3. START WORKER
92
+ current_process = multiprocessing.Process(
93
+ target=run_segmentation,
94
+ args=(img_data, diameter, flow_threshold, cellprob_threshold, channels, result_queue)
95
  )
96
+ current_process.start()
97
+
98
+ start_time = time.time()
99
+ SERVER_TIMEOUT = 300 # 5 Minutes (Seconds)
100
 
101
+ # 4. MONITOR LOOP
102
+ while current_process.is_alive():
103
+ # A. Check Timeout
104
+ elapsed = time.time() - start_time
105
+ if elapsed > SERVER_TIMEOUT:
106
+ logger.error("⏰ JOB TIMED OUT. Killing process.")
107
+ current_process.terminate()
108
+ current_process.join()
109
+ return Response(content="Server Timeout: Processing took longer than 5 minutes.", status_code=504)
110
+
111
+ # B. Check for Result
112
+ try:
113
+ # Poll queue every 0.5s
114
+ result = result_queue.get(timeout=0.5)
115
+ if result["status"] == "success":
116
+ return Response(content=result["data"], media_type="text/plain")
117
+ else:
118
+ return Response(content=result["message"], status_code=500)
119
+ except:
120
+ continue
121
+
122
+ return Response(content="Process failed silently", status_code=500)