AkashKumarave commited on
Commit
7941a82
·
verified ·
1 Parent(s): 0ffc276

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -106
app.py CHANGED
@@ -1,115 +1,29 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn.functional as F
4
- import numpy as np
5
  from PIL import Image
6
- import cv2
7
- import os
8
 
9
- # Ensure models directory is accessible
10
- try:
11
- from models.isnet import ISNetGT
12
- except ImportError:
13
- raise ImportError("Could not import ISNetGT from models.isnet. Ensure models/isnet.py is in the Space.")
14
-
15
- # Define model loading function
16
- def load_model(model_path="isnet-general-use.pth"):
17
- if not os.path.exists(model_path):
18
- raise FileNotFoundError(f"Model file {model_path} not found. Upload it to the Space root directory.")
19
-
20
- model = ISNetGT()
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- model.load_state_dict(torch.load(model_path, map_location=device))
23
- model.to(device).eval()
24
- return model, device
25
-
26
- # Image preprocessing function
27
- def preprocess_image(image, target_size=(1024, 1024)):
28
- # Convert PIL Image to numpy array
29
- image = np.array(image)
30
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
31
 
32
- # Resize image while preserving aspect ratio
33
- h, w = image.shape[:2]
34
- scale = min(target_size[0] / h, target_size[1] / w)
35
- new_h, new_w = int(h * scale), int(w * scale)
36
- image_resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
37
 
38
- # Pad to target size
39
- padded_image = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
40
- padded_image[:new_h, :new_w] = image_resized
41
 
42
- # Normalize and convert to tensor
43
- image_tensor = torch.from_numpy(padded_image).permute(2, 0, 1).float() / 255.0
44
- image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
45
-
46
- return image_tensor, (new_h, new_w), (h, w)
47
-
48
- # Inference function
49
- def inference(model, image_tensor, device):
50
- image_tensor = image_tensor.to(device)
51
- with torch.no_grad():
52
- output = model(image_tensor)[0] # Get segmentation output
53
- output = F.interpolate(output, size=image_tensor.shape[2:], mode='bilinear', align_corners=True)
54
- output = torch.sigmoid(output).cpu().numpy()[0, 0] # Convert to probability map
55
- return output
56
-
57
- # Post-processing function
58
- def postprocess_output(output, original_size, resized_size):
59
- # Resize mask to resized image size, then to original size
60
- mask = cv2.resize(output, resized_size[::-1], interpolation=cv2.INTER_LANCZOS4)
61
- mask = cv2.resize(mask, original_size[::-1], interpolation=cv2.INTER_LANCZOS4)
62
- mask = (mask > 0.5).astype(np.uint8) * 255 # Binarize mask
63
- return mask
64
 
65
- # Background removal function
66
- def remove_background(input_image):
67
- if input_image is None:
68
- return None
69
-
70
- try:
71
- # Load model
72
- model, device = load_model()
73
-
74
- # Preprocess image
75
- image_tensor, resized_size, original_size = preprocess_image(input_image)
76
-
77
- # Run inference
78
- mask = inference(model, image_tensor, device)
79
-
80
- # Post-process mask
81
- mask = postprocess_output(mask, original_size, resized_size)
82
-
83
- # Apply mask to create transparent image
84
- input_array = np.array(input_image)
85
- alpha = mask
86
- rgba = np.zeros((input_array.shape[0], input_array.shape[1], 4), dtype=np.uint8)
87
- rgba[..., :3] = input_array
88
- rgba[..., 3] = alpha
89
-
90
- # Convert to PIL Image
91
- output_image = Image.fromarray(rgba, mode='RGBA')
92
- return output_image
93
-
94
- except Exception as e:
95
- return f"Error: {str(e)}"
96
-
97
- # Set up Gradio Blocks interface
98
- with gr.Blocks(title="DIS Background Remover") as demo:
99
- gr.Markdown("## DIS Background Remover")
100
- gr.Markdown("Upload an image to remove its background using the IS-Net model from xuebinqin/DIS.")
101
-
102
- with gr.Row():
103
- input_image = gr.Image(type="pil", label="Upload Image")
104
- output_image = gr.Image(type="pil", label="Image with Background Removed")
105
-
106
- submit_btn = gr.Button("Remove Background")
107
- submit_btn.click(
108
- fn=remove_background,
109
- inputs=input_image,
110
- outputs=output_image
111
- )
112
 
113
- # Launch the app
114
  if __name__ == "__main__":
115
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
3
  from PIL import Image
 
 
4
 
5
# Remove the background from an image using the RMBG-1.4 segmentation model.
def remove_background(image):
    """Return *image* with its background removed.

    Args:
        image: A ``PIL.Image`` supplied by the Gradio ``Image`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``PIL.Image`` with the background stripped (transparent where the
        model segmented background), or ``None`` if no image was provided.
    """
    if image is None:
        # Gradio passes None when the input is empty; avoid crashing the UI.
        return None

    # Build the segmentation pipeline lazily and cache it on the function:
    # constructing it per call reloads the whole model, which is very slow.
    pipe = getattr(remove_background, "_pipe", None)
    if pipe is None:
        pipe = pipeline(
            "image-segmentation",
            model="briaai/RMBG-1.4",
            trust_remote_code=True,
        )
        remove_background._pipe = pipe

    # NOTE(review): the previous version also called
    # ``pipe(image, return_mask=True)`` but discarded the result, doubling
    # inference time per request — that dead call has been removed.
    return pipe(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# Build the Gradio interface: one uploaded image in, the background-removed
# image out. (UI strings are intentionally in Portuguese for the target
# audience — do not translate them.)
app = gr.Interface(
    fn=remove_background,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil", format="png"),  # PNG output preserves transparency
    title="Remoção de Background de Imagens",
    description="Envie uma imagem e veja o background sendo removido automaticamente. A imagem resultante será no formato PNG."
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
# Launch the interface when run as a script; share=True asks Gradio to
# create a temporary public URL in addition to the local server.
if __name__ == "__main__":
    app.launch(share=True)