AkashKumarave committed on
Commit
93700b4
·
verified ·
1 Parent(s): 5e58353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -6
app.py CHANGED
@@ -1,6 +1,124 @@
1
- torch>=1.10.0
2
- torchvision>=0.11.0
3
- opencv-python>=4.5.0
4
- numpy>=1.21.0
5
- Pillow>=9.0.0
6
- gradio>=3.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
+ # Set up device
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ # Clone model repository if needed
13
+ if not os.path.exists("DIS"):
14
+ os.system("git clone https://github.com/xuebinqin/DIS")
15
+ os.system("mv DIS/IS-Net/* .")
16
+
17
+ # Create model directory and move weights
18
+ os.makedirs("saved_models", exist_ok=True)
19
+ if os.path.exists("isnet.pth"):
20
+ os.rename("isnet.pth", "saved_models/isnet.pth")
21
+
22
+ # Custom normalize function to replace skimage dependency
23
+ def normalize(image, mean, std):
24
+ """Normalize image with mean and std"""
25
+ if isinstance(mean, (int, float)):
26
+ mean = [mean] * image.shape[0]
27
+ if isinstance(std, (int, float)):
28
+ std = [std] * image.shape[0]
29
+
30
+ image = image.clone()
31
+ for t, m, s in zip(image, mean, std):
32
+ t.sub_(m).div_(s)
33
+ return image
34
+
35
+ # Define image preprocessing
36
+ class ImageNormalizer:
37
+ def __init__(self, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]):
38
+ self.mean = mean
39
+ self.std = std
40
+
41
+ def __call__(self, img):
42
+ return normalize(img, self.mean, self.std)
43
+
44
+ transform = transforms.Compose([ImageNormalizer()])
45
+
46
+ # Load model
47
+ from models import ISNetDIS
48
+
49
+ model = ISNetDIS().to(device)
50
+ model_path = "saved_models/isnet.pth"
51
+ if os.path.exists(model_path):
52
+ model.load_state_dict(torch.load(model_path, map_location=device))
53
+ model.eval()
54
+
55
+ def process_image(input_image):
56
+ """Process an image through the segmentation model"""
57
+ try:
58
+ # Convert Gradio input to usable image
59
+ if isinstance(input_image, str):
60
+ image_path = input_image
61
+ else:
62
+ image_path = input_image.name
63
+
64
+ # Read image with OpenCV (replaces skimage)
65
+ img = cv2.imread(image_path)
66
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
67
+
68
+ # Convert to tensor and normalize
69
+ img_tensor = torch.from_numpy(img).float().permute(2, 0, 1) / 255.0
70
+ img_tensor = transform(img_tensor).unsqueeze(0).to(device)
71
+
72
+ # Get prediction
73
+ with torch.no_grad():
74
+ pred = model(img_tensor)[0][0]
75
+ pred = torch.sigmoid(pred[0])
76
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
77
+ mask = (pred.cpu().numpy() * 255).astype(np.uint8)
78
+
79
+ # Create output images
80
+ original_img = Image.open(image_path).convert("RGB")
81
+ mask_img = Image.fromarray(mask).convert("L")
82
+ transparent_img = original_img.copy()
83
+ transparent_img.putalpha(mask_img)
84
+
85
+ return transparent_img, mask_img
86
+
87
+ except Exception as e:
88
+ raise gr.Error(f"Error processing image: {str(e)}")
89
+
90
+ # Gradio interface
91
+ title = "Image Background Removal"
92
+ description = "Upload an image to automatically remove the background"
93
+
94
+ with gr.Blocks() as app:
95
+ gr.Markdown(f"## {title}")
96
+ gr.Markdown(description)
97
+
98
+ with gr.Row():
99
+ with gr.Column():
100
+ image_input = gr.Image(type="filepath", label="Input Image")
101
+ submit_btn = gr.Button("Process", variant="primary")
102
+
103
+ with gr.Column():
104
+ transparent_output = gr.Image(label="Result with Transparency", type="pil")
105
+ mask_output = gr.Image(label="Segmentation Mask", type="pil")
106
+
107
+ # Add examples if files exist
108
+ example_files = [f for f in ["robot.png", "ship.png"] if os.path.exists(f)]
109
+ if example_files:
110
+ gr.Examples(
111
+ examples=[[f] for f in example_files],
112
+ inputs=image_input,
113
+ outputs=[transparent_output, mask_output],
114
+ fn=process_image,
115
+ cache_examples=True
116
+ )
117
+
118
+ submit_btn.click(
119
+ fn=process_image,
120
+ inputs=image_input,
121
+ outputs=[transparent_output, mask_output]
122
+ )
123
+
124
+ app.launch(server_name="0.0.0.0", server_port=7860)