# MarioDiffusion-MLM-regular0 / models/pipeline_loader.py
# Author: schrum2 — "better error message" (commit d57cf30, verified)
from models.text_diffusion_pipeline import TextConditionalDDPMPipeline
from models.latent_diffusion_pipeline import UnconditionalDDPMPipeline
import os
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
def get_pipeline(model_path):
    """Load the appropriate diffusion pipeline for *model_path*.

    Two sources are supported:

    - A local directory: its layout selects the pipeline class. A
      ``text_encoder`` subdirectory alongside ``unet`` selects the
      text-conditional pipeline; ``unet`` alone selects the unconditional one.
    - Anything else is treated as a Hugging Face Hub model ID; the pipeline
      config is fetched to make the same text-conditional/unconditional choice.

    Args:
        model_path: Filesystem path to a saved pipeline directory, or a
            Hugging Face Hub model ID.

    Returns:
        A loaded ``TextConditionalDDPMPipeline`` or
        ``UnconditionalDDPMPipeline`` (for Hub IDs, loaded through
        ``DiffusionPipeline.from_pretrained`` with ``custom_pipeline``).

    Raises:
        ValueError: if *model_path* is a local directory without a ``unet``
            subdirectory, so it cannot be a saved diffusion pipeline.
    """
    # If model_path is a local directory, use the original logic
    if os.path.isdir(model_path):
        # A saved diffusion pipeline must at minimum contain a unet.
        if not os.path.exists(os.path.join(model_path, "unet")):
            # Previously this fell through and crashed with UnboundLocalError
            # at the return; fail loudly with a descriptive message instead.
            raise ValueError(
                f"Local model directory '{model_path}' has no 'unet' "
                "subdirectory; it does not look like a saved diffusion pipeline."
            )
        if os.path.exists(os.path.join(model_path, "text_encoder")):
            # If it has a text encoder and a unet, it's text conditional diffusion
            pipe = TextConditionalDDPMPipeline.from_pretrained(model_path)
        else:
            # If it has no text encoder, use the unconditional diffusion model
            pipe = UnconditionalDDPMPipeline.from_pretrained(model_path)
    else:
        # Assume it's a Hugging Face Hub model ID.
        # Load the config to determine if it's text-conditional.
        config, _ = DiffusionPipeline.load_config(model_path)
        if "text_encoder" in config:
            # Use the local pipeline file for custom_pipeline
            pipe = DiffusionPipeline.from_pretrained(
                model_path,
                custom_pipeline="models.text_diffusion_pipeline.TextConditionalDDPMPipeline",
                trust_remote_code=True,
            )
        else:
            # Fallback: try unconditional
            pipe = DiffusionPipeline.from_pretrained(
                model_path,
                custom_pipeline="models.latent_diffusion_pipeline.UnconditionalDDPMPipeline",
                trust_remote_code=True,
            )
    return pipe