github-actions[bot] commited on
Commit
18a9cc5
·
1 Parent(s): df4d2da

Update code from GitHub Actions - 2025-12-02 18:06:18

Browse files
Files changed (3) hide show
  1. Stencil.py +114 -32
  2. app.py +50 -10
  3. requirements.txt +3 -0
Stencil.py CHANGED
@@ -6,7 +6,14 @@ using pretrained Stable Diffusion models with prompt engineering.
6
  """
7
 
8
  import torch
9
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 
 
 
 
 
 
 
10
  from PIL import Image, ImageOps, ImageEnhance, ImageFilter
11
  from typing import Optional, List, Union
12
  import os
@@ -43,7 +50,9 @@ class StencilGenerator:
43
 
44
  def __init__(
45
  self,
46
- model_id: str = "stabilityai/stable-diffusion-2-1-base",
 
 
47
  device: Optional[str] = None,
48
  use_fp16: bool = True
49
  ):
@@ -51,18 +60,52 @@ class StencilGenerator:
51
  Initialize the Stencil Generator.
52
 
53
  Args:
54
- model_id: HuggingFace model ID for Stable Diffusion model
 
 
55
  device: Device to run on ('cuda', 'cpu', or None for auto-detect)
56
  use_fp16: Whether to use half precision (FP16) for faster inference
57
  """
58
  self.model_id = model_id
 
59
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
60
  self.use_fp16 = use_fp16 and self.device == "cuda"
 
61
 
62
  # Apply monkey-patch to fix transformers version compatibility
63
  _patch_clip_init()
64
 
65
- print(f"Loading model {model_id} on {self.device}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Load the pipeline with version-compatible parameters
68
  dtype = torch.float16 if self.use_fp16 else torch.float32
@@ -86,34 +129,67 @@ class StencilGenerator:
86
  # Uncomment if you have limited VRAM
87
  # self.pipe.enable_vae_slicing()
88
 
89
- print("Model loaded successfully!")
 
 
90
 
91
- # Default stencil prompt suffix - simplified since post-processing does the heavy lifting
92
- self.stencil_suffix = (
93
- "black silhouette, high contrast, simple stencil design, "
94
- "centered in frame, complete object visible, isolated subject"
95
- )
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Default negative prompt to avoid unwanted features
98
- # self.default_negative_prompt = (
99
- # "color, colorful, photograph, realistic, detailed, complex, "
100
- # "blurry, low quality, watermark, text, cropped, cut off, "
101
- # "partial, multiple subjects, duplicate"
102
- # )
103
-
104
- # Simpler stencil prompt suffix (seems to work better) - simplified since post-processing does the heavy lifting
105
- # self.stencil_suffix = (
106
- # "black silhouette, high contrast, sketch line drawing, simple, simple stencil design, white background, "
107
- # # "centered in frame, complete object visible, isolated subject"
108
- # )
109
-
110
- # Simpler negative prompt (seems to work better) to avoid unwanted features
111
- self.default_negative_prompt = (
112
- "color, colorful, photograph, realistic, detailed, complex, "
113
- # "blurry, low quality, watermark, text, cropped, cut off, "
114
- # "partial, multiple subjects, duplicate"
 
 
 
 
 
 
 
 
115
  )
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  def _clean_stencil_image(
118
  self,
119
  image: Image.Image,
@@ -234,12 +310,18 @@ class StencilGenerator:
234
  """
235
 
236
 
237
- # Construct full prompt
238
  full_prompt = prompt
239
- if add_stencil_suffix:
240
- full_prompt = f"{prompt}, {self.stencil_suffix}"
 
 
 
 
 
 
241
 
242
- # Use default negative prompt if none provided
243
  full_negative_prompt = negative_prompt or self.default_negative_prompt
244
 
245
  # Set seed if provided
 
6
  """
7
 
8
  import torch
9
+ from diffusers import (
10
+ StableDiffusionPipeline,
11
+ DPMSolverMultistepScheduler,
12
+ UNet2DConditionModel,
13
+ AutoencoderKL,
14
+ PNDMScheduler
15
+ )
16
+ from transformers import CLIPTextModel, CLIPTokenizer
17
  from PIL import Image, ImageOps, ImageEnhance, ImageFilter
18
  from typing import Optional, List, Union
19
  import os
 
50
 
51
  def __init__(
52
  self,
53
+ model_id: str = "Manojb/stable-diffusion-2-1-base",
54
+ # model_id: str = "runwayml/stable-diffusion-v1-5",
55
+ checkpoint_path: Optional[str] = None,
56
  device: Optional[str] = None,
57
  use_fp16: bool = True
58
  ):
 
60
  Initialize the Stencil Generator.
61
 
62
  Args:
63
+ model_id: HuggingFace model ID for Stable Diffusion model (used if checkpoint_path is None)
64
+ checkpoint_path: Path to fine-tuned checkpoint directory (e.g., "./checkpoint-1000")
65
+ If provided, loads fine-tuned model instead of pretrained model
66
  device: Device to run on ('cuda', 'cpu', or None for auto-detect)
67
  use_fp16: Whether to use half precision (FP16) for faster inference
68
  """
69
  self.model_id = model_id
70
+ self.checkpoint_path = checkpoint_path
71
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
72
  self.use_fp16 = use_fp16 and self.device == "cuda"
73
+ self.is_checkpoint_model = checkpoint_path is not None
74
 
75
  # Apply monkey-patch to fix transformers version compatibility
76
  _patch_clip_init()
77
 
78
+ # Load model based on whether checkpoint is provided
79
+ if self.is_checkpoint_model:
80
+ self._load_from_checkpoint(checkpoint_path)
81
+ else:
82
+ self._load_from_pretrained(model_id)
83
+
84
+ print("Model loaded successfully!")
85
+
86
+ # Set prompt decoration based on model type
87
+ if self.is_checkpoint_model:
88
+ # Fine-tuned models use simple "sketch of" prefix
89
+ self.stencil_suffix = "Sketch of"
90
+ self.default_negative_prompt = None
91
+ else:
92
+ # Standard SD 2.1 models use detailed stencil suffix
93
+ self.stencil_suffix = (
94
+ "black silhouette, high contrast, simple stencil design, "
95
+ "centered in frame, complete object visible, isolated subject"
96
+ )
97
+ self.default_negative_prompt = (
98
+ "color, colorful, photograph, realistic, detailed, complex, "
99
+ )
100
+
101
+ def _load_from_pretrained(self, model_id: str):
102
+ """
103
+ Load a pretrained model from HuggingFace.
104
+
105
+ Args:
106
+ model_id: HuggingFace model ID
107
+ """
108
+ print(f"Loading pretrained model {model_id} on {self.device}...")
109
 
110
  # Load the pipeline with version-compatible parameters
111
  dtype = torch.float16 if self.use_fp16 else torch.float32
 
129
  # Uncomment if you have limited VRAM
130
  # self.pipe.enable_vae_slicing()
131
 
132
+ def _load_from_checkpoint(self, checkpoint_path: str):
133
+ """
134
+ Load a fine-tuned model from checkpoint directory or HuggingFace Hub.
135
 
136
+ Args:
137
+ checkpoint_path: Path to checkpoint directory containing UNet,
138
+ or HuggingFace Hub model ID (e.g., "username/model-name")
139
+ """
140
+ print(f"Loading fine-tuned checkpoint from {checkpoint_path} on {self.device}...")
141
+
142
+ # Base model for standard components
143
+ base_model = "runwayml/stable-diffusion-v1-5"
144
+
145
+ print("Loading tokenizer...")
146
+ tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
147
+
148
+ print("Loading text encoder...")
149
+ text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
150
+
151
+ print("Loading VAE...")
152
+ vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
153
 
154
+ print("Loading scheduler...")
155
+ scheduler = PNDMScheduler.from_pretrained(base_model, subfolder="scheduler")
156
+
157
+ # Load fine-tuned UNet from checkpoint
158
+ # Handles both local paths and HuggingFace Hub model IDs
159
+ if os.path.exists(checkpoint_path):
160
+ # Local path - append /unet subdirectory
161
+ unet_path = f"{checkpoint_path}/unet"
162
+ else:
163
+ # Assume it's a HuggingFace Hub model ID
164
+ unet_path = checkpoint_path
165
+
166
+ print(f"Loading fine-tuned UNet from {unet_path}...")
167
+ unet = UNet2DConditionModel.from_pretrained(unet_path, subfolder="unet" if not os.path.exists(checkpoint_path) else None)
168
+
169
+ # Assemble pipeline
170
+ print("Assembling pipeline...")
171
+ self.pipe = StableDiffusionPipeline(
172
+ vae=vae,
173
+ text_encoder=text_encoder,
174
+ tokenizer=tokenizer,
175
+ unet=unet,
176
+ scheduler=scheduler,
177
+ safety_checker=None,
178
+ feature_extractor=None,
179
+ requires_safety_checker=False
180
  )
181
 
182
+ # Move to device with FP16 if enabled
183
+ if self.device == "cuda":
184
+ if self.use_fp16:
185
+ self.pipe.vae = self.pipe.vae.to(self.device, dtype=torch.float16)
186
+ self.pipe.text_encoder = self.pipe.text_encoder.to(self.device, dtype=torch.float16)
187
+ self.pipe.unet = self.pipe.unet.to(self.device, dtype=torch.float16)
188
+ else:
189
+ self.pipe = self.pipe.to(self.device)
190
+ else:
191
+ self.pipe = self.pipe.to(self.device)
192
+
193
  def _clean_stencil_image(
194
  self,
195
  image: Image.Image,
 
310
  """
311
 
312
 
313
+ # Construct full prompt based on model type
314
  full_prompt = prompt
315
+ if self.is_checkpoint_model:
316
+ # For fine-tuned checkpoints, add "sketch of" prefix
317
+ if add_stencil_suffix and not prompt.lower().startswith("sketch of"):
318
+ full_prompt = f"sketch of {prompt}"
319
+ else:
320
+ # For standard models, use stencil suffix
321
+ if add_stencil_suffix:
322
+ full_prompt = f"{prompt}, {self.stencil_suffix}"
323
 
324
+ # Use default negative prompt if none provided (None for checkpoint models)
325
  full_negative_prompt = negative_prompt or self.default_negative_prompt
326
 
327
  # Set seed if provided
app.py CHANGED
@@ -11,6 +11,7 @@ from StencilCV import StencilCV
11
  import torch
12
  from typing import Optional
13
  import numpy as np
 
14
 
15
  MAX_IMAGES = 4
16
 
@@ -20,22 +21,48 @@ class StencilApp:
20
  def __init__(self):
21
  """Initialize the Stencil Generator."""
22
  self.generator = None
 
23
  self.original_images = [] # Store original images for toggling
24
  self.outlined_status = [] # Track which images have outline applied
25
 
26
- def load_model(self):
27
- """Lazy load the model when first needed."""
28
- if self.generator is None:
29
- print("Initializing Stencil Generator...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.generator = StencilGenerator(
31
  model_id="stabilityai/stable-diffusion-2-1-base",
 
32
  use_fp16=torch.cuda.is_available()
33
  )
 
 
34
  return self.generator
35
 
36
  def generate_stencil(
37
  self,
38
  prompt: str,
 
39
  negative_prompt: Optional[str],
40
  num_images: int,
41
  num_inference_steps: int,
@@ -56,8 +83,8 @@ class StencilApp:
56
  return [], "Please enter a prompt!"
57
 
58
  try:
59
- # Load model if not already loaded
60
- generator = self.load_model()
61
 
62
  # Generate the image(s)
63
  images = generator.generate(
@@ -106,8 +133,12 @@ class StencilApp:
106
  if not gallery_data:
107
  return gallery_data, "No images to process!"
108
 
 
109
  if selected_index is None:
110
- return gallery_data, "Please select an image first by clicking on it!"
 
 
 
111
 
112
  if selected_index >= len(self.original_images):
113
  return gallery_data, "Error: Image index out of range!"
@@ -177,10 +208,17 @@ def create_interface():
177
  lines=3
178
  )
179
 
 
 
 
 
 
 
 
180
  num_images = gr.Slider(
181
  minimum=1,
182
  maximum=MAX_IMAGES,
183
- value=1,
184
  step=1,
185
  label="Number of Images",
186
  info="Generate multiple variations to choose from"
@@ -305,9 +343,10 @@ def create_interface():
305
  """
306
  ### Tips for Best Results:
307
  - Keep prompts simple and descriptive
 
 
308
  - Generate multiple images to see variations
309
- - The AI automatically adds stencil styling
310
- - Use negative prompts to avoid unwanted features
311
  - Try the outline option after generation for different styles
312
  - Higher inference steps = better quality (but slower)
313
  """
@@ -318,6 +357,7 @@ def create_interface():
318
  fn=app.generate_stencil,
319
  inputs=[
320
  prompt,
 
321
  negative_prompt,
322
  num_images,
323
  num_inference_steps,
 
11
  import torch
12
  from typing import Optional
13
  import numpy as np
14
+ import os
15
 
16
  MAX_IMAGES = 4
17
 
 
21
  def __init__(self):
22
  """Initialize the Stencil Generator."""
23
  self.generator = None
24
+ self.current_model_type = None
25
  self.original_images = [] # Store original images for toggling
26
  self.outlined_status = [] # Track which images have outline applied
27
 
28
+ def load_model(self, model_type: str = "Standard SD 2.1"):
29
+ """
30
+ Lazy load the model when first needed or reload if model type changed.
31
+
32
+ Args:
33
+ model_type: Type of model to load ("Standard SD 2.1", "Checkpoint-500", "Checkpoint-1000")
34
+ """
35
+ # Reload if model type changed or first load
36
+ if self.generator is None or self.current_model_type != model_type:
37
+ print(f"Initializing Stencil Generator with {model_type}...")
38
+
39
+ # Determine checkpoint path based on model type
40
+ # Can be local path or HuggingFace Hub model ID
41
+ checkpoint_path = None
42
+ if model_type == "Checkpoint-500":
43
+ # Try local path first, fallback to HuggingFace Hub
44
+ checkpoint_path = "./Fine-tuning/checkpoint-500"
45
+ if not os.path.exists(checkpoint_path):
46
+ checkpoint_path = "mrpink925/stencilai-checkpoint-500"
47
+ elif model_type == "Checkpoint-1000":
48
+ # Try local path first, fallback to HuggingFace Hub
49
+ checkpoint_path = "./Fine-tuning/checkpoint-1000"
50
+ if not os.path.exists(checkpoint_path):
51
+ checkpoint_path = "mrpink925/stencilai-checkpoint-1000"
52
+
53
  self.generator = StencilGenerator(
54
  model_id="stabilityai/stable-diffusion-2-1-base",
55
+ checkpoint_path=checkpoint_path,
56
  use_fp16=torch.cuda.is_available()
57
  )
58
+ self.current_model_type = model_type
59
+
60
  return self.generator
61
 
62
  def generate_stencil(
63
  self,
64
  prompt: str,
65
+ model_type: str,
66
  negative_prompt: Optional[str],
67
  num_images: int,
68
  num_inference_steps: int,
 
83
  return [], "Please enter a prompt!"
84
 
85
  try:
86
+ # Load model (will reload if model type changed)
87
+ generator = self.load_model(model_type)
88
 
89
  # Generate the image(s)
90
  images = generator.generate(
 
133
  if not gallery_data:
134
  return gallery_data, "No images to process!"
135
 
136
+ # If there's only 1 image and no selection, default to index 0
137
  if selected_index is None:
138
+ if len(self.original_images) == 1:
139
+ selected_index = 0
140
+ else:
141
+ return gallery_data, "Please select an image first by clicking on it!"
142
 
143
  if selected_index >= len(self.original_images):
144
  return gallery_data, "Error: Image index out of range!"
 
208
  lines=3
209
  )
210
 
211
+ model_selector = gr.Radio(
212
+ choices=["Standard SD 2.1", "Checkpoint-500", "Checkpoint-1000"],
213
+ value="Checkpoint-1000",
214
+ label="Model Type",
215
+ info="Choose between standard model or fine-tuned checkpoints (trained on sketch-style images)"
216
+ )
217
+
218
  num_images = gr.Slider(
219
  minimum=1,
220
  maximum=MAX_IMAGES,
221
+ value=2,
222
  step=1,
223
  label="Number of Images",
224
  info="Generate multiple variations to choose from"
 
343
  """
344
  ### Tips for Best Results:
345
  - Keep prompts simple and descriptive
346
+ - **Standard SD 2.1**: Best for general stencils with detailed prompt engineering
347
+ - **Checkpoint models**: Fine-tuned for sketch-style stencils (automatically adds "sketch of" prefix)
348
  - Generate multiple images to see variations
349
+ - Use negative prompts to avoid unwanted features (works best with Standard SD 2.1)
 
350
  - Try the outline option after generation for different styles
351
  - Higher inference steps = better quality (but slower)
352
  """
 
357
  fn=app.generate_stencil,
358
  inputs=[
359
  prompt,
360
+ model_selector,
361
  negative_prompt,
362
  num_images,
363
  num_inference_steps,
requirements.txt CHANGED
@@ -3,12 +3,15 @@ diffusers>=0.21.0
3
  transformers>=4.30.0
4
  accelerate>=0.20.0
5
  safetensors>=0.3.0
 
6
  gradio>=4.0.0
7
  numpy>=1.24.0
8
  Pillow>=9.0.0
9
  scipy>=1.10.0
10
  scikit-image>=0.20.0
11
  opencv-python>=4.8.0
 
 
12
 
13
  # Note: Pillow, numpy, scipy, scikit-image required for AI-based post-processing
14
  # opencv-python required for StencilCV (traditional computer vision approach)
 
3
  transformers>=4.30.0
4
  accelerate>=0.20.0
5
  safetensors>=0.3.0
6
+ huggingface-hub>=0.16.0
7
  gradio>=4.0.0
8
  numpy>=1.24.0
9
  Pillow>=9.0.0
10
  scipy>=1.10.0
11
  scikit-image>=0.20.0
12
  opencv-python>=4.8.0
13
+ spacy[cuda11x]
14
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
15
 
16
  # Note: Pillow, numpy, scipy, scikit-image required for AI-based post-processing
17
  # opencv-python required for StencilCV (traditional computer vision approach)