recoilme commited on
Commit
a27c2fe
·
verified ·
1 Parent(s): 0c4b0f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -29
app.py CHANGED
@@ -4,73 +4,120 @@ import torchvision.transforms as T
4
  from PIL import Image
5
  from diffusers import AsymmetricAutoencoderKL
6
  import spaces
7
- import io
8
- import tempfile
9
- import os
10
 
11
  MODEL_ID = "babkasotona/vae8x16x32ch"
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- DTYPE = torch.float16
14
 
15
- def load_vae(model_id=MODEL_ID, device=DEVICE):
 
 
 
 
16
  for attempt in (None, "vae"):
17
  try:
18
  if attempt is None:
19
- vae = AsymmetricAutoencoderKL.from_pretrained(model_id, torch_dtype=DTYPE)
 
 
 
20
  else:
21
- vae = AsymmetricAutoencoderKL.from_pretrained(model_id, subfolder=attempt, torch_dtype=DTYPE)
22
- vae.to(device)
23
- vae.eval().half()
 
 
 
 
 
 
 
24
  return vae
 
25
  except Exception as e:
26
  last_err = e
27
- raise RuntimeError(f"Failed to load VAE {model_id}: {last_err}")
 
 
28
 
29
  _vae = None
 
 
30
  def get_vae():
31
  global _vae
32
  if _vae is None:
33
  _vae = load_vae()
34
  return _vae
35
 
 
 
 
 
36
  @spaces.GPU(duration=50)
37
  def encode_decode(img: Image.Image):
 
 
 
 
38
  vae = get_vae()
 
39
  img = img.convert("RGB")
40
 
41
  tfm = T.Compose([
42
  T.ToTensor(),
43
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
44
  ])
45
- t = tfm(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
 
 
 
46
 
47
  with torch.no_grad():
48
- lat = vae.encode(t).latent_dist.sample()
 
 
 
 
 
49
  dec = vae.decode(lat).sample
50
 
51
  x = (dec.clamp(-1, 1) + 1) * 127.5
52
- x = x.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()
 
 
 
53
  out = Image.fromarray(x)
54
 
55
- # Временный PNG
56
- #tmp_path = os.path.join(tempfile.gettempdir(), "decoded.png")
57
- #out.save(tmp_path, format="PNG")
58
 
59
- return out#, tmp_path
60
 
 
 
 
61
  with gr.Blocks(title="Asymmetric VAE 2x Upscaler") as demo:
62
- gr.Markdown("""
63
- # 🧠 Asymmetric VAE 2x Upscaler
64
- Загрузите изображение → нажмите **"Upscale"**
65
- """)
66
 
67
- with gr.Column():
68
- inp = gr.Image(type="pil", label="Upload image")
69
- run_btn = gr.Button("Upscale")
70
- out = gr.Image(type="pil", label="Decoded output")
71
- download = gr.File(label="Download result (PNG)")
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- run_btn.click(fn=encode_decode, inputs=[inp], outputs=[out])#, download])
74
 
 
 
 
75
  if __name__ == "__main__":
76
- demo.launch()
 
4
  from PIL import Image
5
  from diffusers import AsymmetricAutoencoderKL
6
  import spaces
 
 
 
7
 
8
  MODEL_ID = "babkasotona/vae8x16x32ch"
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
11
 
12
+
13
+ # -------------------------
14
+ # Load VAE
15
+ # -------------------------
16
+ def load_vae(model_id=MODEL_ID):
17
  for attempt in (None, "vae"):
18
  try:
19
  if attempt is None:
20
+ vae = AsymmetricAutoencoderKL.from_pretrained(
21
+ model_id,
22
+ torch_dtype=DTYPE
23
+ )
24
  else:
25
+ vae = AsymmetricAutoencoderKL.from_pretrained(
26
+ model_id,
27
+ subfolder=attempt,
28
+ torch_dtype=DTYPE
29
+ )
30
+
31
+ vae = vae.to(DEVICE)
32
+ vae.eval()
33
+
34
+ print("VAE loaded on", DEVICE)
35
  return vae
36
+
37
  except Exception as e:
38
  last_err = e
39
+
40
+ raise RuntimeError(f"Failed to load VAE: {last_err}")
41
+
42
 
43
  _vae = None
44
+
45
+
46
  def get_vae():
47
  global _vae
48
  if _vae is None:
49
  _vae = load_vae()
50
  return _vae
51
 
52
+
53
+ # -------------------------
54
+ # Encode / Decode
55
+ # -------------------------
56
  @spaces.GPU(duration=50)
57
  def encode_decode(img: Image.Image):
58
+
59
+ if img is None:
60
+ raise gr.Error("Please upload an image")
61
+
62
  vae = get_vae()
63
+
64
  img = img.convert("RGB")
65
 
66
  tfm = T.Compose([
67
  T.ToTensor(),
68
+ T.Normalize([0.5]*3, [0.5]*3),
69
  ])
70
+
71
+ t = tfm(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
72
+
73
+ print("Input tensor:", t.shape, t.dtype, t.device)
74
 
75
  with torch.no_grad():
76
+
77
+ enc = vae.encode(t)
78
+ lat = enc.latent_dist.sample()
79
+
80
+ print("Latents:", lat.shape)
81
+
82
  dec = vae.decode(lat).sample
83
 
84
  x = (dec.clamp(-1, 1) + 1) * 127.5
85
+ x = x.round().to(torch.uint8)
86
+
87
+ x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
88
+
89
  out = Image.fromarray(x)
90
 
91
+ print("Output size:", out.size)
92
+
93
+ return out
94
 
 
95
 
96
+ # -------------------------
97
+ # UI
98
+ # -------------------------
99
  with gr.Blocks(title="Asymmetric VAE 2x Upscaler") as demo:
 
 
 
 
100
 
101
+ gr.Markdown(
102
+ "# 🧠 Asymmetric VAE Upscaler\n"
103
+ "Upload image → press **Upscale**"
104
+ )
105
+
106
+ inp = gr.Image(type="pil", label="Upload image")
107
+
108
+ run_btn = gr.Button("Upscale")
109
+
110
+ out = gr.Image(label="Decoded output")
111
+
112
+ run_btn.click(
113
+ fn=encode_decode,
114
+ inputs=inp,
115
+ outputs=out
116
+ )
117
 
 
118
 
119
+ # -------------------------
120
+ # Launch
121
+ # -------------------------
122
  if __name__ == "__main__":
123
+ demo.launch()