primerz committed on
Commit
364071a
·
verified ·
1 Parent(s): 006a6f3

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +22 -125
model.py CHANGED
@@ -1,83 +1,20 @@
1
  import torch
2
  import os
3
- import cv2
4
- import numpy as np
5
  from config import Config
6
 
7
  from diffusers import (
8
- ControlNetModel,
9
- TCDScheduler,
10
  )
11
- from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
12
 
13
- # Import the custom pipeline from your local file
14
- from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
15
-
16
- from huggingface_hub import snapshot_download, hf_hub_download
17
- from insightface.app import FaceAnalysis
18
- from controlnet_aux import LeresDetector, LineartAnimeDetector
19
 
20
  class ModelHandler:
21
  def __init__(self):
22
  self.pipeline = None
23
- self.app = None # InsightFace
24
- self.leres_detector = None
25
- self.lineart_anime_detector = None
26
- self.face_analysis_loaded = False
27
-
28
- def load_face_analysis(self):
29
- """
30
- Load face analysis model.
31
- Downloads from HF Hub to the path insightface expects.
32
- """
33
- print("Loading face analysis model...")
34
-
35
- model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
36
-
37
- if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
38
- print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
39
- try:
40
- snapshot_download(
41
- repo_id=Config.ANTELOPEV2_REPO,
42
- local_dir=model_path, # Download to the correct expected path
43
- )
44
- except Exception as e:
45
- print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
46
- return False
47
-
48
- try:
49
- self.app = FaceAnalysis(
50
- name=Config.ANTELOPEV2_NAME,
51
- root=Config.ANTELOPEV2_ROOT,
52
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
53
- )
54
- self.app.prepare(ctx_id=0, det_size=(640, 640))
55
- print(f" [OK] Face analysis model loaded successfully.")
56
- return True
57
-
58
- except Exception as e:
59
- print(f" [WARNING] Face detection system failed to initialize: {e}")
60
- return False
61
 
62
  def load_models(self):
63
- # 1. Load Face Analysis
64
- self.face_analysis_loaded = self.load_face_analysis()
65
-
66
- # 2. Load ControlNets
67
- print("Loading ControlNets (InstantID, Zoe, LineArt)...")
68
- cn_instantid = ControlNetModel.from_pretrained(
69
- Config.INSTANTID_REPO,
70
- subfolder="ControlNetModel",
71
- torch_dtype=Config.DTYPE
72
- )
73
- cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
74
- cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
75
-
76
- print("Wrapping ControlNets in MultiControlNetModel...")
77
- controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
78
- controlnet = MultiControlNetModel(controlnet_list)
79
-
80
- # 3. Load SDXL Pipeline (Now from 'reality.safetensors')
81
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
82
 
83
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -91,76 +28,36 @@ class ModelHandler:
91
  )
92
 
93
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
94
- self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
 
95
  checkpoint_local_path,
96
- controlnet=controlnet,
97
  torch_dtype=Config.DTYPE,
98
  use_safetensors=True
99
  )
100
 
101
  self.pipeline.to(Config.DEVICE)
102
 
 
103
  try:
104
  self.pipeline.enable_xformers_memory_efficient_attention()
105
  print(" [OK] xFormers memory efficient attention enabled.")
106
  except Exception as e:
107
  print(f" [WARNING] Failed to enable xFormers: {e}")
108
 
109
- # 4. Set TCD Scheduler (Sanitized Config)
110
- print("Configuring TCDScheduler...")
111
- self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
112
- print(" [OK] TCDScheduler loaded (Forced SDXL Defaults + Karras + Trailing).")
113
-
114
- # 5. Load Adapters
115
- print("Loading Adapters...")
116
-
117
- # 5b. Load and Fuse Style LoRA (lucasart)
118
- print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
119
- style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
120
- if not os.path.exists(style_lora_path):
121
- hf_hub_download(
122
- repo_id=Config.REPO_ID,
123
- filename=Config.LORA_FILENAME,
124
- local_dir="./models",
125
- local_dir_use_symlinks=False
126
- )
127
- self.pipeline.load_lora_weights("./models", weight_name=Config.LORA_FILENAME)
128
- self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
129
- print(" [OK] Style LoRA fused.")
130
-
131
- # 5c. Load IP-Adapter (for InstantID) - *Must be loaded AFTER fusing*
132
- ip_adapter_filename = "ip-adapter.bin"
133
- ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
134
- if not os.path.exists(ip_adapter_local_path):
135
- hf_hub_download(
136
- repo_id=Config.INSTANTID_REPO,
137
- filename=ip_adapter_filename,
138
- local_dir="./models",
139
- local_dir_use_symlinks=False
140
- )
141
- self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
142
- print(" [OK] IP-Adapter loaded.")
143
 
144
- # --- END FIX ---
145
-
146
- # 7. Load Preprocessors
147
- print("Loading Preprocessors (LeReS, LineArtAnime)...")
148
- self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
149
- self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
150
 
151
- print("--- All models loaded successfully ---")
152
-
153
- def get_face_info(self, image):
154
- """Extracts the largest face, returns insightface result object."""
155
- if not self.face_analysis_loaded:
156
- return None
157
- try:
158
- cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
159
- faces = self.app.get(cv2_img)
160
- if len(faces) == 0:
161
- return None
162
- faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
163
- return faces[0]
164
- except Exception as e:
165
- print(f"Face embedding extraction failed: {e}")
166
- return None
 
1
  import torch
2
  import os
 
 
3
  from config import Config
4
 
5
  from diffusers import (
6
+ StableDiffusionXLPipeline,
7
+ LCMScheduler
8
  )
 
9
 
10
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
11
 
12
  class ModelHandler:
13
  def __init__(self):
14
  self.pipeline = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def load_models(self):
17
+ # 1. Load SDXL Text-to-Image Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
19
 
20
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
 
28
  )
29
 
30
  print(f"Loading pipeline from local file: {checkpoint_local_path}")
31
+ # Use standard SDXL Text2Image pipeline
32
+ self.pipeline = StableDiffusionXLPipeline.from_single_file(
33
  checkpoint_local_path,
 
34
  torch_dtype=Config.DTYPE,
35
  use_safetensors=True
36
  )
37
 
38
  self.pipeline.to(Config.DEVICE)
39
 
40
+ # 2. Enable xFormers
41
  try:
42
  self.pipeline.enable_xformers_memory_efficient_attention()
43
  print(" [OK] xFormers memory efficient attention enabled.")
44
  except Exception as e:
45
  print(f" [WARNING] Failed to enable xFormers: {e}")
46
 
47
+ # 3. Set Scheduler (LCM)
48
+ print("Configuring LCMScheduler...")
49
+ scheduler_config = self.pipeline.scheduler.config
50
+ # Disable clipping to prevent NaN artifacts with LCM
51
+ scheduler_config['clip_sample'] = False
52
+ self.pipeline.scheduler = LCMScheduler.from_config(scheduler_config)
53
+ print(" [OK] LCMScheduler loaded (clip_sample=False).")
54
+
55
+ # 4. Load LoRA
56
+ print("Loading LoRA weights...")
57
+ self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
60
+ self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
61
+ print(" [OK] LoRA fused.")
 
 
 
62
 
63
+ print("--- All models loaded successfully ---")