# vit-matte / app/server.py
# pillipop
# remove mask response
# eefe61d unverified
from typing import List, Optional, Union
from enum import Enum
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, HTTPException
from PIL import Image
from pydantic.dataclasses import dataclass
from app.cloth_segmentation.model import Layer, segment
from app.segment_anything.model import predict
from app.simple_segmentation.model import Mode, binary_segment, load_seg_model
from app.segment_anything.model import load_segment_model
# === Context ===
# Registry of models loaded at application startup; route handlers look
# models up here by key (e.g. /mask reads "cloth_segmentation").
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ML checkpoints on startup and release them on shutdown.

    Everything before ``yield`` runs once when the app starts; everything
    after it runs when the app shuts down.
    """
    # (registry key, loader, checkpoint path), loaded in this order.
    checkpoints = (
        ("cloth_segmentation", load_seg_model, ".checkpoint/model.pth"),
        ("segment_anything", load_segment_model, ".checkpoint/sam_vit_h_4b8939.pth"),
    )
    for key, loader, checkpoint_path in checkpoints:
        ml_models[key] = loader(checkpoint_path)

    yield

    # Drop references to the loaded models on shutdown.
    ml_models.clear()
app = FastAPI(lifespan=lifespan)

# Cross-origin policy: every origin, method and header is allowed, and
# credentials are permitted.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard allow_origins/methods/headers with allow_credentials=True —
# confirm this fully permissive policy is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# === Routes ===
@app.get("/")
def index():
    """Root endpoint: tells callers to consult the interactive docs at ``/docs``."""
    message = "Invalid route, please check `/docs`"
    return dict(ok=True, message=message)
class SegmentationMode(str, Enum):
    """Segmentation strategy selectable on the ``/mask`` endpoint.

    Inheriting from ``str`` makes each member compare equal to, and serialize
    as, its plain string value.
    """

    # Handled by ``segment`` (layered result).
    BINARY = "binary"
    # Handled by ``binary_segment`` with the "cloth_segmentation" model.
    SIMPLE = "simple"
@app.post("/mask")
def mask(upload: UploadFile, mode: SegmentationMode = SegmentationMode.BINARY) -> List[Layer]:
    """Segment an uploaded image with the selected strategy.

    Args:
        upload: Image file sent by the client; opened with Pillow.
        mode: ``BINARY`` runs ``segment`` on the image; ``SIMPLE`` runs
            ``binary_segment`` using the "cloth_segmentation" model loaded
            at startup.

    Returns:
        The result produced by the selected segmentation backend.

    Raises:
        HTTPException: 400 if the upload is not an image Pillow can identify
            or the mode is unknown; 500 for any other failure.
    """
    from PIL import UnidentifiedImageError

    # NOTE(review): FastAPI uses the return annotation as the response model,
    # but the SIMPLE branch returns whatever ``binary_segment`` produces
    # (previously noted as a "MaskResponse"). If that is not a List[Layer],
    # response validation will fail — confirm the annotation covers both.
    try:
        image = Image.open(upload.file)
        if mode == SegmentationMode.BINARY:
            return segment(image)  # List[Layer]
        if mode == SegmentationMode.SIMPLE:
            net = ml_models["cloth_segmentation"]
            return binary_segment(image, net, mode=Mode.BINARY)
    except UnidentifiedImageError as e:
        # An undecodable upload is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") from e

    # Raised outside the try so the generic handler above cannot turn this
    # deliberate 400 into a 500. FastAPI validates `mode` against the enum,
    # so this is a defensive guard.
    raise HTTPException(status_code=400, detail="Invalid segmentation mode")
@app.post("/encode")
def encode(upload: UploadFile) -> str:
    """Open the uploaded image and return the result of ``predict`` on it.

    Raises:
        HTTPException: 400 if the upload is not an image Pillow can identify.
    """
    from PIL import UnidentifiedImageError

    try:
        image = Image.open(upload.file)
    except UnidentifiedImageError as e:
        # Same client-error handling as /mask: bad uploads are a 400, not a 500.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # NOTE(review): the "segment_anything" model is loaded into ``ml_models``
    # at startup but is not passed here — confirm ``predict`` uses the intended
    # model instance.
    return predict(image)