devbernie committed on
Commit
ecf2564
·
verified ·
1 Parent(s): f73e05b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -13,12 +13,13 @@ SUPPORTED_FORMATS = ["JPEG", "PNG", "WEBP"]
13
  MAX_IMAGE_SIZE = (1024, 1024)
14
 
15
  def load_model() -> torch.nn.Module:
16
- """Load pretrained ESRGAN model from torch hub"""
17
  model = torch.hub.load(
18
- "pytorch/vision",
19
  "esrgan",
20
  pretrained=True,
21
- verbose=False
 
22
  )
23
  return model.to(device).eval()
24
 
@@ -26,20 +27,20 @@ def preprocess_image(image: Image.Image) -> torch.Tensor:
26
  """Convert PIL image to preprocessed tensor"""
27
  transform = ToTensor()
28
  tensor = transform(image).unsqueeze(0).to(device)
29
- return tensor
30
 
31
  def postprocess_image(tensor: torch.Tensor) -> Image.Image:
32
  """Convert model output tensor to PIL image"""
33
  transform = ToPILImage()
34
- tensor = tensor.squeeze(0).detach().cpu()
35
- tensor = torch.clamp(tensor, 0, 1)
36
  return transform(tensor)
37
 
38
  def validate_image(image: Image.Image) -> None:
39
  """Validate input image dimensions and format"""
40
  if image.mode not in ["RGB", "RGBA"]:
41
  raise gr.Error("Only RGB/RGBA images supported")
42
- if image.size > MAX_IMAGE_SIZE:
43
  raise gr.Error(f"Max image size {MAX_IMAGE_SIZE} exceeded")
44
 
45
  def enhance_image(
@@ -50,23 +51,32 @@ def enhance_image(
50
  Enhance image using ESRGAN model
51
  Args:
52
  input_image: PIL Image to process
53
- scale_factor: Multiplier for image scaling (1.0-4.0)
54
  Returns:
55
  Enhanced PIL Image
56
  """
57
  try:
58
- # Input validation
59
  validate_image(input_image)
 
 
 
 
 
60
 
61
- # Model processing
62
  with torch.no_grad():
63
  input_tensor = preprocess_image(input_image)
64
  output_tensor = model(input_tensor)
65
 
66
- return postprocess_image(output_tensor)
 
 
 
 
 
 
67
 
68
  except Exception as e:
69
- raise gr.Error(f"Image processing failed: {str(e)}")
70
 
71
  # Load model once at startup
72
  model = load_model()
@@ -83,12 +93,12 @@ interface = gr.Interface(
83
  elem_id="input_image"
84
  ),
85
  gr.Slider(
86
- minimum=1.0,
87
  maximum=4.0,
88
  value=2.0,
89
- step=0.5,
90
- label="Scale Factor",
91
- info="Select upscaling multiplier (1x to 4x)"
92
  )
93
  ],
94
  outputs=gr.Image(
@@ -97,13 +107,16 @@ interface = gr.Interface(
97
  elem_id="output_image"
98
  ),
99
  title="🖼️ AI Image Enhancer",
100
- description="Enhance image quality using ESRGAN super-resolution model (Supports 2x-4x upscaling)",
101
  examples=[
102
  ["examples/example1.jpg", 2.0],
103
  ["examples/example2.png", 4.0]
104
  ],
105
  allow_flagging="never",
106
- css="footer {visibility: hidden}"
 
 
 
107
  )
108
 
109
  # Deployment configuration
 
13
  MAX_IMAGE_SIZE = (1024, 1024)
14
 
15
  def load_model() -> torch.nn.Module:
16
+ """Load pretrained ESRGAN model"""
17
  model = torch.hub.load(
18
+ "facebookresearch/AnimatedDrawings",
19
  "esrgan",
20
  pretrained=True,
21
+ verbose=False,
22
+ trust_repo=True
23
  )
24
  return model.to(device).eval()
25
 
 
27
  """Convert PIL image to preprocessed tensor"""
28
  transform = ToTensor()
29
  tensor = transform(image).unsqueeze(0).to(device)
30
+ return tensor * 2.0 - 1.0 # ESRGAN requires [-1,1] normalization
31
 
32
  def postprocess_image(tensor: torch.Tensor) -> Image.Image:
33
  """Convert model output tensor to PIL image"""
34
  transform = ToPILImage()
35
+ tensor = (tensor + 1.0) / 2.0 # Convert back to [0,1]
36
+ tensor = tensor.squeeze(0).detach().cpu().clamp(0, 1)
37
  return transform(tensor)
38
 
39
  def validate_image(image: Image.Image) -> None:
40
  """Validate input image dimensions and format"""
41
  if image.mode not in ["RGB", "RGBA"]:
42
  raise gr.Error("Only RGB/RGBA images supported")
43
+ if image.size[0] > MAX_IMAGE_SIZE[0] or image.size[1] > MAX_IMAGE_SIZE[1]:
44
  raise gr.Error(f"Max image size {MAX_IMAGE_SIZE} exceeded")
45
 
46
  def enhance_image(
 
51
  Enhance image using ESRGAN model
52
  Args:
53
  input_image: PIL Image to process
54
+ scale_factor: Multiplier for image scaling (2.0 or 4.0)
55
  Returns:
56
  Enhanced PIL Image
57
  """
58
  try:
 
59
  validate_image(input_image)
60
+ original_size = input_image.size
61
+
62
+ # Convert RGBA to RGB if needed
63
+ if input_image.mode == 'RGBA':
64
+ input_image = input_image.convert('RGB')
65
 
 
66
  with torch.no_grad():
67
  input_tensor = preprocess_image(input_image)
68
  output_tensor = model(input_tensor)
69
 
70
+ result = postprocess_image(output_tensor)
71
+ result = result.resize(
72
+ (int(original_size[0]*scale_factor),
73
+ int(original_size[1]*scale_factor)),
74
+ Image.LANCZOS
75
+ )
76
+ return result
77
 
78
  except Exception as e:
79
+ raise gr.Error(f"Image processing error: {str(e)}")
80
 
81
  # Load model once at startup
82
  model = load_model()
 
93
  elem_id="input_image"
94
  ),
95
  gr.Slider(
96
+ minimum=2.0,
97
  maximum=4.0,
98
  value=2.0,
99
+ step=2.0,
100
+ label="Upscale Factor",
101
+ info="Select 2x or 4x upscaling"
102
  )
103
  ],
104
  outputs=gr.Image(
 
107
  elem_id="output_image"
108
  ),
109
  title="🖼️ AI Image Enhancer",
110
+ description="Enhance image quality using ESRGAN super-resolution (2x/4x upscaling)",
111
  examples=[
112
  ["examples/example1.jpg", 2.0],
113
  ["examples/example2.png", 4.0]
114
  ],
115
  allow_flagging="never",
116
+ css="""
117
+ footer {visibility: hidden}
118
+ .gradio-container {max-width: 800px !important}
119
+ """
120
  )
121
 
122
  # Deployment configuration