hann1010 committed on
Commit
291dbb1
Β·
verified Β·
1 Parent(s): 3cb650d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -175
app.py CHANGED
@@ -7,11 +7,9 @@ from PIL import Image
7
  import os
8
  from huggingface_hub import hf_hub_download
9
  import cv2
10
- from pathlib import Path
11
  import sys
12
  import warnings
13
-
14
- from models.rrdbnet import RRDBNet, process_with_tiling
15
 
16
  warnings.filterwarnings('ignore', category=FutureWarning)
17
  warnings.filterwarnings('ignore', category=UserWarning)
@@ -19,6 +17,156 @@ os.environ['PYTHONWARNINGS'] = 'ignore'
19
 
20
  sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  MODELS = {
23
  "Classical SR x8 (DIV2K)": {
24
  "repo": "deepinv/swinir",
@@ -76,7 +224,6 @@ model_cache = {}
76
  def setup_directories():
77
  os.makedirs("models", exist_ok=True)
78
  os.makedirs("temp", exist_ok=True)
79
- print("βœ… Directories created")
80
 
81
  def download_all_models():
82
  print("πŸš€ Starting model download...")
@@ -107,13 +254,7 @@ def download_all_models():
107
  print(f"❌ Failed to download {model_name}: {str(e)}")
108
  failed.append(model_name)
109
 
110
- print(f"\nπŸ“Š Download Summary:")
111
- print(f" βœ… Success: {downloaded}/{len(MODELS)}")
112
- if failed:
113
- print(f" ❌ Failed: {len(failed)}")
114
- for name in failed:
115
- print(f" - {name}")
116
-
117
  return downloaded, failed
118
 
119
  def load_realesrgan_model(model_path, device, scale=4):
@@ -126,126 +267,71 @@ def load_realesrgan_model(model_path, device, scale=4):
126
  else:
127
  state_dict = checkpoint
128
 
129
- # βœ… AUTO-DETECT architecture dari state_dict
130
- # Detect input channels
131
  in_nc = 3
132
  if 'conv_first.weight' in state_dict:
133
  in_nc = state_dict['conv_first.weight'].shape[1]
134
 
135
- # Detect output channels
136
  out_nc = 3
137
  if 'conv_last.weight' in state_dict:
138
  out_nc = state_dict['conv_last.weight'].shape[0]
139
 
140
- # Detect nf (number of features)
141
- nf = 64
142
- if 'conv_first.weight' in state_dict:
143
- nf = state_dict['conv_first.weight'].shape[0]
144
-
145
- # Detect nb (number of blocks) by counting body blocks
146
- nb = 0
147
- for key in state_dict.keys():
148
- if key.startswith('body.') and 'rdb1.conv1.weight' in key:
149
- block_idx = int(key.split('.')[1])
150
- nb = max(nb, block_idx + 1)
151
-
152
- # Detect gc (growth channels)
153
- gc = 32
154
- if 'body.0.rdb1.conv1.weight' in state_dict:
155
- gc = state_dict['body.0.rdb1.conv1.weight'].shape[0]
156
-
157
- print(f"πŸ” Auto-detected architecture:")
158
- print(f" in_nc={in_nc}, out_nc={out_nc}, nf={nf}, nb={nb}, gc={gc}, scale={scale}")
159
-
160
- # βœ… Create model with DETECTED architecture (bukan hardcoded!)
161
- model = RRDBNet(
162
- in_nc=in_nc,
163
- out_nc=out_nc,
164
- nf=nf,
165
- nb=nb,
166
- gc=gc,
167
- scale=scale
168
- )
169
-
170
- # Load state dict
171
  model.load_state_dict(state_dict, strict=True)
172
  model.eval()
173
-
174
- # βœ… CPU OPTIMIZATION
175
- if device.type == 'cpu':
176
- torch.set_num_threads(2)
177
- if hasattr(torch.backends, 'mkldnn'):
178
- torch.backends.mkldnn.enabled = True
179
-
180
  model = model.to(device)
181
 
182
- print(f"βœ… Real-ESRGAN model loaded successfully")
183
- print(f" πŸ“Š Model size: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
184
  return model
185
  except Exception as e:
186
- print(f"❌ Error loading Real-ESRGAN: {e}")
187
- import traceback
188
- traceback.print_exc()
189
  return None
190
 
191
- def process_with_realesrgan(image, model_path, device, scale=4, use_tiling=True):
192
  try:
193
  model = load_realesrgan_model(model_path, device, scale)
194
-
195
  if model is None:
196
  return None
197
 
198
- # Get input channels from model
199
  in_nc = model.conv_first.weight.shape[1]
200
 
201
  img = np.array(image).astype(np.float32) / 255.0
202
  if len(img.shape) == 2:
203
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
204
 
205
- # Convert to tensor
206
  img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
207
  img = img.unsqueeze(0).to(device)
208
-
209
- # βœ… Clamp input
210
  img = torch.clamp(img, 0, 1)
211
 
212
- print(f"πŸ“₯ Input image shape: {img.shape}")
213
 
214
- # Handle different input channel requirements
215
  if in_nc == 12:
216
  b, c, h, w = img.shape
217
-
218
- # Ensure dimensions are divisible by 2
219
  pad_h = (2 - h % 2) % 2
220
  pad_w = (2 - w % 2) % 2
221
 
222
  if pad_h > 0 or pad_w > 0:
223
  img = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
224
- print(f"πŸ”§ Padded image to: {img.shape}")
225
 
226
- # Pixel unshuffle
227
  img = F.pixel_unshuffle(img, 2)
228
- print(f"πŸ”„ Applied pixel unshuffle: {img.shape}")
229
 
230
- # βœ… ALWAYS use tiling untuk HF Spaces (hemat memory)
231
- h, w = img.shape[2], img.shape[3]
232
- print(f"πŸ”² Using optimized tiling mode ({h}x{w})")
233
  output = process_with_tiling(model, img, tile_size=160, tile_overlap=32)
234
 
235
- print(f"πŸ“€ Output shape: {output.shape}")
236
 
237
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
238
  output = np.transpose(output, (1, 2, 0))
239
  output = (output * 255.0).round().astype(np.uint8)
240
 
241
- # Clean up GPU memory
242
  del model, img
243
  if device.type == 'cuda':
244
  torch.cuda.empty_cache()
 
245
 
246
  return Image.fromarray(output)
247
  except Exception as e:
248
- print(f"❌ Error in Real-ESRGAN processing: {e}")
249
  import traceback
250
  traceback.print_exc()
251
  return None
@@ -290,35 +376,18 @@ def process_image_simple(image, scale, task_type):
290
  return Image.fromarray(output)
291
 
292
  def upscale_image(image, model_name, output_format="png"):
293
- """Process and upscale image with better error handling"""
294
-
295
- # Validation
296
  if image is None:
297
  return None, "❌ Please upload an image first!"
298
 
 
 
 
 
 
 
299
  try:
300
- # Convert image to PIL if needed
301
- if not isinstance(image, Image.Image):
302
- try:
303
- image = Image.fromarray(image)
304
- except Exception as e:
305
- return None, f"❌ Could not convert image: {str(e)}"
306
-
307
- model_info = MODELS[model_name]
308
- model_path = os.path.join("models", model_info["filename"])
309
-
310
- if not os.path.exists(model_path):
311
- return None, f"❌ Model not found: {model_info['filename']}\nPlease restart the app to download models."
312
-
313
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
314
 
315
- print(f"πŸ“₯ Processing image: {image.size} | Format: {image.format or 'Unknown'}")
316
-
317
- # Ensure image is in RGB mode
318
- if image.mode != 'RGB':
319
- print(f"πŸ”„ Converting from {image.mode} to RGB")
320
- image = image.convert('RGB')
321
-
322
  if model_info["type"] == "realesrgan":
323
  print(f"πŸ”₯ Processing with Real-ESRGAN {model_info['scale']}x...")
324
  result_image = process_with_realesrgan(
@@ -329,50 +398,29 @@ def upscale_image(image, model_name, output_format="png"):
329
  )
330
 
331
  if result_image is None:
332
- return None, "❌ Error processing with Real-ESRGAN model. Check console for details."
333
  else:
334
- print(f"πŸ“Š Processing with {model_info['type'].upper()}...")
335
- # For SwinIR models, use simple upscaling for now
336
  result_image = process_image_simple(
337
  image,
338
  model_info["scale"],
339
  model_info["task"]
340
  )
341
 
342
- # Build info message
343
- info = f"βœ… Successfully Enhanced!\n"
344
- info += f"{'='*50}\n"
345
- info += f"🎯 Model: {model_name}\n"
346
- info += f"πŸ“Š Type: {model_info['type'].upper()}\n"
347
- info += f"πŸ” Task: {model_info['task']}\n"
348
  info += f"πŸ”’ Scale: {model_info['scale']}x\n"
349
- info += f"{'='*50}\n"
350
- info += f"πŸ“ Input: {image.size[0]} x {image.size[1]} px\n"
351
- info += f"πŸ“ Output: {result_image.size[0]} x {result_image.size[1]} px\n"
352
- info += f"{'='*50}\n"
353
  info += f"πŸ’» Device: {device}\n"
354
- info += f"πŸ“ Format: {output_format.upper()}\n"
355
- info += f"πŸ“¦ Model File: {model_info['filename']}\n"
356
 
357
  return result_image, info
358
 
359
- except FileNotFoundError as e:
360
- error_msg = f"❌ File Error: Could not read uploaded image.\n"
361
- error_msg += f"This might be a temporary Gradio issue.\n"
362
- error_msg += f"Please try:\n"
363
- error_msg += f" 1. Re-uploading the image\n"
364
- error_msg += f" 2. Using a different image\n"
365
- error_msg += f" 3. Refreshing the page\n\n"
366
- error_msg += f"Technical: {str(e)}"
367
- return None, error_msg
368
-
369
  except Exception as e:
370
  import traceback
371
- error_msg = f"❌ Error processing image\n"
372
- error_msg += f"{'='*50}\n"
373
- error_msg += f"Error: {str(e)}\n"
374
- error_msg += f"{'='*50}\n"
375
- error_msg += f"Traceback:\n{traceback.format_exc()}"
376
  return None, error_msg
377
 
378
  def get_model_status():
@@ -393,16 +441,12 @@ def get_model_status():
393
  return status
394
 
395
  print("="*60)
396
- print("🎨 SwinIR & Real-ESRGAN Image Upscaler")
397
  print("="*60)
398
  downloaded_count, failed_models = download_all_models()
399
  print("="*60)
400
 
401
- # GRADIO INTERFACE
402
- with gr.Blocks(
403
- title="AI Image Upscaler",
404
- theme=gr.themes.Soft(),
405
- ) as demo:
406
 
407
  gr.HTML("""
408
  <div style="text-align: center; padding: 2rem 0;">
@@ -410,7 +454,7 @@ with gr.Blocks(
410
  πŸš€ AI Image Upscaler
411
  </h1>
412
  <p style="font-size: 1.1rem; color: #666;">
413
- Professional Image Enhancement powered by Real-ESRGAN & SwinIR
414
  </p>
415
  </div>
416
  """)
@@ -422,15 +466,14 @@ with gr.Blocks(
422
  input_image = gr.Image(
423
  label="πŸ“€ Upload Your Image",
424
  type="pil",
425
- sources=["upload", "clipboard"], # Allow clipboard paste
426
  height=400
427
  )
428
 
429
  model_dropdown = gr.Dropdown(
430
  choices=list(MODELS.keys()),
431
  value="πŸ”₯ Real-ESRGAN x4 (Best for 4x)",
432
- label="🎯 Choose AI Model",
433
- info="Select the enhancement model"
434
  )
435
 
436
  output_format = gr.Radio(
@@ -446,11 +489,11 @@ with gr.Blocks(
446
  )
447
 
448
  gr.Markdown("""
449
- ### πŸ’‘ Pro Tips
450
- - πŸ”₯ **Real-ESRGAN**: Best for real photos
451
- - 🎨 **SwinIR**: Good for general upscaling
452
- - πŸ“Š Higher scale = Larger output file
453
- - πŸ–ΌοΈ Supports JPG, PNG, WEBP formats
454
  """)
455
 
456
  with gr.Column(scale=1):
@@ -462,8 +505,7 @@ with gr.Blocks(
462
 
463
  output_info = gr.Textbox(
464
  label="πŸ“Š Processing Details",
465
- lines=15,
466
- max_lines=20
467
  )
468
 
469
  with gr.Tab("πŸ“Š Model Status"):
@@ -473,7 +515,6 @@ with gr.Blocks(
473
  label="Model Status",
474
  value=get_model_status(),
475
  lines=25,
476
- max_lines=30,
477
  interactive=False
478
  )
479
 
@@ -484,37 +525,27 @@ with gr.Blocks(
484
  gr.Markdown(f"""
485
  ## About This App
486
 
487
- This application provides state-of-the-art image upscaling using AI models.
488
-
489
  ### πŸ“ˆ Statistics
490
  - **Models Available:** {downloaded_count}/{len(MODELS)}
491
  - **Device:** {'πŸš€ GPU (CUDA)' if torch.cuda.is_available() else 'πŸ’» CPU'}
492
- - **PyTorch Version:** {torch.__version__}
493
- - **Gradio Version:** {gr.__version__}
494
 
495
- ### 🎯 Supported Tasks
496
- 1. **Real-ESRGAN πŸ”₯** - Best quality upscaling for real photos
497
- 2. **Super Resolution** - Increase image resolution up to 8x
498
- 3. **Denoising** - Remove noise from images (SwinIR)
499
- 4. **JPEG Artifact Removal** - Fix compression artifacts (SwinIR)
 
 
 
 
 
500
 
501
  ### πŸ“š Model Sources
502
  - **SwinIR:** [deepinv/swinir](https://huggingface.co/deepinv/swinir)
503
  - **Real-ESRGAN:** [ai-forever/Real-ESRGAN](https://huggingface.co/ai-forever/Real-ESRGAN)
504
 
505
- ### πŸ”§ Features
506
- - βœ… Multiple AI models
507
- - βœ… GPU acceleration
508
- - βœ… Batch processing ready
509
- - βœ… Multiple output formats
510
- - βœ… High-quality upscaling
511
-
512
- ### πŸ› Troubleshooting
513
- If you encounter issues:
514
- 1. **File upload errors**: Try re-uploading or use a different browser
515
- 2. **Processing errors**: Check the console logs
516
- 3. **Slow processing**: GPU acceleration requires CUDA
517
-
518
  ---
519
  Made with ❀️ using Gradio and PyTorch
520
  """)
@@ -527,9 +558,4 @@ with gr.Blocks(
527
  )
528
 
529
  if __name__ == "__main__":
530
- demo.queue(max_size=10)
531
- demo.launch(
532
- server_name="0.0.0.0",
533
- server_port=7860,
534
- share=False
535
- )
 
7
  import os
8
  from huggingface_hub import hf_hub_download
9
  import cv2
 
10
  import sys
11
  import warnings
12
+ import gc
 
13
 
14
  warnings.filterwarnings('ignore', category=FutureWarning)
15
  warnings.filterwarnings('ignore', category=UserWarning)
 
17
 
18
  sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
19
 
20
class ResidualDenseBlock(nn.Module):
    """Five-layer dense block with a 0.2-scaled residual connection.

    Each conv receives the block input concatenated with every previous
    feature map (dense connectivity); conv5 projects back to ``nf`` channels.
    """

    def __init__(self, nf=64, gc=32):
        super(ResidualDenseBlock, self).__init__()
        # Channel counts grow by gc per layer because inputs are concatenated.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=True)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=True)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Accumulate dense features; each layer sees all earlier outputs.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        residual = self.conv5(torch.cat(feats, 1))
        # Residual scaling (0.2) stabilises very deep ESRGAN-style trunks.
        return residual * 0.2 + x
37
+
38
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three chained dense blocks
    wrapped in one more 0.2-scaled skip connection."""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        return out * 0.2 + x
50
+
51
class RRDBNet(nn.Module):
    """ESRGAN-style super-resolution generator.

    A shallow conv, a trunk of ``nb`` RRDB blocks with a global skip, then
    2x upsampling stages (bilinear interpolate + conv): two stages always,
    a third when ``scale >= 8``. Note the net spatial factor of the module
    itself is 4x (or 8x); smaller effective scales are achieved by callers
    via pixel-unshuffled 12-channel input.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.body = nn.ModuleList([RRDB(nf, gc) for _ in range(nb)])
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if scale >= 8:
            self.conv_up3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _upsample(self, feat, conv):
        # One 2x stage: bilinear upsample, then conv + LeakyReLU.
        upscaled = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
        return self.lrelu(conv(upscaled))

    def forward(self, x):
        shallow = self.conv_first(x)
        deep = shallow
        for block in self.body:
            deep = block(deep)
        # Global residual: trunk output added back onto the shallow features.
        feat = shallow + self.conv_body(deep)
        del deep  # free the trunk activations before upsampling

        feat = self._upsample(feat, self.conv_up1)
        feat = self._upsample(feat, self.conv_up2)
        if self.scale >= 8:
            feat = self._upsample(feat, self.conv_up3)

        return self.conv_last(self.lrelu(self.conv_hr(feat)))
86
+
87
def hdr_like(img):
    """Apply a simple HDR-style tone curve to a (B, C, H, W) tensor.

    Recenters each (batch, channel) plane around 0.5 using its spatial mean,
    boosts contrast by 1.1x, clamps to [0, 1], then brightens midtones with
    a 0.85 gamma. Returns a tensor clamped to [0, 1].
    """
    plane_mean = img.mean(dim=(2, 3), keepdim=True)
    toned = (img - plane_mean) * 1.1 + 0.5
    toned = torch.clamp(toned, 0, 1)
    # Gamma < 1 lifts midtones; final clamp guards against numeric drift.
    return torch.clamp(toned ** 0.85, 0, 1)
95
+
96
def sharpen(img, amount=0.15):
    """Unsharp-mask sharpening for a (B, C, H, W) tensor in [0, 1].

    Subtracts a 3x3 box blur to isolate high frequencies, adds them back
    scaled by ``amount``, and clamps the result to [0, 1].
    """
    smoothed = F.avg_pool2d(img, kernel_size=3, stride=1, padding=1)
    return torch.clamp(img + amount * (img - smoothed), 0, 1)
100
+
101
def process_with_tiling(model, img_tensor, tile_size=160, tile_overlap=32):
    """Run ``model`` over ``img_tensor`` tile by tile to bound peak memory,
    then apply HDR-like tone mapping and sharpening to the stitched result.

    Tiles are written back in scan order, so later tiles simply overwrite
    the overlap region of earlier ones (no weighted blending). The model's
    true output channel count and spatial scale are measured from a probe
    tile rather than trusted from metadata.
    """
    device = img_tensor.device
    b, c, h, w = img_tensor.shape
    scale = model.scale  # read for parity; actual scale is re-measured below

    # Use smaller tiles and less overlap on CPU to keep the working set modest.
    if device.type == 'cpu':
        tile_size = min(tile_size, 128)
        tile_overlap = 16

    # Image fits in a single tile: process in one shot.
    if h <= tile_size and w <= tile_size:
        with torch.no_grad():
            output = model(img_tensor)
        output = hdr_like(output)
        output = sharpen(output, amount=0.15)
        return torch.clamp(output, 0, 1)

    # Probe one tile to discover the output channel count and scale factor.
    sample_tile = img_tensor[:, :, :min(tile_size, h), :min(tile_size, w)]
    with torch.no_grad():
        sample_output = model(sample_tile)

    output_channels = sample_output.shape[1]
    sample_scale_h = sample_output.shape[2] / sample_tile.shape[2]
    sample_scale_w = sample_output.shape[3] / sample_tile.shape[3]
    del sample_tile, sample_output

    output_h = int(h * sample_scale_h)
    output_w = int(w * sample_scale_w)

    output = torch.zeros((b, output_channels, output_h, output_w), device=device)

    stride = tile_size - tile_overlap
    tiles_h = (h - 1) // stride + 1
    tiles_w = (w - 1) // stride + 1

    print(f"πŸ”² Processing {tiles_h}x{tiles_w} = {tiles_h*tiles_w} tiles")
    print(f" Input: {c}ch {h}x{w} β†’ Output: {output_channels}ch {output_h}x{output_w}")

    for i in range(0, h, stride):
        for j in range(0, w, stride):
            # Edge tiles may be smaller than tile_size; slicing handles that.
            tile = img_tensor[:, :, i:min(i + tile_size, h), j:min(j + tile_size, w)]

            with torch.no_grad():
                tile_output = model(tile)

            out_i = int(i * sample_scale_h)
            out_j = int(j * sample_scale_w)
            output[:, :,
                   out_i:out_i + tile_output.shape[2],
                   out_j:out_j + tile_output.shape[3]] = tile_output

            del tile, tile_output

            # Reclaim memory every fourth tile while stitching.
            if ((i // stride) * tiles_w + (j // stride)) % 4 == 0:
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

    output = hdr_like(output)
    output = sharpen(output, amount=0.15)

    return torch.clamp(output, 0, 1)
169
+
170
  MODELS = {
171
  "Classical SR x8 (DIV2K)": {
172
  "repo": "deepinv/swinir",
 
224
  def setup_directories():
225
  os.makedirs("models", exist_ok=True)
226
  os.makedirs("temp", exist_ok=True)
 
227
 
228
  def download_all_models():
229
  print("πŸš€ Starting model download...")
 
254
  print(f"❌ Failed to download {model_name}: {str(e)}")
255
  failed.append(model_name)
256
 
257
+ print(f"\nπŸ“Š Download Summary: βœ… {downloaded}/{len(MODELS)}")
 
 
 
 
 
 
258
  return downloaded, failed
259
 
260
  def load_realesrgan_model(model_path, device, scale=4):
 
267
  else:
268
  state_dict = checkpoint
269
 
 
 
270
  in_nc = 3
271
  if 'conv_first.weight' in state_dict:
272
  in_nc = state_dict['conv_first.weight'].shape[1]
273
 
 
274
  out_nc = 3
275
  if 'conv_last.weight' in state_dict:
276
  out_nc = state_dict['conv_last.weight'].shape[0]
277
 
278
+ model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=64, nb=23, gc=32, scale=scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  model.load_state_dict(state_dict, strict=True)
280
  model.eval()
 
 
 
 
 
 
 
281
  model = model.to(device)
282
 
283
+ print(f"βœ… Model loaded: {scale}x | In:{in_nc}ch | Out:{out_nc}ch")
 
284
  return model
285
  except Exception as e:
286
+ print(f"❌ Error loading model: {e}")
 
 
287
  return None
288
 
289
+ def process_with_realesrgan(image, model_path, device, scale=4):
290
  try:
291
  model = load_realesrgan_model(model_path, device, scale)
 
292
  if model is None:
293
  return None
294
 
 
295
  in_nc = model.conv_first.weight.shape[1]
296
 
297
  img = np.array(image).astype(np.float32) / 255.0
298
  if len(img.shape) == 2:
299
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
300
 
 
301
  img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
302
  img = img.unsqueeze(0).to(device)
 
 
303
  img = torch.clamp(img, 0, 1)
304
 
305
+ print(f"πŸ“₯ Input: {img.shape}")
306
 
 
307
  if in_nc == 12:
308
  b, c, h, w = img.shape
 
 
309
  pad_h = (2 - h % 2) % 2
310
  pad_w = (2 - w % 2) % 2
311
 
312
  if pad_h > 0 or pad_w > 0:
313
  img = F.pad(img, (0, pad_w, 0, pad_h), mode='replicate')
314
+ print(f"πŸ”§ Padded: {img.shape}")
315
 
 
316
  img = F.pixel_unshuffle(img, 2)
317
+ print(f"πŸ”„ Pixel unshuffle: {img.shape}")
318
 
 
 
 
319
  output = process_with_tiling(model, img, tile_size=160, tile_overlap=32)
320
 
321
+ print(f"πŸ“€ Output: {output.shape}")
322
 
323
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
324
  output = np.transpose(output, (1, 2, 0))
325
  output = (output * 255.0).round().astype(np.uint8)
326
 
 
327
  del model, img
328
  if device.type == 'cuda':
329
  torch.cuda.empty_cache()
330
+ gc.collect()
331
 
332
  return Image.fromarray(output)
333
  except Exception as e:
334
+ print(f"❌ Processing error: {e}")
335
  import traceback
336
  traceback.print_exc()
337
  return None
 
376
  return Image.fromarray(output)
377
 
378
  def upscale_image(image, model_name, output_format="png"):
 
 
 
379
  if image is None:
380
  return None, "❌ Please upload an image first!"
381
 
382
+ model_info = MODELS[model_name]
383
+ model_path = os.path.join("models", model_info["filename"])
384
+
385
+ if not os.path.exists(model_path):
386
+ return None, f"❌ Model not found: {model_info['filename']}\nPlease restart the app."
387
+
388
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
390
 
 
 
 
 
 
 
 
391
  if model_info["type"] == "realesrgan":
392
  print(f"πŸ”₯ Processing with Real-ESRGAN {model_info['scale']}x...")
393
  result_image = process_with_realesrgan(
 
398
  )
399
 
400
  if result_image is None:
401
+ return None, "❌ Error processing with Real-ESRGAN"
402
  else:
403
+ state_dict = load_model(model_path, device)
 
404
  result_image = process_image_simple(
405
  image,
406
  model_info["scale"],
407
  model_info["task"]
408
  )
409
 
410
+ info = f"βœ… Model: {model_name}\n"
411
+ info += f"🎯 Type: {model_info['type'].upper()}\n"
412
+ info += f"πŸ“Š Task: {model_info['task']}\n"
 
 
 
413
  info += f"πŸ”’ Scale: {model_info['scale']}x\n"
414
+ info += f"πŸ“ Input: {image.size[0]}x{image.size[1]}\n"
415
+ info += f"πŸ“ Output: {result_image.size[0]}x{result_image.size[1]}\n"
 
 
416
  info += f"πŸ’» Device: {device}\n"
417
+ info += f"πŸ“ Format: {output_format.upper()}"
 
418
 
419
  return result_image, info
420
 
 
 
 
 
 
 
 
 
 
 
421
  except Exception as e:
422
  import traceback
423
+ error_msg = f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
 
 
 
 
424
  return None, error_msg
425
 
426
  def get_model_status():
 
441
  return status
442
 
443
  print("="*60)
444
+ print("🎨 AI Image Upscaler - Optimized Edition")
445
  print("="*60)
446
  downloaded_count, failed_models = download_all_models()
447
  print("="*60)
448
 
449
+ with gr.Blocks(title="AI Image Upscaler", theme=gr.themes.Soft()) as demo:
 
 
 
 
450
 
451
  gr.HTML("""
452
  <div style="text-align: center; padding: 2rem 0;">
 
454
  πŸš€ AI Image Upscaler
455
  </h1>
456
  <p style="font-size: 1.1rem; color: #666;">
457
+ Enhanced with HDR-like processing & Smart Tiling
458
  </p>
459
  </div>
460
  """)
 
466
  input_image = gr.Image(
467
  label="πŸ“€ Upload Your Image",
468
  type="pil",
469
+ sources=["upload", "clipboard"],
470
  height=400
471
  )
472
 
473
  model_dropdown = gr.Dropdown(
474
  choices=list(MODELS.keys()),
475
  value="πŸ”₯ Real-ESRGAN x4 (Best for 4x)",
476
+ label="🎯 Choose AI Model"
 
477
  )
478
 
479
  output_format = gr.Radio(
 
489
  )
490
 
491
  gr.Markdown("""
492
+ ### πŸ’‘ New Features
493
+ - 🎨 **HDR-like tone mapping**
494
+ - πŸ”ͺ **Smart sharpening**
495
+ - πŸ”² **Optimized tiling**
496
+ - πŸš€ **Better memory management**
497
  """)
498
 
499
  with gr.Column(scale=1):
 
505
 
506
  output_info = gr.Textbox(
507
  label="πŸ“Š Processing Details",
508
+ lines=15
 
509
  )
510
 
511
  with gr.Tab("πŸ“Š Model Status"):
 
515
  label="Model Status",
516
  value=get_model_status(),
517
  lines=25,
 
518
  interactive=False
519
  )
520
 
 
525
  gr.Markdown(f"""
526
  ## About This App
527
 
 
 
528
  ### πŸ“ˆ Statistics
529
  - **Models Available:** {downloaded_count}/{len(MODELS)}
530
  - **Device:** {'πŸš€ GPU (CUDA)' if torch.cuda.is_available() else 'πŸ’» CPU'}
531
+ - **PyTorch:** {torch.__version__}
532
+ - **Gradio:** {gr.__version__}
533
 
534
+ ### ✨ Optimizations
535
+ - Bilinear upsampling for smooth results
536
+ - HDR-like tone mapping for better contrast
537
+ - Smart sharpening (DSLR look)
538
+ - Memory-efficient tiling for large images
539
+ - Automatic garbage collection
540
+
541
+ ### 🎯 Supported Models
542
+ 1. **Real-ESRGAN πŸ”₯** - Best for real photos (2x, 4x, 8x)
543
+ 2. **SwinIR** - Lightweight super-resolution (2x, 3x, 4x, 8x)
544
 
545
  ### πŸ“š Model Sources
546
  - **SwinIR:** [deepinv/swinir](https://huggingface.co/deepinv/swinir)
547
  - **Real-ESRGAN:** [ai-forever/Real-ESRGAN](https://huggingface.co/ai-forever/Real-ESRGAN)
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  ---
550
  Made with ❀️ using Gradio and PyTorch
551
  """)
 
558
  )
559
 
560
  if __name__ == "__main__":
561
+ demo.launch()