Spaces:
Sleeping
Sleeping
supunnadeera commited on
Commit ·
b340283
1
Parent(s): 0b74e31
rollback to FastAPI
Browse files
app.py
CHANGED
|
@@ -1,96 +1,122 @@
|
|
| 1 |
-
import
|
| 2 |
-
from cellpose import models, utils
|
| 3 |
import numpy as np
|
| 4 |
-
import os
|
| 5 |
import cv2
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
# Point to baked-in models (if you are using a Docker container)
|
| 10 |
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = "/app/models"
|
| 11 |
|
| 12 |
-
|
| 13 |
-
# We load the model once at startup. This fixes the "slow" issue.
|
| 14 |
-
# Gradio handles concurrency automatically, so we don't need manual multiprocessing.
|
| 15 |
-
print(">>> Loading Cellpose Model...")
|
| 16 |
-
model = models.CellposeModel(model_type='cpsam', gpu=False)
|
| 17 |
-
print(">>> Model Loaded.")
|
| 18 |
-
|
| 19 |
-
def segment(image, diameter, flow_threshold, cellprob_threshold, channels_str):
|
| 20 |
-
"""
|
| 21 |
-
Args:
|
| 22 |
-
image: Numpy array from Gradio (H, W, 3)
|
| 23 |
-
diameter: Number
|
| 24 |
-
flow_threshold: Number
|
| 25 |
-
cellprob_threshold: Number
|
| 26 |
-
channels_str: String "0,0"
|
| 27 |
-
"""
|
| 28 |
-
if image is None:
|
| 29 |
-
return None, "No image provided"
|
| 30 |
-
|
| 31 |
-
# 2. PARSE INPUTS
|
| 32 |
-
try:
|
| 33 |
-
chan_list = [int(c) for c in channels_str.split(',')]
|
| 34 |
-
except:
|
| 35 |
-
chan_list = [0, 0]
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# Gradio handles threading, so we just run the eval directly.
|
| 41 |
-
masks, flows, styles = model.eval(
|
| 42 |
-
image,
|
| 43 |
-
diameter=float(diameter) if diameter > 0 else None,
|
| 44 |
-
channels=chan_list,
|
| 45 |
-
flow_threshold=float(flow_threshold),
|
| 46 |
-
cellprob_threshold=float(cellprob_threshold)
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
for point in outline:
|
| 55 |
-
x, y = point
|
| 56 |
-
roi_coords.append(f"{x},{y}")
|
| 57 |
-
response_lines.append(",".join(roi_coords))
|
| 58 |
-
|
| 59 |
-
text_result = "\n".join(response_lines)
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
# so you can verify the results visually.
|
| 64 |
-
overlay_image = mask_overlay(image, masks)
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
run_btn = gr.Button("Segment", variant="primary")
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
#
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

# Point to baked-in models. Set this BEFORE importing cellpose: the library
# may resolve its local model directory when it is first imported, in which
# case assigning the variable afterwards would have no effect. (Harmless
# either way if cellpose reads the variable lazily — TODO confirm against
# the pinned cellpose version.)
os.environ["CELLPOSE_LOCAL_MODELS_PATH"] = "/app/models"

import logging
import multiprocessing
import sys
import time

import cv2
import numpy as np
from cellpose import models, utils, io
from fastapi import FastAPI, UploadFile, File, Response

# --- LOGGING SETUP ---
io.logger_setup()
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', stream=sys.stdout)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global process management: this server runs at most one segmentation job
# at a time. `current_process` holds the active worker (or None) and
# `result_queue` carries the worker's result back to the request handler.
current_process = None
result_queue = multiprocessing.Queue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
# --- WORKER FUNCTION ---
|
| 26 |
+
def run_segmentation(img_data, diameter, flow, prob, channels, queue):
    """Worker entry point: segment raw image bytes with Cellpose.

    Runs in a child process so the parent can terminate it on timeout or
    preemption. Communicates exclusively through *queue*, putting either
    {"status": "success", "data": <ROI text>} or
    {"status": "error", "message": <str>}.

    Args:
        img_data: Encoded image bytes (as uploaded).
        diameter: Expected cell diameter; values <= 0 mean auto-estimate.
        flow: Cellpose flow threshold.
        prob: Cellpose cell-probability threshold.
        channels: Channel spec string such as "0,0".
        queue: multiprocessing.Queue for the single result message.
    """
    try:
        # print + flush (rather than logging) so output from the child
        # process shows up immediately in the container logs.
        print(">>> START SEGMENTATION JOB")
        sys.stdout.flush()

        # Instantiate the model in the child (fast: weights are cached).
        cp_model = models.CellposeModel(model_type='cpsam', gpu=False)

        # Decode the uploaded bytes into an image array.
        raw = np.frombuffer(img_data, np.uint8)
        decoded = cv2.imdecode(raw, cv2.IMREAD_UNCHANGED)

        # Run inference with the caller-supplied parameters.
        parsed_channels = [int(part) for part in channels.split(',')]
        masks, flows, styles = cp_model.eval(
            decoded,
            diameter=diameter if diameter > 0 else None,
            channels=parsed_channels,
            flow_threshold=flow,
            cellprob_threshold=prob
        )

        # Serialize each mask outline as one "x,y,x,y,..." line.
        serialized = []
        for contour in utils.outlines_list(masks):
            coords = []
            for px, py in contour:
                coords.append(f"{px},{py}")
            serialized.append(",".join(coords))

        queue.put({"status": "success", "data": "\n".join(serialized)})
        print("<<< END SEGMENTATION JOB")
        sys.stdout.flush()

    except Exception as e:
        # Report the failure to the parent instead of dying silently.
        print(f"!!! WORKER ERROR: {e}")
        queue.put({"status": "error", "message": str(e)})
|
| 69 |
+
|
| 70 |
+
@app.post("/segment")
async def segment(
    image: UploadFile = File(...),
    diameter: float = 0.0,
    flow_threshold: float = 0.4,
    cellprob_threshold: float = 0.0,
    channels: str = "0,0"
):
    """Segment an uploaded image with Cellpose in a killable worker process.

    A new request preempts any job still running. Returns the ROI outlines
    as plain text (one "x,y,x,y,..." line per mask), 504 on timeout, or 500
    on worker failure.
    """
    from queue import Empty  # local import: names multiprocessing.Queue's timeout error

    global current_process

    # 1. KILL EXISTING JOB (If a new request comes in)
    if current_process is not None and current_process.is_alive():
        logger.info(f"⚠️ Stopping previous job (PID: {current_process.pid}) for new request")
        current_process.terminate()
        current_process.join()
        # Drop any stale result the killed job may have left behind.
        while not result_queue.empty():
            result_queue.get()

    # 2. READ DATA
    img_data = await image.read()

    # 3. START WORKER
    current_process = multiprocessing.Process(
        target=run_segmentation,
        args=(img_data, diameter, flow_threshold, cellprob_threshold, channels, result_queue)
    )
    current_process.start()

    start_time = time.time()
    SERVER_TIMEOUT = 300  # 5 Minutes (Seconds)

    # 4. MONITOR LOOP
    # NOTE(review): the blocking get()/join() calls below run on the event
    # loop thread. Acceptable for a one-job-at-a-time server, but a thread
    # executor would be needed for real request concurrency.
    while current_process.is_alive():
        # A. Check Timeout
        if time.time() - start_time > SERVER_TIMEOUT:
            logger.error("⏰ JOB TIMED OUT. Killing process.")
            current_process.terminate()
            current_process.join()
            return Response(content="Server Timeout: Processing took longer than 5 minutes.", status_code=504)

        # B. Check for Result (poll the queue every 0.5s)
        try:
            result = result_queue.get(timeout=0.5)
        except Empty:
            # Only the poll timed out — keep waiting. (Was a bare `except:`,
            # which also swallowed KeyboardInterrupt/SystemExit.)
            continue
        if result["status"] == "success":
            return Response(content=result["data"], media_type="text/plain")
        return Response(content=result["message"], status_code=500)

    # The worker exited between polls. It may have posted its result just
    # before exiting, so drain the queue once more before declaring failure
    # (fixes a race where a successful job was reported as a silent failure).
    try:
        result = result_queue.get_nowait()
    except Empty:
        return Response(content="Process failed silently", status_code=500)
    if result["status"] == "success":
        return Response(content=result["data"], media_type="text/plain")
    return Response(content=result["message"], status_code=500)
|