pratyyush commited on
Commit
0b84bc4
·
verified ·
1 Parent(s): 08fa660

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -6,11 +6,11 @@ from PIL import Image
6
  import numpy as np
7
  import os
8
 
9
- # Check if CUDA is available
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  print(f"Using device: {device}")
12
 
13
- # Model definition
14
  class DeblurNet(nn.Module):
15
  def __init__(self):
16
  super(DeblurNet, self).__init__()
@@ -53,13 +53,19 @@ class DeblurNet(nn.Module):
53
  x = self.final_conv(x)
54
  return torch.tanh(x)
55
 
56
- # Load model
57
  model = DeblurNet().to(device)
58
  model_path = os.path.join('model', 'best_deblur_model.pth')
59
- model.load_state_dict(torch.load(model_path, map_location=device))
60
- model.eval()
61
 
62
- # Image processing functions
 
 
 
 
 
 
 
 
63
  transform = transforms.Compose([
64
  transforms.Resize((256, 256)),
65
  transforms.ToTensor(),
@@ -67,6 +73,7 @@ transform = transforms.Compose([
67
  ])
68
 
69
  def postprocess_image(tensor):
 
70
  tensor = tensor * 0.5 + 0.5
71
  tensor = torch.clamp(tensor, 0, 1)
72
  image = tensor.cpu().detach().numpy()
@@ -74,11 +81,12 @@ def postprocess_image(tensor):
74
  return (image * 255).astype(np.uint8)
75
 
76
  def deblur_image(filepath):
77
- if filepath is None:
 
78
  return None
79
-
80
  try:
81
- # Load the image from the filepath
82
  input_image = Image.open(filepath).convert("RGB")
83
 
84
  # Save original size
@@ -91,24 +99,30 @@ def deblur_image(filepath):
91
  with torch.no_grad():
92
  output_tensor = model(input_tensor)
93
 
94
- # Postprocess
95
  output_image = postprocess_image(output_tensor[0])
96
 
97
  # Resize back to original size
98
  output_image = Image.fromarray(output_image).resize(original_size)
 
99
  return np.array(output_image)
100
 
101
  except Exception as e:
102
  print(f"Error processing image: {e}")
103
  return None
104
 
105
- # ✅ Gradio interface using gr.File with correct type
106
  custom_css = """
107
  /* Hide Gradio's footer and header */
108
  footer, header, .gradio-footer, .gradio-header {
109
  display: none !important;
110
  }
111
 
 
 
 
 
 
112
  /* Non-draggable images */
113
  img {
114
  pointer-events: none !important;
@@ -140,11 +154,6 @@ body, .gradio-container {
140
  color: white !important;
141
  border: 1px solid #333333 !important;
142
  }
143
-
144
- /* Hide the share button icon in the output box */
145
- .gradio-container .gr-image-output .wrap.svelte-1ipelgc {
146
- display: none !important;
147
- }
148
  """
149
 
150
  # ✅ Gradio interface
@@ -157,5 +166,6 @@ demo = gr.Interface(
157
  css=custom_css
158
  )
159
 
 
160
  if __name__ == "__main__":
161
  demo.launch()
 
6
  import numpy as np
7
  import os
8
 
9
+ # Check if CUDA is available
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  print(f"Using device: {device}")
12
 
13
+ # Model definition
14
  class DeblurNet(nn.Module):
15
  def __init__(self):
16
  super(DeblurNet, self).__init__()
 
53
  x = self.final_conv(x)
54
  return torch.tanh(x)
55
 
56
+ # Load model
57
  model = DeblurNet().to(device)
58
  model_path = os.path.join('model', 'best_deblur_model.pth')
 
 
59
 
60
+ # Ensure model path exists before loading
61
+ if os.path.exists(model_path):
62
+ model.load_state_dict(torch.load(model_path, map_location=device))
63
+ model.eval()
64
+ print("Model loaded successfully.")
65
+ else:
66
+ print(f"Model file not found at {model_path}. Please check the path.")
67
+
68
+ # ✅ Image processing functions
69
  transform = transforms.Compose([
70
  transforms.Resize((256, 256)),
71
  transforms.ToTensor(),
 
73
  ])
74
 
75
  def postprocess_image(tensor):
76
+ """Post-process the output tensor into a displayable image."""
77
  tensor = tensor * 0.5 + 0.5
78
  tensor = torch.clamp(tensor, 0, 1)
79
  image = tensor.cpu().detach().numpy()
 
81
  return (image * 255).astype(np.uint8)
82
 
83
  def deblur_image(filepath):
84
+ """Deblurs the uploaded image."""
85
+ if not filepath:
86
  return None
87
+
88
  try:
89
+ # Load image from filepath
90
  input_image = Image.open(filepath).convert("RGB")
91
 
92
  # Save original size
 
99
  with torch.no_grad():
100
  output_tensor = model(input_tensor)
101
 
102
+ # Post-process
103
  output_image = postprocess_image(output_tensor[0])
104
 
105
  # Resize back to original size
106
  output_image = Image.fromarray(output_image).resize(original_size)
107
+
108
  return np.array(output_image)
109
 
110
  except Exception as e:
111
  print(f"Error processing image: {e}")
112
  return None
113
 
114
+ # ✅ Custom CSS for styling and hiding share button
115
  custom_css = """
116
  /* Hide Gradio's footer and header */
117
  footer, header, .gradio-footer, .gradio-header {
118
  display: none !important;
119
  }
120
 
121
+ /* Hide share button */
122
+ .share-wrap {
123
+ display: none !important;
124
+ }
125
+
126
  /* Non-draggable images */
127
  img {
128
  pointer-events: none !important;
 
154
  color: white !important;
155
  border: 1px solid #333333 !important;
156
  }
 
 
 
 
 
157
  """
158
 
159
  # ✅ Gradio interface
 
166
  css=custom_css
167
  )
168
 
169
+ # ✅ Launch Gradio app
170
  if __name__ == "__main__":
171
  demo.launch()