# Advancedbgr / app.py
# mohamed12ahmed's picture
# Update app.py
# 3d2350b verified
from flask import Flask, request, send_file, render_template_string
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from briarmbg import BriaRMBG
import io
from PIL import Image
# --- Model Loading and Processing Functions ---
# The model is loaded a single time, when the application starts, and is
# shared by every request.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained weights, move them to the selected device and switch the
# module to evaluation mode (nn.Module.to / .eval both return the module).
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device).eval()
def resize_image(image):
    """Return *image* as an RGB PIL image scaled to the 1024x1024 model input.

    The aspect ratio is not preserved; bilinear resampling is used.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)
def process(image_np):
    """Remove the background of an image using the module-level BriaRMBG model.

    Args:
        image_np: image as a NumPy array accepted by ``Image.fromarray``
            (the ``/remove_bg`` route passes ``np.array`` of the upload).

    Returns:
        A copy of the input as a PIL image whose alpha channel is the
        predicted mask, resized back to the input size and min-max
        stretched to the 0-255 range.
    """
    # --- prepare input ---
    orig_image = Image.fromarray(image_np)
    w, h = orig_image.size
    image = resize_image(orig_image)  # RGB, 1024x1024
    # HWC uint8 -> CHW float32, add a batch dimension, scale to [0, 1].
    im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    # Mean 0.5 / std 1.0 per channel: values end up in [-0.5, 0.5].
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Use the same device the model was moved to at load time rather than
    # re-probing CUDA availability here.
    im_tensor = im_tensor.to(device)

    # --- inference ---
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        result = net(im_tensor)

    # --- post process ---
    # NOTE(review): result[0][0] is assumed to be the primary mask tensor of
    # shape (1, 1, H, W) returned by BriaRMBG - confirm against briarmbg.
    mask = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(mask)
    mi = torch.min(mask)
    if ma > mi:
        mask = (mask - mi) / (ma - mi)
    else:
        # A constant prediction has no range to stretch; dividing by zero
        # would give NaN, whose uint8 cast is undefined. Treat it as fully
        # transparent instead.
        mask = torch.zeros_like(mask)

    # --- mask to PIL ---
    mask_array = (mask * 255).cpu().numpy().astype(np.uint8)
    pil_mask = Image.fromarray(np.squeeze(mask_array))

    # Attach the mask to a copy of the original image as its alpha channel.
    new_im = orig_image.copy()
    new_im.putalpha(pil_mask)
    return new_im
# --- Flask App Setup ---
# The WSGI application object; routes below are registered on it.
app = Flask(import_name=__name__)
@app.route('/')
def index():
    """Serve the single-page upload UI for the background-removal API.

    The page lets the user pick an image, POSTs it as the multipart form
    field ``file`` to ``/remove_bg`` and shows the returned image next to the
    original.
    """
    # NOTE: the page is rendered through Jinja, so the embedded CSS/JS must not
    # contain Jinja delimiters (double curly braces, "{%" or "{#").
    return render_template_string('''
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Background Remover</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; }
        img { max-width: 90%; margin: 10px; border: 1px solid #ddd; }
        .container { display: flex; justify-content: center; gap: 20px; flex-wrap: wrap; }
        button { padding: 10px 20px; font-size: 16px; cursor: pointer; }
    </style>
</head>
<body>
    <h1>Background Remover API</h1>
    <input type="file" id="imageInput" accept="image/*">
    <br><br>
    <button id="processBtn">Remove Background</button>
    <br><br>
    <div class="container">
        <div>
            <h3>Original Image</h3>
            <img id="originalImage" src="" alt="Original">
        </div>
        <div>
            <h3>Processed Image</h3>
            <img id="processedImage" src="" alt="Processed">
        </div>
    </div>
    <script>
        const imageInput = document.getElementById('imageInput');
        const processBtn = document.getElementById('processBtn');
        const originalImage = document.getElementById('originalImage');
        const processedImage = document.getElementById('processedImage');
        let selectedFile = null;

        // Point an <img> at a new object URL (or clear it), revoking the
        // previous object URL so repeated uploads do not leak blob memory.
        function setObjectUrl(img, blob) {
            if (img.src.startsWith('blob:')) {
                URL.revokeObjectURL(img.src);
            }
            img.src = blob ? URL.createObjectURL(blob) : '';
        }

        imageInput.addEventListener('change', (e) => {
            selectedFile = e.target.files[0];
            if (selectedFile) {
                setObjectUrl(originalImage, selectedFile);
            }
        });

        processBtn.addEventListener('click', () => {
            if (!selectedFile) {
                alert('Please select an image first.');
                return;
            }
            const formData = new FormData();
            formData.append('file', selectedFile);
            setObjectUrl(processedImage, null);
            processedImage.alt = 'Processing...';
            // Prevent duplicate submissions while a request is in flight.
            processBtn.disabled = true;
            // API Endpoint
            fetch('/remove_bg', {
                method: 'POST',
                body: formData,
            })
            .then(response => {
                if (!response.ok) {
                    // The server may answer with JSON ("error" key) or with
                    // plain text; accept either so the real message is shown.
                    return response.text().then(text => {
                        let message = text;
                        try {
                            message = JSON.parse(text).error || text;
                        } catch (parseError) {
                            // Not JSON: keep the raw text.
                        }
                        throw new Error(message || 'Unknown error');
                    });
                }
                return response.blob();
            })
            .then(blob => {
                setObjectUrl(processedImage, blob);
                processedImage.alt = 'Processed Image';
            })
            .catch(error => {
                alert('Error: ' + error.message);
                processedImage.alt = 'Error';
            })
            .finally(() => {
                processBtn.disabled = false;
            });
        });
    </script>
</body>
</html>
''')
# API Route for background removal
@app.route('/remove_bg', methods=['POST'])
def remove_background():
    """Remove the background from an uploaded image and return it as a PNG.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        On success, the processed image as ``image/png``. PNG is used because
        it keeps the alpha channel. On failure, a JSON body with an ``error``
        key, which is the shape the bundled page reads. The status is 400 for
        a missing, empty or undecodable upload and 500 for a processing
        failure.
    """
    # Local imports keep this route self-contained; both packages are already
    # dependencies of this file.
    from flask import jsonify
    from PIL import ImageOps

    try:
        if 'file' not in request.files:
            return jsonify(error="No file part in the request"), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify(error="No selected file"), 400

        try:
            pil_image = Image.open(io.BytesIO(file.read()))
            # Image.open is lazy; force decoding now so truncated or corrupt
            # data is reported as a client error rather than a server error.
            pil_image.load()
        except OSError as e:
            # PIL.UnidentifiedImageError (non-image / unsupported data) is an
            # OSError subclass, so it is handled here too.
            return jsonify(error=f"Invalid image file: {e}"), 400

        # Apply the EXIF orientation tag so the result is not rotated relative
        # to how the original is displayed. Then normalise to RGB, for two
        # reasons. np.array() of a palette ('P') image yields raw palette
        # indices. A 4-channel CMYK array would be read back as RGBA by
        # Image.fromarray inside process().
        pil_image = ImageOps.exif_transpose(pil_image).convert('RGB')
        image_np = np.array(pil_image)

        processed_image = process(image_np)

        # Serialise to PNG in memory; PNG is required to keep transparency.
        img_io = io.BytesIO()
        processed_image.save(img_io, format='PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report the failure to the client as JSON.
        app.logger.exception("Error processing image")
        return jsonify(error=f"Error processing image: {str(e)}"), 500
if __name__ == '__main__':
    # Run the built-in server on all interfaces, port 7860 (presumably the
    # port the hosting platform expects - confirm before changing).
    app.run(port=7860, host='0.0.0.0')