AkashKumarave commited on
Commit
6c0952e
·
verified ·
1 Parent(s): e890d3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -43
app.py CHANGED
@@ -1,52 +1,31 @@
1
  from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import Response
3
- from PIL import Image
4
  import io
5
- import numpy as np
6
- from transformers import AutoModelForImageSegmentation, AutoProcessor
7
- import torch
8
 
9
  app = FastAPI()
10
 
11
- # Load the RMBG V1.4 model and processor with trust_remote_code=True
12
- model = AutoModelForImageSegmentation.from_pretrained(
13
- "briaai/RMBG-1.4", trust_remote_code=True
14
- )
15
- processor = AutoProcessor.from_pretrained(
16
- "briaai/RMBG-1.4", trust_remote_code=True
 
17
  )
18
 
19
  @app.post("/remove-background")
20
  async def remove_background(file: UploadFile = File(...)):
21
- try:
22
- # Read uploaded image
23
- image_data = await file.read()
24
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
25
-
26
- # Preprocess image
27
- inputs = processor(images=image, return_tensors="pt")
28
-
29
- # Run model
30
- with torch.no_grad():
31
- outputs = model(**inputs)
32
-
33
- # Post-process to get mask
34
- mask = outputs.logits
35
- mask = torch.sigmoid(mask).cpu().numpy()
36
- mask = (mask > 0.5).astype(np.uint8) * 255
37
- mask = mask.squeeze()
38
-
39
- # Apply mask to remove background
40
- image_np = np.array(image)
41
- alpha_channel = mask
42
- result = np.dstack((image_np, alpha_channel))
43
- result_image = Image.fromarray(result, mode="RGBA")
44
-
45
- # Save result to bytes
46
- output_buffer = io.BytesIO()
47
- result_image.save(output_buffer, format="PNG")
48
- output_bytes = output_buffer.getvalue()
49
-
50
- return Response(content=output_bytes, media_type="image/png")
51
- except Exception as e:
52
- return {"error": str(e)}
 
1
  from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import pipeline
4
  import io
5
+ from PIL import Image
 
 
6
 
7
  app = FastAPI()
8
 
9
+ # Allow CORS for Figma plugin
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["*"], # Update to specific origins in production
13
+ allow_credentials=True,
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
  )
17
 
18
  @app.post("/remove-background")
19
  async def remove_background(file: UploadFile = File(...)):
20
+ # Load the RMBG-1.4 model
21
+ pipe = pipeline('image-segmentation', 'briaai/RMBG-1.4')
22
+ # Read image
23
+ image_data = await file.read()
24
+ image = Image.open(io.BytesIO(image_data))
25
+ # Process image
26
+ output = pipe(image)
27
+ # Convert output to bytes
28
+ output_image = Image.fromarray(output.data)
29
+ output_bytes = io.BytesIO()
30
+ output_image.save(output_bytes, format='PNG')
31
+ return {"image": output_bytes.getvalue()}