JustForWorld committed on
Commit
98f1dc6
·
1 Parent(s): fd2ee05

feat: implement lazy loading for YOLO and Stable Diffusion models to speed up startup time

Browse files
Files changed (1) hide show
  1. logic.py +113 -65
logic.py CHANGED
@@ -1,6 +1,5 @@
1
  from ultralytics import YOLO
2
  import cv2
3
- import os
4
  import numpy as np
5
  from PIL import Image, ImageDraw
6
  import torch
@@ -8,41 +7,52 @@ from loguru import logger
8
  import time
9
  from diffusers import AutoPipelineForInpainting
10
 
11
- # ===================================================================
12
- # Класс WatermarkRemover
13
- # ===================================================================
14
  class WatermarkRemover:
15
  def __init__(self, device=None):
16
- # 👇 Автоматический выбор GPU, если доступен
17
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
18
- logger.info(f"Используемое устройство: {self.device}")
19
-
20
- # ---------------------------------------------------------------
21
- # Загрузка кастомной модели YOLOv8
22
- # ---------------------------------------------------------------
23
- logger.info("Загрузка кастомной модели YOLOv8 ('best.pt')...")
24
- self.detector = YOLO("best.pt")
25
- self.detector.to(self.device)
26
- self.detector.fuse() # 👈 ускоряет инференс YOLO
27
- logger.info("Кастомная модель YOLOv8 успешно загружена.")
28
-
29
- # ---------------------------------------------------------------
30
- # Загрузка модели Stable Diffusion 2 Inpainting
31
- # ---------------------------------------------------------------
32
- logger.info("Загрузка модели Stable Diffusion 2 Inpainting...")
33
- self.inpainting_pipe = AutoPipelineForInpainting.from_pretrained(
34
- "stabilityai/stable-diffusion-2-inpainting",
35
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, # 👈 экономия VRAM
36
- safety_checker=None, # 👈 не нужен для локального inpainting
37
- )
38
- self.inpainting_pipe = self.inpainting_pipe.to(self.device)
39
- self.inpainting_pipe.enable_attention_slicing() # 👈 снижает пиковое использование VRAM
40
- logger.info("Модель Stable Diffusion 2 Inpainting успешно загружена.")
41
-
42
- # ===================================================================
43
- # Генерация маски с помощью YOLO
44
- # ===================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def _get_mask_yolo(self, image: Image.Image) -> Image.Image:
 
 
46
  img_np = np.array(image.convert("RGB"))
47
  results = self.detector.predict(img_np, conf=0.25, imgsz=864, device=self.device)
48
  mask = Image.new("L", image.size, 0)
@@ -50,17 +60,19 @@ class WatermarkRemover:
50
  if results and len(results[0].boxes) > 0:
51
  draw = ImageDraw.Draw(mask)
52
  boxes = results[0].boxes.xyxy.cpu().numpy()
53
- logger.info(f"Кастомная модель YOLO нашла {len(boxes)} bbox.")
54
  for bbox in boxes:
55
  draw.rectangle(list(bbox), fill=255)
56
  else:
57
- logger.warning("Кастомная модель YOLO не нашла watermark.")
58
  return mask
59
 
60
- # ===================================================================
61
- # Инпейнтинг изображения с помощью diffusers
62
- # ===================================================================
63
- def _inpaint_image(self, image: Image.Image, mask: Image.Image) -> np.ndarray:
 
 
64
  prompt = (
65
  "ultra realistic photo of interior or exterior architecture, "
66
  "natural lighting, clean surface, consistent material texture, realistic color balance"
@@ -70,58 +82,94 @@ class WatermarkRemover:
70
  "painting, mirror artifact, blurry, distorted, deformed, low quality, noise, grain"
71
  )
72
 
73
- logger.info("Запуск Stable Diffusion Inpainting с 30 шагами...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # --- 🔹 Сохраняем оригинальный размер
76
- orig_size = image.size # (width, height)
 
77
 
78
- # --- 🔹 Resize до кратного 8 (иначе модель может ругаться)
79
- new_w = (orig_size[0] // 8) * 8
80
- new_h = (orig_size[1] // 8) * 8
81
- resized_image = image.resize((new_w, new_h), Image.LANCZOS)
82
- resized_mask = mask.resize((new_w, new_h), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- # --- 🔹 Инференс
85
  with torch.inference_mode():
86
  result = self.inpainting_pipe(
87
  prompt=prompt,
88
  negative_prompt=negative_prompt,
89
- image=resized_image,
90
  mask_image=resized_mask,
91
- num_inference_steps=30,
92
- guidance_scale=7.5,
93
  ).images[0]
94
 
95
- # --- 🔹 Возвращаем к оригинальному размеру
96
- result = result.resize(orig_size, Image.LANCZOS)
 
 
 
 
 
 
 
97
 
98
- return np.array(result)
 
99
 
100
- # ===================================================================
101
- # Основной процесс
102
- # ===================================================================
 
 
103
  def run(self, image: Image.Image) -> Image.Image:
104
  start_time = time.time()
105
- logger.info("Начало процесса удаления вотермарок (YOLOv8 + Stable Diffusion)...")
106
 
107
  mask_image = self._get_mask_yolo(image)
108
  mask_np = np.array(mask_image)
109
 
110
  if not np.any(mask_np):
111
- logger.info("Вотермарки не найдены. Возвращаем оригинальное изображение.")
112
  return image
113
 
114
- logger.info("Постобработка маски...")
115
  kernel = np.ones((15, 15), np.uint8)
116
  closed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_CLOSE, kernel)
117
  final_kernel = np.ones((7, 7), np.uint8)
118
  processed_mask_np = cv2.dilate(closed_mask, final_kernel, iterations=1)
119
  processed_mask_pil = Image.fromarray(processed_mask_np)
120
- logger.success("Маска обработана.")
121
-
122
- logger.info("Закрашивание области с помощью Stable Diffusion...")
123
- result_np_rgb = self._inpaint_image(image, processed_mask_pil)
124
 
 
125
  end_time = time.time()
126
- logger.success(f"Удаление watermark завершено за {end_time - start_time:.2f} сек.")
127
- return Image.fromarray(result_np_rgb)
 
1
  from ultralytics import YOLO
2
  import cv2
 
3
  import numpy as np
4
  from PIL import Image, ImageDraw
5
  import torch
 
7
  import time
8
  from diffusers import AutoPipelineForInpainting
9
 
 
 
 
10
  class WatermarkRemover:
11
  def __init__(self, device=None):
 
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
+ logger.info(f"Using device: {self.device}")
14
+
15
+ # Lazy-loaded models
16
+ self.detector = None
17
+ self.inpainting_pipe = None
18
+
19
+ # ======================================================
20
+ # Lazy-load YOLO
21
+ # ======================================================
22
+ def _load_detector(self):
23
+ if self.detector is None:
24
+ logger.info("Loading YOLOv8 custom model ('best.pt')...")
25
+ self.detector = YOLO("best.pt")
26
+ self.detector.to(self.device)
27
+ try:
28
+ self.detector.fuse()
29
+ except Exception:
30
+ pass
31
+ logger.success("YOLOv8 model loaded successfully.")
32
+
33
+ # ======================================================
34
+ # Lazy-load Stable Diffusion
35
+ # ======================================================
36
+ def _load_inpainting_model(self):
37
+ if self.inpainting_pipe is None:
38
+ logger.info("Loading Stable Diffusion 2 Inpainting...")
39
+ self.inpainting_pipe = AutoPipelineForInpainting.from_pretrained(
40
+ "stabilityai/stable-diffusion-2-inpainting",
41
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
42
+ safety_checker=None,
43
+ ).to(self.device)
44
+ try:
45
+ self.inpainting_pipe.enable_attention_slicing()
46
+ except Exception:
47
+ pass
48
+ logger.success("Stable Diffusion 2 Inpainting model loaded successfully.")
49
+
50
+ # ======================================================
51
+ # Mask generation via YOLO
52
+ # ======================================================
53
  def _get_mask_yolo(self, image: Image.Image) -> Image.Image:
54
+ self._load_detector() # ensure YOLO loaded
55
+
56
  img_np = np.array(image.convert("RGB"))
57
  results = self.detector.predict(img_np, conf=0.25, imgsz=864, device=self.device)
58
  mask = Image.new("L", image.size, 0)
 
60
  if results and len(results[0].boxes) > 0:
61
  draw = ImageDraw.Draw(mask)
62
  boxes = results[0].boxes.xyxy.cpu().numpy()
63
+ logger.info(f"YOLO found {len(boxes)} watermark box(es).")
64
  for bbox in boxes:
65
  draw.rectangle(list(bbox), fill=255)
66
  else:
67
+ logger.warning("No watermark detected.")
68
  return mask
69
 
70
+ # ======================================================
71
+ # Partial inpainting
72
+ # ======================================================
73
+ def _inpaint_image(self, image: Image.Image, mask: Image.Image) -> Image.Image:
74
+ self._load_inpainting_model() # ensure pipeline loaded
75
+
76
  prompt = (
77
  "ultra realistic photo of interior or exterior architecture, "
78
  "natural lighting, clean surface, consistent material texture, realistic color balance"
 
82
  "painting, mirror artifact, blurry, distorted, deformed, low quality, noise, grain"
83
  )
84
 
85
+ logger.info("Running partial Stable Diffusion inpainting...")
86
+
87
+ orig_w, orig_h = image.size
88
+ mask_np = np.array(mask)
89
+ ys, xs = np.where(mask_np > 0)
90
+
91
+ if xs.size == 0 or ys.size == 0:
92
+ logger.info("Mask empty — skipping inpainting.")
93
+ return image
94
+
95
+ pad = max(48, int(min(orig_w, orig_h) * 0.03))
96
+ x_min = max(int(xs.min()) - pad, 0)
97
+ x_max = min(int(xs.max()) + pad, orig_w)
98
+ y_min = max(int(ys.min()) - pad, 0)
99
+ y_max = min(int(ys.max()) + pad, orig_h)
100
 
101
+ crop_box = (x_min, y_min, x_max, y_max)
102
+ crop_img = image.crop(crop_box)
103
+ crop_mask = mask.crop(crop_box)
104
 
105
+ crop_w, crop_h = crop_img.size
106
+ max_side = 1024
107
+ scale = 1.0
108
+ if max(crop_w, crop_h) > max_side:
109
+ scale = max_side / max(crop_w, crop_h)
110
+
111
+ new_w = int(np.ceil((crop_w * scale) / 8) * 8)
112
+ new_h = int(np.ceil((crop_h * scale) / 8) * 8)
113
+
114
+ if (new_w, new_h) != (crop_w, crop_h):
115
+ resized_img = crop_img.resize((new_w, new_h), resample=Image.LANCZOS)
116
+ resized_mask = crop_mask.resize((new_w, new_h), resample=Image.LANCZOS)
117
+ else:
118
+ resized_img, resized_mask = crop_img, crop_mask
119
+
120
+ resized_mask = resized_mask.convert("L")
121
+ mask_thr = np.array(resized_mask)
122
+ mask_thr = (mask_thr > 127).astype(np.uint8) * 255
123
+ resized_mask = Image.fromarray(mask_thr, mode="L")
124
 
 
125
  with torch.inference_mode():
126
  result = self.inpainting_pipe(
127
  prompt=prompt,
128
  negative_prompt=negative_prompt,
129
+ image=resized_img,
130
  mask_image=resized_mask,
131
+ num_inference_steps=35,
132
+ guidance_scale=8.0,
133
  ).images[0]
134
 
135
+ if result.size != crop_img.size:
136
+ result_resized = result.resize(crop_img.size, resample=Image.LANCZOS)
137
+ else:
138
+ result_resized = result
139
+
140
+ base = image.copy()
141
+ paste_mask = crop_mask.convert("L")
142
+ paste_mask = Image.fromarray((np.array(paste_mask) > 127).astype(np.uint8) * 255, mode="L")
143
+ base.paste(result_resized, (x_min, y_min), mask=paste_mask)
144
 
145
+ if self.device == "cuda":
146
+ torch.cuda.empty_cache()
147
 
148
+ return base
149
+
150
+ # ======================================================
151
+ # Main process
152
+ # ======================================================
153
  def run(self, image: Image.Image) -> Image.Image:
154
  start_time = time.time()
155
+ logger.info("Starting watermark removal...")
156
 
157
  mask_image = self._get_mask_yolo(image)
158
  mask_np = np.array(mask_image)
159
 
160
  if not np.any(mask_np):
161
+ logger.info("No watermark found. Returning original image.")
162
  return image
163
 
164
+ logger.info("Post-processing mask (morphology)...")
165
  kernel = np.ones((15, 15), np.uint8)
166
  closed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_CLOSE, kernel)
167
  final_kernel = np.ones((7, 7), np.uint8)
168
  processed_mask_np = cv2.dilate(closed_mask, final_kernel, iterations=1)
169
  processed_mask_pil = Image.fromarray(processed_mask_np)
170
+ logger.success("Mask processed.")
 
 
 
171
 
172
+ result_img = self._inpaint_image(image, processed_mask_pil)
173
  end_time = time.time()
174
+ logger.success(f"Watermark removal completed in {end_time - start_time:.2f}s.")
175
+ return result_img