primerz commited on
Commit
6d5987b
·
verified ·
1 Parent(s): eb833d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +87 -53
model.py CHANGED
@@ -1,14 +1,17 @@
1
  import torch
 
 
 
 
 
2
  from diffusers import (
3
  StableDiffusionXLControlNetPipeline,
4
  ControlNetModel,
5
- LCMScheduler,
6
- AutoencoderKL
7
  )
8
- from huggingface_hub import hf_hub_download
9
  from insightface.app import FaceAnalysis
10
- from config import Config
11
- import os
12
 
13
  class ModelHandler:
14
  def __init__(self):
@@ -16,74 +19,105 @@ class ModelHandler:
16
  self.app = None # InsightFace
17
  self.zoe_detector = None
18
  self.lineart_detector = None
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def load_models(self):
21
- print("Loading Face Analysis (InsightFace)...")
22
- # Ensure you have downloaded antelopev2 to the local folder or handle download here
23
- self.app = FaceAnalysis(
24
- name='antelopev2',
25
- root='./',
26
- providers=['CPUExecutionProvider']
27
- )
28
- self.app.prepare(ctx_id=0, det_size=(640, 640))
29
 
30
- print("Loading ControlNets...")
31
- # 1. InstantID ControlNet
32
  cn_instantid = ControlNetModel.from_pretrained(
33
  Config.CN_INSTANTID_REPO,
34
  subfolder="controlnet",
35
  torch_dtype=Config.DTYPE
36
  )
37
- # 2. Zoe Depth ControlNet
38
- cn_zoe = ControlNetModel.from_pretrained(
39
- Config.CN_ZOE_REPO,
40
- torch_dtype=Config.DTYPE
41
- )
42
- # 3. LineArt ControlNet
43
- cn_lineart = ControlNetModel.from_pretrained(
44
- Config.CN_LINEART_REPO,
45
- torch_dtype=Config.DTYPE
46
- )
47
 
48
- print("Loading SDXL Pipeline...")
49
- # Load custom checkpoint (Single File)
50
- ckpt_path = hf_hub_download(repo_id=Config.REPO_ID, filename=Config.CHECKPOINT_FILENAME)
51
-
52
  self.pipeline = StableDiffusionXLControlNetPipeline.from_single_file(
53
- ckpt_path,
 
54
  controlnet=[cn_instantid, cn_zoe, cn_lineart], # ORDER MATTERS
55
  torch_dtype=Config.DTYPE,
 
56
  ).to(Config.DEVICE)
57
 
58
- # Load LCM Scheduler
59
  self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
60
 
 
61
  print("Loading Adapters (IP-Adapter & LoRA)...")
62
- # Load InstantID IP-Adapter
63
  self.pipeline.load_ip_adapter_instantid(Config.INSTANTID_REPO)
64
-
65
- # Load Custom Style LoRA
66
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
67
- self.pipeline.fuse_lora() # Optional: Fuse for speed, might make switching harder if you have multiple styles
68
 
69
- print("Loading Preprocessors...")
70
- from controlnet_aux import ZoeDetector, LineartDetector
71
- self.zoe_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
72
- self.lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
73
-
74
- print("Models Loaded Successfully.")
75
 
 
 
76
  def get_face_embedding(self, image):
77
- # Convert PIL to CV2
78
- cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
79
-
80
- faces = self.app.get(cv2_img)
81
-
82
- if len(faces) == 0:
83
- return None # Return None instead of crashing
84
 
85
- # Sort by size (width * height) to find the main character
86
- faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
87
-
88
- # Return the largest face
89
- return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from config import Config
6
+
7
  from diffusers import (
8
  StableDiffusionXLControlNetPipeline,
9
  ControlNetModel,
10
+ LCMScheduler
 
11
  )
12
+ from huggingface_hub import snapshot_download
13
  from insightface.app import FaceAnalysis
14
+ from controlnet_aux import ZoeDetector, LineartDetector
 
15
 
16
  class ModelHandler:
17
  def __init__(self):
 
19
  self.app = None # InsightFace
20
  self.zoe_detector = None
21
  self.lineart_detector = None
22
+ self.face_analysis_loaded = False
23
+
24
+ def load_face_analysis(self):
25
+ """
26
+ Load face analysis model.
27
+ Downloads from HF Hub (fast) instead of GitHub (slow) if not present.
28
+ Forces CPU to avoid ZeroGPU initialization errors.
29
+ """
30
+ print("Loading face analysis model...")
31
+ model_root_path = os.path.join(Config.ANTELOPEV2_ROOT, Config.ANTELOPEV2_NAME)
32
 
33
+ # 1. Download from HF Hub (Much faster than default InsightFace download)
34
+ if not os.path.exists(os.path.join(model_root_path, "scrfd_10g_bnkps.onnx")):
35
+ print("Downloading AntelopeV2 models from HuggingFace...")
36
+ try:
37
+ snapshot_download(
38
+ repo_id=Config.ANTELOPEV2_REPO,
39
+ local_dir=model_root_path,
40
+ local_dir_use_symlinks=False
41
+ )
42
+ except Exception as e:
43
+ print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
44
+ return False
45
+
46
+ # 2. Initialize InsightFace on CPU
47
+ # CRITICAL: Use ONLY 'CPUExecutionProvider'.
48
+ # ZeroGPU will crash if you try to look for CUDA during init.
49
+ try:
50
+ self.app = FaceAnalysis(
51
+ name=Config.ANTELOPEV2_NAME,
52
+ root=Config.ANTELOPEV2_ROOT,
53
+ providers=['CPUExecutionProvider']
54
+ )
55
+ self.app.prepare(ctx_id=0, det_size=(640, 640))
56
+ print(f" [OK] Face analysis model loaded successfully.")
57
+ return True
58
+
59
+ except Exception as e:
60
+ print(f" [WARNING] Face detection system failed to initialize: {e}")
61
+ return False
62
+
63
  def load_models(self):
64
+ # 1. Load Face Analysis
65
+ self.face_analysis_loaded = self.load_face_analysis()
 
 
 
 
 
 
66
 
67
+ # 2. Load ControlNets
68
+ print("Loading ControlNets (InstantID, Zoe, LineArt)...")
69
  cn_instantid = ControlNetModel.from_pretrained(
70
  Config.CN_INSTANTID_REPO,
71
  subfolder="controlnet",
72
  torch_dtype=Config.DTYPE
73
  )
74
+ cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
75
+ cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
 
 
 
 
 
 
 
76
 
77
+ # 3. Load SDXL Pipeline
78
+ print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
 
 
79
  self.pipeline = StableDiffusionXLControlNetPipeline.from_single_file(
80
+ Config.REPO_ID,
81
+ filename=Config.CHECKPOINT_FILENAME,
82
  controlnet=[cn_instantid, cn_zoe, cn_lineart], # ORDER MATTERS
83
  torch_dtype=Config.DTYPE,
84
+ use_safetensors=True
85
  ).to(Config.DEVICE)
86
 
87
+ # 4. Set Scheduler
88
  self.pipeline.scheduler = LCMScheduler.from_config(self.pipeline.scheduler.config)
89
 
90
+ # 5. Load Adapters (IP-Adapter & LoRA)
91
  print("Loading Adapters (IP-Adapter & LoRA)...")
 
92
  self.pipeline.load_ip_adapter_instantid(Config.INSTANTID_REPO)
 
 
93
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
94
+ self.pipeline.fuse_lora(lora_scale=1.0) # Fuse with scale 1.0
95
 
96
+ # 6. Load Preprocessors
97
+ print("Loading Preprocessors (Zoe, LineArt)...")
98
+ self.zoe_detector = ZoeDetector.from_pretrained(Config.ANNOTATOR_REPO)
99
+ self.lineart_detector = LineartDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
 
100
 
101
+ print("--- All models loaded successfully ---")
102
+
103
  def get_face_embedding(self, image):
104
+ """Extracts face embedding, returns None if no face is found."""
105
+ if not self.face_analysis_loaded:
106
+ return None
 
 
 
 
107
 
108
+ try:
109
+ # Convert PIL to CV2
110
+ cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
111
+ faces = self.app.get(cv2_img)
112
+
113
+ if len(faces) == 0:
114
+ return None
115
+
116
+ # Sort by size (width * height) to find the main character
117
+ faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
118
+
119
+ # Return the largest face
120
+ return torch.tensor(faces[0].normed_embedding).unsqueeze(0)
121
+ except Exception as e:
122
+ print(f"Face embedding extraction failed: {e}")
123
+ return None