devbernie committed on
Commit
8c70566
·
verified ·
1 Parent(s): 244d9ce
Files changed (1) hide show
  1. app.py +45 -75
app.py CHANGED
@@ -3,63 +3,65 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  from torchvision.transforms import ToTensor, ToPILImage
6
- from typing import Tuple, Optional
 
7
 
8
  # Device configuration
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # Constants
12
- SUPPORTED_FORMATS = ["JPEG", "PNG", "WEBP"]
 
13
  MAX_IMAGE_SIZE = (1024, 1024)
14
 
 
 
 
 
 
 
 
 
 
 
15
  def load_model() -> torch.nn.Module:
16
- """Load pretrained ESRGAN model"""
17
- model = torch.hub.load(
18
- "facebookresearch/AnimatedDrawings",
19
- "esrgan",
20
- pretrained=True,
21
- verbose=False,
22
- trust_repo=True
23
- )
24
  return model.to(device).eval()
25
 
26
  def preprocess_image(image: Image.Image) -> torch.Tensor:
27
- """Convert PIL image to preprocessed tensor"""
28
  transform = ToTensor()
29
- tensor = transform(image).unsqueeze(0).to(device)
30
- return tensor * 2.0 - 1.0 # ESRGAN requires [-1,1] normalization
31
 
32
  def postprocess_image(tensor: torch.Tensor) -> Image.Image:
33
- """Convert model output tensor to PIL image"""
34
  transform = ToPILImage()
35
- tensor = (tensor + 1.0) / 2.0 # Convert back to [0,1]
36
  tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
37
  return transform(tensor)
38
 
39
- def validate_image(image: Image.Image) -> None:
40
- """Validate input image dimensions and format"""
41
  if image.mode not in ["RGB", "RGBA"]:
42
  raise gr.Error("Only RGB/RGBA images supported")
43
- if image.size[0] > MAX_IMAGE_SIZE[0] or image.size[1] > MAX_IMAGE_SIZE[1]:
44
- raise gr.Error(f"Max image size {MAX_IMAGE_SIZE} exceeded")
45
 
46
  def enhance_image(
47
  input_image: Image.Image,
48
  scale_factor: float = 2.0
49
  ) -> Image.Image:
50
- """
51
- Enhance image using ESRGAN model
52
- Args:
53
- input_image: PIL Image to process
54
- scale_factor: Multiplier for image scaling (2.0 or 4.0)
55
- Returns:
56
- Enhanced PIL Image
57
- """
58
  try:
59
  validate_image(input_image)
60
- original_size = input_image.size
61
 
62
- # Convert RGBA to RGB if needed
63
  if input_image.mode == 'RGBA':
64
  input_image = input_image.convert('RGB')
65
 
@@ -68,62 +70,30 @@ def enhance_image(
68
  output_tensor = model(input_tensor)
69
 
70
  result = postprocess_image(output_tensor)
71
- result = result.resize(
72
- (int(original_size[0]*scale_factor),
73
- int(original_size[1]*scale_factor)),
74
  Image.LANCZOS
75
  )
76
- return result
77
-
78
  except Exception as e:
79
- raise gr.Error(f"Image processing error: {str(e)}")
80
 
81
- # Load model once at startup
82
  model = load_model()
83
 
84
- # Gradio interface configuration
85
  interface = gr.Interface(
86
  fn=enhance_image,
87
  inputs=[
88
- gr.Image(
89
- label="Input Image",
90
- type="pil",
91
- image_mode="RGB",
92
- sources=["upload"],
93
- elem_id="input_image"
94
- ),
95
- gr.Slider(
96
- minimum=2.0,
97
- maximum=4.0,
98
- value=2.0,
99
- step=2.0,
100
- label="Upscale Factor",
101
- info="Select 2x or 4x upscaling"
102
- )
103
- ],
104
- outputs=gr.Image(
105
- label="Enhanced Image",
106
- type="pil",
107
- elem_id="output_image"
108
- ),
109
- title="🖼️ AI Image Enhancer",
110
- description="Enhance image quality using ESRGAN super-resolution (2x/4x upscaling)",
111
- examples=[
112
- ["examples/example1.jpg", 2.0],
113
- ["examples/example2.png", 4.0]
114
  ],
115
- allow_flagging="never",
116
- css="""
117
- footer {visibility: hidden}
118
- .gradio-container {max-width: 800px !important}
119
- """
120
  )
121
 
122
- # Deployment configuration
123
  if __name__ == "__main__":
124
- interface.launch(
125
- server_name="0.0.0.0",
126
- server_port=7860,
127
- show_error=True,
128
- debug=False
129
- )
 
3
  import numpy as np
4
  from PIL import Image
5
  from torchvision.transforms import ToTensor, ToPILImage
6
+ from urllib.request import urlretrieve
7
+ import os
8
 
9
  # Device configuration
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Constants
13
+ MODEL_URL = "https://github.com/xinntao/ESRGAN/releases/download/v0.1.1/RRDB_ESRGAN_x4.pth"
14
+ MODEL_PATH = "RRDB_ESRGAN_x4.pth"
15
  MAX_IMAGE_SIZE = (1024, 1024)
16
 
17
+ # ESRGAN model architecture
18
# ESRGAN model architecture
class _ResidualDenseBlock5C(torch.nn.Module):
    """Residual dense block with five 3x3 convolutions.

    Each convolution receives the block input concatenated with the outputs
    of all earlier convolutions; the final output is scaled by 0.2 and added
    back to the input.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv1 = conv(nf, gc, 3, 1, 1)
        self.conv2 = conv(nf + gc, gc, 3, 1, 1)
        self.conv3 = conv(nf + 2 * gc, gc, 3, 1, 1)
        self.conv4 = conv(nf + 3 * gc, gc, 3, 1, 1)
        self.conv5 = conv(nf + 4 * gc, nf, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class _RRDB(torch.nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a skip."""

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.RDB1 = _ResidualDenseBlock5C(nf, gc)
        self.RDB2 = _ResidualDenseBlock5C(nf, gc)
        self.RDB3 = _ResidualDenseBlock5C(nf, gc)

    def forward(self, x):
        out = self.RDB3(self.RDB2(self.RDB1(x)))
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):
    """ESRGAN generator (RRDB trunk followed by two 2x upsampling stages).

    Layers are registered directly on this module rather than under a
    ``self.model`` wrapper, and use the attribute names of the reference
    implementation (``conv_first``, ``RRDB_trunk``, ``trunk_conv``,
    ``upconv1``, ``upconv2``, ``HRconv``, ``conv_last``), so state-dict keys
    carry no extra prefix.
    NOTE(review): confirm these names against the checkpoint actually used.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Number of feature channels in the trunk.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels inside each residual dense block.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super().__init__()
        conv = torch.nn.Conv2d
        self.conv_first = conv(in_nc, nf, 3, 1, 1)
        self.RRDB_trunk = torch.nn.Sequential(
            *(_RRDB(nf, gc) for _ in range(nb))
        )
        self.trunk_conv = conv(nf, nf, 3, 1, 1)
        # Upsampling path: two nearest-neighbour 2x steps give 4x overall.
        self.upconv1 = conv(nf, nf, 3, 1, 1)
        self.upconv2 = conv(nf, nf, 3, 1, 1)
        self.HRconv = conv(nf, nf, 3, 1, 1)
        self.conv_last = conv(nf, out_nc, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return the input batch upscaled by a factor of 4 in height and width."""
        interpolate = torch.nn.functional.interpolate
        fea = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        fea = fea + self.trunk_conv(self.RRDB_trunk(fea))
        fea = self.lrelu(
            self.upconv1(interpolate(fea, scale_factor=2, mode="nearest"))
        )
        fea = self.lrelu(
            self.upconv2(interpolate(fea, scale_factor=2, mode="nearest"))
        )
        return self.conv_last(self.lrelu(self.HRconv(fea)))
26
+
27
def load_model() -> torch.nn.Module:
    """Download (if missing) and load the ESRGAN model.

    Returns:
        An ``RRDBNet`` with the weights from ``MODEL_PATH`` loaded, moved to
        ``device`` and switched to eval mode.
    """
    if not os.path.exists(MODEL_PATH):
        print("Downloading ESRGAN model...")
        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated file at MODEL_PATH
        # that later runs would mistake for a complete checkpoint.
        partial_path = MODEL_PATH + ".part"
        try:
            urlretrieve(MODEL_URL, partial_path)
            os.replace(partial_path, MODEL_PATH)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    model = RRDBNet()
    # The checkpoint comes from the network and torch.load is pickle-based;
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
37
 
38
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on ``device``.

    The image goes through torchvision's ``ToTensor`` (presumably yielding
    values scaled to [0, 1] -- confirm against torchvision docs), gains a
    leading batch dimension, and is moved to the configured device.
    """
    to_tensor = ToTensor()
    batch = to_tensor(image).unsqueeze(0)
    return batch.to(device)
 
42
 
43
def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Turn a model output tensor back into a PIL image.

    The leading dimension is squeezed away, the tensor is detached from the
    graph and moved to the CPU, and values are clamped to [0, 1] before the
    conversion with torchvision's ``ToPILImage``.
    """
    to_pil = ToPILImage()
    image_tensor = tensor.squeeze(0)
    image_tensor = image_tensor.detach().cpu()
    image_tensor = image_tensor.clamp(0, 1)
    return to_pil(image_tensor)
48
 
49
def validate_image(image: Image.Image):
    """Validate input image constraints.

    Args:
        image: The uploaded PIL image.

    Raises:
        gr.Error: If no image was supplied, if its mode is neither RGB nor
            RGBA, or if its width or height exceeds the matching entry of
            ``MAX_IMAGE_SIZE``.
    """
    if image is None:
        raise gr.Error("No image provided")
    if image.mode not in ["RGB", "RGBA"]:
        raise gr.Error("Only RGB/RGBA images supported")
    # Compare per axis: PIL reports size as (width, height), and comparing
    # only the largest values would ignore a non-square limit.
    max_width, max_height = MAX_IMAGE_SIZE
    if image.width > max_width or image.height > max_height:
        raise gr.Error(f"Max image dimension exceeded ({MAX_IMAGE_SIZE[0]}x{MAX_IMAGE_SIZE[1]})")
55
 
56
  def enhance_image(
57
  input_image: Image.Image,
58
  scale_factor: float = 2.0
59
  ) -> Image.Image:
60
+ """Main processing function"""
 
 
 
 
 
 
 
61
  try:
62
  validate_image(input_image)
 
63
 
64
+ # Convert RGBA to RGB
65
  if input_image.mode == 'RGBA':
66
  input_image = input_image.convert('RGB')
67
 
 
70
  output_tensor = model(input_tensor)
71
 
72
  result = postprocess_image(output_tensor)
73
+ return result.resize(
74
+ (int(input_image.width*scale_factor),
75
+ int(input_image.height*scale_factor)),
76
  Image.LANCZOS
77
  )
78
+
 
79
  except Exception as e:
80
+ raise gr.Error(f"Processing error: {str(e)}")
81
 
82
# Load the network once at import time so every request reuses it.
model = load_model()

# Gradio UI: one image input plus a 2x/4x scale slider, one image output.
_input_image = gr.Image(type="pil", label="Input Image")
_scale_slider = gr.Slider(
    minimum=2.0,
    maximum=4.0,
    value=2.0,
    step=2.0,
    label="Scale Factor",
)
_output_image = gr.Image(type="pil", label="Enhanced Image")

interface = gr.Interface(
    fn=enhance_image,
    inputs=[_input_image, _scale_slider],
    outputs=_output_image,
    title="🎨 AI Image Enhancer",
    examples=[["examples/example1.jpg", 2.0]],
    css=".gradio-container {max-width: 800px !important}",
)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the host.
    interface.launch(server_name="0.0.0.0")