File size: 1,678 Bytes
84adbaa
730aa68
 
84adbaa
 
730aa68
 
 
 
 
 
84adbaa
730aa68
 
 
 
 
 
84adbaa
730aa68
 
 
 
 
84adbaa
730aa68
 
84adbaa
730aa68
 
84adbaa
730aa68
 
 
84adbaa
730aa68
84adbaa
 
 
730aa68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_gif

def load_model(
    *,
    adapter_id: str = "guoyww/animatediff-motion-adapter-v1-5",
    base_model_id: str = "runwayml/stable-diffusion-v1-5",
    device: "str | torch.device | None" = None,
):
    """Load the AnimateDiff pipeline, configure its scheduler, and move it to a device.

    All arguments are keyword-only and optional; ``load_model()`` with no
    arguments loads the same checkpoints as before.

    Args:
        adapter_id: Hub id (or local path) of the ``MotionAdapter`` checkpoint.
        base_model_id: Hub id (or local path) of the base model passed to
            ``AnimateDiffPipeline.from_pretrained``.
        device: Target device. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The ``AnimateDiffPipeline`` with an Euler scheduler, moved to ``device``.

    Raises:
        Exception: any error raised while loading is printed and re-raised
            unchanged.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 saves memory on GPU, but on CPU half-precision kernels are slow
    # and in many torch builds missing entirely ("... not implemented for
    # 'Half'"), so an fp16 pipeline fails at inference time on CPU-only
    # machines. Use float32 whenever we are not on CUDA.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    try:
        # Motion adapter, plugged into the base pipeline via motion_adapter= below.
        adapter = MotionAdapter.from_pretrained(adapter_id, torch_dtype=dtype)

        # NOTE(review): the runwayml/stable-diffusion-v1-5 Hub repo has
        # reportedly been taken down; if the download fails, pass a mirror such
        # as base_model_id="stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm.
        pipeline = AnimateDiffPipeline.from_pretrained(
            base_model_id,
            motion_adapter=adapter,
            torch_dtype=dtype,
        )

        # Replace the checkpoint's scheduler with an Euler scheduler that reuses
        # its config but uses "trailing" timestep spacing.
        pipeline.scheduler = EulerDiscreteScheduler.from_config(
            pipeline.scheduler.config,
            timestep_spacing="trailing",
        )

        pipeline = pipeline.to(device)

        print("✅ Models loaded successfully!")
        return pipeline

    except Exception as e:
        # Report, then re-raise so callers still see the original exception.
        print(f"❌ Error during model loading: {e}")
        raise

# Load the pipeline once, at import time, and share it across all generate()
# calls. NOTE: importing this module therefore runs load_model(), which may
# download model weights on first use and re-raises if loading fails.
pipe = load_model()


def generate(prompt: str, num_frames: int = 16, steps: int = 25, guidance: float = 7.5, seed: int = 42, out_path: str = "output.gif"):
    """Generate an animated GIF from a text prompt.

    Uses the module-level ``pipe`` loaded at import time.

    Args:
        prompt: Text prompt describing the animation.
        num_frames: Number of frames to generate; must be >= 1.
        steps: Number of denoising steps; must be >= 1.
        guidance: Classifier-free guidance scale passed to the pipeline.
        seed: Seed for the random generator, for reproducible output.
        out_path: Path the GIF is written to.

    Returns:
        ``out_path``, after the GIF has been written there.

    Raises:
        ValueError: if ``num_frames`` or ``steps`` is less than 1. This is
            checked before any model work starts.
    """
    # Fail fast with a clear message instead of erroring deep inside the
    # pipeline / GIF export (or, for steps=0, returning undenoised output).
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    # Create the RNG on the pipeline's own device rather than re-deriving it
    # from CUDA availability, so the two can never drift apart.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
    )
    # The first (and only) prompt's frame sequence.
    frames = result.frames[0]
    export_to_gif(frames, out_path)
    return out_path