# AI_VID / models.py
# (Hugging Face Space module; last commit edb58da "Update models.py" by arif670)
import torch
import logging
import os
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from huggingface_hub import login
from typing import Tuple
# Module-level logging: configure the root logger once at import time
# (INFO and above) and create a module-scoped logger named after this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_models() -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load the text-to-image and image-to-video diffusion pipelines.

    Picks CUDA with fp16 when a GPU is available, otherwise CPU with fp32,
    and optionally authenticates with the Hugging Face Hub via the
    ``HF_TOKEN`` environment variable.

    Returns:
        Tuple of (text-to-image pipeline, image-to-video pipeline, None).
        The trailing ``None`` is kept so existing callers that unpack three
        values keep working.

    Raises:
        Exception: re-raised (after logging with traceback) if either model
        fails to load.
    """
    try:
        # Device and precision configuration: fp16 is only safe on GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Authenticate only when a token is provided; both checkpoints are
        # public, so this is optional.
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        # Text-to-image model.
        logger.info("Loading text-to-image model on %s with %s", device, dtype)
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,  # disabled deliberately; outputs are unfiltered
        )
        text_to_image = _prepare_pipeline(text_to_image, device)

        # Image-to-video model.
        logger.info("Loading video model on %s with %s", device, dtype)
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype,
        )
        image_to_video = _prepare_pipeline(image_to_video, device)

        return text_to_image, image_to_video, None
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Model load failed: %s", e)
        raise


def _prepare_pipeline(pipe, device):
    """Move *pipe* to *device* and enable the best available attention optimization."""
    pipe = pipe.to(device)
    if device.type == "cuda":
        try:
            # xformers is an optional dependency; the original code crashed
            # here when it was missing. Fall back to attention slicing instead.
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            logger.warning("xformers unavailable; falling back to attention slicing")
            pipe.enable_attention_slicing()
    else:
        pipe.enable_attention_slicing()
    return pipe