MarioDiffusion-MLM-regular0 / models /pipeline_loader.py
schrum2's picture
just use snapshot download?
634e6bc verified
raw
history blame
1.49 kB
from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
import os
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from huggingface_hub import snapshot_download
def _pipeline_class_for(path):
    """Return the pipeline class to use for the model directory at *path*.

    A ``text_encoder`` entry inside *path* marks a text-conditional model;
    without one the unconditional pipeline class is returned.
    """
    if os.path.exists(os.path.join(path, "text_encoder")):
        return TextConditionalDDPMPipeline
    return UnconditionalDDPMPipeline


def get_pipeline(model_path, *, cache_dir="./temp_model_cache"):
    """Load a diffusion pipeline from a local directory or the Hugging Face Hub.

    If ``model_path`` is an existing local directory, it must contain a
    ``unet`` entry. Any other ``model_path`` is treated as a Hub repo id and
    downloaded with ``snapshot_download`` into ``cache_dir``. In both cases
    the pipeline class is chosen from the on-disk directory by whether it has
    a ``text_encoder`` entry.

    Args:
        model_path: Local model directory, or a Hugging Face Hub repo id.
        cache_dir: Cache directory passed to ``snapshot_download`` for Hub
            models. Defaults to ``"./temp_model_cache"``.

    Returns:
        A ``TextConditionalDDPMPipeline`` when a ``text_encoder`` entry is
        present, otherwise an ``UnconditionalDDPMPipeline``, loaded via
        ``from_pretrained``.

    Raises:
        ValueError: If ``model_path`` is a local directory without a ``unet``
            entry.
    """
    if os.path.isdir(model_path):
        # Local diffusion models must include a UNet. Without this guard the
        # function reached `return pipe` with `pipe` never assigned and
        # failed with an opaque UnboundLocalError.
        if not os.path.exists(os.path.join(model_path, "unet")):
            raise ValueError(
                f"Local model directory {model_path!r} has no 'unet' "
                "subdirectory; cannot load a diffusion pipeline from it."
            )
        local_path = model_path
    else:
        # For HF Hub models, download first then load locally. As before,
        # the Hub branch does not require a 'unet' entry up front.
        print(f"Downloading model {model_path}...")
        local_path = snapshot_download(repo_id=model_path, cache_dir=cache_dir)

    return _pipeline_class_for(local_path).from_pretrained(local_path)