Janeka commited on
Commit
aad3cf7
·
verified ·
1 Parent(s): 7ac7dad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -1,40 +1,53 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
4
  from PIL import Image
5
  from torchvision.transforms import ToTensor, ToPILImage
6
 
7
- # Define the EDSR model architecture (simplified version)
8
  class EDSR(nn.Module):
9
  def __init__(self):
10
  super(EDSR, self).__init__()
11
- # Simplified architecture - in practice you'd want the full EDSR
12
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
13
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
14
- self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
 
 
 
 
15
 
16
  def forward(self, x):
17
  x = torch.relu(self.conv1(x))
18
- x = torch.relu(self.conv2(x))
19
- x = self.conv3(x)
 
 
20
  return x
21
 
22
- # Load the model (we'll use a pretrained version from Hugging Face)
23
  model = EDSR()
24
 
25
- # Load pretrained weights (alternative approach)
26
  try:
27
  state_dict = torch.hub.load_state_dict_from_url(
28
  "https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model.bin",
29
- map_location="cpu"
 
30
  )
31
  model.load_state_dict(state_dict)
32
- except:
33
- print("Couldn't load pretrained weights, using random initialization")
 
 
34
 
35
  model.eval()
36
 
37
  def enhance_image(input_img):
 
 
 
38
  # Convert to tensor and add batch dimension
39
  input_tensor = ToTensor()(input_img).unsqueeze(0)
40
 
@@ -46,13 +59,19 @@ def enhance_image(input_img):
46
  output_img = ToPILImage()(output_tensor.squeeze(0).clamp(0, 1))
47
  return output_img
48
 
 
 
 
 
 
49
  # Gradio UI
50
  demo = gr.Interface(
51
  fn=enhance_image,
52
- inputs=gr.Image(type="pil", label="Upload Image"),
53
  outputs=gr.Image(type="pil", label="Enhanced Image"),
54
- title="Image Super-Resolution (EDSR)",
55
- examples=["example_image.jpg"] if os.path.exists("example_image.jpg") else None,
 
56
  )
57
 
58
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import os
5
  from PIL import Image
6
  from torchvision.transforms import ToTensor, ToPILImage
7
 
8
+ # Define the EDSR model architecture
9
  class EDSR(nn.Module):
10
  def __init__(self):
11
  super(EDSR, self).__init__()
12
+ # Basic EDSR architecture
13
+ self.conv1 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
14
+ self.resblocks = nn.Sequential(*[nn.Sequential(
15
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
16
+ nn.ReLU(),
17
+ nn.Conv2d(256, 256, kernel_size=3, padding=1)
18
+ ) for _ in range(8)])
19
+ self.conv2 = nn.Conv2d(256, 3, kernel_size=3, padding=1)
20
 
21
  def forward(self, x):
22
  x = torch.relu(self.conv1(x))
23
+ residual = x
24
+ x = self.resblocks(x)
25
+ x += residual
26
+ x = self.conv2(x)
27
  return x
28
 
29
+ # Initialize model
30
  model = EDSR()
31
 
32
+ # Try loading pretrained weights from Hugging Face
33
  try:
34
  state_dict = torch.hub.load_state_dict_from_url(
35
  "https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model.bin",
36
+ map_location="cpu",
37
+ file_name="edsr_weights.pth"
38
  )
39
  model.load_state_dict(state_dict)
40
+ print("Successfully loaded pretrained weights")
41
+ except Exception as e:
42
+ print(f"Couldn't load pretrained weights: {str(e)}")
43
+ print("Using randomly initialized model")
44
 
45
  model.eval()
46
 
47
  def enhance_image(input_img):
48
+ # Resize input to prevent memory issues
49
+ input_img = input_img.resize((256, 256))
50
+
51
  # Convert to tensor and add batch dimension
52
  input_tensor = ToTensor()(input_img).unsqueeze(0)
53
 
 
59
  output_img = ToPILImage()(output_tensor.squeeze(0).clamp(0, 1))
60
  return output_img
61
 
62
+ # Prepare examples
63
+ example_images = []
64
+ if os.path.exists("example_image.jpg"):
65
+ example_images = ["example_image.jpg"]
66
+
67
  # Gradio UI
68
  demo = gr.Interface(
69
  fn=enhance_image,
70
+ inputs=gr.Image(type="pil", label="Input Image"),
71
  outputs=gr.Image(type="pil", label="Enhanced Image"),
72
+ title="EDSR Image Super-Resolution",
73
+ examples=example_images,
74
+ allow_flagging="never"
75
  )
76
 
77
  demo.launch()