mohamed12ahmed commited on
Commit
999f1f0
·
verified ·
1 Parent(s): 61e0008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -69
app.py CHANGED
@@ -1,84 +1,179 @@
1
  import os
2
  import shutil
3
  import torch
 
4
  import cv2
5
- import gradio as gr
6
- from PIL import Image
 
 
 
7
 
8
- #os.chdir('Restormer')
 
 
 
 
 
9
 
10
- # Download sample images
11
- os.system("wget https://github.com/swz30/Restormer/releases/download/v1.0/sample_images.zip")
12
- shutil.unpack_archive('sample_images.zip')
13
- os.remove('sample_images.zip')
14
 
 
 
 
15
 
16
- examples = [['sample_images/Real_Denoising/degraded/117355.png', 'Denoising'],
17
- ['sample_images/Single_Image_Defocus_Deblurring/degraded/engagement.jpg', 'Defocus Deblurring'],
18
- ['sample_images/Motion_Deblurring/degraded/GoPro-GOPR0854_11_00-000090-input.jpg','Motion Deblurring'],
19
- ['sample_images/Deraining/degraded/Rain100H-77-input.jpg','Deraining']]
 
 
 
 
 
 
 
 
 
20
 
21
- inference_on = ['Full Resolution Image', 'Downsampled Image']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- title = "Restormer"
24
- description = """
25
- Gradio demo for <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
26
- <b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup). But if you want to perform inference on the original input, then choose the option "Full Resolution Image" from the dropdown menu.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
- ##With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring, (3) Motion Deblurring, and (4) Image Deraining.
29
- ##To use it, simply upload your own image, or click one of the examples provided below.
30
 
31
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"
 
 
 
32
 
33
-
34
- def inference(img, task, run_on):
35
- if not os.path.exists('temp'):
36
- os.system('mkdir temp')
37
-
38
- if run_on == 'Full Resolution Image':
39
- img = img
40
- else: # 'Downsampled Image'
41
- #### Resize the longer edge of the input image
42
- max_res = 512
43
- width, height = img.size
44
- if max(width,height) > max_res:
45
- scale = max_res /max(width,height)
46
- width = int(scale*width)
47
- height = int(scale*height)
48
- img = img.resize((width,height), Image.ANTIALIAS)
49
-
50
- img.save("temp/image.jpg", "JPEG")
51
-
52
- if task == 'Motion Deblurring':
53
- task = 'Motion_Deblurring'
54
- os.system("python demo_gradio.py --task 'Motion_Deblurring' --input_path 'temp/image.jpg' --result_dir './temp/'")
55
-
56
- if task == 'Defocus Deblurring':
57
- task = 'Single_Image_Defocus_Deblurring'
58
- os.system("python demo_gradio.py --task 'Single_Image_Defocus_Deblurring' --input_path 'temp/image.jpg' --result_dir './temp/'")
59
-
60
- if task == 'Denoising':
61
- task = 'Real_Denoising'
62
- os.system("python demo_gradio.py --task 'Real_Denoising' --input_path 'temp/image.jpg' --result_dir './temp/'")
63
-
64
- if task == 'Deraining':
65
- os.system("python demo_gradio.py --task 'Deraining' --input_path 'temp/image.jpg' --result_dir './temp/'")
66
-
67
- return f'temp/{task}/image.jpg'
68
 
69
- gr.Interface(
70
- inference,
71
- [
72
- gr.inputs.Image(type="pil", label="Input"),
73
- gr.inputs.Radio(["Denoising", "Defocus Deblurring", "Motion Deblurring", "Deraining"], default="Denoising", label='task'),
74
- gr.inputs.Dropdown(choices=inference_on, type="value", default='Downsampled Image', label='Inference on')
75
 
76
- ],
77
- gr.outputs.Image(type="file", label="Output"),
78
- title=title,
79
- description=description,
80
- article=article,
81
- theme ="huggingface",
82
- examples=examples,
83
- allow_flagging=False,
84
- ).launch(debug=False,enable_queue=True)
 
1
  import os
2
  import shutil
3
  import torch
4
+ import torch.nn.functional as F
5
  import cv2
6
+ from skimage import img_as_ubyte
7
+ from flask import Flask, request, jsonify, send_file, render_template_string
8
+ from werkzeug.utils import secure_filename
9
+ import webbrowser
10
+ import time
11
 
12
+ # Flask App setup
13
+ app = Flask(__name__)
14
+ UPLOAD_FOLDER = 'uploads'
15
+ RESULTS_FOLDER = 'results'
16
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
17
+ app.config['RESULTS_FOLDER'] = RESULTS_FOLDER
18
 
19
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
20
+ os.makedirs(RESULTS_FOLDER, exist_ok=True)
 
 
21
 
22
+ # Model and Device setup
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model = None
25
 
26
+ def get_model():
27
+ global model
28
+ if model is None:
29
+ try:
30
+ # تم تعديل اسم النموذج هنا
31
+ model = torch.jit.load("motion_deblurring.pt", map_location=device)
32
+ model.to(device)
33
+ model.eval()
34
+ print("✅ Model loaded successfully")
35
+ except Exception as e:
36
+ print(f"❌ Error loading model: {e}")
37
+ model = None
38
+ return model
39
 
40
+ # Image Processing function
41
+ def process_image_with_model(input_path):
42
+ model = get_model()
43
+ if model is None:
44
+ raise RuntimeError("Model not loaded.")
45
+
46
+ # تم تعديل اسم المهمة هنا
47
+ task = "Motion_Deblurring"
48
+ out_dir = os.path.join(app.config["RESULTS_FOLDER"], task)
49
+ os.makedirs(out_dir, exist_ok=True)
50
+
51
+ img = cv2.cvtColor(cv2.imread(input_path), cv2.COLOR_BGR2RGB)
52
+ input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)
53
+
54
+ h, w = input_.shape[2], input_.shape[3]
55
+ H = ((h + 8) // 8) * 8
56
+ W = ((w + 8) // 8) * 8
57
+ padh = H - h if h % 8 != 0 else 0
58
+ padw = W - w if w % 8 != 0 else 0
59
+ input_ = F.pad(input_, (0, padw, 0, padh), "reflect")
60
+
61
+ with torch.inference_mode():
62
+ restored = torch.clamp(model(input_), 0, 1)
63
+
64
+ restored = img_as_ubyte(
65
+ restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0]
66
+ )
67
+
68
+ out_path = os.path.join(out_dir, os.path.split(input_path)[-1])
69
+ cv2.imwrite(out_path, cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
70
+ return out_path
71
 
72
+ # HTML Interface
73
+ html_content = """
74
+ <!DOCTYPE html>
75
+ <html>
76
+ <head>
77
+ <title>Restormer Motion Deblurring Demo</title>
78
+ <style>
79
+ body { text-align:center; font-family: sans-serif; }
80
+ .container { max-width: 600px; margin: auto; padding: 20px; border: 1px solid #ccc; border-radius: 8px; }
81
+ .image-display { display:flex; justify-content:center; gap:20px; margin-top:20px; }
82
+ img { max-width:300px; border:1px solid #ddd; }
83
+ h3 { margin-bottom: 10px; }
84
+ </style>
85
+ </head>
86
+ <body>
87
+ <div class="container">
88
+ <h1>Restormer: Motion Deblurring Demo</h1>
89
+ <form id="uploadForm" enctype="multipart/form-data">
90
+ <input type="file" id="fileInput" name="file" accept="image/*" required><br><br>
91
+ <button type="submit">Process Image</button>
92
+ </form>
93
+ <p id="loading" style="display:none;">Processing... Please wait.</p>
94
+ <div class="image-display">
95
+ <div>
96
+ <h3>Original</h3>
97
+ <img id="original" style="display:none;">
98
+ </div>
99
+ <div>
100
+ <h3>Restored</h3>
101
+ <img id="restored" style="display:none;">
102
+ </div>
103
+ </div>
104
+ </div>
105
+ <script>
106
+ const form = document.getElementById("uploadForm");
107
+ const fileInput = document.getElementById("fileInput");
108
+ const loading = document.getElementById("loading");
109
+ const original = document.getElementById("original");
110
+ const restored = document.getElementById("restored");
111
+
112
+ fileInput.addEventListener("change", (e) => {
113
+ if (e.target.files.length > 0) {
114
+ original.src = URL.createObjectURL(e.target.files[0]);
115
+ original.style.display = "block";
116
+ restored.style.display = "none";
117
+ }
118
+ });
119
+
120
+ form.addEventListener("submit", async (e) => {
121
+ e.preventDefault();
122
+ if (fileInput.files.length === 0) return;
123
+ const formData = new FormData();
124
+ formData.append("file", fileInput.files[0]);
125
+
126
+ loading.style.display = "block";
127
+
128
+ try {
129
+ const response = await fetch("/process_image", {
130
+ method: "POST",
131
+ body: formData
132
+ });
133
+
134
+ if (response.ok) {
135
+ const blob = await response.blob();
136
+ const url = URL.createObjectURL(blob);
137
+ restored.src = url;
138
+ restored.style.display = "block";
139
+ } else {
140
+ const error = await response.json();
141
+ alert("Error: " + error.error);
142
+ }
143
+ } catch (err) {
144
+ alert("Request failed: " + err);
145
+ } finally {
146
+ loading.style.display = "none";
147
+ }
148
+ });
149
+ </script>
150
+ </body>
151
+ </html>
152
  """
 
 
153
 
154
+ # Flask Routes
155
+ @app.route("/")
156
+ def index():
157
+ return render_template_string(html_content)
158
 
159
+ @app.route("/process_image", methods=["POST"])
160
+ def process_image():
161
+ if "file" not in request.files:
162
+ return jsonify({"error": "No file part"}), 400
163
+ file = request.files["file"]
164
+ if file.filename == "":
165
+ return jsonify({"error": "No filename"}), 400
166
+
167
+ filename = secure_filename(file.filename)
168
+ input_path = os.path.join(app.config["UPLOAD_FOLDER"], filename)
169
+ file.save(input_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ try:
172
+ output_path = process_image_with_model(input_path)
173
+ return send_file(output_path, mimetype="image/jpeg")
174
+ except Exception as e:
175
+ return jsonify({"error": str(e)}), 500
 
176
 
177
+ # Main
178
+ if __name__ == "__main__":
179
+ app.run(host="0.0.0.0", port=7860, debug=True)