Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,19 +1,18 @@
|
|
| 1 |
-
import io
|
| 2 |
-
import torch
|
| 3 |
-
import numpy as np
|
| 4 |
from flask import Flask, request, send_file, render_template_string
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
from briarmbg import BriaRMBG
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
|
| 18 |
def resize_image(image):
|
| 19 |
image = image.convert('RGB')
|
|
@@ -21,71 +20,154 @@ def resize_image(image):
|
|
| 21 |
image = image.resize(model_input_size, Image.BILINEAR)
|
| 22 |
return image
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
<input type="file" name="file" accept="image/*">
|
| 41 |
-
<button type="submit">Process</button>
|
| 42 |
-
</form>
|
| 43 |
-
{% if input_url %}
|
| 44 |
-
<h3>Original:</h3>
|
| 45 |
-
<img src="{{ input_url }}" width="250"/>
|
| 46 |
-
{% endif %}
|
| 47 |
-
{% if output_url %}
|
| 48 |
-
<h3>Processed:</h3>
|
| 49 |
-
<img src="{{ output_url }}" width="250"/>
|
| 50 |
-
{% endif %}
|
| 51 |
-
</body>
|
| 52 |
-
</html>
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
@app.route("/", methods=["GET"])
|
| 56 |
def index():
|
| 57 |
-
return render_template_string(
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from flask import Flask, request, send_file, render_template_string
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.transforms.functional import normalize
|
| 6 |
from briarmbg import BriaRMBG
|
| 7 |
+
import io
|
| 8 |
+
from PIL import Image
|
| 9 |
|
| 10 |
+
# --- Model Loading and Processing Functions ---
|
| 11 |
+
# يتم تحميل النموذج مرة واحدة عند بدء تشغيل التطبيق
|
| 12 |
+
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
|
| 13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
net.to(device)
|
| 15 |
+
net.eval()
|
| 16 |
|
| 17 |
def resize_image(image):
|
| 18 |
image = image.convert('RGB')
|
|
|
|
| 20 |
image = image.resize(model_input_size, Image.BILINEAR)
|
| 21 |
return image
|
| 22 |
|
| 23 |
+
def process(image_np):
|
| 24 |
+
# prepare input
|
| 25 |
+
orig_image = Image.fromarray(image_np)
|
| 26 |
+
w, h = orig_im_size = orig_image.size
|
| 27 |
+
image = resize_image(orig_image)
|
| 28 |
+
im_np = np.array(image)
|
| 29 |
+
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
|
| 30 |
+
im_tensor = torch.unsqueeze(im_tensor, 0)
|
| 31 |
+
im_tensor = torch.divide(im_tensor, 255.0)
|
| 32 |
+
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
| 33 |
+
if torch.cuda.is_available():
|
| 34 |
+
im_tensor = im_tensor.cuda()
|
| 35 |
+
|
| 36 |
+
# inference
|
| 37 |
+
result = net(im_tensor)
|
| 38 |
|
| 39 |
+
# post process
|
| 40 |
+
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
|
| 41 |
+
ma = torch.max(result)
|
| 42 |
+
mi = torch.min(result)
|
| 43 |
+
result = (result - mi) / (ma - mi)
|
| 44 |
|
| 45 |
+
# image to pil
|
| 46 |
+
result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
|
| 47 |
+
pil_mask = Image.fromarray(np.squeeze(result_array))
|
| 48 |
+
|
| 49 |
+
# add the mask on the original image as alpha channel
|
| 50 |
+
new_im = orig_image.copy()
|
| 51 |
+
new_im.putalpha(pil_mask)
|
| 52 |
+
|
| 53 |
+
return new_im
|
| 54 |
+
|
| 55 |
+
# --- Flask App Setup ---
|
| 56 |
+
app = Flask(__name__)
|
| 57 |
+
|
| 58 |
+
@app.route('/')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def index():
|
| 60 |
+
return render_template_string('''
|
| 61 |
+
<!DOCTYPE html>
|
| 62 |
+
<html>
|
| 63 |
+
<head>
|
| 64 |
+
<title>Background Remover</title>
|
| 65 |
+
<style>
|
| 66 |
+
body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
|
| 67 |
+
img { max-width: 90%; margin: 10px; border: 1px solid #ddd; }
|
| 68 |
+
.container { display: flex; justify-content: center; gap: 20px; flex-wrap: wrap; }
|
| 69 |
+
button { padding: 10px 20px; font-size: 16px; cursor: pointer; }
|
| 70 |
+
</style>
|
| 71 |
+
</head>
|
| 72 |
+
<body>
|
| 73 |
+
<h1>Background Remover API</h1>
|
| 74 |
+
<input type="file" id="imageInput" accept="image/*">
|
| 75 |
+
<br><br>
|
| 76 |
+
<button id="processBtn">Remove Background</button>
|
| 77 |
+
<br><br>
|
| 78 |
+
<div class="container">
|
| 79 |
+
<div>
|
| 80 |
+
<h3>Original Image</h3>
|
| 81 |
+
<img id="originalImage" src="" alt="Original">
|
| 82 |
+
</div>
|
| 83 |
+
<div>
|
| 84 |
+
<h3>Processed Image</h3>
|
| 85 |
+
<img id="processedImage" src="" alt="Processed">
|
| 86 |
+
</div>
|
| 87 |
+
</div>
|
| 88 |
+
|
| 89 |
+
<script>
|
| 90 |
+
const imageInput = document.getElementById('imageInput');
|
| 91 |
+
const processBtn = document.getElementById('processBtn');
|
| 92 |
+
const originalImage = document.getElementById('originalImage');
|
| 93 |
+
const processedImage = document.getElementById('processedImage');
|
| 94 |
+
|
| 95 |
+
let selectedFile = null;
|
| 96 |
+
|
| 97 |
+
imageInput.addEventListener('change', (e) => {
|
| 98 |
+
selectedFile = e.target.files[0];
|
| 99 |
+
if (selectedFile) {
|
| 100 |
+
originalImage.src = URL.createObjectURL(selectedFile);
|
| 101 |
+
}
|
| 102 |
+
});
|
| 103 |
+
|
| 104 |
+
processBtn.addEventListener('click', () => {
|
| 105 |
+
if (!selectedFile) {
|
| 106 |
+
alert('Please select an image first.');
|
| 107 |
+
return;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
const formData = new FormData();
|
| 111 |
+
formData.append('file', selectedFile);
|
| 112 |
+
|
| 113 |
+
processedImage.src = '';
|
| 114 |
+
processedImage.alt = 'Processing...';
|
| 115 |
+
|
| 116 |
+
// API Endpoint
|
| 117 |
+
fetch('/remove_bg', {
|
| 118 |
+
method: 'POST',
|
| 119 |
+
body: formData,
|
| 120 |
+
})
|
| 121 |
+
.then(response => {
|
| 122 |
+
if (!response.ok) {
|
| 123 |
+
return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
|
| 124 |
+
}
|
| 125 |
+
return response.blob();
|
| 126 |
+
})
|
| 127 |
+
.then(blob => {
|
| 128 |
+
processedImage.src = URL.createObjectURL(blob);
|
| 129 |
+
processedImage.alt = 'Processed Image';
|
| 130 |
+
})
|
| 131 |
+
.catch(error => {
|
| 132 |
+
alert('Error: ' + error.message);
|
| 133 |
+
processedImage.alt = 'Error';
|
| 134 |
+
});
|
| 135 |
+
});
|
| 136 |
+
</script>
|
| 137 |
+
</body>
|
| 138 |
+
</html>
|
| 139 |
+
''')
|
| 140 |
+
|
| 141 |
+
# API Route for background removal
|
| 142 |
+
@app.route('/remove_bg', methods=['POST'])
|
| 143 |
+
def remove_background():
|
| 144 |
+
try:
|
| 145 |
+
if 'file' not in request.files:
|
| 146 |
+
return "No file part in the request", 400
|
| 147 |
+
file = request.files['file']
|
| 148 |
+
if file.filename == '':
|
| 149 |
+
return "No selected file", 400
|
| 150 |
+
|
| 151 |
+
# Read the file and convert to a NumPy array
|
| 152 |
+
image_bytes = file.read()
|
| 153 |
+
pil_image = Image.open(io.BytesIO(image_bytes))
|
| 154 |
+
image_np = np.array(pil_image)
|
| 155 |
+
|
| 156 |
+
# استدعاء دالة المعالجة
|
| 157 |
+
processed_image = process(image_np)
|
| 158 |
+
|
| 159 |
+
# تحويل الصورة الناتجة إلى bytes وإرسالها
|
| 160 |
+
# ملاحظة: يجب استخدام PNG لدعم الشفافية
|
| 161 |
+
img_io = io.BytesIO()
|
| 162 |
+
processed_image.save(img_io, format='PNG')
|
| 163 |
+
img_io.seek(0)
|
| 164 |
+
|
| 165 |
+
return send_file(img_io, mimetype='image/png')
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
import traceback
|
| 169 |
+
traceback.print_exc()
|
| 170 |
+
return f"Error processing image: {str(e)}", 500
|
| 171 |
+
|
| 172 |
+
if __name__ == '__main__':
|
| 173 |
+
app.run(host='0.0.0.0', port=7860)
|