devbernie commited on
Commit
f73e05b
·
verified ·
1 Parent(s): 1d8c753

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision.transforms import ToTensor, ToPILImage
6
+ from typing import Tuple, Optional
7
+
8
+ # Device configuration
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Constants
12
+ SUPPORTED_FORMATS = ["JPEG", "PNG", "WEBP"]
13
+ MAX_IMAGE_SIZE = (1024, 1024)
14
+
15
+ def load_model() -> torch.nn.Module:
16
+ """Load pretrained ESRGAN model from torch hub"""
17
+ model = torch.hub.load(
18
+ "pytorch/vision",
19
+ "esrgan",
20
+ pretrained=True,
21
+ verbose=False
22
+ )
23
+ return model.to(device).eval()
24
+
25
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
26
+ """Convert PIL image to preprocessed tensor"""
27
+ transform = ToTensor()
28
+ tensor = transform(image).unsqueeze(0).to(device)
29
+ return tensor
30
+
31
+ def postprocess_image(tensor: torch.Tensor) -> Image.Image:
32
+ """Convert model output tensor to PIL image"""
33
+ transform = ToPILImage()
34
+ tensor = tensor.squeeze(0).detach().cpu()
35
+ tensor = torch.clamp(tensor, 0, 1)
36
+ return transform(tensor)
37
+
38
+ def validate_image(image: Image.Image) -> None:
39
+ """Validate input image dimensions and format"""
40
+ if image.mode not in ["RGB", "RGBA"]:
41
+ raise gr.Error("Only RGB/RGBA images supported")
42
+ if image.size > MAX_IMAGE_SIZE:
43
+ raise gr.Error(f"Max image size {MAX_IMAGE_SIZE} exceeded")
44
+
45
+ def enhance_image(
46
+ input_image: Image.Image,
47
+ scale_factor: float = 2.0
48
+ ) -> Image.Image:
49
+ """
50
+ Enhance image using ESRGAN model
51
+ Args:
52
+ input_image: PIL Image to process
53
+ scale_factor: Multiplier for image scaling (1.0-4.0)
54
+ Returns:
55
+ Enhanced PIL Image
56
+ """
57
+ try:
58
+ # Input validation
59
+ validate_image(input_image)
60
+
61
+ # Model processing
62
+ with torch.no_grad():
63
+ input_tensor = preprocess_image(input_image)
64
+ output_tensor = model(input_tensor)
65
+
66
+ return postprocess_image(output_tensor)
67
+
68
+ except Exception as e:
69
+ raise gr.Error(f"Image processing failed: {str(e)}")
70
+
71
+ # Load model once at startup
72
+ model = load_model()
73
+
74
+ # Gradio interface configuration
75
+ interface = gr.Interface(
76
+ fn=enhance_image,
77
+ inputs=[
78
+ gr.Image(
79
+ label="Input Image",
80
+ type="pil",
81
+ image_mode="RGB",
82
+ sources=["upload"],
83
+ elem_id="input_image"
84
+ ),
85
+ gr.Slider(
86
+ minimum=1.0,
87
+ maximum=4.0,
88
+ value=2.0,
89
+ step=0.5,
90
+ label="Scale Factor",
91
+ info="Select upscaling multiplier (1x to 4x)"
92
+ )
93
+ ],
94
+ outputs=gr.Image(
95
+ label="Enhanced Image",
96
+ type="pil",
97
+ elem_id="output_image"
98
+ ),
99
+ title="🖼️ AI Image Enhancer",
100
+ description="Enhance image quality using ESRGAN super-resolution model (Supports 2x-4x upscaling)",
101
+ examples=[
102
+ ["examples/example1.jpg", 2.0],
103
+ ["examples/example2.png", 4.0]
104
+ ],
105
+ allow_flagging="never",
106
+ css="footer {visibility: hidden}"
107
+ )
108
+
109
+ # Deployment configuration
110
+ if __name__ == "__main__":
111
+ interface.launch(
112
+ server_name="0.0.0.0",
113
+ server_port=7860,
114
+ show_error=True,
115
+ debug=False
116
+ )