dev1461 commited on
Commit
3c9c540
·
verified ·
1 Parent(s): a65f1cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -68,6 +68,11 @@ transform = transforms.Compose([
68
  # INFERENCE FUNCTION
69
  # ---------------------------
70
 
 
 
 
 
 
71
  def enhance_image(input_image):
72
  img = input_image.convert("RGB")
73
 
@@ -79,20 +84,27 @@ def enhance_image(input_image):
79
  output_img = output.squeeze().permute(1,2,0).cpu().numpy()
80
  output_img = (output_img * 255).astype(np.uint8)
81
 
82
- return output_img
 
 
 
 
83
 
84
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown("# 🔍 Image Super Resolution")
87
-
88
  input_img = gr.Image(type="pil", label="Upload Image")
89
- output_img = gr.Image(type="numpy", label="Enhanced Image")
 
90
 
91
  btn = gr.Button("Enhance Image")
92
 
93
- btn.click(fn=enhance_image, inputs=input_img, outputs=output_img)
94
-
95
- gr.DownloadButton(label="Download Enhanced Image", data=output_img)
 
 
96
 
97
  demo.launch()
98
 
 
68
  # INFERENCE FUNCTION
69
  # ---------------------------
70
 
71
+ import gradio as gr
72
+ import numpy as np
73
+ from PIL import Image
74
+ import tempfile
75
+
76
  def enhance_image(input_image):
77
  img = input_image.convert("RGB")
78
 
 
84
  output_img = output.squeeze().permute(1,2,0).cpu().numpy()
85
  output_img = (output_img * 255).astype(np.uint8)
86
 
87
+ # 🔥 SAVE TEMP FILE FOR DOWNLOAD
88
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
89
+ Image.fromarray(output_img).save(temp_file.name)
90
+
91
+ return output_img, temp_file.name
92
 
93
 
94
  with gr.Blocks() as demo:
95
  gr.Markdown("# 🔍 Image Super Resolution")
96
+
97
  input_img = gr.Image(type="pil", label="Upload Image")
98
+ output_img = gr.Image(label="Enhanced Image")
99
+ download_file = gr.File(label="Download Image")
100
 
101
  btn = gr.Button("Enhance Image")
102
 
103
+ btn.click(
104
+ fn=enhance_image,
105
+ inputs=input_img,
106
+ outputs=[output_img, download_file]
107
+ )
108
 
109
  demo.launch()
110