AkashKumarave commited on
Commit
5e58353
·
verified ·
1 Parent(s): c1db78e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -140
app.py CHANGED
@@ -1,140 +1,6 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import torch
5
- from torchvision import transforms
6
- from PIL import Image
7
- import gradio as gr
8
-
9
- # Set up device
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
- # Clone model repository if needed
13
- if not os.path.exists("DIS"):
14
- os.system("git clone https://github.com/xuebinqin/DIS")
15
- os.system("mv DIS/IS-Net/* .")
16
-
17
- # Import model components
18
- from models import ISNetDIS
19
- from data_loader_cache import normalize
20
-
21
- # Create model directory
22
- os.makedirs("saved_models", exist_ok=True)
23
- if os.path.exists("isnet.pth"):
24
- os.rename("isnet.pth", "saved_models/isnet.pth")
25
-
26
- # Define image preprocessing
27
- class ImageNormalizer:
28
- def __init__(self, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]):
29
- self.mean = mean
30
- self.std = std
31
-
32
- def __call__(self, img):
33
- return normalize(img, self.mean, self.std)
34
-
35
- transform = transforms.Compose([ImageNormalizer()])
36
-
37
- # Load and configure model
38
- model_config = {
39
- "model_path": "saved_models",
40
- "model_file": "isnet.pth",
41
- "input_size": [1024, 1024],
42
- "device": device
43
- }
44
-
45
- model = ISNetDIS().to(device)
46
- if os.path.exists(f"{model_config['model_path']}/{model_config['model_file']}"):
47
- model.load_state_dict(
48
- torch.load(
49
- f"{model_config['model_path']}/{model_config['model_file']}",
50
- map_location=device
51
- )
52
- )
53
- model.eval()
54
-
55
- def process_image(input_image):
56
- """Process an image through the segmentation model"""
57
- try:
58
- # Convert Gradio input to usable image path
59
- if hasattr(input_image, 'name'):
60
- image_path = input_image.name
61
- else:
62
- image_path = input_image
63
-
64
- # Read and preprocess image
65
- img = cv2.imread(image_path)
66
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
67
- img = torch.from_numpy(img).float().permute(2, 0, 1) / 255.0
68
- img = transform(img).unsqueeze(0).to(device)
69
-
70
- # Get prediction
71
- with torch.no_grad():
72
- pred = model(img)[0][0]
73
- pred = torch.sigmoid(pred[0])
74
- pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
75
- mask = (pred.cpu().numpy() * 255).astype(np.uint8)
76
-
77
- # Create output images
78
- original_img = Image.open(image_path).convert("RGB")
79
- mask_img = Image.fromarray(mask).convert("L")
80
- transparent_img = original_img.copy()
81
- transparent_img.putalpha(mask_img)
82
-
83
- return transparent_img, mask_img
84
-
85
- except Exception as e:
86
- raise gr.Error(f"Error processing image: {str(e)}")
87
-
88
- # Gradio interface setup
89
- title = "Image Background Removal"
90
- description = """
91
- Upload an image to automatically remove the background using DIS (Dichotomous Image Segmentation).
92
- <br>Model from: <a href="https://github.com/xuebinqin/DIS">xuebinqin/DIS</a>
93
- """
94
-
95
- # Check for example images
96
- examples = []
97
- for img_file in ["robot.png", "ship.png"]:
98
- if os.path.exists(img_file):
99
- examples.append([img_file])
100
-
101
- # Create interface
102
- with gr.Blocks() as app:
103
- gr.Markdown(f"## {title}")
104
- gr.Markdown(description)
105
-
106
- with gr.Row():
107
- input_col = gr.Column()
108
- output_col = gr.Column()
109
-
110
- with input_col:
111
- image_input = gr.Image(type="filepath", label="Upload Image")
112
- submit_btn = gr.Button("Remove Background", variant="primary")
113
-
114
- with output_col:
115
- transparent_output = gr.Image(label="Transparent Result", type="pil")
116
- mask_output = gr.Image(label="Segmentation Mask", type="pil")
117
-
118
- if examples:
119
- gr.Examples(
120
- examples=examples,
121
- inputs=image_input,
122
- outputs=[transparent_output, mask_output],
123
- fn=process_image,
124
- cache_examples=True,
125
- label="Try Example Images"
126
- )
127
-
128
- submit_btn.click(
129
- fn=process_image,
130
- inputs=image_input,
131
- outputs=[transparent_output, mask_output]
132
- )
133
-
134
- # Launch the app
135
- if __name__ == "__main__":
136
- app.launch(
137
- server_name="0.0.0.0",
138
- server_port=7860,
139
- share=False
140
- )
 
1
+ torch>=1.10.0
2
+ torchvision>=0.11.0
3
+ opencv-python>=4.5.0
4
+ numpy>=1.21.0
5
+ Pillow>=9.0.0
6
+ gradio>=3.0