import os

import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

from config import Config


class ModelHandler:
    def __init__(self):
        self.pipeline = None

    def load_models(self):
        # 1. Load SDXL Text-to-Image Pipeline
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)

        if not os.path.exists(checkpoint_local_path):
            print(f"Downloading checkpoint to {checkpoint_local_path}...")
            hf_hub_download(
                repo_id=Config.REPO_ID,
                filename=Config.CHECKPOINT_FILENAME,
                local_dir="./models",
                local_dir_use_symlinks=False,
            )

        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        # Use the standard SDXL text-to-image pipeline
        self.pipeline = StableDiffusionXLPipeline.from_single_file(
            checkpoint_local_path,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        self.pipeline.to(Config.DEVICE)

        # 2. Enable xFormers memory-efficient attention (optional; warn and continue if unavailable)
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print(" [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f" [WARNING] Failed to enable xFormers: {e}")

        # 3. Set Scheduler (LCM)
        print("Configuring LCMScheduler...")
        # scheduler.config is a FrozenDict, so copy it into a mutable dict before editing
        scheduler_config = dict(self.pipeline.scheduler.config)
        # Disable clipping to prevent NaN artifacts with LCM
        scheduler_config["clip_sample"] = False
        self.pipeline.scheduler = LCMScheduler.from_config(
            scheduler_config,
            timestep_spacing="trailing",
            beta_schedule="scaled_linear",
        )
        print(" [OK] LCMScheduler loaded (clip_sample=False).")

        # 4. Load LoRA
        print("Loading LoRA weights...")
        self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
        print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print(" [OK] LoRA fused.")

        print("--- All models loaded successfully ---")
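

# Example usage (a minimal sketch, not part of the original module): shows how the
# handler would typically be driven. The prompt, step count, and guidance scale below
# are illustrative assumptions; LCM-distilled SDXL checkpoints are generally run with
# few inference steps and a low guidance scale.
if __name__ == "__main__":
    handler = ModelHandler()
    handler.load_models()
    image = handler.pipeline(
        prompt="a photo of a red fox in a snowy forest",  # hypothetical prompt
        num_inference_steps=8,  # LCM usually needs only ~4-8 steps
        guidance_scale=1.5,     # low CFG is typical for LCM schedulers
    ).images[0]
    image.save("output.png")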