Spaces:
Paused
Paused
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import uuid
|
| 3 |
-
import os
|
| 4 |
|
| 5 |
-
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
|
| 6 |
from diffusers.utils import export_to_video
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from safetensors.torch import load_file
|
|
@@ -31,9 +30,9 @@ motions = {
|
|
| 31 |
"Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
|
| 32 |
"Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
|
| 33 |
}
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
|
| 38 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
| 39 |
if not torch.cuda.is_available():
|
|
@@ -41,36 +40,29 @@ if not torch.cuda.is_available():
|
|
| 41 |
|
| 42 |
device = "cuda"
|
| 43 |
dtype = torch.float16
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
models[base_name] = {}
|
| 48 |
-
for step in steps:
|
| 49 |
-
repo = "ByteDance/AnimateDiff-Lightning"
|
| 50 |
-
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
| 51 |
-
model = AnimateDiffPipeline.from_pretrained(base_repo, torch_dtype=dtype).to(device)
|
| 52 |
-
model.scheduler = EulerDiscreteScheduler.from_config(model.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
| 53 |
-
model.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
| 54 |
-
models[base_name][step] = model
|
| 55 |
-
|
| 56 |
-
# Load all motion models
|
| 57 |
-
for motion_name, motion_repo in motions.items():
|
| 58 |
-
motion_weights = hf_hub_download(motion_repo, "pytorch_model.bin")
|
| 59 |
-
motions_loaded[motion_name] = torch.load(motion_weights, map_location=device)
|
| 60 |
|
| 61 |
# Safety checkers
|
| 62 |
from transformers import CLIPFeatureExtractor
|
| 63 |
|
| 64 |
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
|
| 65 |
|
|
|
|
| 66 |
class GenerateImageRequest(BaseModel):
|
| 67 |
prompt: str
|
| 68 |
base: str = "Realistic"
|
| 69 |
motion: str = ""
|
| 70 |
step: int = 8
|
| 71 |
|
|
|
|
| 72 |
@app.post("/generate-image")
|
| 73 |
def generate_image(request: GenerateImageRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
prompt = request.prompt
|
| 75 |
base = request.base
|
| 76 |
motion = request.motion
|
|
@@ -78,26 +70,34 @@ def generate_image(request: GenerateImageRequest):
|
|
| 78 |
|
| 79 |
print(prompt, base, step)
|
| 80 |
|
| 81 |
-
if
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
if motion:
|
| 87 |
-
if motion not in motions_loaded:
|
| 88 |
-
raise HTTPException(status_code=400, detail="Invalid motion")
|
| 89 |
-
pipe.unet.load_state_dict(motions_loaded[motion], strict=False)
|
| 90 |
-
pipe.set_adapters(["motion"], [0.7])
|
| 91 |
-
else:
|
| 92 |
pipe.unload_lora_weights()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
|
| 95 |
|
| 96 |
name = str(uuid.uuid4()).replace("-", "")
|
| 97 |
path = f"/tmp/{name}.mp4"
|
| 98 |
export_to_video(output.frames[0], path, fps=10)
|
| 99 |
-
|
| 100 |
return FileResponse(path, media_type="video/mp4", filename=f"{name}.mp4")
|
| 101 |
|
|
|
|
| 102 |
if __name__ == "__main__":
|
| 103 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
import torch
|
| 2 |
import uuid
|
|
|
|
| 3 |
|
| 4 |
+
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
|
| 5 |
from diffusers.utils import export_to_video
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
from safetensors.torch import load_file
|
|
|
|
| 30 |
"Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
|
| 31 |
"Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
|
| 32 |
}
|
| 33 |
+
step_loaded = None
|
| 34 |
+
base_loaded = "Realistic"
|
| 35 |
+
motion_loaded = None
|
| 36 |
|
| 37 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
| 38 |
if not torch.cuda.is_available():
|
|
|
|
| 40 |
|
| 41 |
device = "cuda"
|
| 42 |
dtype = torch.float16
|
| 43 |
+
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
|
| 44 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing",
|
| 45 |
+
beta_schedule="linear")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# Safety checkers
|
| 48 |
from transformers import CLIPFeatureExtractor
|
| 49 |
|
| 50 |
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
|
| 51 |
|
| 52 |
+
|
| 53 |
class GenerateImageRequest(BaseModel):
|
| 54 |
prompt: str
|
| 55 |
base: str = "Realistic"
|
| 56 |
motion: str = ""
|
| 57 |
step: int = 8
|
| 58 |
|
| 59 |
+
|
| 60 |
@app.post("/generate-image")
|
| 61 |
def generate_image(request: GenerateImageRequest):
|
| 62 |
+
global step_loaded
|
| 63 |
+
global base_loaded
|
| 64 |
+
global motion_loaded
|
| 65 |
+
|
| 66 |
prompt = request.prompt
|
| 67 |
base = request.base
|
| 68 |
motion = request.motion
|
|
|
|
| 70 |
|
| 71 |
print(prompt, base, step)
|
| 72 |
|
| 73 |
+
if step_loaded != step:
|
| 74 |
+
repo = "ByteDance/AnimateDiff-Lightning"
|
| 75 |
+
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
|
| 76 |
+
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
|
| 77 |
+
step_loaded = step
|
| 78 |
|
| 79 |
+
if base_loaded != base:
|
| 80 |
+
pipe.unet.load_state_dict(
|
| 81 |
+
torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
|
| 82 |
+
strict=False)
|
| 83 |
+
base_loaded = base
|
| 84 |
|
| 85 |
+
if motion_loaded != motion:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
pipe.unload_lora_weights()
|
| 87 |
+
if motion in motions:
|
| 88 |
+
motion_repo = motions[motion]
|
| 89 |
+
pipe.load_lora_weights(motion_repo, adapter_name="motion")
|
| 90 |
+
pipe.set_adapters(["motion"], [0.7])
|
| 91 |
+
motion_loaded = motion
|
| 92 |
|
| 93 |
output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
|
| 94 |
|
| 95 |
name = str(uuid.uuid4()).replace("-", "")
|
| 96 |
path = f"/tmp/{name}.mp4"
|
| 97 |
export_to_video(output.frames[0], path, fps=10)
|
| 98 |
+
|
| 99 |
return FileResponse(path, media_type="video/mp4", filename=f"{name}.mp4")
|
| 100 |
|
| 101 |
+
|
| 102 |
if __name__ == "__main__":
|
| 103 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|