File size: 1,578 Bytes
61fe82c
6a88e99
 
 
bef5dde
3881e25
f11be46
d8934a7
3514da2
 
ab5fe15
 
 
 
 
6a88e99
bdbe97a
 
 
 
6a88e99
 
3881e25
6a88e99
f11be46
e7d9c1d
f11be46
6a88e99
 
3881e25
 
6a88e99
3881e25
6a88e99
 
 
bdbe97a
351880b
6a88e99
 
 
720818d
 
 
 
6a88e99
d8934a7
 
6a88e99
720818d
 
 
 
 
f11be46
6a88e99
 
720818d
 
 
6a88e99
 
720818d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from fastapi import FastAPI, UploadFile, File, Query
import tifffile
import numpy as np
from cellpose import models
from fastapi.responses import FileResponse
from huggingface_hub import snapshot_download
import os
from io import BytesIO
import tempfile

# ASGI application object. Swagger UI is served at /docs; ReDoc is disabled.
app = FastAPI(title="Cellpose API", docs_url="/docs", redoc_url=None)

@app.get("/")
def root():
    """Liveness check: always responds with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
    
# Hugging Face Hub repository that hosts the fine-tuned Cellpose model.
MODEL_REPO = "Shalani08/cellpose_3.1.1_finetuned"

# Runs at module import time. The revision is pinned (presumably a commit
# hash) so every deployment resolves the same snapshot. The token comes from
# the HF_TOKEN environment variable and is None when that variable is unset.
repo_dir = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    revision="8a020c0",
    token=os.environ.get("HF_TOKEN"),
)

# Path of the ``ddq_model`` entry inside the downloaded snapshot; presumably
# the Cellpose weights file -- confirm against the repository layout.
model_dir = os.path.join(repo_dir, "ddq_model")

# A single model instance, created at import time and used by every request
# handler in this module; inference is forced onto the CPU (gpu=False).
model = models.CellposeModel(
    pretrained_model=model_dir,
    gpu=False,
)



@app.post("/segment")
async def segment(
    image: UploadFile = File(...),
    diameter: float = Query(0, description="Cell diameter (0 = auto)"),
    channels: str = Query("0,0", description="Cellpose channels"),
    flow_threshold: float = Query(0.4),
    cellprob_threshold: float = Query(0.0)
):
    """Segment an uploaded TIFF image and return the label mask as a TIFF.

    Query parameters:
        diameter: forwarded to ``model.eval``. Clients are told 0 means
            "auto"; NOTE(review): confirm how this Cellpose version treats
            ``diameter=0`` for a ``CellposeModel``.
        channels: two comma-separated integers (e.g. ``"0,0"``), forwarded
            as the ``channels`` list to ``model.eval``.
        flow_threshold, cellprob_threshold: forwarded unchanged to
            ``model.eval``.

    Returns:
        A ``Response`` with media type ``image/tiff`` whose body is the
        integer label mask, stored as ``uint16`` (or ``uint32`` when a label
        id does not fit in ``uint16``).

    Raises:
        HTTPException(422): ``channels`` is not two comma-separated integers.
        HTTPException(400): the upload could not be decoded as a TIFF.
    """
    # Local imports keep this endpoint self-contained; both ship with FastAPI.
    from fastapi import HTTPException
    from fastapi.responses import Response

    channels_error = "channels must be two comma-separated integers, e.g. '0,0'"

    # Validate the cheap query parameter before doing any image work, so bad
    # input yields a client error instead of an unhandled 500.
    try:
        ch = [int(c) for c in channels.split(",")]
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=channels_error) from exc
    if len(ch) != 2:
        raise HTTPException(status_code=422, detail=channels_error)

    img_bytes = await image.read()
    try:
        img = tifffile.imread(BytesIO(img_bytes))
    except ValueError as exc:
        # tifffile's TiffFileError subclasses ValueError; a non-TIFF or
        # corrupt upload is the client's fault, so report 400, not 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable TIFF image",
        ) from exc

    # A trailing axis of length 3 is treated as colour channels and only the
    # first is kept. NOTE(review): for RGB input this discards G and B --
    # confirm that is intended.
    if img.ndim == 3 and img.shape[-1] == 3:
        img = img[..., 0]

    # NOTE(review): model.eval is CPU-bound and runs directly on the event
    # loop, so other requests (including "/") stall during inference.
    masks, _, _ = model.eval(
        img,
        diameter=diameter,
        channels=ch,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
    )

    # uint16 holds label ids up to 65535; widen instead of letting larger
    # ids silently wrap around.
    if masks.max(initial=0) <= np.iinfo(np.uint16).max:
        out_dtype = np.uint16
    else:
        out_dtype = np.uint32

    # Encode in memory so no temporary file or open handle is left behind
    # after each request.
    buf = BytesIO()
    tifffile.imwrite(buf, masks.astype(out_dtype))
    return Response(content=buf.getvalue(), media_type="image/tiff")