pratyyush commited on
Commit
8549916
·
verified ·
1 Parent(s): 3c3c722

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -1,16 +1,11 @@
1
  import os
2
  import gradio as gr
3
  from zeroscratches import EraseScratches
4
- import torch
5
  import cv2
6
  import numpy as np
7
  from concurrent.futures import ThreadPoolExecutor
8
  from PIL import Image
9
 
10
- # ✅ Force GPU usage if available
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- restorer = EraseScratches().to(device)
13
-
14
  # ✅ Custom CSS for clean UI
15
  custom_css = """
16
  /* Dark theme styling */
@@ -64,16 +59,15 @@ img {
64
  }
65
  """
66
 
67
- # ✅ Optimized image processing function
68
  def predict(img_file):
69
- """Run image processing with multi-threading"""
70
  with ThreadPoolExecutor(max_workers=4) as executor:
71
  future = executor.submit(process_image, img_file)
72
  restored_img = future.result()
73
-
74
  return restored_img
75
 
76
- # ✅ Image pre-processing function
77
  def process_image(img_file):
78
  """Efficient image loading and restoration"""
79
 
@@ -81,18 +75,16 @@ def process_image(img_file):
81
  img = cv2.imdecode(np.frombuffer(img_file.read(), np.uint8), cv2.IMREAD_COLOR)
82
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
83
 
84
- # Resize large images to speed up processing
85
  max_size = (1024, 1024)
86
  img = cv2.resize(img, max_size, interpolation=cv2.INTER_AREA)
87
 
88
- # Convert image to tensor and send to GPU
89
- img_tensor = torch.from_numpy(img).to(device)
90
-
91
- # Perform restoration
92
- restored_img = restorer.erase(img_tensor).cpu().numpy()
93
 
94
- # Convert back to PIL format
95
- restored_img = Image.fromarray(restored_img)
 
96
 
97
  return restored_img
98
 
 
1
  import os
2
  import gradio as gr
3
  from zeroscratches import EraseScratches
 
4
  import cv2
5
  import numpy as np
6
  from concurrent.futures import ThreadPoolExecutor
7
  from PIL import Image
8
 
 
 
 
 
9
  # ✅ Custom CSS for clean UI
10
  custom_css = """
11
  /* Dark theme styling */
 
59
  }
60
  """
61
 
62
+ # ✅ Image processing function
63
  def predict(img_file):
64
+ """Apply scratch removal with multi-threading"""
65
  with ThreadPoolExecutor(max_workers=4) as executor:
66
  future = executor.submit(process_image, img_file)
67
  restored_img = future.result()
 
68
  return restored_img
69
 
70
+ # ✅ Faster image processing with OpenCV
71
  def process_image(img_file):
72
  """Efficient image loading and restoration"""
73
 
 
75
  img = cv2.imdecode(np.frombuffer(img_file.read(), np.uint8), cv2.IMREAD_COLOR)
76
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
77
 
78
+ # Resize large images for faster processing
79
  max_size = (1024, 1024)
80
  img = cv2.resize(img, max_size, interpolation=cv2.INTER_AREA)
81
 
82
+ # Convert to PIL format
83
+ img_pil = Image.fromarray(img)
 
 
 
84
 
85
+ # Perform scratch removal
86
+ restorer = EraseScratches()
87
+ restored_img = restorer.erase(img_pil)
88
 
89
  return restored_img
90