|
|
import torch |
|
|
import logging |
|
|
import os |
|
|
from diffusers import StableDiffusionPipeline, DiffusionPipeline |
|
|
from huggingface_hub import login |
|
|
from typing import Tuple |
|
|
|
|
|
# Configure root logging once at import time; per-module logger per stdlib convention.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)
|
|
|
|
|
def load_models() -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load the text-to-image and image-to-video diffusion pipelines.

    Selects CUDA with fp16 when a GPU is available, otherwise CPU with
    fp32, optionally authenticates with the Hugging Face Hub, and enables
    a memory-efficient attention mode on each pipeline.

    Returns:
        A 3-tuple ``(text_to_image, image_to_video, None)``. The trailing
        ``None`` is kept for backward compatibility with callers that
        unpack three values.

    Raises:
        Exception: re-raises whatever the underlying loaders raise, after
        logging the failure with its traceback.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # fp16 halves VRAM usage on GPU; CPU kernels generally require fp32.
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Authenticate only when a token is provided; both checkpoints are
        # public, so the token is optional.
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        # Lazy %-style args so the message is only formatted if emitted.
        logger.info("Loading text-to-image model on %s with %s", device, dtype)
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None
        )
        text_to_image = text_to_image.to(device)
        _enable_memory_efficient_attention(text_to_image, device)

        logger.info("Loading video model on %s with %s", device, dtype)
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype
        )
        image_to_video = image_to_video.to(device)
        _enable_memory_efficient_attention(image_to_video, device)

        return text_to_image, image_to_video, None

    except Exception as e:
        # logger.exception records the traceback in addition to the message.
        logger.exception("Model load failed: %s", e)
        raise


def _enable_memory_efficient_attention(pipeline, device: torch.device) -> None:
    """Enable the best available attention memory optimization in place.

    On CUDA, try xformers first; it is an optional dependency and the
    enable call raises when it is missing or unsupported, so fall back to
    attention slicing instead of crashing the whole model load. On CPU,
    attention slicing is the only applicable option.
    """
    if device.type == "cuda":
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as exc:  # xformers not installed / GPU unsupported
            logger.warning(
                "xformers unavailable (%s); falling back to attention slicing",
                exc,
            )
            pipeline.enable_attention_slicing()
    else:
        pipeline.enable_attention_slicing()