File size: 2,008 Bytes
57c1ee2
 
 
97921c2
 
57c1ee2
78c4d3f
97921c2
57c1ee2
97921c2
57c1ee2
 
97921c2
57c1ee2
 
a50b994
97921c2
57c1ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97921c2
a50b994
57c1ee2
a50b994
57c1ee2
 
 
78c4d3f
57c1ee2
a50b994
 
 
57c1ee2
 
 
 
 
a50b994
57c1ee2
 
 
 
97921c2
 
57c1ee2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import io
import threading

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

app = FastAPI()

# ========= Model loading =========
# Use the GPU ("cuda") when PyTorch reports one is available; otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# Hub repository; consider pinning `revision=` to a reviewed commit.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet = birefnet.to(device)
# Evaluation mode for inference (affects layers such as dropout/batch-norm, if present).
birefnet.eval()

# ========= Image preprocessing =========
# Every input is resized to a fixed 1024x1024, converted to a float tensor and
# normalized with the standard ImageNet per-channel mean/std.
_MODEL_INPUT_SIZE = (1024, 1024)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform_image = transforms.Compose(
    [
        transforms.Resize(_MODEL_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Serializes model forward passes. The handler used to run the model directly
# on the event loop, so at most one inference was ever in flight; now that the
# pipeline runs in a worker thread, this lock keeps that one-at-a-time
# guarantee (and the bounded accelerator memory use that comes with it).
_INFERENCE_LOCK = threading.Lock()


def _load_rgb_image(content: bytes) -> Image.Image:
    """Decode uploaded bytes into an upright RGB image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        im = Image.open(io.BytesIO(content))
        # Apply any EXIF Orientation tag to the pixels so camera photos are
        # segmented and returned upright instead of relying on output metadata
        # (the white/JPEG output is built on a fresh image without EXIF).
        im = ImageOps.exif_transpose(im)
        # convert() forces the full decode, so truncated files also fail here.
        return im.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown format / empty upload) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err


def _predict_alpha_mask(im: Image.Image) -> Image.Image:
    """Run BiRefNet on *im* and return its foreground mask resized to im.size."""
    input_tensor = transform_image(im).unsqueeze(0)
    with _INFERENCE_LOCK, torch.no_grad():
        preds = birefnet(input_tensor.to(device))[-1].sigmoid().cpu()
    # sigmoid() puts values in [0, 1]; presumably preds is (1, 1, H, W), so
    # preds[0].squeeze() is a 2-D map that ToPILImage turns into an 8-bit
    # grayscale image — TODO confirm the model's output shape.
    return transforms.ToPILImage()(preds[0].squeeze()).resize(im.size)


def _remove_background(content: bytes, mode: str) -> tuple:
    """Decode, segment, composite and encode an uploaded image (blocking).

    Args:
        content: Raw uploaded file bytes.
        mode: "w" (case-insensitive, surrounding whitespace ignored) for a
            white background returned as JPEG; anything else returns a
            transparent PNG.

    Returns:
        ``(payload, media_type)``: the encoded image bytes and their MIME type.

    Raises:
        HTTPException: 400 if *content* is not a readable image.
    """
    im = _load_rgb_image(content)
    mask = _predict_alpha_mask(im)

    # Transparent result: the predicted mask becomes the alpha channel.
    transparent = im.convert("RGBA")
    transparent.putalpha(mask)

    if mode.strip().lower() == "w":
        # White background: composite the cut-out over opaque white.
        bg = Image.new("RGBA", transparent.size, (255, 255, 255, 255))
        bg.paste(transparent, (0, 0), transparent)
        final_img = bg.convert("RGB")
        format_type = "JPEG"
    else:
        final_img = transparent
        format_type = "PNG"

    # Encode to bytes so they can be returned to the caller (e.g. n8n).
    buffer = io.BytesIO()
    final_img.save(buffer, format=format_type)
    return buffer.getvalue(), f"image/{format_type.lower()}"


@app.post("/process")
async def process_image(data: UploadFile = File(...), mode: str = Form("r")):
    """Remove the background from an uploaded image.

    Args:
        data: The uploaded image file.
        mode: "w" for a white background (JPEG response); any other value,
            default "r", for a transparent background (PNG response).

    Returns:
        A ``Response`` whose body is the encoded result image.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    content = await data.read()
    # Decoding, inference and encoding are blocking; run them in the threadpool
    # so the event loop stays free to serve other requests meanwhile.
    payload, media_type = await run_in_threadpool(_remove_background, content, mode)
    return Response(content=payload, media_type=media_type)

if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces, port 7860.
    from uvicorn import run as _run_server

    _run_server(app, host="0.0.0.0", port=7860)