AkashKumarave committed on
Commit
cdcf1d0
·
verified ·
1 Parent(s): 42af69b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -101
app.py CHANGED
@@ -1,113 +1,50 @@
1
- # app.py
2
- import gradio as gr
3
- import torch
4
- import numpy as np
5
  from PIL import Image
6
  import io
7
- import base64
8
- import requests
9
- import os
10
- from torchvision.transforms import Compose, Resize, ToTensor, Normalize
11
-
12
- # Download pre-trained DIS (IS-Net) weights
13
def download_weights():
    """Ensure the pre-trained DIS (IS-Net) weights exist locally.

    The checkpoint is fetched from the DIS GitHub release only when
    ``isnet-general-use.pth`` is not already present in the current working
    directory.

    Returns:
        str: Relative path of the weights file.

    Raises:
        Exception: If the download fails. The underlying error is chained as
            ``__cause__`` and no partially written file is left behind.
    """
    weights_path = "isnet-general-use.pth"
    if os.path.exists(weights_path):
        return weights_path

    url = "https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth"
    # Stream into a sibling temporary file first: if the transfer is
    # interrupted, a truncated checkpoint must not be mistaken for a complete
    # one by the existence check on the next call.
    partial_path = weights_path + ".part"
    try:
        # The context manager releases the connection; the timeout keeps
        # start-up from hanging indefinitely on a stalled server.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the finished file under its final name in a single step.
        os.replace(partial_path, weights_path)
    except Exception as e:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise Exception(f"Failed to download weights: {str(e)}") from e
    return weights_path
26
-
27
- # DIS (IS-Net) model architecture (simplified from https://github.com/xuebinqin/DIS)
28
class ISNet(torch.nn.Module):
    """Simplified stand-in for the DIS (IS-Net) segmentation network.

    This is a small convolutional encoder/decoder kept for demonstration; it
    is not the full IS-Net architecture
    (see https://github.com/xuebinqin/DIS/blob/main/ISNet.py).

    NOTE(review): the layer names defined here presumably do not match the
    keys of the released ``isnet-general-use.pth`` checkpoint - confirm before
    relying on ``load_state_dict`` with that file.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3-channel input -> 64 features, halve resolution, -> 128 features.
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Decoder: double the resolution again and project to a 1-channel map.
        self.upconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv3 = torch.nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Predict a single-channel map for a 3-channel image tensor.

        Args:
            x: Tensor whose channel dimension has size 3.

        Returns:
            Tensor with one channel and values in the range [0, 1] (sigmoid
            output). Height and width equal the input's when they are even;
            an odd dimension comes back one pixel smaller because the 2x2
            pooling floors it before the 2x upsampling.
        """
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.upconv(x)
        x = torch.sigmoid(self.conv3(x))
        return x
47
 
48
- # Initialize model
49
# Build the network once at import time so every request reuses it.
try:
    weights_path = download_weights()
    model = ISNet()
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Inference only: switch the module to evaluation mode.
    model.eval()
except Exception as e:
    # Chain the original error so the root cause stays in the traceback.
    raise Exception(f"Model initialization failed: {str(e)}") from e
57
 
58
def remove_background(image):
    """Cut the background out of an image with the module-level ``model``.

    Args:
        image: PIL image; it is converted to RGB when it is in another mode.

    Returns:
        str: A ``data:image/png;base64,...`` URI holding the RGBA result, or
        a string starting with ``"Error: "`` when any step fails.

    NOTE(review): the Gradio output component in this file is declared with
    ``type="pil"`` while this function returns a string - confirm the
    component accepts that.
    """
    try:
        # The model input pipeline expects three colour channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize to the fixed network resolution and scale pixels to [-1, 1].
        preprocess = Compose([
            Resize((1024, 1024)),
            ToTensor(),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        batch = preprocess(image).unsqueeze(0)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            prediction = model(batch)
            scores = prediction.squeeze().cpu().numpy()

        # Binarise at 0.5, then bring the mask back to the source resolution.
        binary = (scores > 0.5).astype(np.uint8) * 255
        alpha = Image.fromarray(binary).resize(image.size, Image.LANCZOS)

        # Write the mask into the alpha channel of an RGBA copy.
        pixels = np.array(image.convert("RGBA"))
        pixels[:, :, 3] = alpha
        cutout = Image.fromarray(pixels)

        # Serialise to PNG in memory and wrap the bytes in a data URI.
        png_buffer = io.BytesIO()
        cutout.save(png_buffer, format="PNG")
        img_str = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        return f"Error: {str(e)}"
100
 
101
- # Create Gradio interface
102
# Input and output widgets, built first so the Interface call stays short.
_image_input = gr.Image(type="pil", label="Upload Image")
_image_output = gr.Image(type="pil", label="Image with Background Removed")

# Single-function Gradio app wired to remove_background; flagging is disabled.
iface = gr.Interface(
    fn=remove_background,
    inputs=_image_input,
    outputs=_image_output,
    title="DIS Background Removal",
    description="Remove backgrounds from any image using the open-source DIS (IS-Net) model.",
    allow_flagging="never",
)

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import Response
 
 
3
  from PIL import Image
4
  import io
5
+ import numpy as np
6
+ from transformers import AutoModelForImageSegmentation, AutoProcessor
7
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# ASGI application object; the route handler below is registered on it.
app = FastAPI()

# Hugging Face Hub repository that holds the RMBG V1.4 checkpoint.
_RMBG_REPO = "briaai/RMBG-1.4"

# Load the segmentation model and its processor once, at import time.
# trust_remote_code=True allows transformers to run model code shipped in the repo.
model = AutoModelForImageSegmentation.from_pretrained(
    _RMBG_REPO,
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(_RMBG_REPO)
 
 
 
 
16
 
17
@app.post("/remove-background")
async def remove_background(file: UploadFile = File(...)):
    """Remove the background from an uploaded image.

    The upload is decoded to RGB, passed through the module-level
    ``processor`` and ``model``, and the thresholded prediction becomes the
    alpha channel of the returned PNG.

    Args:
        file: Image sent as multipart form data.

    Returns:
        A ``Response`` with the RGBA PNG bytes (``image/png``) on success.
        On failure, a JSON body ``{"error": "<message>"}`` with status 400
        when the upload cannot be decoded as an image, or status 500 for any
        other error.
    """
    # Function-scope import: only this handler needs JSONResponse.
    from fastapi.responses import JSONResponse

    try:
        # Read uploaded image
        image_data = await file.read()
        try:
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except OSError as e:
            # Pillow reports unreadable or truncated data with OSError
            # subclasses; that is a client error, not a server fault.
            return JSONResponse(status_code=400, content={"error": str(e)})

        # Preprocess image
        inputs = processor(images=image, return_tensors="pt")

        # Run model
        # NOTE(review): this call is synchronous inside an async handler, so
        # it presumably blocks the event loop while it runs - confirm whether
        # that is acceptable for the expected load.
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process to get mask
        # NOTE(review): assumes the model output exposes ``.logits`` - confirm
        # against the RMBG-1.4 model card.
        mask = outputs.logits
        mask = torch.sigmoid(mask).cpu().numpy()
        mask = (mask > 0.5).astype(np.uint8) * 255
        mask = mask.squeeze()

        # The prediction has the processor's resolution, not the upload's, so
        # scale it to the image size before using it as alpha. NEAREST keeps
        # the mask strictly binary (0 or 255).
        alpha = Image.fromarray(mask)
        if alpha.size != image.size:
            alpha = alpha.resize(image.size, Image.NEAREST)

        # Apply mask to remove background (converts the RGB image to RGBA).
        image.putalpha(alpha)

        # Save result to bytes
        output_buffer = io.BytesIO()
        image.save(output_buffer, format="PNG")
        output_bytes = output_buffer.getvalue()

        return Response(content=output_bytes, media_type="image/png")
    except Exception as e:
        # Same JSON body as before, but with a failing status so clients can
        # detect the error without parsing the payload.
        return JSONResponse(status_code=500, content={"error": str(e)})