Janeka commited on
Commit
01831da
·
verified ·
1 Parent(s): 4f91b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -35
app.py CHANGED
@@ -1,45 +1,48 @@
1
  import gradio as gr
2
  import numpy as np
3
- import cv2
4
- import torch
5
  from PIL import Image
 
 
6
 
7
- # Simple image enhancement function (no complex dependencies)
8
- def enhance_image(input_img, contrast=1.2, brightness=10, sharpness=2.0):
9
- # Convert to OpenCV format
10
- img = np.array(input_img)
11
-
12
- # Contrast and brightness adjustment
13
- img = cv2.convertScaleAbs(img, alpha=contrast, beta=brightness)
14
-
15
- # Sharpening
16
- kernel = np.array([[-1,-1,-1],
17
- [-1,9,-1],
18
- [-1,-1,-1]])
19
- img = cv2.filter2D(img, -1, kernel)
20
-
21
- # Color correction
22
- img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
23
- l, a, b = cv2.split(img)
24
- clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
25
- l = clahe.apply(l)
26
- img = cv2.merge((l,a,b))
27
- img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR)
 
 
 
 
28
 
29
- return Image.fromarray(img)
 
 
 
30
 
31
- # Gradio interface with adjustable parameters
32
  demo = gr.Interface(
33
- fn=enhance_image,
34
- inputs=[
35
- gr.Image(type="pil", label="Input Image"),
36
- gr.Slider(0.5, 2.0, value=1.2, label="Contrast"),
37
- gr.Slider(0, 30, value=10, label="Brightness"),
38
- gr.Slider(0.5, 3.0, value=2.0, label="Sharpness")
39
- ],
40
- outputs=gr.Image(type="pil", label="Enhanced Image"),
41
- title="Professional Image Enhancement",
42
- examples=["example.jpg"] if os.path.exists("example.jpg") else None
43
  )
44
 
45
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import onnxruntime as ort
 
4
  from PIL import Image
5
+ import cv2
6
+ import os
7
 
8
+ # Load ONNX model
9
+ model_path = "esrgan.onnx" # Replace with your ONNX file name
10
+ ort_session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
11
+
12
+ def preprocess(img):
13
+ """Convert PIL image to ONNX-compatible input"""
14
+ img = np.array(img)
15
+ img = img.astype(np.float32) / 255.0 # Normalize to [0,1]
16
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # ESRGAN expects BGR
17
+ img = np.transpose(img, (2, 0, 1)) # HWC to CHW
18
+ img = np.expand_dims(img, axis=0) # Add batch dimension
19
+ return img
20
+
21
+ def postprocess(output):
22
+ """Convert model output to PIL image"""
23
+ output = output.squeeze() # Remove batch dim
24
+ output = np.transpose(output, (1, 2, 0)) # CHW to HWC
25
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) # BGR to RGB
26
+ output = (output * 255.0).clip(0, 255).astype(np.uint8) # Denormalize
27
+ return Image.fromarray(output)
28
+
29
+ def enhance(image):
30
+ # Resize if too large (free-tier GPU memory is limited)
31
+ if max(image.size) > 1024:
32
+ image = image.resize((512, 512))
33
 
34
+ # Preprocess → Inference → Postprocess
35
+ input_tensor = preprocess(image)
36
+ output = ort_session.run(None, {'input': input_tensor})[0]
37
+ return postprocess(output)
38
 
39
+ # Gradio Interface
40
  demo = gr.Interface(
41
+ fn=enhance,
42
+ inputs=gr.Image(type="pil", label="Input Image"),
43
+ outputs=gr.Image(type="pil", label="Enhanced"),
44
+ title="ESRGAN Image Enhancement (ONNX)",
45
+ examples=["example.jpg"] if os.path.exists("example.jpg") else None,
 
 
 
 
 
46
  )
47
 
48
  demo.launch()