Saumya Saksena commited on
Commit
b47bd8c
Β·
1 Parent(s): 639155f

Add files

Browse files
Files changed (4) hide show
  1. app.py +259 -0
  2. examples/heavy_rain.jpg +0 -0
  3. examples/light_rain.jpg +0 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Gradio app for ClearView image deraining model.
3
+
4
+ A user-friendly interface for removing rain from images using deep learning.
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from PIL import Image
11
+ import numpy as np
12
+ import time
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Import your model (adjust paths as needed)
16
+ from clearview.models import UNet
17
+ from clearview.utils.image import numpy_to_tensor, tensor_to_numpy
18
+
19
+
20
+ class DerainModel:
21
+ """Wrapper for the deraining model with preprocessing/postprocessing."""
22
+
23
+ def __init__(self, checkpoint_path: str, device: str = "cuda"):
24
+ """Initialize the model.
25
+
26
+ Args:
27
+ checkpoint_path: Path to model checkpoint (.pth file)
28
+ device: Device to run inference on ('cuda' or 'cpu')
29
+ """
30
+ self.device = device if torch.cuda.is_available() else "cpu"
31
+
32
+ # Load model
33
+ self.model = UNet(in_channels=3, out_channels=3)
34
+
35
+ # Load checkpoint
36
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
37
+ if 'model_state_dict' in checkpoint:
38
+ self.model.load_state_dict(checkpoint['model_state_dict'])
39
+ else:
40
+ self.model.load_state_dict(checkpoint)
41
+
42
+ self.model = self.model.to(self.device)
43
+ self.model.eval()
44
+
45
+ print(f"Model loaded on {self.device}")
46
+
47
+ def preprocess(self, image: Image.Image) -> torch.Tensor:
48
+ """Preprocess PIL image to tensor.
49
+
50
+ Args:
51
+ image: PIL Image
52
+
53
+ Returns:
54
+ Preprocessed tensor (1, 3, H, W)
55
+ """
56
+ # Convert to numpy array
57
+ img_np = np.array(image).astype(np.float32) / 255.0
58
+
59
+ # Convert to tensor (C, H, W)
60
+ img_tensor = numpy_to_tensor(img_np, normalize=False)
61
+
62
+ # Add batch dimension (1, C, H, W)
63
+ img_tensor = img_tensor.unsqueeze(0).to(self.device)
64
+
65
+ return img_tensor
66
+
67
+ def postprocess(self, tensor: torch.Tensor) -> Image.Image:
68
+ """Postprocess tensor to PIL image.
69
+
70
+ Args:
71
+ tensor: Model output (1, 3, H, W)
72
+
73
+ Returns:
74
+ PIL Image
75
+ """
76
+ # Remove batch dimension and move to CPU
77
+ tensor = tensor.squeeze(0).cpu()
78
+
79
+ # Clamp to [0, 1]
80
+ tensor = torch.clamp(tensor, 0, 1)
81
+
82
+ # Convert to numpy (H, W, C)
83
+ img_np = tensor_to_numpy(tensor)
84
+
85
+ # Convert to uint8
86
+ img_np = (img_np * 255).astype(np.uint8)
87
+
88
+ # Convert to PIL
89
+ return Image.fromarray(img_np)
90
+
91
+ @torch.no_grad()
92
+ def derain(self, image: Image.Image) -> tuple[Image.Image, float]:
93
+ """Remove rain from image.
94
+
95
+ Args:
96
+ image: Input PIL Image with rain
97
+
98
+ Returns:
99
+ Tuple of (derained PIL Image, inference time in seconds)
100
+ """
101
+ start_time = time.time()
102
+
103
+ # Store original size
104
+ original_size = image.size # (width, height)
105
+
106
+ # Preprocess
107
+ input_tensor = self.preprocess(image)
108
+
109
+ # Inference
110
+ output_tensor = self.model(input_tensor)
111
+
112
+ # Postprocess
113
+ output_image = self.postprocess(output_tensor)
114
+
115
+ # Ensure output matches input size exactly
116
+ if output_image.size != original_size:
117
+ output_image = output_image.resize(original_size, Image.LANCZOS)
118
+
119
+ inference_time = time.time() - start_time
120
+
121
+ return output_image, inference_time
122
+
123
+
124
+ # Initialize model (will be loaded when app starts)
125
+ MODEL = None
126
+ CHECKPOINT_PATH = hf_hub_download(
127
+ repo_id="dronefreak/clearview-derain-unet", # Your model repo
128
+ filename="clearview-derain-unet.pth"
129
+ )
130
+
131
+
132
+ def load_model():
133
+ """Load model on startup."""
134
+ global MODEL
135
+ if MODEL is None:
136
+ MODEL = DerainModel(CHECKPOINT_PATH)
137
+ return MODEL
138
+
139
+
140
+ def derain_interface(image: Image.Image) -> tuple[tuple[Image.Image, Image.Image], str]:
141
+ """Gradio interface function.
142
+
143
+ Args:
144
+ image: Input PIL Image
145
+
146
+ Returns:
147
+ Tuple of ((input_image, output_image) for slider, info text)
148
+ """
149
+ if image is None:
150
+ return None, "Please upload an image first."
151
+
152
+ # Load model if not already loaded
153
+ model = load_model()
154
+
155
+ image = image.convert("RGB")
156
+
157
+ # Run inference
158
+ output_image, inference_time = model.derain(image)
159
+
160
+ output_image = output_image.convert("RGB")
161
+ output_image = output_image.resize(image.size, Image.NEAREST)
162
+
163
+ # Create info text
164
+ info = f"βœ… Rain removed successfully!\n"
165
+ info += f"⚑ Inference time: {inference_time:.3f}s\n"
166
+ info += f"πŸ“ Image size: {image.size[0]}x{image.size[1]}"
167
+
168
+ # Return (input, output) tuple for ImageSlider
169
+ return (image, output_image), info
170
+
171
+
172
+ # Create Gradio interface
173
+ with gr.Blocks(title="ClearView: Image Deraining") as demo:
174
+ gr.Markdown("""
175
+ # 🌧️ ClearView: Image Deraining
176
+
177
+ Remove rain streaks from images using deep learning. Upload a rainy image and see it transform!
178
+
179
+ **Model:** UNet with L1 loss, trained on Rain1400 dataset (12,600 images)
180
+ **Performance:** 30.9 PSNR / 0.914 SSIM on test set
181
+ """)
182
+
183
+ with gr.Row():
184
+ with gr.Column():
185
+ input_image = gr.Image(
186
+ type="pil",
187
+ label="πŸ“€ Upload Rainy Image",
188
+ # height=400
189
+ )
190
+
191
+ derain_button = gr.Button("✨ Remove Rain", variant="primary", size="lg")
192
+
193
+ info_text = gr.Textbox(
194
+ label="ℹ️ Info",
195
+ lines=3,
196
+ interactive=False
197
+ )
198
+
199
+ with gr.Column():
200
+ output_slider = gr.ImageSlider(
201
+ label="πŸ”„ Before & After (Drag slider to compare)",
202
+ # height=400,
203
+ interactive=False, # Don't allow uploading to the slider
204
+ show_download_button=True
205
+ )
206
+
207
+ # Examples
208
+ gr.Markdown("### πŸ“Έ Try these examples:")
209
+
210
+ gr.Examples(
211
+ examples=[
212
+ # Add paths to your example images here
213
+ ["examples/heavy_rain.jpg"],
214
+ ["examples/light_rain.jpg"],
215
+ ],
216
+ inputs=input_image,
217
+ outputs=[output_slider, info_text],
218
+ fn=derain_interface,
219
+ cache_examples=False,
220
+ )
221
+
222
+ gr.Markdown("""
223
+ ---
224
+ ### πŸ“ Notes:
225
+ - Works best on synthetic rain patterns (trained on Rain1400)
226
+ - May smooth fine textures slightly
227
+ - For best results, use images with clear rain streaks
228
+ - Real-world rain may have mixed results
229
+
230
+ ### πŸ”— Links:
231
+ - [GitHub Repository](https://github.com/dronefreak/clearview)
232
+ - [Model Card](https://huggingface.co/YOUR_USERNAME/clearview-unet)
233
+ - [Paper/Documentation](#)
234
+
235
+ ### ⚠️ Limitations:
236
+ - Trained on synthetic data (may not generalize to all real-world scenarios)
237
+ - No temporal consistency for video (processes frames independently)
238
+ - May lose fine details in heavily textured areas
239
+ """)
240
+
241
+ # Connect button to function
242
+ derain_button.click(
243
+ fn=derain_interface,
244
+ inputs=input_image,
245
+ outputs=[output_slider, info_text]
246
+ )
247
+
248
+
249
+ if __name__ == "__main__":
250
+ print("Starting ClearView Gradio app...")
251
+ print(f"Model checkpoint: {CHECKPOINT_PATH}")
252
+ print(f"CUDA available: {torch.cuda.is_available()}")
253
+
254
+ # Launch app
255
+ demo.launch(
256
+ share=False, # Set to True to create public link
257
+ server_name="0.0.0.0", # Allow external connections
258
+ server_port=7860,
259
+ )
examples/heavy_rain.jpg ADDED
examples/light_rain.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ Pillow>=9.0.0
5
+ numpy>=1.21.0
6
+ huggingface-hub>=0.19.0