Spaces:
Runtime error
Runtime error
Minor update.
Browse files
app.py
CHANGED
|
@@ -22,7 +22,7 @@ import spaces
|
|
| 22 |
TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
|
| 23 |
DESCRIPTION = """
|
| 24 |
<div>
|
| 25 |
-
Try <a href='https://arxiv.org/abs/24xx.xxxxx'><b>DragAPart</b></a> yourself to manipulate your favorite articulated objects in
|
| 26 |
</div>
|
| 27 |
"""
|
| 28 |
INSTRUCTION = '''
|
|
@@ -185,11 +185,8 @@ def single_image_sample(
|
|
| 185 |
drags,
|
| 186 |
hidden_cls,
|
| 187 |
num_steps=50,
|
| 188 |
-
vae=None,
|
| 189 |
):
|
| 190 |
z = torch.randn(2, 4, 32, 32).to("cuda")
|
| 191 |
-
if vae is not None:
|
| 192 |
-
vae = vae.to("cuda")
|
| 193 |
|
| 194 |
# Prepare input for classifier-free guidance
|
| 195 |
rel = torch.cat([rel, rel], dim=0).to("cuda")
|
|
@@ -226,9 +223,7 @@ def single_image_sample(
|
|
| 226 |
|
| 227 |
samples, _ = samples.chunk(2, dim=0)
|
| 228 |
|
| 229 |
-
|
| 230 |
-
images = vae.decode(samples / 0.18215).sample
|
| 231 |
-
return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 232 |
|
| 233 |
@spaces.GPU
|
| 234 |
def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
|
|
@@ -278,7 +273,7 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
|
|
| 278 |
if idx == 9:
|
| 279 |
break
|
| 280 |
|
| 281 |
-
|
| 282 |
model.to("cuda"),
|
| 283 |
diffusion,
|
| 284 |
x_cond,
|
|
@@ -289,9 +284,10 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
|
|
| 289 |
drags,
|
| 290 |
cls_embedding,
|
| 291 |
num_steps=50,
|
| 292 |
-
vae=vae,
|
| 293 |
)
|
| 294 |
-
|
|
|
|
|
|
|
| 295 |
|
| 296 |
|
| 297 |
sam_predictor = sam_init()
|
|
|
|
| 22 |
TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
|
| 23 |
DESCRIPTION = """
|
| 24 |
<div>
|
| 25 |
+
Try <a href='https://arxiv.org/abs/24xx.xxxxx'><b>DragAPart</b></a> yourself to manipulate your favorite articulated objects in seconds!
|
| 26 |
</div>
|
| 27 |
"""
|
| 28 |
INSTRUCTION = '''
|
|
|
|
| 185 |
drags,
|
| 186 |
hidden_cls,
|
| 187 |
num_steps=50,
|
|
|
|
| 188 |
):
|
| 189 |
z = torch.randn(2, 4, 32, 32).to("cuda")
|
|
|
|
|
|
|
| 190 |
|
| 191 |
# Prepare input for classifier-free guidance
|
| 192 |
rel = torch.cat([rel, rel], dim=0).to("cuda")
|
|
|
|
| 223 |
|
| 224 |
samples, _ = samples.chunk(2, dim=0)
|
| 225 |
|
| 226 |
+
return samples
|
|
|
|
|
|
|
| 227 |
|
| 228 |
@spaces.GPU
|
| 229 |
def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
|
|
|
|
| 273 |
if idx == 9:
|
| 274 |
break
|
| 275 |
|
| 276 |
+
samples = single_image_sample(
|
| 277 |
model.to("cuda"),
|
| 278 |
diffusion,
|
| 279 |
x_cond,
|
|
|
|
| 284 |
drags,
|
| 285 |
cls_embedding,
|
| 286 |
num_steps=50,
|
|
|
|
| 287 |
)
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
images = vae.decode(samples / 0.18215).sample
|
| 290 |
+
return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 291 |
|
| 292 |
|
| 293 |
sam_predictor = sam_init()
|