|
|
import torch |
|
|
import os |
|
|
from config import Config |
|
|
|
|
|
from diffusers import ( |
|
|
StableDiffusionXLPipeline, |
|
|
LCMScheduler |
|
|
) |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class ModelHandler:
    """Load and hold one SDXL pipeline configured with LCMScheduler and a fused LoRA.

    Checkpoint/LoRA locations, dtype, device and LoRA strength are all read
    from ``Config``. Call :meth:`load_models` before using ``self.pipeline``.
    """

    # Directory the checkpoint is downloaded into and loaded from.
    MODELS_DIR = "./models"

    def __init__(self):
        # Set by load_models(); stays None until loading fully succeeds.
        self.pipeline = None

    def load_models(self):
        """Build the SDXL pipeline, swap in LCMScheduler and fuse the LoRA.

        Steps:
          1. Make sure ``Config.CHECKPOINT_FILENAME`` exists under
             ``MODELS_DIR``, downloading it from ``Config.REPO_ID`` if missing.
          2. Load it via ``StableDiffusionXLPipeline.from_single_file`` with
             ``torch_dtype=Config.DTYPE`` and move it to ``Config.DEVICE``.
          3. Try to enable xFormers attention (best effort: failure only warns).
          4. Replace the scheduler with an ``LCMScheduler`` built from the
             loaded scheduler's config.
          5. Load ``Config.LORA_FILENAME`` from ``Config.REPO_ID`` and fuse it
             with scale ``Config.LORA_STRENGTH``.

        ``self.pipeline`` is cleared on entry and assigned only after every
        step has succeeded, so a failure part-way through leaves it ``None``
        rather than pointing at a half-configured pipeline.

        Raises:
            Any exception from the download or diffusers calls in steps 1, 2,
            4 and 5 propagates; only step 3 swallows its errors.
        """
        # Drop any previous pipeline up front so it can be garbage-collected
        # before the new one is moved to the device.
        self.pipeline = None

        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = self._ensure_checkpoint()

        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        pipeline = StableDiffusionXLPipeline.from_single_file(
            checkpoint_local_path,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        pipeline.to(Config.DEVICE)

        self._try_enable_xformers(pipeline)
        self._configure_lcm_scheduler(pipeline)
        self._fuse_lora(pipeline)

        # Publish only the fully configured pipeline.
        self.pipeline = pipeline
        print("--- All models loaded successfully ---")

    def _ensure_checkpoint(self):
        """Return the local checkpoint path, downloading the file first if absent."""
        checkpoint_local_path = os.path.join(self.MODELS_DIR, Config.CHECKPOINT_FILENAME)
        if os.path.exists(checkpoint_local_path):
            return checkpoint_local_path

        print(f"Downloading checkpoint to {checkpoint_local_path}...")
        # hf_hub_download returns the path of the downloaded file; use it
        # rather than assuming the on-disk layout.
        return hf_hub_download(
            repo_id=Config.REPO_ID,
            filename=Config.CHECKPOINT_FILENAME,
            local_dir=self.MODELS_DIR,
            # NOTE(review): deprecated in newer huggingface_hub releases; kept
            # so older releases copy the file instead of symlinking to the cache.
            local_dir_use_symlinks=False,
        )

    @staticmethod
    def _try_enable_xformers(pipeline):
        """Enable xFormers memory-efficient attention; warn and continue on failure."""
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:  # deliberate best effort: loading continues without it
            print(f" [WARNING] Failed to enable xFormers: {e}")
        else:
            print(" [OK] xFormers memory efficient attention enabled.")

    @staticmethod
    def _configure_lcm_scheduler(pipeline):
        """Replace ``pipeline.scheduler`` with an LCMScheduler derived from its config."""
        print("Configuring LCMScheduler...")
        # Overrides go through from_config() kwargs (which take precedence over
        # config entries) instead of writing into pipeline.scheduler.config,
        # a diffusers FrozenDict owned by the current scheduler.
        pipeline.scheduler = LCMScheduler.from_config(
            pipeline.scheduler.config,
            clip_sample=False,
            timestep_spacing="trailing",
            beta_schedule="scaled_linear",
        )
        print(" [OK] LCMScheduler loaded (clip_sample=False).")

    @staticmethod
    def _fuse_lora(pipeline):
        """Load the configured LoRA weights and fuse them into the pipeline."""
        print("Loading LoRA weights...")
        pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)

        print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
        pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print(" [OK] LoRA fused.")
|
|
|