primerz committed
Commit 8444fe7 · verified · 1 parent: 1f6c7a0

Upload 7 files

Files changed (6):
  1. app.py +27 -30
  2. config.py +15 -6
  3. generator.py +63 -16
  4. model.py +124 -19
  5. requirements.txt +4 -1
  6. utils.py +36 -1
app.py CHANGED
@@ -1,24 +1,20 @@
 import gradio as gr
+import spaces
 import torch
 from model import ModelHandler
 from generator import Generator
+# --- IMPORT CONFIG ---
 from config import Config
 
-# Handle spaces module for HF Spaces ZeroGPU (optional)
-try:
-    import spaces
-    SPACES_AVAILABLE = True
-except ImportError:
-    SPACES_AVAILABLE = False
-    print("Running without HF Spaces ZeroGPU support")
-
 # 1. Initialize Models Globally (in RAM)
+# ZeroGPU will move them to VRAM inside the @spaces.GPU function
 print("Initializing Application...")
 handler = ModelHandler()
 handler.load_models()
 gen = Generator(handler)
 
-# 2. Define Inference Function
+# 2. Define GPU-enabled Inference Function
+@spaces.GPU(duration=20)  # <-- MODIFIED
 def process_img(
     image,
     prompt,
@@ -26,9 +22,10 @@ def process_img(
     cfg_scale,
     steps,
     img_strength,
+    face_strength,
     depth_strength,
     edge_strength,
-    lora_strength,
+    # tile_strength,  # <-- REMOVED
     seed
 ):
     if image is None:
@@ -44,9 +41,10 @@ def process_img(
             guidance_scale=cfg_scale,
             num_inference_steps=steps,
             img2img_strength=img_strength,
+            face_strength=face_strength,
             depth_strength=depth_strength,
             lineart_strength=edge_strength,
-            lora_strength=lora_strength,
+            # tile_strength=tile_strength,  # <-- REMOVED
             seed=seed
         )
         print("--- Generation Complete ---")
@@ -56,16 +54,12 @@ def process_img(
         print(f"Error during generation: {e}")
        raise gr.Error(f"An error occurred: {str(e)}")
 
-# Apply spaces.GPU decorator only if available
-if SPACES_AVAILABLE:
-    process_img = spaces.GPU(duration=20)(process_img)
-
 # 3. Build Gradio Interface
-with gr.Blocks(title="Image To Pixel Art", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="Face To Pixel Art", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
-        # 🎮 Image to Pixel Art
-        Upload any image and transform it into pixel art style!
+        # 🎮 Face to Pixel Art
+        Upload any image. If there is a face, we'll keep the identity. If not, we'll pixelate the scene!
         """
     )
 
@@ -81,7 +75,7 @@ with gr.Blocks(title="Image To Pixel Art", theme=gr.themes.Soft()) as demo:
             negative_prompt = gr.Textbox(
                 label="Negative Prompt (Optional)",
                 placeholder="e.g., blurry, text, watermark, bad art...",
-                value=Config.DEFAULT_NEGATIVE_PROMPT
+                value=Config.DEFAULT_NEGATIVE_PROMPT  # <-- MODIFIED
            )
 
            with gr.Accordion("Advanced Settings", open=False):
@@ -116,6 +110,14 @@ with gr.Blocks(title="Image To Pixel Art", theme=gr.themes.Soft()) as demo:
                     value=Config.IMG_STRENGTH,
                     label="Image Strength (Img2Img)"
                 )
+                face_strength = gr.Slider(
+                    elem_id="face_strength",
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    value=Config.FACE_STRENGTH,
+                    label="Face Strength"
+                )
                 depth_strength = gr.Slider(
                     elem_id="depth_strength",
                     minimum=0.0,
@@ -132,14 +134,8 @@ with gr.Blocks(title="Image To Pixel Art", theme=gr.themes.Soft()) as demo:
                     value=Config.EDGE_STRENGTH,
                     label="EdgeMap Strength (LineArt)"
                 )
-                lora_strength = gr.Slider(
-                    elem_id="lora_strength",
-                    minimum=0.0,
-                    maximum=2.0,
-                    step=0.05,
-                    value=Config.LORA_STRENGTH,
-                    label="LoRA Strength (Pixel Art Style)"
-                )
+                # --- MODIFIED: Renamed slider ---
+                # tile_strength = gr.Slider(...)  # <-- REMOVED
 
            run_btn = gr.Button("Generate Pixel Art", variant="primary")
 
@@ -154,9 +150,10 @@ with gr.Blocks(title="Image To Pixel Art", theme=gr.themes.Soft()) as demo:
         cfg_scale,
         steps,
         img_strength,
+        face_strength,
         depth_strength,
         edge_strength,
-        lora_strength,
+        # tile_strength,  # <-- REMOVED
         seed
     ]
 
@@ -173,5 +170,5 @@ if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,
-        show_api=True
-    )
+        show_api=True  # share=True is not needed on Spaces
+    )
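Note: with the unconditional `import spaces`, the app now requires the ZeroGPU runtime; the old try/except fallback is gone. A minimal sketch (an assumption, not part of this commit) of keeping the app runnable outside HF Spaces while still using decorator syntax:

    # Hypothetical guard restoring the removed fallback behaviour.
    try:
        import spaces
        gpu_wrap = spaces.GPU(duration=20)   # ZeroGPU allocates a GPU per call
    except ImportError:
        print("Running without HF Spaces ZeroGPU support")
        gpu_wrap = lambda fn: fn             # no-op decorator on local machines

    @gpu_wrap
    def process_img(image, prompt, *args):
        ...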
config.py CHANGED
@@ -7,16 +7,19 @@ class Config:
 
     # Base Model & LoRA (from primerz/pixagram)
     REPO_ID = "primerz/pixagram"
-    CHECKPOINT_FILENAME = "picto.safetensors"
+    CHECKPOINT_FILENAME = "horizon.safetensors"
     LORA_FILENAME = "retroart.safetensors"
-    LORA_STRENGTH = 1.0  # Default strength, now configurable via UI
+    LORA_STRENGTH = 1.0  # Fixed strength for fusion
 
     # Trigger Words for the LoRA
-    STYLE_TRIGGER = "p1x3l4rt, pixel art"
+    STYLE_TRIGGER = "HD pixel art artwork and high quality illustration in retroart style of "
 
     # Default Negative Prompt
     DEFAULT_NEGATIVE_PROMPT = "Ugly, artifacts, blurry, disformed, photo-realistic, photo, photography, realistic, low-quality, text."
 
+    # InstantID Assets
+    INSTANTID_REPO = "InstantX/InstantID"
+
     # ControlNet Repos
     CN_ZOE_REPO = "diffusers/controlnet-zoe-depth-sdxl-1.0"
     CN_LINEART_REPO = "ShermanG/ControlNet-Standard-Lineart-for-SDXL"
@@ -27,10 +30,16 @@ class Config:
     # Captioning Model
     CAPTIONER_REPO = "Salesforce/blip-image-captioning-base"
 
+    # InsightFace Model (HF Hub mirror)
+    ANTELOPEV2_REPO = "DIAMONIK7777/antelopev2"
+    ANTELOPEV2_ROOT = "."  # Parent folder
+    ANTELOPEV2_NAME = "antelopev2"
+
     # Gradio Parameters
     CGF_SCALE = 1.2
     STEPS_NUMBER = 10
     IMG_STRENGTH = 0.65
-    DEPTH_STRENGTH = 0.8  # Updated default
-    EDGE_STRENGTH = 0.2  # Updated default
-    CLIP_SKIP = 1
+    FACE_STRENGTH = 0.75
+    DEPTH_STRENGTH = 0.75
+    EDGE_STRENGTH = 0.75
+    CLIP_SKIP = 2
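Note: the new STYLE_TRIGGER is a natural-language prefix ending in "of ", rather than a bare trigger token, so a caption can be concatenated onto it directly. A rough illustration (the exact assembly lives in generator.py and is assumed here):

    from config import Config

    caption = "a woman with red hair"  # hypothetical BLIP output
    final_prompt = Config.STYLE_TRIGGER + caption
    # -> "HD pixel art artwork and high quality illustration in retroart style of a woman with red hair"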
generator.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 from config import Config
-from utils import resize_image_to_1mp, get_caption
+from utils import resize_image_to_1mp, get_caption, draw_kps
 from PIL import Image
 
 class Generator:
@@ -34,9 +34,9 @@ class Generator:
         guidance_scale=1.5,
         num_inference_steps=6,
         img2img_strength=0.3,
-        depth_strength=0.8,
-        lineart_strength=0.2,
-        lora_strength=1.0,
+        face_strength=0.3,
+        depth_strength=0.3,
+        lineart_strength=0.3,
         seed=-1
     ):
         # 1. Pre-process Inputs
@@ -44,8 +44,8 @@ class Generator:
         processed_image = resize_image_to_1mp(input_image)
         target_width, target_height = processed_image.size
 
-        # 2. Set LoRA scale
-        self.mh.set_lora_scale(lora_strength)
+        # 2. Get Face Info (replaces get_face_embedding)
+        face_info = self.mh.get_face_info(processed_image)
 
         # 3. Generate Prompt
         if not user_prompt.strip():
@@ -61,40 +61,87 @@ class Generator:
         print(f"Prompt: {final_prompt}")
         print(f"Negative Prompt: {negative_prompt}")
 
-        # 4. Generate Control Maps (Depth, LineArt)
+        # 4. Generate OTHER Control Maps (Structure)
         print("Generating Control Maps (Depth, LineArt)...")
         depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
 
-        # 5. Set controlnet conditioning scales
-        # ControlNet order: [Zoe_Depth, LineArt]
-        controlnet_conditioning_scale = [depth_strength, lineart_strength]
+        # 5. Logic for Face vs No-Face (NOW INCLUDES KPS)
+        # ControlNet order: [InstantID_KPS, Zoe_Depth, LineArt]
 
-        # Control guidance end - keep structure through the denoising process
-        control_guidance_end = [depth_strength, lineart_strength]
+        if face_info is not None:
+            print("Face detected: Applying InstantID with keypoints.")
+
+            # We use face_info['embedding'] (raw) instead of normed_embedding.
+            # Raw embedding has higher magnitude (~20-30) required for the adapter.
+            face_emb = torch.tensor(
+                face_info['embedding'],
+                dtype=Config.DTYPE,
+                device=Config.DEVICE
+            ).unsqueeze(0)
+
+            # Create keypoint image
+            face_kps = draw_kps(processed_image, face_info['kps'])
+
+            # Set strengths
+            controlnet_conditioning_scale = [face_strength, depth_strength, lineart_strength]
+
+            # --- UPDATED: Reduced IP Adapter Scale ---
+            # Lowered from 0.8 to 0.7 to allow LoRA style (pixel art) to
+            # override realistic skin textures while keeping identity.
+            self.mh.pipeline.set_ip_adapter_scale(0.7)
+        else:
+            print("No face detected: Disabling InstantID.")
+            # Create dummy embedding
+            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
+            # Create dummy keypoint image (black)
+            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+
+            # Set strengths
+            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
+            self.mh.pipeline.set_ip_adapter_scale(0.0)
+
+        # --- UPDATED: Control Guidance End Strategy ---
+        # We cap the Face ControlNet duration.
+        # Even if strength is 1.0, we stop it at 0.6 (60%) of the steps.
+        # This leaves the final 40% of steps pure for the Pixel Art LoRA
+        # to "pixelize" the face without the ControlNet trying to fix it back to a photo.
+
+        face_end_step = min(0.6, face_strength)
+
+        control_guidance_end = [
+            face_end_step,     # InstantID: Stop early for style
+            depth_strength,    # Depth: Keep structure longer
+            lineart_strength   # Lineart: Keep outlines longer
+        ]
 
         # --- Seed/Generator Logic ---
         if seed == -1 or seed is None:
             seed = torch.Generator().seed()
         generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
         print(f"Using seed: {seed}")
+        # --- END ---
 
         # 6. Run Inference
         print("Running pipeline...")
         result = self.mh.pipeline(
             prompt=final_prompt,
             negative_prompt=negative_prompt,
-            image=processed_image,
-            control_image=[depth_map, lineart_map],
+            image=processed_image,  # Base img2img image
+            control_image=[face_kps, depth_map, lineart_map],
+            image_embeds=face_emb,  # Face identity embedding
             generator=generator,
 
-            # Parameters from UI
+            # --- Parameters from UI ---
             strength=img2img_strength,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
+            # --- End Parameters from UI ---
 
             controlnet_conditioning_scale=controlnet_conditioning_scale,
             control_guidance_end=control_guidance_end,
 
+            clip_skip=Config.CLIP_SKIP,
+
         ).images[0]
 
-        return result
+        return result
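Note: the control_guidance_end strategy is easiest to see with concrete numbers. A worked example (values assumed; STEPS_NUMBER defaults to 10 and the new Config strengths to 0.75):

    steps = 10
    face_strength = depth_strength = lineart_strength = 0.75

    face_end_step = min(0.6, face_strength)  # capped at 0.6
    ends = [face_end_step, depth_strength, lineart_strength]

    # Step at which each ControlNet stops contributing:
    print([round(e * steps) for e in ends])  # [6, 8, 8]
    # InstantID stops after step 6, depth/lineart after step 8; the remaining
    # steps are left entirely to the fused pixel-art LoRA.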
model.py CHANGED
@@ -1,38 +1,93 @@
 import torch
 import os
+import cv2
+import numpy as np
 from config import Config
 
 from diffusers import (
     ControlNetModel,
     LCMScheduler,
-    StableDiffusionXLControlNetImg2ImgPipeline,
+    # AutoencoderKL  # Removed as requested
 )
 from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
 
-from huggingface_hub import hf_hub_download
+# Import the custom pipeline from your local file
+from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
+
+from huggingface_hub import snapshot_download, hf_hub_download
+from insightface.app import FaceAnalysis
 from controlnet_aux import LeresDetector, LineartAnimeDetector
 
 class ModelHandler:
     def __init__(self):
         self.pipeline = None
+        self.app = None  # InsightFace
         self.leres_detector = None
         self.lineart_anime_detector = None
-        self._lora_fused = False
-        self._current_lora_scale = None
+        self.face_analysis_loaded = False
+
+    def load_face_analysis(self):
+        """
+        Load face analysis model.
+        Downloads from HF Hub to the path insightface expects.
+        """
+        print("Loading face analysis model...")
+
+        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
+
+        if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
+            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
+            try:
+                snapshot_download(
+                    repo_id=Config.ANTELOPEV2_REPO,
+                    local_dir=model_path,  # Download to the correct expected path
+                )
+            except Exception as e:
+                print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
+                return False
+
+        try:
+            self.app = FaceAnalysis(
+                name=Config.ANTELOPEV2_NAME,
+                root=Config.ANTELOPEV2_ROOT,
+                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+            )
+            self.app.prepare(ctx_id=0, det_size=(640, 640))
+            print(f" [OK] Face analysis model loaded successfully.")
+            return True
+
+        except Exception as e:
+            print(f" [WARNING] Face detection system failed to initialize: {e}")
+            return False
 
     def load_models(self):
-        # 1. Load ControlNets (Zoe Depth and LineArt only)
-        print("Loading ControlNets (Zoe, LineArt)...")
+        # 1. Load Face Analysis
+        self.face_analysis_loaded = self.load_face_analysis()
+
+        # 2. Load ControlNets
+        print("Loading ControlNets (InstantID, Zoe, LineArt)...")
+
+        # Load the InstantID ControlNet from the correct subfolder
+        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
+        cn_instantid = ControlNetModel.from_pretrained(
+            Config.INSTANTID_REPO,
+            subfolder="ControlNetModel",
+            torch_dtype=Config.DTYPE
+        )
+        print(" [OK] Loaded InstantID ControlNet.")
 
+        # Load other ControlNets normally
+        print("Loading Zoe and LineArt ControlNets...")
         cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
         cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
-        # Wrap the list of models in a MultiControlNetModel
+        # --- Manually wrap the list of models in a MultiControlNetModel ---
         print("Wrapping ControlNets in MultiControlNetModel...")
-        controlnet_list = [cn_zoe, cn_lineart]
+        controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
         controlnet = MultiControlNetModel(controlnet_list)
+        # --- End wrapping ---
 
-        # 2. Load SDXL Pipeline
+        # 3. Load SDXL Pipeline
         print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
 
         checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -46,7 +101,7 @@ class ModelHandler:
         )
 
         print(f"Loading pipeline from local file: {checkpoint_local_path}")
-        self.pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
+        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
             checkpoint_local_path,
             controlnet=controlnet,
             torch_dtype=Config.DTYPE,
@@ -65,6 +120,7 @@ class ModelHandler:
         scheduler_config = self.pipeline.scheduler.config
         scheduler_config['clip_sample'] = False
 
+        # --- MODIFIED: optimize for sharp pixel art style ---
         self.pipeline.scheduler = LCMScheduler.from_config(
             scheduler_config,
             timestep_spacing="trailing",
@@ -72,25 +128,74 @@ class ModelHandler:
         )
         print(" [OK] LCMScheduler loaded (clip_sample=False, trailing spacing).")
 
-        # 3. Load LoRA weights (keep unfused for dynamic scaling)
+        # 5. Load Adapters (IP-Adapter & LoRA)
+        print("Loading Adapters (IP-Adapter & LoRA)...")
+
+        ip_adapter_filename = "ip-adapter.bin"
+        ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
+
+        if not os.path.exists(ip_adapter_local_path):
+            print(f"Downloading IP-Adapter to {ip_adapter_local_path}...")
+            hf_hub_download(
+                repo_id=Config.INSTANTID_REPO,
+                filename=ip_adapter_filename,
+                local_dir="./models",
+                local_dir_use_symlinks=False
+            )
+
+        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
+        # Load InstantID adapter first
+        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
+
         print("Loading LCM LoRA weights...")
+        # KEY CHANGE 1: Assign an adapter_name so Diffusers distinguishes it from InstantID
         self.pipeline.load_lora_weights(
             Config.REPO_ID,
             weight_name=Config.LORA_FILENAME,
             adapter_name="lcm_lora"
         )
-        print(" [OK] LoRA weights loaded (unfused for dynamic scaling).")
 
-        # 4. Load Preprocessors
+        # KEY CHANGE 2: Hardcode scale to 1.0 for LCM to remove trigger word dependency
+        # (Or ensure Config.LORA_STRENGTH is set to 1.0)
+        fuse_scale = 1.0
+
+        print(f"Fusing LoRA 'lcm_lora' with scale {fuse_scale}...")
+
+        # KEY CHANGE 3: Fuse ONLY the named adapter
+        self.pipeline.fuse_lora(
+            adapter_names=["lcm_lora"],
+            lora_scale=fuse_scale
+        )
+
+        # KEY CHANGE 4: Unload the side-car weights to free VRAM (since they are now inside the UNet)
+        self.pipeline.unload_lora_weights()
+
+        print(" [OK] LoRA fused and cleaned up.")
+
+        # 6. Load Preprocessors
         print("Loading Preprocessors (LeReS, LineArtAnime)...")
         self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
         self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
         print("--- All models loaded successfully ---")
 
-    def set_lora_scale(self, scale):
-        """Set the LoRA scale for generation."""
-        if self.pipeline is not None:
-            self.pipeline.set_adapters(["lcm_lora"], adapter_weights=[scale])
-            self._current_lora_scale = scale
-            print(f" [OK] LoRA scale set to {scale}")
+    def get_face_info(self, image):
+        """Extracts the largest face, returns insightface result object."""
+        if not self.face_analysis_loaded:
+            return None
+
+        try:
+            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+            faces = self.app.get(cv2_img)
+
+            if len(faces) == 0:
+                return None
+
+            # Sort by size (width * height) to find the main character
+            faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
+
+            # Return the largest face info
+            return faces[0]
+        except Exception as e:
+            print(f"Face embedding extraction failed: {e}")
+            return None
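Note: fusing the LoRA makes its scale permanent, which is why the `lora_strength` slider and `set_lora_scale()` disappear in this commit. A minimal sketch of the load → fuse → unload lifecycle, assuming a diffusers release where `fuse_lora()` accepts `adapter_names`:

    from diffusers import DiffusionPipeline

    def bake_style_lora(pipe: DiffusionPipeline, repo_id: str, weight_name: str) -> None:
        pipe.load_lora_weights(repo_id, weight_name=weight_name, adapter_name="lcm_lora")
        pipe.fuse_lora(adapter_names=["lcm_lora"], lora_scale=1.0)  # merge deltas into the UNet
        pipe.unload_lora_weights()  # drop the side-car LoRA state; fused weights remain

The design trade-off: fusion removes per-step LoRA overhead and avoids adapter-name clashes with InstantID, at the cost of runtime adjustability.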
requirements.txt CHANGED
@@ -5,7 +5,10 @@ peft
 torch
 opencv-python-headless
 Pillow
+insightface
+onnxruntime
 gradio>=4.0.0
 controlnet_aux
 huggingface_hub
-timm
+mediapipe
+timm
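Note: the plain `onnxruntime` wheel added here ships only CPU execution providers, while model.py requests `CUDAExecutionProvider` (with CPU fallback); `onnxruntime-gpu` would be needed for GPU-side face detection. A quick check (not part of this commit):

    import onnxruntime as ort
    print(ort.get_available_providers())
    # CPU wheel -> no 'CUDAExecutionProvider'; the onnxruntime-gpu wheel adds it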
utils.py CHANGED
@@ -2,6 +2,9 @@ from PIL import Image
 from transformers import BlipProcessor, BlipForConditionalGeneration
 import torch
 from config import Config
+import cv2
+import numpy as np
+import math
 
 # Simple global caching for the captioner
 captioner_processor = None
@@ -18,7 +21,7 @@ def resize_image_to_1mp(image):
     new_h = int((target_pixels / aspect_ratio) ** 0.5)
     new_w = int(new_h * aspect_ratio)
 
-    # Ensure divisibility by 64 for efficiency
+    # Ensure dimensions are divisible by 64 for efficiency
     new_w = (new_w // 64) * 64
     new_h = (new_h // 64) * 64
 
@@ -40,3 +43,35 @@ def get_caption(image):
     out = captioner_model.generate(**inputs)
     caption = captioner_processor.decode(out[0], skip_special_tokens=True)
     return caption
+
+# --- ADDED: Function from your provided file ---
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
+    stickwidth = 4
+    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+    kps = np.array(kps)
+
+    w, h = image_pil.size
+    out_img = np.zeros([h, w, 3])
+
+    for i in range(len(limbSeq)):
+        index = limbSeq[i]
+        color = color_list[index[0]]
+
+        x = kps[index][:, 0]
+        y = kps[index][:, 1]
+        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+        polygon = cv2.ellipse2Poly(
+            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+        )
+        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+    out_img = (out_img * 0.6).astype(np.uint8)
+
+    for idx_kp, kp in enumerate(kps):
+        color = color_list[idx_kp]
+        x, y = kp
+        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+    out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+    return out_img_pil
+# --- END ADDED ---
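Note: a usage sketch for the new draw_kps helper (inputs assumed): AntelopeV2 returns five facial keypoints per face in pixel coordinates, ordered [left eye, right eye, nose, left mouth corner, right mouth corner], which is what face_info['kps'] feeds in generator.py.

    from PIL import Image
    from utils import draw_kps

    canvas = Image.new("RGB", (640, 640))
    kps = [(220, 260), (420, 260), (320, 360), (250, 460), (390, 460)]  # hypothetical face
    kps_image = draw_kps(canvas, kps)  # RGB conditioning image for the InstantID ControlNet
    kps_image.save("kps_debug.png")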