File size: 1,835 Bytes
f324573
a80662e
bedec2e
f324573
2855157
bedec2e
a80662e
 
 
f324573
edb58da
a80662e
edb58da
 
 
 
 
bedec2e
a80662e
 
 
edb58da
 
a80662e
bedec2e
edb58da
a80662e
edb58da
a80662e
edb58da
bedec2e
edb58da
bedec2e
 
 
a80662e
edb58da
 
a80662e
bedec2e
edb58da
2855157
edb58da
 
 
bedec2e
 
 
2855157
edb58da
2855157
a80662e
edb58da
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import logging
import os
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from huggingface_hub import login
from typing import Tuple

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _configure_attention(pipeline: DiffusionPipeline, device: torch.device) -> None:
    """Enable a memory-saving attention implementation on ``pipeline``.

    On CUDA, xformers memory-efficient attention is tried first. If it cannot
    be enabled (for example xformers is not installed), attention slicing is
    used instead, so an optional optimization cannot abort model loading.
    On any other device, attention slicing is enabled directly.

    Args:
        pipeline: A diffusers pipeline, already moved to ``device``.
        device: The torch device the pipeline runs on.
    """
    if device.type == "cuda":
        try:
            pipeline.enable_xformers_memory_efficient_attention()
            return
        except (ImportError, RuntimeError, ValueError) as err:
            # xformers is an optional dependency: a missing or incompatible
            # build should degrade to slicing rather than fail the whole load.
            logger.warning(
                "xformers attention unavailable (%s); falling back to attention slicing",
                err,
            )
    pipeline.enable_attention_slicing()


def load_models(
    *,
    text_to_image_model_id: str = "runwayml/stable-diffusion-v1-5",
    video_model_id: str = "cerspense/zeroscope_v2_576w",
) -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load the text-to-image and video diffusion pipelines onto the best device.

    Uses CUDA with float16 when available, otherwise CPU with float32. If the
    ``HF_TOKEN`` environment variable is set, it is used to log in to the
    Hugging Face Hub before any weights are downloaded.

    Args:
        text_to_image_model_id: Hub ID of the Stable Diffusion checkpoint.
            NOTE(review): the default ``runwayml/stable-diffusion-v1-5`` repo
            appears to have been removed from the Hub; the community mirror
            ``stable-diffusion-v1-5/stable-diffusion-v1-5`` may be needed — verify.
        video_model_id: Hub ID of the video checkpoint, loaded through the
            generic ``DiffusionPipeline``. NOTE(review): zeroscope is a
            text-to-video model despite the ``image_to_video`` name — confirm intent.

    Returns:
        ``(text_to_image, image_to_video, None)``. The third element is always
        ``None``; presumably it is a placeholder kept for callers that unpack
        three values — confirm with callers.

    Raises:
        Exception: Any error raised while authenticating or loading a model is
            logged with its traceback and then re-raised unchanged.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU to save memory; CPU kernels generally need fp32.
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Authenticate only when a token is provided.
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        logger.info("Loading text-to-image model on %s with %s", device, dtype)
        text_to_image = StableDiffusionPipeline.from_pretrained(
            text_to_image_model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,  # safety checker deliberately disabled
        ).to(device)
        _configure_attention(text_to_image, device)

        logger.info("Loading video model on %s with %s", device, dtype)
        image_to_video = DiffusionPipeline.from_pretrained(
            video_model_id,
            torch_dtype=dtype,
        ).to(device)
        _configure_attention(image_to_video, device)

        return text_to_image, image_to_video, None

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Model load failed: %s", e)
        raise