saicharan1234 committed on
Commit
54bb5af
·
verified ·
1 Parent(s): 538c6d5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +38 -38
main.py CHANGED
@@ -1,13 +1,14 @@
1
  import torch
2
- import os
3
  import uuid
4
- from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
 
5
  from diffusers.utils import export_to_video
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
- from transformers import CLIPFeatureExtractor
9
- from fastapi import FastAPI, Form, HTTPException
10
- from fastapi.responses import FileResponse
 
11
 
12
  app = FastAPI()
13
 
@@ -18,16 +19,15 @@ bases = {
18
  "3d": "Lykon/DreamShaper",
19
  "Anime": "Yntec/mistoonAnime2"
20
  }
21
- motion_choices = {
22
- "": "Default",
23
- "guoyww/animatediff-motion-lora-zoom-in": "Zoom in",
24
- "guoyww/animatediff-motion-lora-zoom-out": "Zoom out",
25
- "guoyww/animatediff-motion-lora-tilt-up": "Tilt up",
26
- "guoyww/animatediff-motion-lora-tilt-down": "Tilt down",
27
- "guoyww/animatediff-motion-lora-pan-left": "Pan left",
28
- "guoyww/animatediff-motion-lora-pan-right": "Pan right",
29
- "guoyww/animatediff-motion-lora-rolling-anticlockwise": "Roll left",
30
- "guoyww/animatediff-motion-lora-rolling-clockwise": "Roll right"
31
  }
32
  step_loaded = None
33
  base_loaded = "Realistic"
@@ -43,14 +43,28 @@ pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype
43
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
44
 
45
  # Safety checkers
 
 
46
  feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
47
 
48
- # Function to generate image
49
- def generate_image(prompt, base="Realistic", motion="", step=8):
 
 
 
 
 
 
50
  global step_loaded
51
  global base_loaded
52
  global motion_loaded
53
- print(prompt, base, step, motion)
 
 
 
 
 
 
54
 
55
  if step_loaded != step:
56
  repo = "ByteDance/AnimateDiff-Lightning"
@@ -64,8 +78,9 @@ def generate_image(prompt, base="Realistic", motion="", step=8):
64
 
65
  if motion_loaded != motion:
66
  pipe.unload_lora_weights()
67
- if motion != "":
68
- pipe.load_lora_weights(motion, adapter_name="motion")
 
69
  pipe.set_adapters(["motion"], [0.7])
70
  motion_loaded = motion
71
 
@@ -74,23 +89,8 @@ def generate_image(prompt, base="Realistic", motion="", step=8):
74
  name = str(uuid.uuid4()).replace("-", "")
75
  path = f"/tmp/{name}.mp4"
76
  export_to_video(output.frames[0], path, fps=10)
77
- return path
78
-
79
- # API Endpoint to generate video
80
- @app.post("/generate-video/")
81
- async def generate_video(
82
- prompt: str = Form(...),
83
- base: str = Form("Realistic"),
84
- motion: str = Form(""),
85
- step: int = Form(8)
86
- ):
87
- try:
88
- video_path = generate_image(prompt, base, motion, step)
89
- return FileResponse(video_path, media_type="video/mp4")
90
- except Exception as e:
91
- raise HTTPException(status_code=500, detail=str(e))
92
-
93
- # Run the app
94
  if __name__ == "__main__":
95
- import uvicorn
96
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import torch
 
2
  import uuid
3
+
4
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
5
  from diffusers.utils import export_to_video
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
+ from PIL import Image
9
+ from fastapi import FastAPI, HTTPException
10
+ from pydantic import BaseModel
11
+ import uvicorn
12
 
13
  app = FastAPI()
14
 
 
19
  "3d": "Lykon/DreamShaper",
20
  "Anime": "Yntec/mistoonAnime2"
21
  }
22
# Human-readable camera-motion names mapped to their AnimateDiff motion-LoRA
# repositories on the Hugging Face Hub. Built from (label, slug) pairs so the
# shared "guoyww/animatediff-motion-lora-" prefix is stated once.
_MOTION_LORA_PAIRS = [
    ("Zoom in", "zoom-in"),
    ("Zoom out", "zoom-out"),
    ("Tilt up", "tilt-up"),
    ("Tilt down", "tilt-down"),
    ("Pan left", "pan-left"),
    ("Pan right", "pan-right"),
    ("Roll left", "rolling-anticlockwise"),
    ("Roll right", "rolling-clockwise"),
]
motions = {
    label: f"guoyww/animatediff-motion-lora-{slug}"
    for label, slug in _MOTION_LORA_PAIRS
}

# Module-level cache of what is currently loaded into the shared pipeline;
# compared against incoming requests to avoid redundant reloads.
step_loaded = None
base_loaded = "Realistic"
 
43
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
44
 
45
# Safety checkers
# NOTE(review): this import sits mid-file rather than in the top import block.
# The extractor is instantiated at module import time (downloads/loads the
# openai/clip-vit-base-patch32 preprocessor config), but it is not visibly
# passed to the pipeline anywhere in this file — presumably intended for a
# safety-checker hookup; confirm it is actually used before keeping it.
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
49
 
50
class GenerateImageRequest(BaseModel):
    """Request payload for the POST /generate-image endpoint."""

    prompt: str  # text prompt describing the video to generate
    base: str = "Realistic"  # key into the `bases` dict selecting the base model
    motion: str = ""  # key into `motions`; any value not in `motions` (e.g. "") skips the motion LoRA
    step: int = 8  # inference-step count; also selects the AnimateDiff-Lightning checkpoint to load
55
+
56
+ @app.post("/generate-image")
57
+ def generate_image(request: GenerateImageRequest):
58
  global step_loaded
59
  global base_loaded
60
  global motion_loaded
61
+
62
+ prompt = request.prompt
63
+ base = request.base
64
+ motion = request.motion
65
+ step = request.step
66
+
67
+ print(prompt, base, step)
68
 
69
  if step_loaded != step:
70
  repo = "ByteDance/AnimateDiff-Lightning"
 
78
 
79
  if motion_loaded != motion:
80
  pipe.unload_lora_weights()
81
+ if motion in motions:
82
+ motion_repo = motions[motion]
83
+ pipe.load_lora_weights(motion_repo, adapter_name="motion")
84
  pipe.set_adapters(["motion"], [0.7])
85
  motion_loaded = motion
86
 
 
89
  name = str(uuid.uuid4()).replace("-", "")
90
  path = f"/tmp/{name}.mp4"
91
  export_to_video(output.frames[0], path, fps=10)
92
+
93
+ return {"video_path": path}
94
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
if __name__ == "__main__":
    # Serve the FastAPI app directly when this file is run as a script.
    serve_host = "0.0.0.0"
    serve_port = 7860
    uvicorn.run(app, host=serve_host, port=serve_port)