mohamed12ahmed commited on
Commit
3d2350b
·
verified ·
1 Parent(s): de97dd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -77
app.py CHANGED
@@ -1,19 +1,18 @@
1
- import io
2
- import torch
3
- import numpy as np
4
  from flask import Flask, request, send_file, render_template_string
5
- from PIL import Image
6
- from torchvision import transforms
7
-
8
- # استدعاء الموديل
9
  from briarmbg import BriaRMBG
 
 
10
 
11
- app = Flask(__name__)
12
-
13
- model = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model.to(device)
16
- model.eval()
17
 
18
  def resize_image(image):
19
  image = image.convert('RGB')
@@ -21,71 +20,154 @@ def resize_image(image):
21
  image = image.resize(model_input_size, Image.BILINEAR)
22
  return image
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
25
 
26
- # transform
27
- to_tensor = transforms.ToTensor()
28
-
29
- # HTML واجهة بسيطة
30
- HTML_TEMPLATE = """
31
- <!DOCTYPE html>
32
- <html lang="en">
33
- <head>
34
- <meta charset="UTF-8">
35
- <title>Remove Background</title>
36
- </head>
37
- <body>
38
- <h2>Upload Image</h2>
39
- <form action="/remove_bg" method="post" enctype="multipart/form-data">
40
- <input type="file" name="file" accept="image/*">
41
- <button type="submit">Process</button>
42
- </form>
43
- {% if input_url %}
44
- <h3>Original:</h3>
45
- <img src="{{ input_url }}" width="250"/>
46
- {% endif %}
47
- {% if output_url %}
48
- <h3>Processed:</h3>
49
- <img src="{{ output_url }}" width="250"/>
50
- {% endif %}
51
- </body>
52
- </html>
53
- """
54
-
55
- @app.route("/", methods=["GET"])
56
  def index():
57
- return render_template_string(HTML_TEMPLATE)
58
-
59
- @app.route("/remove_bg", methods=["POST"])
60
- def remove_bg():
61
- if "file" not in request.files:
62
- return "No file uploaded", 400
63
-
64
- file = request.files["file"]
65
- img = Image.open(file.stream).convert("RGB")
66
- w, h = img.size
67
-
68
- # preprocess
69
- input_tensor = to_tensor(img).unsqueeze(0).to(device)
70
-
71
- # predict
72
- with torch.no_grad():
73
- preds, _ = model(input_tensor)
74
- mask = preds[0] # ناخد أول output
75
- mask = torch.nn.functional.interpolate(mask, size=(h, w), mode="bilinear")
76
- mask = mask[0][0].cpu().numpy()
77
-
78
- # apply mask
79
- mask = (mask * 255).astype(np.uint8)
80
- mask_img = Image.fromarray(mask).convert("L")
81
- output = Image.new("RGBA", img.size, (0, 0, 0, 0))
82
- output.paste(img, mask=mask_img)
83
-
84
- # رجع كـ response (مباشر)
85
- img_io = io.BytesIO()
86
- output.save(img_io, format="PNG")
87
- img_io.seek(0)
88
- return send_file(img_io, mimetype="image/png")
89
-
90
- if __name__ == "__main__":
91
- app.run(host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from flask import Flask, request, send_file, render_template_string
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms.functional import normalize
6
  from briarmbg import BriaRMBG
7
+ import io
8
+ from PIL import Image
9
 
10
+ # --- Model Loading and Processing Functions ---
11
+ # يتم تحميل النموذج مرة واحدة عند بدء تشغيل التطبيق
12
+ net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ net.to(device)
15
+ net.eval()
16
 
17
  def resize_image(image):
18
  image = image.convert('RGB')
 
20
  image = image.resize(model_input_size, Image.BILINEAR)
21
  return image
22
 
23
+ def process(image_np):
24
+ # prepare input
25
+ orig_image = Image.fromarray(image_np)
26
+ w, h = orig_im_size = orig_image.size
27
+ image = resize_image(orig_image)
28
+ im_np = np.array(image)
29
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
30
+ im_tensor = torch.unsqueeze(im_tensor, 0)
31
+ im_tensor = torch.divide(im_tensor, 255.0)
32
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
33
+ if torch.cuda.is_available():
34
+ im_tensor = im_tensor.cuda()
35
+
36
+ # inference
37
+ result = net(im_tensor)
38
 
39
+ # post process
40
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
41
+ ma = torch.max(result)
42
+ mi = torch.min(result)
43
+ result = (result - mi) / (ma - mi)
44
 
45
+ # image to pil
46
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
47
+ pil_mask = Image.fromarray(np.squeeze(result_array))
48
+
49
+ # add the mask on the original image as alpha channel
50
+ new_im = orig_image.copy()
51
+ new_im.putalpha(pil_mask)
52
+
53
+ return new_im
54
+
55
+ # --- Flask App Setup ---
56
+ app = Flask(__name__)
57
+
58
+ @app.route('/')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def index():
60
+ return render_template_string('''
61
+ <!DOCTYPE html>
62
+ <html>
63
+ <head>
64
+ <title>Background Remover</title>
65
+ <style>
66
+ body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
67
+ img { max-width: 90%; margin: 10px; border: 1px solid #ddd; }
68
+ .container { display: flex; justify-content: center; gap: 20px; flex-wrap: wrap; }
69
+ button { padding: 10px 20px; font-size: 16px; cursor: pointer; }
70
+ </style>
71
+ </head>
72
+ <body>
73
+ <h1>Background Remover API</h1>
74
+ <input type="file" id="imageInput" accept="image/*">
75
+ <br><br>
76
+ <button id="processBtn">Remove Background</button>
77
+ <br><br>
78
+ <div class="container">
79
+ <div>
80
+ <h3>Original Image</h3>
81
+ <img id="originalImage" src="" alt="Original">
82
+ </div>
83
+ <div>
84
+ <h3>Processed Image</h3>
85
+ <img id="processedImage" src="" alt="Processed">
86
+ </div>
87
+ </div>
88
+
89
+ <script>
90
+ const imageInput = document.getElementById('imageInput');
91
+ const processBtn = document.getElementById('processBtn');
92
+ const originalImage = document.getElementById('originalImage');
93
+ const processedImage = document.getElementById('processedImage');
94
+
95
+ let selectedFile = null;
96
+
97
+ imageInput.addEventListener('change', (e) => {
98
+ selectedFile = e.target.files[0];
99
+ if (selectedFile) {
100
+ originalImage.src = URL.createObjectURL(selectedFile);
101
+ }
102
+ });
103
+
104
+ processBtn.addEventListener('click', () => {
105
+ if (!selectedFile) {
106
+ alert('Please select an image first.');
107
+ return;
108
+ }
109
+
110
+ const formData = new FormData();
111
+ formData.append('file', selectedFile);
112
+
113
+ processedImage.src = '';
114
+ processedImage.alt = 'Processing...';
115
+
116
+ // API Endpoint
117
+ fetch('/remove_bg', {
118
+ method: 'POST',
119
+ body: formData,
120
+ })
121
+ .then(response => {
122
+ if (!response.ok) {
123
+ return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
124
+ }
125
+ return response.blob();
126
+ })
127
+ .then(blob => {
128
+ processedImage.src = URL.createObjectURL(blob);
129
+ processedImage.alt = 'Processed Image';
130
+ })
131
+ .catch(error => {
132
+ alert('Error: ' + error.message);
133
+ processedImage.alt = 'Error';
134
+ });
135
+ });
136
+ </script>
137
+ </body>
138
+ </html>
139
+ ''')
140
+
141
+ # API Route for background removal
142
+ @app.route('/remove_bg', methods=['POST'])
143
+ def remove_background():
144
+ try:
145
+ if 'file' not in request.files:
146
+ return "No file part in the request", 400
147
+ file = request.files['file']
148
+ if file.filename == '':
149
+ return "No selected file", 400
150
+
151
+ # Read the file and convert to a NumPy array
152
+ image_bytes = file.read()
153
+ pil_image = Image.open(io.BytesIO(image_bytes))
154
+ image_np = np.array(pil_image)
155
+
156
+ # استدعاء دالة المعالجة
157
+ processed_image = process(image_np)
158
+
159
+ # تحويل الصورة الناتجة إلى bytes وإرسالها
160
+ # ملاحظة: يجب استخدام PNG لدعم الشفافية
161
+ img_io = io.BytesIO()
162
+ processed_image.save(img_io, format='PNG')
163
+ img_io.seek(0)
164
+
165
+ return send_file(img_io, mimetype='image/png')
166
+
167
+ except Exception as e:
168
+ import traceback
169
+ traceback.print_exc()
170
+ return f"Error processing image: {str(e)}", 500
171
+
172
+ if __name__ == '__main__':
173
+ app.run(host='0.0.0.0', port=7860)