saicharan1234 committed on
Commit
2683177
·
verified ·
1 Parent(s): 1b42456

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +30 -24
main.py CHANGED
@@ -1,7 +1,8 @@
1
  import torch
2
  import uuid
 
3
 
4
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
5
  from diffusers.utils import export_to_video
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
@@ -30,9 +31,9 @@ motions = {
30
  "Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
31
  "Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
32
  }
33
- step_loaded = None
34
- base_loaded = "Realistic"
35
- motion_loaded = None
36
 
37
  # Ensure model and scheduler are initialized in GPU-enabled function
38
  if not torch.cuda.is_available():
@@ -40,8 +41,22 @@ if not torch.cuda.is_available():
40
 
41
  device = "cuda"
42
  dtype = torch.float16
43
- pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
44
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Safety checkers
47
  from transformers import CLIPFeatureExtractor
@@ -56,10 +71,6 @@ class GenerateImageRequest(BaseModel):
56
 
57
  @app.post("/generate-image")
58
  def generate_image(request: GenerateImageRequest):
59
- global step_loaded
60
- global base_loaded
61
- global motion_loaded
62
-
63
  prompt = request.prompt
64
  base = request.base
65
  motion = request.motion
@@ -67,23 +78,18 @@ def generate_image(request: GenerateImageRequest):
67
 
68
  print(prompt, base, step)
69
 
70
- if step_loaded != step:
71
- repo = "ByteDance/AnimateDiff-Lightning"
72
- ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
73
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
74
- step_loaded = step
75
 
76
- if base_loaded != base:
77
- pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
78
- base_loaded = base
79
 
80
- if motion_loaded != motion:
 
 
 
 
 
81
  pipe.unload_lora_weights()
82
- if motion in motions:
83
- motion_repo = motions[motion]
84
- pipe.load_lora_weights(motion_repo, adapter_name="motion")
85
- pipe.set_adapters(["motion"], [0.7])
86
- motion_loaded = motion
87
 
88
  output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
89
 
 
1
  import torch
2
  import uuid
3
+ import os
4
 
5
+ from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
6
  from diffusers.utils import export_to_video
7
  from huggingface_hub import hf_hub_download
8
  from safetensors.torch import load_file
 
31
  "Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
32
  "Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
33
  }
34
# Pre-load configuration and caches, populated once at module import time:
#   steps          - the AnimateDiff-Lightning step counts to keep resident
#   models         - base-model name -> step count -> ready AnimateDiffPipeline
#   motions_loaded - motion name -> loaded motion-adapter state dict
steps = [4, 8]  # Different steps you want to pre-load
models = {}
motions_loaded = {}
37
 
38
  # Ensure model and scheduler are initialized in GPU-enabled function
39
  if not torch.cuda.is_available():
 
41
 
42
device = "cuda"
dtype = torch.float16

# Load all base models and steps.
# Every (base, step) pair gets its own full pipeline moved to the GPU.
# NOTE(review): memory cost is len(bases) * len(steps) full pipelines resident
# on one device — confirm this fits the target GPU before deploying.
repo = "ByteDance/AnimateDiff-Lightning"  # hoisted: invariant across both loops
for base_name, base_repo in bases.items():
    models[base_name] = {}
    for step in steps:
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        model = AnimateDiffPipeline.from_pretrained(base_repo, torch_dtype=dtype).to(device)
        # Lightning checkpoints are distilled against trailing-timestep Euler
        # with a linear beta schedule, so the scheduler is reconfigured to match.
        model.scheduler = EulerDiscreteScheduler.from_config(
            model.scheduler.config,
            timestep_spacing="trailing",
            beta_schedule="linear",
        )
        # Overlay the Lightning UNet weights onto the base pipeline
        # (strict=False: the checkpoint covers only a subset of UNet keys).
        model.unet.load_state_dict(
            load_file(hf_hub_download(repo, ckpt), device=device),
            strict=False,
        )
        models[base_name][step] = model

# Load all motion models into plain state dicts for later overlay.
# NOTE(review): motion-LoRA repos commonly ship
# "diffusion_pytorch_model.safetensors" rather than "pytorch_model.bin" —
# confirm this filename exists in each repo listed in `motions`.
# NOTE: torch.load deserializes a pickle; only acceptable because these are
# fixed, trusted Hub repo names — never point this at untrusted input.
for motion_name, motion_repo in motions.items():
    motion_weights = hf_hub_download(motion_repo, "pytorch_model.bin")
    motions_loaded[motion_name] = torch.load(motion_weights, map_location=device)
 
61
  # Safety checkers
62
  from transformers import CLIPFeatureExtractor
 
71
 
72
  @app.post("/generate-image")
73
  def generate_image(request: GenerateImageRequest):
 
 
 
 
74
  prompt = request.prompt
75
  base = request.base
76
  motion = request.motion
 
78
 
79
  print(prompt, base, step)
80
 
81
+ if base not in models or step not in models[base]:
82
+ raise HTTPException(status_code=400, detail="Invalid base model or step")
 
 
 
83
 
84
+ pipe = models[base][step]
 
 
85
 
86
+ if motion:
87
+ if motion not in motions_loaded:
88
+ raise HTTPException(status_code=400, detail="Invalid motion")
89
+ pipe.unet.load_state_dict(motions_loaded[motion], strict=False)
90
+ pipe.set_adapters(["motion"], [0.7])
91
+ else:
92
  pipe.unload_lora_weights()
 
 
 
 
 
93
 
94
  output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
95