saicharan1234 commited on
Commit
30ed707
·
verified ·
1 Parent(s): bb6d3ad

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -0
main.py CHANGED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import os
import uuid
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import CLIPFeatureExtractor
from fastapi import FastAPI, Form, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()

# Base text-to-image checkpoints selectable per request, keyed by style name.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2",
}

# Module-level cache of what is currently loaded into `pipe`, so repeated
# requests with identical settings skip the expensive weight reloads.
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# The pipeline is GPU-only: fail fast at import time when CUDA is absent.
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

device = "cuda"
dtype = torch.float16

# Build the AnimateDiff pipeline once at startup on the default base model,
# then swap its scheduler for the Euler config Lightning checkpoints expect.
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

# Safety checkers
# NOTE(review): `feature_extractor` is loaded but never referenced elsewhere in
# this file — presumably intended to be wired into a safety checker; confirm
# before removing.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
36
+
37
# Function to generate a short AnimateDiff-Lightning video clip.
def generate_image(prompt, base="Realistic", motion="", step=8):
    """Run the AnimateDiff pipeline and export the result as an MP4.

    Parameters
    ----------
    prompt : str
        Text prompt describing the video.
    base : str
        Key into the module-level ``bases`` dict selecting the base model.
    motion : str
        Optional motion-LoRA repo id; "" disables the motion adapter.
    step : int
        AnimateDiff-Lightning distillation step count (used both to pick the
        checkpoint file and as ``num_inference_steps``).

    Returns
    -------
    str
        Path to the exported MP4 under /tmp.

    Raises
    ------
    ValueError
        If ``base`` is not one of the known base-model keys.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    # Fail early with a clear message instead of a KeyError mid-reload.
    if base not in bases:
        raise ValueError(f"Unknown base {base!r}; expected one of {sorted(bases)}")

    print(prompt, base, step)

    # Swap in the Lightning UNet weights for the requested step count.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    # Swap in the requested base model's UNet weights.
    if base_loaded != base:
        # SECURITY: torch.load unpickles a downloaded file; weights_only=True
        # restricts it to plain tensor data (requires torch >= 1.13).
        pipe.unet.load_state_dict(
            torch.load(
                hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
                map_location=device,
                weights_only=True,
            ),
            strict=False,
        )
        base_loaded = base

    # (Un)load the motion LoRA whenever the requested adapter changes.
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)

    # uuid4().hex is exactly the dashless hex string the original built by hand.
    path = f"/tmp/{uuid.uuid4().hex}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path
67
+
68
# API Endpoint to generate video
@app.post("/generate-video/")
def generate_video(
    prompt: str = Form(...),
    base: str = Form("Realistic"),
    motion: str = Form(""),
    step: int = Form(8),
):
    """Generate a video from form fields and stream the MP4 back.

    Declared as a plain ``def`` (not ``async def``): ``generate_image`` blocks
    on GPU work for seconds, and FastAPI runs sync endpoints in a threadpool
    instead of stalling the event loop for every other request.
    """
    try:
        video_path = generate_image(prompt, base, motion, step)
        return FileResponse(video_path, media_type="video/mp4")
    except Exception as e:
        # Surface any failure as a 500 with the error text, preserving the
        # original traceback via exception chaining.
        raise HTTPException(status_code=500, detail=str(e)) from e
81
+
82
# Script entry point: serve the FastAPI app directly with uvicorn.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)