from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional, Union

from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from PIL import UnidentifiedImageError
from pydantic.dataclasses import dataclass

from app.cloth_segmentation.model import Layer, segment
from app.segment_anything.model import predict
from app.segment_anything.model import load_segment_model
from app.simple_segmentation.model import Mode, binary_segment, load_seg_model
|
|
| |
# Registry of loaded models, keyed by name. `lifespan` fills it at startup
# and clears it on shutdown; request handlers read entries by key.
ml_models: dict = dict()
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ML models before the app serves requests and drop them afterwards.

    The models are stored in the module-level ``ml_models`` registry so that
    request handlers can look them up by key.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    try:
        # NOTE(review): the "cloth_segmentation" entry is produced by
        # `load_seg_model` from app.simple_segmentation — confirm the key name.
        ml_models["cloth_segmentation"] = load_seg_model(".checkpoint/model.pth")
        ml_models["segment_anything"] = load_segment_model(".checkpoint/sam_vit_h_4b8939.pth")
        yield
    finally:
        # Runs on normal shutdown, when an exception is thrown into the
        # context at `yield` (e.g. cancellation), and when a load above fails
        # part-way, so the registry never keeps stale model references.
        ml_models.clear()
|
|
|
|
# Application instance; `lifespan` manages model loading and teardown.
app = FastAPI(lifespan=lifespan)

# Permissive CORS policy: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# Starlette may echo the caller's Origin back, so any site could make
# credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
| |
|
|
|
|
@app.get("/")
def index():
    """Root endpoint: tell callers to use the interactive docs instead."""
    payload = {"ok": True, "message": "Invalid route, please check `/docs`"}
    return payload
|
|
class SegmentationMode(str, Enum):
    """Segmentation strategy accepted by the ``/mask`` endpoint.

    Subclassing ``str`` lets members compare equal to their raw query-string
    values (e.g. ``SegmentationMode.BINARY == "binary"``).
    """

    # Dispatched to `segment` in the /mask handler.
    BINARY = "binary"
    # Dispatched to `binary_segment` in the /mask handler.
    SIMPLE = "simple"
|
|
@app.post("/mask")
def mask(upload: UploadFile, mode: SegmentationMode = SegmentationMode.BINARY) -> List[Layer]:
    """Segment an uploaded image and return its layers.

    Args:
        upload: The image file sent by the client.
        mode: ``binary`` runs ``segment`` on the image; ``simple`` runs
            ``binary_segment`` with the model stored under the
            ``"cloth_segmentation"`` key of ``ml_models``.

    Returns:
        The result of the selected segmentation function.

    Raises:
        HTTPException: 400 if the upload is not a recognized image format or
            the mode has no handler; 500 if segmentation itself fails.
    """
    try:
        image = Image.open(upload.file)

        # NOTE(review): BINARY dispatches to `segment` while SIMPLE dispatches
        # to `binary_segment(..., mode=Mode.BINARY)`; the naming looks
        # inverted — confirm the mapping is intended.
        if mode == SegmentationMode.BINARY:
            return segment(image)
        if mode == SegmentationMode.SIMPLE:
            net = ml_models["cloth_segmentation"]
            return binary_segment(image, net, mode=Mode.BINARY)
    except UnidentifiedImageError as e:
        # A bad upload is a client error, not a server failure.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a recognized image format"
        ) from e
    except Exception as e:
        # NOTE(review): the detail exposes the internal exception text to clients.
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") from e

    # Raised outside the try so the generic handler above cannot rewrap it as
    # a 500. FastAPI validates `mode` against the enum, so this is reached
    # only if a member is added without a matching branch.
    raise HTTPException(status_code=400, detail="Invalid segmentation mode")
|
|
@app.post("/encode")
def encode(upload: UploadFile) -> str:
    """Run ``predict`` from app.segment_anything.model on an uploaded image.

    Args:
        upload: The image file sent by the client.

    Returns:
        Whatever ``predict`` returns for the image (declared as ``str``).

    Raises:
        HTTPException: 400 if the upload is not a recognized image format.
    """
    try:
        image = Image.open(upload.file)
    except UnidentifiedImageError as e:
        # A bad upload is a client error, not a server failure.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a recognized image format"
        ) from e
    # NOTE(review): the model loaded into ml_models["segment_anything"] by
    # `lifespan` is not passed here; presumably `predict` uses its own
    # model — verify.
    return predict(image)
|
|