File size: 1,852 Bytes
fec397b
 
e429bd1
fec397b
 
 
 
 
 
 
 
 
 
e429bd1
 
 
fec397b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e429bd1
fec397b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI, UploadFile, File, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from rembg import remove, new_session
from PIL import Image, UnidentifiedImageError
import io
import os
import logging

# Send INFO-level (and above) records through the root logger configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="NoBG API")

# Create the rembg session once at import time; the request handler reuses it
# through `session=session` rather than building a new one per request.
model_name = "birefnet-general"
session = new_session(model_name)

# ALLOWED_ORIGINS is a comma-separated list; "*" (the default) means any origin.
# Whitespace around entries is stripped and empty entries are dropped so that a
# value such as "https://a.example, https://b.example," still compares equal to
# the Origin header sent by the browser.
origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # A wildcard origin must not be combined with credentialed requests (see
    # the FastAPI CORS documentation), so credentials are only enabled when
    # explicit origins have been configured.
    allow_credentials="*" not in origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/remove-bg")
def remove_background(file: UploadFile = File(...)) -> Response:
    """Remove the background from an uploaded image and return it as a PNG.

    Args:
        file: The uploaded file. When a content type is declared it must
            start with ``image/``.

    Returns:
        A response whose body is the processed image encoded as PNG, served
        with media type ``image/png``.

    Raises:
        HTTPException: 400 if the declared content type is not an image type
            or the payload cannot be decoded as an image; 500 if background
            removal or PNG encoding fails.
    """
    if file.content_type and not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {file.content_type}. Please upload an image.",
        )

    try:
        try:
            input_image = Image.open(file.file)
            # Image.open() only identifies the file; pixel data is read lazily.
            # Force decoding here so that truncated or corrupt uploads are
            # reported as client errors instead of failing later inside
            # remove() and being reported as a server error.
            input_image.load()
        except UnidentifiedImageError as exc:
            logger.warning(
                "Failed to identify image format for file: %s", file.filename
            )
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid or supported image.",
            ) from exc
        except (OSError, Image.DecompressionBombError) as exc:
            logger.warning(
                "Failed to decode image data for file: %s (%s)", file.filename, exc
            )
            raise HTTPException(
                status_code=400,
                detail="Uploaded file is not a valid or supported image.",
            ) from exc

        output_image = remove(input_image, session=session)
        png_buffer = io.BytesIO()
        output_image.save(png_buffer, format="PNG")
    except HTTPException:
        # Client errors raised above must reach the caller unchanged rather
        # than being converted into a 500 by the generic handler below.
        raise
    except Exception as exc:
        # logger.exception records the traceback along with the message.
        logger.exception("Error processing image %s: %s", file.filename, exc)
        raise HTTPException(
            status_code=500, detail="An error occurred while processing the image."
        ) from exc

    # getvalue() returns the whole buffer regardless of the current position.
    return Response(content=png_buffer.getvalue(), media_type="image/png")


if __name__ == "__main__":
    # Imported here because uvicorn is only needed when run as a script.
    import uvicorn

    # The listening port comes from the PORT environment variable, with 8000
    # as the fallback when it is not set.
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)