Commit
·
d5c0caa
1
Parent(s):
dcd6671
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,6 +15,7 @@ from models.vae_flow import *
|
|
| 15 |
airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
|
| 16 |
chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
|
| 17 |
|
|
|
|
| 18 |
|
| 19 |
ckpt_airplane = torch.load(airplane)
|
| 20 |
ckpt_chair = torch.load(chair)
|
|
@@ -42,15 +43,15 @@ def predict(Seed,ckpt):
|
|
| 42 |
seed_all(Seed)
|
| 43 |
|
| 44 |
if ckpt['args'].model == 'gaussian':
|
| 45 |
-
model = GaussianVAE(ckpt['args']).to(
|
| 46 |
elif ckpt['args'].model == 'flow':
|
| 47 |
-
model = FlowVAE(ckpt['args']).to(
|
| 48 |
|
| 49 |
model.load_state_dict(ckpt['state_dict'])
|
| 50 |
# Generate Point Clouds
|
| 51 |
gen_pcs = []
|
| 52 |
with torch.no_grad():
|
| 53 |
-
z = torch.randn([1, ckpt['args'].latent_dim]).to(
|
| 54 |
x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
|
| 55 |
gen_pcs.append(x.detach().cpu())
|
| 56 |
gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
|
|
|
|
| 15 |
airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
|
| 16 |
chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
|
| 17 |
|
| 18 |
+
device='cuda' if torch.cuda.is_available() else 'cpu'
|
| 19 |
|
| 20 |
ckpt_airplane = torch.load(airplane)
|
| 21 |
ckpt_chair = torch.load(chair)
|
|
|
|
| 43 |
seed_all(Seed)
|
| 44 |
|
| 45 |
if ckpt['args'].model == 'gaussian':
|
| 46 |
+
model = GaussianVAE(ckpt['args']).to(device)
|
| 47 |
elif ckpt['args'].model == 'flow':
|
| 48 |
+
model = FlowVAE(ckpt['args']).to(device)
|
| 49 |
|
| 50 |
model.load_state_dict(ckpt['state_dict'])
|
| 51 |
# Generate Point Clouds
|
| 52 |
gen_pcs = []
|
| 53 |
with torch.no_grad():
|
| 54 |
+
z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
|
| 55 |
x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
|
| 56 |
gen_pcs.append(x.detach().cpu())
|
| 57 |
gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
|