pixel-art / model.py
primerz's picture
Upload 7 files
0f0e9c7 verified
raw
history blame
2.37 kB
import torch
import os
from config import Config
from diffusers import (
StableDiffusionXLPipeline,
LCMScheduler
)
from huggingface_hub import hf_hub_download
class ModelHandler:
    """Own the SDXL text-to-image pipeline and load it on demand.

    ``load_models`` builds a ``StableDiffusionXLPipeline`` from a single
    checkpoint file (downloaded from ``Config.REPO_ID`` if it is not already
    under ``models_dir``). It then:

    - tries to enable xFormers attention;
    - swaps in an ``LCMScheduler``;
    - fuses the LoRA named by ``Config.LORA_FILENAME``.
    """

    def __init__(self, models_dir="./models"):
        # Local directory holding the downloaded checkpoint file.
        self.models_dir = models_dir
        # Set by load_models(); stays None until loading succeeds.
        self.pipeline = None

    def load_models(self):
        """Download (if needed) and assemble the full generation pipeline.

        Errors raised while downloading or loading the checkpoint, the
        scheduler or the LoRA propagate to the caller. Only the xFormers step
        is best-effort.
        """
        # 1. Load SDXL Text-to-Image Pipeline
        self.pipeline = self._load_pipeline()
        # 2. Enable xFormers (optional speed/memory optimisation)
        self._enable_xformers()
        # 3. Set Scheduler (LCM)
        self._configure_scheduler()
        # 4. Load and fuse LoRA
        self._fuse_lora()
        print("--- All models loaded successfully ---")

    def _ensure_checkpoint(self):
        """Return the local checkpoint path, downloading the file if it is missing."""
        checkpoint_local_path = os.path.join(self.models_dir, Config.CHECKPOINT_FILENAME)
        if not os.path.exists(checkpoint_local_path):
            print(f"Downloading checkpoint to {checkpoint_local_path}...")
            hf_hub_download(
                repo_id=Config.REPO_ID,
                filename=Config.CHECKPOINT_FILENAME,
                local_dir=self.models_dir,
                # Store a real file in models_dir rather than a cache symlink.
                local_dir_use_symlinks=False,
            )
        return checkpoint_local_path

    def _load_pipeline(self):
        """Build the SDXL pipeline from the local checkpoint and move it to Config.DEVICE."""
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = self._ensure_checkpoint()
        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        # Use standard SDXL Text2Image pipeline
        pipeline = StableDiffusionXLPipeline.from_single_file(
            checkpoint_local_path,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        pipeline.to(Config.DEVICE)
        return pipeline

    def _enable_xformers(self):
        """Try to enable xFormers attention on the pipeline; return True on success.

        Best-effort: any exception is printed as a warning and swallowed, so
        loading continues with the default attention implementation.
        """
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f" [WARNING] Failed to enable xFormers: {e}")
            return False
        print(" [OK] xFormers memory efficient attention enabled.")
        return True

    def _configure_scheduler(self):
        """Replace the checkpoint's scheduler with an LCMScheduler derived from its config."""
        print("Configuring LCMScheduler...")
        # Overrides are passed as from_config kwargs instead of writing into
        # scheduler.config, which diffusers exposes as a read-only FrozenDict.
        # clip_sample=False disables clipping to prevent NaN artifacts with LCM.
        self.pipeline.scheduler = LCMScheduler.from_config(
            self.pipeline.scheduler.config,
            clip_sample=False,
            timestep_spacing="trailing",
            beta_schedule="scaled_linear",
        )
        print(" [OK] LCMScheduler loaded (clip_sample=False).")

    def _fuse_lora(self):
        """Load the LoRA from Config.REPO_ID and fuse it at Config.LORA_STRENGTH."""
        print("Loading LoRA weights...")
        self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
        print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print(" [OK] LoRA fused.")