recoilme commited on
Commit
a0a72cb
·
verified ·
1 Parent(s): 1eb4d93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -33,7 +33,7 @@ def get_vae():
33
  _vae = load_vae()
34
  return _vae
35
 
36
- @spaces.GPU(duration=5)
37
  def encode_decode(img: Image.Image):
38
  vae = get_vae()
39
  img = img.convert("RGB")
@@ -45,8 +45,7 @@ def encode_decode(img: Image.Image):
45
  t = tfm(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
46
 
47
  with torch.no_grad():
48
- enc = vae.encode(t)
49
- lat = enc.latent_dist.mean
50
  dec = vae.decode(lat).sample
51
 
52
  x = (dec.clamp(-1, 1) + 1) * 127.5
@@ -54,10 +53,10 @@ def encode_decode(img: Image.Image):
54
  out = Image.fromarray(x)
55
 
56
  # Временный PNG
57
- tmp_path = os.path.join(tempfile.gettempdir(), "decoded.png")
58
- out.save(tmp_path, format="PNG")
59
 
60
- return out, tmp_path
61
 
62
  with gr.Blocks(title="Asymmetric VAE 2x Upscaler") as demo:
63
  gr.Markdown("""
@@ -71,7 +70,7 @@ with gr.Blocks(title="Asymmetric VAE 2x Upscaler") as demo:
71
  out = gr.Image(type="pil", label="Decoded output")
72
  download = gr.File(label="Download result (PNG)")
73
 
74
- run_btn.click(fn=encode_decode, inputs=[inp], outputs=[out, download])
75
 
76
  if __name__ == "__main__":
77
  demo.launch()
 
33
  _vae = load_vae()
34
  return _vae
35
 
36
+ @spaces.GPU(duration=50)
37
  def encode_decode(img: Image.Image):
38
  vae = get_vae()
39
  img = img.convert("RGB")
 
45
  t = tfm(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
46
 
47
  with torch.no_grad():
48
+ lat = vae.encode(t).latent_dist.sample()
 
49
  dec = vae.decode(lat).sample
50
 
51
  x = (dec.clamp(-1, 1) + 1) * 127.5
 
53
  out = Image.fromarray(x)
54
 
55
  # Временный PNG
56
+ #tmp_path = os.path.join(tempfile.gettempdir(), "decoded.png")
57
+ #out.save(tmp_path, format="PNG")
58
 
59
+ return out#, tmp_path
60
 
61
  with gr.Blocks(title="Asymmetric VAE 2x Upscaler") as demo:
62
  gr.Markdown("""
 
70
  out = gr.Image(type="pil", label="Decoded output")
71
  download = gr.File(label="Download result (PNG)")
72
 
73
+ run_btn.click(fn=encode_decode, inputs=[inp], outputs=[out]#, download])
74
 
75
  if __name__ == "__main__":
76
  demo.launch()