# SAM / app.py
# sajabdoli's picture
# Update app.py
# fcd4753 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import numpy as np
import torch
import io
import base64
import json
app = FastAPI()

# Add CORS middleware for CVAT: requests from any origin, with any method
# and any header, are accepted, and credentials are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins/methods/headers with allow_credentials=True -- confirm
# whether credentials are really needed, or list the CVAT origin explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Load SAM Model
sam_checkpoint = "sam_vit_b.pth"
model_type = "vit_b"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the model from the checkpoint file, move it to the chosen device,
# and wrap it in the predictor shared by the request handlers.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
predictor = SamPredictor(sam)
@app.get("/")
def read_root():
    """Report that the service is up (simple status endpoint)."""
    status_payload = dict(status="SAM API is running")
    return status_payload
@app.post("/segment")
async def segment_image(file: UploadFile = File(...)):
    """Segment an uploaded image using its centre pixel as the prompt.

    The upload is decoded to RGB, the module-level ``predictor`` is prompted
    with a single point at the image centre (label 1, i.e. a foreground
    point in SAM's convention), and the highest-scoring of the returned
    masks is sent back.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``"score"`` (float score of the chosen mask) and
        ``"mask"`` (the chosen mask as nested lists of booleans).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure.
    """
    try:
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as e:
            # Pillow signals unrecognised or truncated image data with
            # OSError (UnidentifiedImageError is a subclass). That is a
            # problem with the request, not with the server.
            raise HTTPException(status_code=400, detail=str(e)) from e
        image_np = np.array(image)

        # Get image dimensions
        height, width = image_np.shape[:2]

        # Use center point instead of fixed point
        center_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=center_point,
            point_labels=input_label,
            multimask_output=True,  # Return multiple masks
        )

        # Return the best mask
        best_mask_idx = int(np.argmax(scores))
        mask = masks[best_mask_idx].astype(bool)
        return {
            "score": float(scores[best_mask_idx]),
            "mask": mask.tolist()
        }
    except HTTPException:
        # Already carries the intended status code; do not re-wrap as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/models")
def list_models():
    """Describe the single model this service advertises."""
    sam_entry = dict(
        name="sam-cvat",
        type="segmentation",
        labels=["object"],
    )
    return dict(models=[sam_entry])
# CVAT-specific endpoint
@app.post("/predict")
async def predict_for_cvat(body: str = Form(...)):
    """Run point-prompted segmentation for a JSON request sent as a form field.

    ``body`` is parsed as a JSON object. Its ``"image"`` entry is a
    base64-encoded image; its optional ``"points"`` entry lists the prompt
    points as ``[x, y]`` pairs (a flat ``[x, y, ...]`` list is also
    accepted). When no points are given, the image centre is used. Every
    point is passed to the predictor with label 1, and the highest-scoring
    returned mask is run-length encoded with ``mask_to_rle``.

    Args:
        body: Form field holding the JSON request.

    Returns:
        A dict with ``"model"`` and a one-element ``"annotations"`` list
        containing the mask's RLE, width, height and score.

    Raises:
        HTTPException: 400 if the request cannot be parsed (invalid JSON,
            invalid base64, undecodable image, malformed points);
            500 if segmentation itself fails.
    """
    # Phase 1: parse and validate the request. Anything that goes wrong
    # here is caused by the payload, so it is reported as a client error.
    try:
        data = json.loads(body)
        if not isinstance(data, dict):
            raise ValueError("request body must be a JSON object")

        # Decode base64 image
        image_bytes = base64.b64decode(data.get('image', ''))
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_np = np.array(image)

        # Get points from CVAT request
        points = data.get('points', [])
        if not points:
            # If no points, use center of image
            height, width = image_np.shape[:2]
            points = [[width // 2, height // 2]]

        # Normalise to shape (N, 2) so a flat [x, y] list also works;
        # non-numeric or odd-length input raises and is reported as 400.
        input_points = np.asarray(points, dtype=float).reshape(-1, 2)
    except (ValueError, TypeError, OSError) as e:
        # ValueError covers JSON and base64 decoding errors and bad point
        # shapes; OSError covers Pillow's unreadable-image errors.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Phase 2: inference. Failures here are server-side.
    try:
        input_labels = np.ones(len(input_points))

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Get best mask
        best_mask_idx = int(np.argmax(scores))
        mask = masks[best_mask_idx].astype(bool)

        # Convert mask to CVAT format
        height, width = mask.shape
        rle = mask_to_rle(mask)

        return {
            "model": "sam-cvat",
            "annotations": [{
                "name": "object",
                "score": float(scores[best_mask_idx]),
                "mask": {
                    "rle": rle,
                    "width": width,
                    "height": height
                },
                "type": "mask"
            }]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Helper function to convert mask to RLE (Run-Length Encoding)
def mask_to_rle(mask):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order and returned as alternating
    run lengths, starting with the number of leading background (falsy)
    pixels. A mask whose first pixel is foreground therefore starts with
    a ``0`` count. An empty mask yields ``[0]``.

    NOTE(review): the original docstring describes this as the RLE format
    expected by CVAT -- confirm against CVAT's mask specification.

    Args:
        mask: Array-like binary mask; any non-zero value counts as
            foreground.

    Returns:
        A list of plain Python ints whose sum equals the number of pixels.
    """
    flat = np.asarray(mask, dtype=bool).ravel()
    if flat.size == 0:
        return [0]

    # Positions where the value differs from its predecessor mark the start
    # of a new run; found in one vectorised pass instead of a Python loop
    # over every pixel.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], run_starts, [flat.size]))
    rle = np.diff(boundaries).tolist()

    # The encoding always begins with a background run, so a mask that
    # opens with foreground gets an explicit zero-length background run.
    if flat[0]:
        rle.insert(0, 0)
    return rle