import logging
import os
from typing import Tuple

import torch
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from huggingface_hub import login

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _enable_memory_optimizations(pipeline, device: torch.device) -> None:
    """Apply the best available attention memory optimization to *pipeline*.

    xformers is an optional dependency: on CUDA we try it first and fall
    back to attention slicing if it is missing or fails to initialize,
    instead of crashing the whole load.
    """
    if device.type == "cuda":
        try:
            pipeline.enable_xformers_memory_efficient_attention()
            return
        except Exception:
            logger.warning(
                "xformers unavailable; falling back to attention slicing"
            )
    pipeline.enable_attention_slicing()


def load_models() -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load the text-to-image and image-to-video diffusion pipelines.

    Picks CUDA + fp16 when a GPU is available, otherwise CPU + fp32, and
    optionally authenticates with Hugging Face via the HF_TOKEN env var.

    Returns:
        A 3-tuple ``(text_to_image, image_to_video, None)``. The trailing
        ``None`` is kept for backward compatibility with existing callers
        that unpack three values.

    Raises:
        Exception: re-raised (after logging with traceback) if any model
        fails to download or initialize.
    """
    try:
        # Device and precision configuration: fp16 only on GPU — many CPU
        # ops are slow or unsupported in half precision.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Optional authentication (required for gated model repos).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        # Text-to-image model (safety checker disabled by the original code).
        logger.info("Loading text-to-image model on %s with %s", device, dtype)
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,
        ).to(device)
        _enable_memory_optimizations(text_to_image, device)

        # Image-to-video model.
        logger.info("Loading video model on %s with %s", device, dtype)
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype,
        ).to(device)
        _enable_memory_optimizations(image_to_video, device)

        return text_to_image, image_to_video, None
    except Exception:
        # logger.exception records the full traceback (str(e) alone loses it);
        # re-raise so callers still observe the failure.
        logger.exception("Model load failed")
        raise