saicharan1234 committed on
Commit
ee2cec4
·
verified ·
1 Parent(s): 2683177

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -32
main.py CHANGED
@@ -1,8 +1,7 @@
1
  import torch
2
  import uuid
3
- import os
4
 
5
- from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
6
  from diffusers.utils import export_to_video
7
  from huggingface_hub import hf_hub_download
8
  from safetensors.torch import load_file
@@ -31,9 +30,9 @@ motions = {
31
  "Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
32
  "Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
33
  }
34
- steps = [4,8] # Different steps you want to pre-load
35
- models = {}
36
- motions_loaded = {}
37
 
38
  # Ensure model and scheduler are initialized in GPU-enabled function
39
  if not torch.cuda.is_available():
@@ -41,36 +40,29 @@ if not torch.cuda.is_available():
41
 
42
  device = "cuda"
43
  dtype = torch.float16
44
-
45
- # Load all base models and steps
46
- for base_name, base_repo in bases.items():
47
- models[base_name] = {}
48
- for step in steps:
49
- repo = "ByteDance/AnimateDiff-Lightning"
50
- ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
51
- model = AnimateDiffPipeline.from_pretrained(base_repo, torch_dtype=dtype).to(device)
52
- model.scheduler = EulerDiscreteScheduler.from_config(model.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
53
- model.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
54
- models[base_name][step] = model
55
-
56
- # Load all motion models
57
- for motion_name, motion_repo in motions.items():
58
- motion_weights = hf_hub_download(motion_repo, "pytorch_model.bin")
59
- motions_loaded[motion_name] = torch.load(motion_weights, map_location=device)
60
 
61
  # Safety checkers
62
  from transformers import CLIPFeatureExtractor
63
 
64
  feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
65
 
 
66
  class GenerateImageRequest(BaseModel):
67
  prompt: str
68
  base: str = "Realistic"
69
  motion: str = ""
70
  step: int = 8
71
 
 
72
  @app.post("/generate-image")
73
  def generate_image(request: GenerateImageRequest):
 
 
 
 
74
  prompt = request.prompt
75
  base = request.base
76
  motion = request.motion
@@ -78,26 +70,34 @@ def generate_image(request: GenerateImageRequest):
78
 
79
  print(prompt, base, step)
80
 
81
- if base not in models or step not in models[base]:
82
- raise HTTPException(status_code=400, detail="Invalid base model or step")
 
 
 
83
 
84
- pipe = models[base][step]
 
 
 
 
85
 
86
- if motion:
87
- if motion not in motions_loaded:
88
- raise HTTPException(status_code=400, detail="Invalid motion")
89
- pipe.unet.load_state_dict(motions_loaded[motion], strict=False)
90
- pipe.set_adapters(["motion"], [0.7])
91
- else:
92
  pipe.unload_lora_weights()
 
 
 
 
 
93
 
94
  output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)
95
 
96
  name = str(uuid.uuid4()).replace("-", "")
97
  path = f"/tmp/{name}.mp4"
98
  export_to_video(output.frames[0], path, fps=10)
99
-
100
  return FileResponse(path, media_type="video/mp4", filename=f"{name}.mp4")
101
 
 
102
  if __name__ == "__main__":
103
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import torch
2
  import uuid
 
3
 
4
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
5
  from diffusers.utils import export_to_video
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
 
30
  "Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
31
  "Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
32
  }
33
# Cache of what is currently loaded into the shared pipeline, so repeated
# requests with the same settings skip redundant weight downloads/loads.
step_loaded = None         # step count whose Lightning unet weights are loaded (None = not yet)
base_loaded = "Realistic"  # base checkpoint currently loaded; matches the pipe init below
motion_loaded = None       # motion LoRA currently attached ("" means none; None = not yet set)
36
 
37
  # Ensure model and scheduler are initialized in GPU-enabled function
38
  if not torch.cuda.is_available():
 
40
 
41
  device = "cuda"
42
  dtype = torch.float16
43
# Single shared AnimateDiff pipeline, created once at import time on the GPU.
# Per-request base/step/motion weights are swapped into it in-place (see
# generate_image), rather than keeping one pipeline per combination.
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
# Scheduler configured for the Lightning checkpoints: trailing timestep
# spacing and a linear beta schedule on the Euler scheduler.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing",
                                                    beta_schedule="linear")
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Safety checkers
48
  from transformers import CLIPFeatureExtractor
49
 
50
  feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
51
 
52
+
53
class GenerateImageRequest(BaseModel):
    """Request payload for POST /generate-image."""

    prompt: str  # text prompt describing the video to generate
    base: str = "Realistic"  # key into the module-level `bases` dict
    # NOTE(review): `base` is not validated against `bases` — unknown values
    # raise KeyError downstream in the handler.
    motion: str = ""  # key into `motions`; "" means no motion LoRA
    step: int = 8  # AnimateDiff-Lightning distillation step count
58
 
59
+
60
@app.post("/generate-image")
def generate_image(request: GenerateImageRequest):
    """Generate a short AnimateDiff video for the request and return it as MP4.

    Lazily swaps the requested step/base/motion weights into the shared
    module-level `pipe`, runs inference, writes the frames to a uniquely
    named file under /tmp, and streams it back.

    Raises:
        HTTPException(400): unknown base, unsupported step, or unknown motion.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    prompt = request.prompt
    base = request.base
    motion = request.motion
    step = request.step

    # Validate up front with a 400 instead of letting the lookups below fail:
    # `bases[base]` would raise KeyError (a 500 to the client), an unsupported
    # step would try to download a nonexistent checkpoint, and an unknown
    # motion would previously be silently ignored.
    if base not in bases:
        raise HTTPException(status_code=400, detail="Invalid base model")
    if step not in (1, 2, 4, 8):
        raise HTTPException(status_code=400, detail="Invalid step; must be 1, 2, 4 or 8")
    if motion and motion not in motions:
        raise HTTPException(status_code=400, detail="Invalid motion")

    print(prompt, base, step)

    # Swap in the Lightning-distilled unet weights for the requested step
    # count only when it differs from what is already loaded.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    # Swap the base checkpoint's unet weights when the base selection changed.
    # NOTE(review): torch.load on a downloaded .bin deserializes a pickle —
    # acceptable only because the repos come from the fixed `bases` mapping.
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False)
        base_loaded = base

    # Attach/detach the motion LoRA when the selection changed ("" = none).
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion in motions:
            motion_repo = motions[motion]
            pipe.load_lora_weights(motion_repo, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)

    # Write frames to a uniquely named temp file and stream it back.
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)

    return FileResponse(path, media_type="video/mp4", filename=f"{name}.mp4")
100
 
101
+
102
# Serve the FastAPI app directly when run as a script (port 7860 — the
# convention for Hugging Face Spaces, presumably this deployment target).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)