Debluar / app.py
mohamed12ahmed's picture
Update app.py
999f1f0 verified
import os
import shutil
import time
import uuid
import webbrowser

import torch
import torch.nn.functional as F
import cv2
from skimage import img_as_ubyte
from flask import Flask, request, jsonify, send_file, render_template_string
from werkzeug.utils import secure_filename
# Flask App setup: uploaded inputs and model outputs are kept in two
# directories, resolved relative to the current working directory, which are
# created up front so request handlers can write into them directly.
UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
app = Flask(__name__)
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULTS_FOLDER=RESULTS_FOLDER)
for _folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
# Model and Device setup: run on the GPU when CUDA is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Module-level cache for the TorchScript model; populated lazily by get_model().
model = None


def get_model():
    """Return the cached deblurring model, loading it on first use.

    The TorchScript file ``motion_deblurring.pt`` is loaded from the current
    working directory onto ``device`` and switched to eval mode. A load
    failure is printed rather than raised; the cache stays ``None`` (so the
    next call tries again) and ``None`` is returned.
    """
    global model
    if model is not None:
        return model
    try:
        loaded = torch.jit.load("motion_deblurring.pt", map_location=device)
        loaded.to(device)
        loaded.eval()
        model = loaded
        print("✅ Model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        model = None
    return model
# Image Processing function
def process_image_with_model(input_path):
    """Deblur the image at ``input_path`` and save the restored copy.

    The image is read with OpenCV, converted to an RGB float tensor in
    [0, 1], padded so height and width are multiples of 8, run through the
    cached model, cropped back to its original size and written as 8-bit
    BGR to ``RESULTS_FOLDER/Motion_Deblurring/<same file name>``.

    Args:
        input_path: Path of the uploaded image. Its extension also selects
            the encoder cv2.imwrite uses for the result.

    Returns:
        Path of the written result image.

    Raises:
        RuntimeError: If the model could not be loaded or the result could
            not be written.
        ValueError: If OpenCV cannot decode ``input_path``.
    """
    net = get_model()
    if net is None:
        raise RuntimeError("Model not loaded.")

    # Results for this task are grouped in their own sub-directory.
    task = "Motion_Deblurring"
    out_dir = os.path.join(app.config["RESULTS_FOLDER"], task)
    os.makedirs(out_dir, exist_ok=True)

    bgr = cv2.imread(input_path)
    if bgr is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # instead of raising; fail with a clear message rather than a
        # cryptic cv2.cvtColor assertion.
        raise ValueError(
            f"Could not read image file: {os.path.basename(input_path)}"
        )
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> 1xCxHxW float in [0, 1] on the model's device.
    input_ = (
        torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)
    )

    h, w = input_.shape[2], input_.shape[3]
    # Extra rows/cols needed to reach the next multiple of 8 (0 if aligned);
    # presumably required by the model's down/up-sampling stages — confirm.
    padh = (-h) % 8
    padw = (-w) % 8
    # "reflect" requires each pad to be smaller than that dimension, so very
    # small images (e.g. h < 8) fall back to "replicate" instead of erroring.
    mode = "reflect" if padh < h and padw < w else "replicate"
    input_ = F.pad(input_, (0, padw, 0, padh), mode)

    with torch.inference_mode():
        restored = torch.clamp(net(input_), 0, 1)
        # Drop the padding and go back to a single HWC array.
        restored = restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0]
    restored = img_as_ubyte(restored)

    out_path = os.path.join(out_dir, os.path.basename(input_path))
    # cv2.imwrite signals many failures by returning False; without this
    # check a stale result with the same name could be served afterwards.
    if not cv2.imwrite(out_path, cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)):
        raise RuntimeError(f"Could not write result image: {out_path}")
    return out_path
# HTML Interface
# Single-page UI served by the "/" route. The inline script previews the
# chosen file locally, POSTs it as multipart field "file" to /process_image,
# and shows the returned image blob as the restored result; on a non-OK
# response it parses the body as JSON and alerts its "error" field.
# The "/" route passes this string through render_template_string() (Jinja),
# so Jinja delimiters ({{ }}, {% %}, {# #}) must not appear in the markup.
# NOTE(review): object URLs created for the previews are never released with
# URL.revokeObjectURL, so each selected/returned image stays referenced until
# the page is reloaded.
html_content = """
<!DOCTYPE html>
<html>
<head>
<title>Restormer Motion Deblurring Demo</title>
<style>
body { text-align:center; font-family: sans-serif; }
.container { max-width: 600px; margin: auto; padding: 20px; border: 1px solid #ccc; border-radius: 8px; }
.image-display { display:flex; justify-content:center; gap:20px; margin-top:20px; }
img { max-width:300px; border:1px solid #ddd; }
h3 { margin-bottom: 10px; }
</style>
</head>
<body>
<div class="container">
<h1>Restormer: Motion Deblurring Demo</h1>
<form id="uploadForm" enctype="multipart/form-data">
<input type="file" id="fileInput" name="file" accept="image/*" required><br><br>
<button type="submit">Process Image</button>
</form>
<p id="loading" style="display:none;">Processing... Please wait.</p>
<div class="image-display">
<div>
<h3>Original</h3>
<img id="original" style="display:none;">
</div>
<div>
<h3>Restored</h3>
<img id="restored" style="display:none;">
</div>
</div>
</div>
<script>
const form = document.getElementById("uploadForm");
const fileInput = document.getElementById("fileInput");
const loading = document.getElementById("loading");
const original = document.getElementById("original");
const restored = document.getElementById("restored");
fileInput.addEventListener("change", (e) => {
if (e.target.files.length > 0) {
original.src = URL.createObjectURL(e.target.files[0]);
original.style.display = "block";
restored.style.display = "none";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (fileInput.files.length === 0) return;
const formData = new FormData();
formData.append("file", fileInput.files[0]);
loading.style.display = "block";
try {
const response = await fetch("/process_image", {
method: "POST",
body: formData
});
if (response.ok) {
const blob = await response.blob();
const url = URL.createObjectURL(blob);
restored.src = url;
restored.style.display = "block";
} else {
const error = await response.json();
alert("Error: " + error.error);
}
} catch (err) {
alert("Request failed: " + err);
} finally {
loading.style.display = "none";
}
});
</script>
</body>
</html>
"""
# Flask Routes
@app.route("/")
def index():
    """Serve the single-page upload/preview UI stored in ``html_content``."""
    page = render_template_string(html_content)
    return page
# Upload extensions accepted by /process_image, mapped to the MIME type of the
# returned result. The result keeps the upload's extension, so cv2.imwrite
# encodes it in that same format and this type describes it correctly.
_IMAGE_MIMETYPES = {
    ".bmp": "image/bmp",
    ".jpeg": "image/jpeg",
    ".jpg": "image/jpeg",
    ".png": "image/png",
    ".tif": "image/tiff",
    ".tiff": "image/tiff",
    ".webp": "image/webp",
}


@app.route("/process_image", methods=["POST"])
def process_image():
    """Accept an uploaded image, deblur it, and return the restored image.

    Expects a multipart form with the image in the ``file`` field.

    Returns:
        The restored image with a MIME type matching its format, or a JSON
        ``{"error": ...}`` body with status 400 (missing/unsupported upload)
        or 500 (saving or processing failed).
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    # FileStorage.filename is "" when no file was chosen and may be None.
    if not file.filename:
        return jsonify({"error": "No filename"}), 400

    # Only the (whitelisted) extension of the client-supplied name is used.
    # secure_filename() drops non-ASCII characters and may return an empty or
    # extensionless name, which would break saving and cv2.imwrite.
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in _IMAGE_MIMETYPES:
        return jsonify({"error": f"Unsupported file type: {ext or 'none'}"}), 400

    # A random name keeps concurrent uploads that share a file name (e.g.
    # "image.png") from overwriting each other's input and result files.
    filename = f"{uuid.uuid4().hex}{ext}"
    input_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
    try:
        file.save(input_path)
        output_path = process_image_with_model(input_path)
        return send_file(output_path, mimetype=_IMAGE_MIMETYPES[ext])
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Main
if __name__ == "__main__":
    # The server listens on all interfaces, and Werkzeug's interactive
    # debugger (enabled by debug=True) lets anyone who can reach it execute
    # arbitrary Python. Keep it off unless explicitly requested for local
    # development with FLASK_DEBUG=1 (or true/yes/on).
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=7860, debug=debug)