Spaces:
Running
on
Zero
Running
on
Zero
Update model.py
Browse files
model.py
CHANGED
|
@@ -110,58 +110,17 @@ class ModelHandler:
|
|
| 110 |
|
| 111 |
self.pipeline.to(Config.DEVICE)
|
| 112 |
|
| 113 |
-
# --- NEW: Enable xFormers ---
|
| 114 |
try:
|
| 115 |
self.pipeline.enable_xformers_memory_efficient_attention()
|
| 116 |
print(" [OK] xFormers memory efficient attention enabled.")
|
| 117 |
except Exception as e:
|
| 118 |
print(f" [WARNING] Failed to enable xFormers: {e}")
|
| 119 |
-
# --- END NEW ---
|
| 120 |
-
|
| 121 |
-
# ==============================================================================
|
| 122 |
-
# 3.5 LOAD PIVOTAL TUNING EMBEDDINGS (TEXTUAL INVERSION)
|
| 123 |
-
# ==============================================================================
|
| 124 |
-
print("Loading Textual Inversion Embeddings (Pivotal Tuning)...")
|
| 125 |
-
|
| 126 |
-
# Define the embedding name (assumed to be in the same repo as Config.REPO_ID)
|
| 127 |
-
embedding_filename = "retroart.safetensors"
|
| 128 |
-
embedding_path = os.path.join("./models", embedding_filename)
|
| 129 |
-
|
| 130 |
-
# 1. Download if missing
|
| 131 |
-
if not os.path.exists(embedding_path):
|
| 132 |
-
print(f"Downloading embedding '{embedding_filename}' to {embedding_path}...")
|
| 133 |
-
try:
|
| 134 |
-
hf_hub_download(
|
| 135 |
-
repo_id=Config.REPO_ID, # Or 'primerz/pixagram' if separate
|
| 136 |
-
filename=embedding_filename,
|
| 137 |
-
local_dir="./models",
|
| 138 |
-
local_dir_use_symlinks=False
|
| 139 |
-
)
|
| 140 |
-
except Exception as e:
|
| 141 |
-
print(f" [WARNING] Could not download embedding: {e}")
|
| 142 |
-
|
| 143 |
-
# 2. Load into the pipeline
|
| 144 |
-
# SDXL pipelines automatically handle loading into both tokenizers/text_encoders
|
| 145 |
-
if os.path.exists(embedding_path):
|
| 146 |
-
try:
|
| 147 |
-
self.pipeline.load_textual_inversion(
|
| 148 |
-
embedding_path,
|
| 149 |
-
token="retroart", # Trigger word: <retroart> or retroart
|
| 150 |
-
local_files_only=True
|
| 151 |
-
)
|
| 152 |
-
print(f" [OK] Loaded embedding '{embedding_filename}' associated with token 'retroart'.")
|
| 153 |
-
except Exception as e:
|
| 154 |
-
print(f" [ERROR] Failed to load textual inversion: {e}")
|
| 155 |
-
else:
|
| 156 |
-
print(f" [SKIP] Embedding file not found locally.")
|
| 157 |
-
# ==============================================================================
|
| 158 |
-
|
| 159 |
|
| 160 |
# 4. Set Scheduler
|
| 161 |
# --- MODIFIED: Disable clipping to prevent NaN artifacts ---
|
| 162 |
print("Configuring LCMScheduler...")
|
| 163 |
scheduler_config = self.pipeline.scheduler.config
|
| 164 |
-
|
| 165 |
self.pipeline.scheduler = LCMScheduler.from_config(scheduler_config)
|
| 166 |
print(" [OK] LCMScheduler loaded (clip_sample=False).")
|
| 167 |
# --- END MODIFIED ---
|
|
|
|
| 110 |
|
| 111 |
self.pipeline.to(Config.DEVICE)
|
| 112 |
|
|
|
|
| 113 |
try:
|
| 114 |
self.pipeline.enable_xformers_memory_efficient_attention()
|
| 115 |
print(" [OK] xFormers memory efficient attention enabled.")
|
| 116 |
except Exception as e:
|
| 117 |
print(f" [WARNING] Failed to enable xFormers: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# 4. Set Scheduler
|
| 120 |
# --- MODIFIED: Disable clipping to prevent NaN artifacts ---
|
| 121 |
print("Configuring LCMScheduler...")
|
| 122 |
scheduler_config = self.pipeline.scheduler.config
|
| 123 |
+
scheduler_config['clip_sample'] = False
|
| 124 |
self.pipeline.scheduler = LCMScheduler.from_config(scheduler_config)
|
| 125 |
print(" [OK] LCMScheduler loaded (clip_sample=False).")
|
| 126 |
# --- END MODIFIED ---
|