Spaces:
Sleeping
Sleeping
| from flask import Flask, request, send_file, render_template_string | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision.transforms.functional import normalize | |
| from briarmbg import BriaRMBG | |
| import io | |
| from PIL import Image | |
# --- Model Loading and Processing Functions ---
# The model is loaded a single time, when the application starts, and shared by
# every request.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Presumably fetches (or reads from a local cache) the pretrained weights for
# "briaai/RMBG-1.4" — confirm against the briarmbg package.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
# Move the weights to the chosen device and put the module in eval mode
# (affects layers such as dropout / batch-norm, if the model has any).
net.to(device).eval()
def resize_image(image):
    """Return an RGB copy of *image* stretched to the 1024x1024 model input size.

    The aspect ratio is not preserved; bilinear resampling is used.
    """
    rgb = image.convert('RGB')
    return rgb.resize((1024, 1024), Image.BILINEAR)
def process(image_np):
    """Predict a foreground mask for an image and attach it as the alpha channel.

    Args:
        image_np: image as a NumPy array in any layout ``Image.fromarray``
            accepts (presumably ``(H, W, 3)`` uint8 RGB from the upload
            handler — confirm for other callers).

    Returns:
        A PIL image: a copy of the original at its original size, with the
        model's min-max-normalized mask (0-255) set as its alpha channel.
    """
    # prepare input
    orig_image = Image.fromarray(image_np)
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)  # add batch dim -> (1, 3, 1024, 1024)
    im_tensor = torch.divide(im_tensor, 255.0)
    # mean 0.5 / std 1.0 shifts pixel values from [0, 1] to [-0.5, 0.5].
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Use the same device the model was moved to at startup.
    im_tensor = im_tensor.to(device)

    # Inference only: skip building an autograd graph, which would otherwise
    # hold intermediate activations in memory on every request.
    with torch.no_grad():
        result = net(im_tensor)
        # post process: bilinear F.interpolate needs a 4-D (N, C, H, W) tensor.
        # NOTE(review): assumes result[0][0] is the (1, 1, H', W') mask map — confirm with BriaRMBG.
        result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        span = ma - mi
        # Min-max normalize to [0, 1]. A constant map would divide by zero and
        # produce NaNs (undefined after the uint8 cast), so emit an all-zero mask.
        if span > 0:
            result = (result - mi) / span
        else:
            result = torch.zeros_like(result)
        # image to pil
        result_array = (result * 255).cpu().numpy().astype(np.uint8)

    pil_mask = Image.fromarray(np.squeeze(result_array))
    # add the mask on the original image as alpha channel
    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)
    return new_im
# --- Flask App Setup ---
# The WSGI application object; routes below are registered on it.
app = Flask(import_name=__name__)
@app.route('/')
def index():
    """Serve the single-page upload UI.

    The page lets the user pick an image, POSTs it as the multipart field
    ``file`` to ``/remove_bg``, and shows the returned image next to the
    original. Non-OK responses are shown with the server's error message.
    """
    return render_template_string('''
<!DOCTYPE html>
<html>
<head>
    <title>Background Remover</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
        img { max-width: 90%; margin: 10px; border: 1px solid #ddd; }
        .container { display: flex; justify-content: center; gap: 20px; flex-wrap: wrap; }
        button { padding: 10px 20px; font-size: 16px; cursor: pointer; }
    </style>
</head>
<body>
    <h1>Background Remover API</h1>
    <input type="file" id="imageInput" accept="image/*">
    <br><br>
    <button id="processBtn">Remove Background</button>
    <br><br>
    <div class="container">
        <div>
            <h3>Original Image</h3>
            <img id="originalImage" src="" alt="Original">
        </div>
        <div>
            <h3>Processed Image</h3>
            <img id="processedImage" src="" alt="Processed">
        </div>
    </div>
    <script>
        const imageInput = document.getElementById('imageInput');
        const processBtn = document.getElementById('processBtn');
        const originalImage = document.getElementById('originalImage');
        const processedImage = document.getElementById('processedImage');
        let selectedFile = null;
        imageInput.addEventListener('change', (e) => {
            selectedFile = e.target.files[0];
            if (selectedFile) {
                originalImage.src = URL.createObjectURL(selectedFile);
            }
        });
        processBtn.addEventListener('click', () => {
            if (!selectedFile) {
                alert('Please select an image first.');
                return;
            }
            const formData = new FormData();
            formData.append('file', selectedFile);
            processedImage.src = '';
            processedImage.alt = 'Processing...';
            // API Endpoint
            fetch('/remove_bg', {
                method: 'POST',
                body: formData,
            })
            .then(response => {
                if (!response.ok) {
                    // The server answers errors with plain text; read the body as
                    // text and only use a JSON "error" field if one is present.
                    return response.text().then(body => {
                        let message = body;
                        try {
                            message = JSON.parse(body).error || body;
                        } catch (_) {
                            // Not JSON: keep the plain-text body.
                        }
                        throw new Error(message || 'Unknown error');
                    });
                }
                return response.blob();
            })
            .then(blob => {
                processedImage.src = URL.createObjectURL(blob);
                processedImage.alt = 'Processed Image';
            })
            .catch(error => {
                alert('Error: ' + error.message);
                processedImage.alt = 'Error';
            });
        });
    </script>
</body>
</html>
''')
# API Route for background removal
@app.route('/remove_bg', methods=['POST'])
def remove_background():
    """Remove the background of an uploaded image and return it as a PNG.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        The processed image as ``image/png`` (RGBA, mask in the alpha
        channel); or a plain-text message with status 400 when the upload is
        missing or cannot be decoded as an image; or status 500 for any other
        processing failure.
    """
    # Local import keeps this handler self-contained; ImageOps ships with
    # Pillow, which this module already depends on.
    from PIL import ImageOps

    try:
        if 'file' not in request.files:
            return "No file part in the request", 400
        file = request.files['file']
        if file.filename == '':
            return "No selected file", 400

        # Decode the upload; anything Pillow cannot read is a client error.
        # (UnidentifiedImageError and truncated-file errors subclass OSError.)
        try:
            pil_image = Image.open(io.BytesIO(file.read()))
            # Apply the EXIF orientation tag (typical for phone photos). Browsers
            # honour it when previewing the original, and the PNG we return has
            # no EXIF, so without this the result comes back rotated.
            pil_image = ImageOps.exif_transpose(pil_image)
            # Normalise palette/CMYK/16-bit/etc. to plain RGB; otherwise
            # np.array() yields e.g. palette indices that process() would
            # interpret as grayscale pixel values.
            pil_image = pil_image.convert('RGB')
        except OSError:
            return "Invalid or unsupported image file", 400
        image_np = np.array(pil_image)

        # Run background removal.
        processed_image = process(image_np)

        # Serialize the result to bytes and send it.
        # PNG is required to preserve the transparency (alpha channel).
        img_io = io.BytesIO()
        processed_image.save(img_io, format='PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        # Top-level boundary: log the full traceback server-side and report a 500.
        import traceback
        traceback.print_exc()
        return f"Error processing image: {str(e)}", 500
if __name__ == '__main__':
    # Start Flask's built-in server on every interface. 7860 is presumably the
    # port the hosting environment expects (e.g. Hugging Face Spaces) — confirm.
    app.run(port=7860, host='0.0.0.0')