from typing import List, Optional, Union
from enum import Enum
from contextlib import asynccontextmanager

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, HTTPException
from PIL import Image, UnidentifiedImageError
from pydantic.dataclasses import dataclass

from app.cloth_segmentation.model import Layer, segment
from app.segment_anything.model import load_segment_model, predict
from app.simple_segmentation.model import Mode, binary_segment, load_seg_model

# === Context ===

# Models loaded once at startup by `lifespan`, keyed by name.
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ML checkpoints on startup and drop them on shutdown.

    Checkpoint paths are relative, so they resolve against the process's
    current working directory.
    """
    ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
    ml_models["segment_anything"] = load_segment_model('.checkpoint/sam_vit_h_4b8939.pth')
    try:
        yield
    finally:
        # Release model references even if the app exits via an exception.
        ml_models.clear()


app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; confirm this is intended before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# === Routes


@app.get("/")
def index():
    """Root endpoint; points callers at the interactive `/docs` page."""
    return {"ok": True, "message": "Invalid route, please check `/docs`"}


class SegmentationMode(str, Enum):
    """Segmentation strategy selectable via the `mode` query parameter of `/mask`."""

    # Uses `segment` from app.cloth_segmentation.
    BINARY = "binary"
    # Uses `binary_segment` with the loaded "cloth_segmentation" network.
    # NOTE(review): names look swapped relative to what each branch calls; confirm.
    SIMPLE = "simple"


def _open_image(upload: UploadFile) -> Image.Image:
    """Open an uploaded file as a PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot identify the file as an image.
    """
    try:
        return Image.open(upload.file)
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail=f"Uploaded file is not a valid image: {e}") from e


# NOTE(review): FastAPI uses the `List[Layer]` return annotation as the response
# model, but the SIMPLE branch was commented as returning a MaskResponse; confirm
# `binary_segment`'s result validates as List[Layer].
@app.post("/mask")
def mask(upload: UploadFile, mode: SegmentationMode = SegmentationMode.BINARY) -> List[Layer]:
    """Segment an uploaded image using the requested mode.

    Args:
        upload: The image file to segment.
        mode: Which segmentation routine to run (default BINARY).

    Raises:
        HTTPException: 400 for an unreadable image or an unknown mode;
            500 for any unexpected failure during segmentation.
    """
    try:
        image = _open_image(upload)
        if mode == SegmentationMode.BINARY:
            return segment(image)  # List[Layer]
        if mode == SegmentationMode.SIMPLE:
            net = ml_models["cloth_segmentation"]
            return binary_segment(image, net, mode=Mode.BINARY)  # MaskResponse
        # FastAPI validates the enum before we get here; kept for direct callers.
        raise HTTPException(status_code=400, detail="Invalid segmentation mode")
    except HTTPException:
        # Preserve deliberate client errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") from e


@app.post("/encode")
def encode(upload: UploadFile) -> str:
    """Run `predict` from app.segment_anything on an uploaded image.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    # NOTE(review): ml_models["segment_anything"] is loaded at startup but not
    # passed here; confirm `predict` obtains its model some other way.
    image = _open_image(upload)
    return predict(image)