Spaces:
Paused
Paused
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
import uuid
|
|
|
|
| 3 |
|
| 4 |
-
from diffusers import AnimateDiffPipeline,
|
| 5 |
from diffusers.utils import export_to_video
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
from safetensors.torch import load_file
|
|
@@ -30,9 +31,9 @@ motions = {
|
|
| 30 |
"Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
|
| 31 |
"Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
|
| 32 |
}
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
|
| 37 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
| 38 |
if not torch.cuda.is_available():
|
|
@@ -40,8 +41,22 @@ if not torch.cuda.is_available():
|
|
| 40 |
|
| 41 |
device = "cuda"
|
| 42 |
dtype = torch.float16
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Safety checkers
|
| 47 |
from transformers import CLIPFeatureExtractor
|
|
@@ -56,10 +71,6 @@ class GenerateImageRequest(BaseModel):
|
|
| 56 |
|
| 57 |
@app.post("/generate-image")
|
| 58 |
def generate_image(request: GenerateImageRequest):
|
| 59 |
-
global step_loaded
|
| 60 |
-
global base_loaded
|
| 61 |
-
global motion_loaded
|
| 62 |
-
|
| 63 |
prompt = request.prompt
|
| 64 |
base = request.base
|
| 65 |
motion = request.motion
|
|
@@ -67,23 +78,18 @@ def generate_image(request: GenerateImageRequest):
|
|
| 67 |
|
| 68 |
print(prompt, base, step)
|
| 69 |
|
| 70 |
-
if
|
| 71 |
-
|
| 72 |
-
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
| 73 |
-
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
| 74 |
-
step_loaded = step
|
| 75 |
|
| 76 |
-
|
| 77 |
-
pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
|
| 78 |
-
base_loaded = base
|
| 79 |
|
| 80 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
pipe.unload_lora_weights()
|
| 82 |
-
if motion in motions:
|
| 83 |
-
motion_repo = motions[motion]
|
| 84 |
-
pipe.load_lora_weights(motion_repo, adapter_name="motion")
|
| 85 |
-
pipe.set_adapters(["motion"], [0.7])
|
| 86 |
-
motion_loaded = motion
|
| 87 |
|
| 88 |
output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
|
| 89 |
|
|
|
|
| 1 |
import torch
|
| 2 |
import uuid
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
+
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
|
| 6 |
from diffusers.utils import export_to_video
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from safetensors.torch import load_file
|
|
|
|
| 31 |
"Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
|
| 32 |
"Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
|
| 33 |
}
|
| 34 |
+
steps = [4,8] # Different steps you want to pre-load
|
| 35 |
+
models = {}
|
| 36 |
+
motions_loaded = {}
|
| 37 |
|
| 38 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
| 39 |
if not torch.cuda.is_available():
|
|
|
|
| 41 |
|
| 42 |
device = "cuda"
|
| 43 |
dtype = torch.float16
|
| 44 |
+
|
| 45 |
+
# Load all base models and steps
|
| 46 |
+
for base_name, base_repo in bases.items():
|
| 47 |
+
models[base_name] = {}
|
| 48 |
+
for step in steps:
|
| 49 |
+
repo = "ByteDance/AnimateDiff-Lightning"
|
| 50 |
+
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
| 51 |
+
model = AnimateDiffPipeline.from_pretrained(base_repo, torch_dtype=dtype).to(device)
|
| 52 |
+
model.scheduler = EulerDiscreteScheduler.from_config(model.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
| 53 |
+
model.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
| 54 |
+
models[base_name][step] = model
|
| 55 |
+
|
| 56 |
+
# Load all motion models
|
| 57 |
+
for motion_name, motion_repo in motions.items():
|
| 58 |
+
motion_weights = hf_hub_download(motion_repo, "pytorch_model.bin")
|
| 59 |
+
motions_loaded[motion_name] = torch.load(motion_weights, map_location=device)
|
| 60 |
|
| 61 |
# Safety checkers
|
| 62 |
from transformers import CLIPFeatureExtractor
|
|
|
|
| 71 |
|
| 72 |
@app.post("/generate-image")
|
| 73 |
def generate_image(request: GenerateImageRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
prompt = request.prompt
|
| 75 |
base = request.base
|
| 76 |
motion = request.motion
|
|
|
|
| 78 |
|
| 79 |
print(prompt, base, step)
|
| 80 |
|
| 81 |
+
if base not in models or step not in models[base]:
|
| 82 |
+
raise HTTPException(status_code=400, detail="Invalid base model or step")
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
pipe = models[base][step]
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
if motion:
|
| 87 |
+
if motion not in motions_loaded:
|
| 88 |
+
raise HTTPException(status_code=400, detail="Invalid motion")
|
| 89 |
+
pipe.unet.load_state_dict(motions_loaded[motion], strict=False)
|
| 90 |
+
pipe.set_adapters(["motion"], [0.7])
|
| 91 |
+
else:
|
| 92 |
pipe.unload_lora_weights()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
|
| 95 |
|