File size: 4,160 Bytes
cb317f0
 
586d0cb
 
 
 
 
cb317f0
 
586d0cb
 
 
cb317f0
 
 
 
 
 
 
 
 
586d0cb
cb317f0
586d0cb
 
 
 
 
 
fd54065
 
 
 
586d0cb
 
cb317f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586d0cb
fcd4753
 
 
 
 
 
 
 
 
 
 
 
 
cb317f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcd4753
cb317f0
fcd4753
 
 
 
 
 
 
 
cb317f0
 
 
 
586d0cb
cb317f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import numpy as np
import torch
import io
import base64
import json

# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI()

# Add CORS middleware for CVAT
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled -- confirm this wide-open policy is intended beyond local testing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load SAM Model
# Everything in this section runs once, at import time.
# NOTE(review): the checkpoint path is relative -- presumably resolved against
# the process working directory; confirm against how the service is launched.
sam_checkpoint = "sam_vit_b.pth"
model_type = "vit_b"  # key used to pick the builder from sam_model_registry

# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
# Single module-level predictor instance used by the request handlers below.
predictor = SamPredictor(sam)

@app.get("/")
def read_root():
    """Report that the service is up.

    Returns:
        dict: a single ``"status"`` key with a fixed human-readable message.
    """
    status_payload = {"status": "SAM API is running"}
    return status_payload

@app.post("/segment")
async def segment_image(file: UploadFile = File(...)):
    """Segment an uploaded image using its centre pixel as the point prompt.

    The upload is decoded to RGB, the image centre is passed to the shared
    ``predictor`` as a single point with label ``1``, and the highest-scoring
    of the returned masks is sent back.

    Args:
        file: The uploaded image file.

    Returns:
        dict: ``{"score": float, "mask": list}`` where ``mask`` is the selected
        mask converted to nested Python lists of booleans.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image;
            500 (with the error text as detail) for any other failure.
    """
    try:
        image_bytes = await file.read()

        # An undecodable upload is a client error, not a server fault.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError) as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid image file: {e}"
            ) from e
        image_np = np.array(image)

        # Get image dimensions
        height, width = image_np.shape[:2]

        # Use center point instead of fixed point; coordinates are (x, y).
        center_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=center_point,
            point_labels=input_label,
            multimask_output=True,  # Return multiple masks
        )

        # Return the best mask
        best_mask_idx = int(np.argmax(scores))
        mask = masks[best_mask_idx].astype(bool)

        return {
            "score": float(scores[best_mask_idx]),
            "mask": mask.tolist(),
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) instead of
        # letting the generic handler below rewrap them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/models")
def list_models():
    """List the models this service advertises.

    Returns:
        dict: a ``"models"`` list holding one static entry -- a segmentation
        model named ``"sam-cvat"`` with the single label ``"object"``.
    """
    sam_descriptor = {
        "name": "sam-cvat",
        "type": "segmentation",
        "labels": ["object"],
    }
    return {"models": [sam_descriptor]}

# CVAT-specific endpoint
@app.post("/predict")
async def predict_for_cvat(body: str = Form(...)):
    """Run point-prompted segmentation for a JSON request sent as a form field.

    ``body`` is parsed as JSON. Its ``"image"`` entry is base64-decoded and
    opened as an RGB image; its optional ``"points"`` entry supplies the
    point prompts. When no points are given, the image centre is used. Every
    point is given label ``1``. The highest-scoring mask returned by the shared
    ``predictor`` is run-length encoded with ``mask_to_rle``.

    Args:
        body: JSON text with an ``"image"`` key (base64 string) and an
            optional ``"points"`` key (list of ``[x, y]`` pairs).

    Returns:
        dict: ``{"model": "sam-cvat", "annotations": [...]}`` with a single
        annotation carrying the score and the mask as RLE plus its width and
        height.

    Raises:
        HTTPException: 400 when the body is not a JSON object, the image is
            missing or cannot be decoded, or the points are not ``[x, y]``
            pairs; 500 (with the error text as detail) for any other failure.
    """
    try:
        # --- Request validation: problems here are client errors (400). ---
        try:
            data = json.loads(body)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid JSON body: {e}"
            ) from e
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        image_data = data.get('image', '')
        if not image_data:
            raise HTTPException(
                status_code=400, detail="Missing 'image' field"
            )

        # Decode base64 image
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (ValueError, TypeError, OSError) as e:
            # binascii.Error is a ValueError; undecodable image data raises
            # an OSError subclass from PIL.
            raise HTTPException(
                status_code=400, detail=f"Invalid image data: {e}"
            ) from e
        image_np = np.array(image)

        # Get points from CVAT request
        points = data.get('points', [])
        if not points:
            # If no points, use center of image
            height, width = image_np.shape[:2]
            points = [[width // 2, height // 2]]

        try:
            input_points = np.array(points, dtype=float)
        except (ValueError, TypeError) as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid 'points': {e}"
            ) from e
        if input_points.ndim != 2 or input_points.shape[1] != 2:
            raise HTTPException(
                status_code=400,
                detail="'points' must be a list of [x, y] pairs",
            )
        # One label per point; every supplied point is labelled 1.
        input_labels = np.ones(input_points.shape[0])

        # --- Inference: failures from here on are reported as 500. ---
        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Get best mask
        best_mask_idx = int(np.argmax(scores))
        mask = masks[best_mask_idx].astype(bool)

        # Convert mask to CVAT format
        height, width = mask.shape
        rle = mask_to_rle(mask)
        return {
            "model": "sam-cvat",
            "annotations": [{
                "name": "object",
                "score": float(scores[best_mask_idx]),
                "mask": {
                    "rle": rle,
                    "width": width,
                    "height": height,
                },
                "type": "mask",
            }],
        }
    except HTTPException:
        # Preserve the deliberate 400s raised above instead of letting the
        # generic handler below rewrap them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# Helper function to convert mask to RLE (Run-Length Encoding)
def mask_to_rle(mask):
    """Run-length encode a mask in row-major order.

    The mask is flattened row by row and encoded as alternating run lengths.
    The first count always refers to zero/``False`` pixels, so a mask whose
    first pixel is set yields a leading ``0``.

    NOTE(review): this file's original docstring describes the output as the
    RLE format expected by CVAT; that cannot be verified from this file --
    confirm against the CVAT version in use.

    Args:
        mask: Array-like mask (for example a 2-D boolean ``numpy`` array).

    Returns:
        list[int]: Run lengths as plain Python ints; ``[0]`` for an empty mask.
    """
    flat = np.asarray(mask).ravel()
    if flat.size == 0:
        return [0]

    # Positions where a pixel differs from its predecessor, i.e. run starts.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], run_starts, [flat.size]))
    rle = np.diff(boundaries).tolist()

    # Encoding starts by counting zeros; emit an empty zero-run when the
    # mask begins with a set pixel.
    if flat[0] != 0:
        rle.insert(0, 0)
    return rle