Allex21 committed on
Commit
9bbc908
·
verified ·
1 Parent(s): e3f6801

Create utils/editor.py

Browse files
Files changed (1) hide show
  1. utils/editor.py +258 -0
utils/editor.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/editor.py
2
+ import os
3
+ import io
4
+ import math
5
+ from typing import Tuple, Dict, Any
6
+ from PIL import Image, ImageOps
7
+ import numpy as np
8
+
9
+ import torch
10
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
11
+ from transformers import logging as hf_logging
12
+ hf_logging.set_verbosity_error()
13
+
14
+ # detector auxiliar para gerar mapa de pose OpenPose-like
15
+ from controlnet_aux import OpenposeDetector
16
+
17
+ # para remoção de fundo da peça (extrair RGBA)
18
+ from rembg import remove
19
+
20
# Default parameters (adjust as needed).
MODEL_ID: str = "runwayml/stable-diffusion-v1-5"  # base SD v1.5
CONTROLNET_ID: str = "lllyasviel/sd-controlnet-openpose"  # ControlNet (OpenPose)
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level caches: both start empty and are filled on first use by
# load_pipeline() and get_openpose_detector() respectively.
_PIPELINE = None
_OP_DETECTOR = None
28
+
29
def get_openpose_detector():
    """Return the shared OpenPose detector, creating it on first use.

    The instance is cached in the module-level ``_OP_DETECTOR`` so repeated
    calls reuse the same detector instead of rebuilding it.

    Returns:
        The cached ``OpenposeDetector`` instance.
    """
    global _OP_DETECTOR
    if _OP_DETECTOR is None:
        # controlnet_aux detectors are built from pretrained annotator
        # weights; the bare constructor expects an already-built estimator
        # and fails when called without arguments.
        from_pretrained = getattr(OpenposeDetector, "from_pretrained", None)
        if from_pretrained is not None:
            _OP_DETECTOR = from_pretrained("lllyasviel/Annotators")
        else:
            # NOTE(review): only reachable on a controlnet_aux build without
            # from_pretrained; keeps the original construction as a fallback.
            _OP_DETECTOR = OpenposeDetector()
    return _OP_DETECTOR
34
+
35
def load_pipeline():
    """Load (once) and return the Stable Diffusion + ControlNet pipeline.

    The pipeline is cached in the module-level ``_PIPELINE``; later calls
    return the cached object. Weights are loaded in float16 when ``DEVICE``
    is ``"cuda"`` and in float32 otherwise.

    The image-to-image ControlNet pipeline is preferred because it accepts an
    init ``image`` together with a separate ``control_image`` and a
    ``strength``. The text-to-image ``StableDiffusionControlNetPipeline``
    treats ``image`` as the conditioning image and has no ``strength`` input,
    so it is used only when the img2img class is missing from the installed
    diffusers.

    Returns:
        The ready-to-use pipeline, already moved to ``DEVICE``.
    """
    global _PIPELINE
    if _PIPELINE is not None:
        return _PIPELINE

    try:
        # Local import: the class is absent from older diffusers releases.
        from diffusers import StableDiffusionControlNetImg2ImgPipeline as pipeline_cls
    except ImportError:
        pipeline_cls = StableDiffusionControlNetPipeline

    on_gpu = DEVICE == "cuda"
    dtype = torch.float16 if on_gpu else torch.float32

    # Load the ControlNet first, then the SD pipeline that wraps it.
    controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, torch_dtype=dtype)
    pipe = pipeline_cls.from_pretrained(
        MODEL_ID,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=dtype,
    )

    # Swap in the UniPC multistep scheduler, reusing the existing config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    if on_gpu:
        pipe.enable_attention_slicing()  # trades some speed for lower VRAM use
    pipe.to(DEVICE)

    _PIPELINE = pipe
    return _PIPELINE
63
+
64
def remove_background(pil_img: Image.Image) -> Image.Image:
    """Strip the background from a garment picture.

    The image is converted to RGBA, serialised to PNG bytes, passed through
    ``rembg.remove`` and the returned bytes are decoded again.

    Args:
        pil_img: Garment picture in any PIL mode.

    Returns:
        The decoded rembg output converted to RGBA.
    """
    png_buffer = io.BytesIO()
    rgba_input = pil_img.convert("RGBA")
    rgba_input.save(png_buffer, format="PNG")

    # rembg is given the encoded PNG bytes and is expected to hand back
    # encoded image bytes carrying an alpha channel.
    cutout_bytes = remove(png_buffer.getvalue())

    return Image.open(io.BytesIO(cutout_bytes)).convert("RGBA")
76
+
77
def _paste_garment(model_rgba, garment_rgba, size, position):
    """Resize the garment to ``size`` and alpha-paste it onto a copy of the model."""
    resized = garment_rgba.resize(size, resample=Image.LANCZOS)
    canvas = model_rgba.copy()
    # The garment doubles as its own mask so transparent pixels stay untouched.
    canvas.paste(resized, position, resized)
    return canvas


def _shoulder_placement(pose_keypoints, garment_size):
    """Return ``(size, position)`` derived from shoulder keypoints, or ``None``."""
    if not pose_keypoints:
        return None
    garment_w, garment_h = garment_size
    try:
        ls = pose_keypoints.get("left_shoulder")
        rs = pose_keypoints.get("right_shoulder")
        if not (ls and rs):
            return None
        shoulder_dist = math.hypot(rs[0] - ls[0], rs[1] - ls[1])
        # Garment should span ~1.4x the shoulder width (tune per garment type).
        target_w = int(shoulder_dist * 1.4)
        scale = max(0.1, target_w / garment_w)
        size = (max(1, int(garment_w * scale)), max(1, int(garment_h * scale)))
        center_x = int((ls[0] + rs[0]) / 2)
        # Dividing the summed y by 1.8 (not 2) lands slightly below the
        # shoulder line. NOTE(review): the offset grows with the absolute y
        # coordinate - confirm this heuristic is intended.
        top_y = int((ls[1] + rs[1]) / 1.8)
        position = (
            max(0, center_x - size[0] // 2),
            max(0, top_y - size[1] // 6),
        )
    except (AttributeError, TypeError, ValueError, IndexError, KeyError, OverflowError):
        # Malformed keypoints (wrong container, missing or NaN coordinates):
        # signal the caller to use the centred fallback.
        return None
    return size, position


def simple_align_garment_to_model(model_img: Image.Image, garment_rgba: Image.Image, pose_keypoints=None) -> Image.Image:
    """Roughly overlay the garment on the model's torso.

    When both shoulder keypoints are available, the garment is scaled to about
    1.4x the shoulder distance and centred between the shoulders. Otherwise it
    is scaled to half the model width, centred horizontally and placed 28%
    from the top. This is only a starting point; the diffusion step is
    expected to refine it.

    Args:
        model_img: Picture of the person.
        garment_rgba: Garment cut-out; converted to RGBA if it is not already,
            because it is used as its own paste mask.
        pose_keypoints: Optional mapping with ``"left_shoulder"`` and
            ``"right_shoulder"`` entries holding ``(x, y)`` pixel coordinates.
            ``None``, an empty mapping or malformed data selects the centred
            fallback.

    Returns:
        A new RGBA image of the model with the garment pasted on top. The
        model is returned unchanged (as RGBA) when the garment has no pixels.
    """
    model = model_img.convert("RGBA")
    garment = garment_rgba if garment_rgba.mode == "RGBA" else garment_rgba.convert("RGBA")

    model_w, model_h = model.size
    garment_w, garment_h = garment.size
    if garment_w == 0 or garment_h == 0:
        # Nothing to paste, and scaling would divide by zero.
        return model

    placement = _shoulder_placement(pose_keypoints, (garment_w, garment_h))
    if placement is not None:
        try:
            return _paste_garment(model, garment, *placement)
        except (ValueError, MemoryError):
            # Keypoint-derived size was unusable; use the centred placement.
            placement = None

    # Fallback: half the model width, horizontally centred, 28% from the top
    # as a rough torso position.
    target_w = int(model_w * 0.5)
    scale = target_w / garment_w
    size = (max(1, int(garment_w * scale)), max(1, int(garment_h * scale)))
    position = ((model_w - size[0]) // 2, int(model_h * 0.28))
    return _paste_garment(model, garment, size, position)
125
+
126
def _render_pose_map(detector, model_img):
    """Ask the detector for a rendered pose image sized like ``model_img``."""
    detect = getattr(detector, "detect", None)
    # Prefer a ``detect`` method when the detector has one; otherwise call the
    # detector object directly.
    pose_image = detect(model_img) if callable(detect) else detector(model_img)
    if isinstance(pose_image, Image.Image) and pose_image.size != model_img.size:
        # The detector may work at its own resolution; keep the map aligned
        # with the image that the keypoints refer to.
        pose_image = pose_image.resize(model_img.size, Image.BILINEAR)
    return pose_image


def _query_pose_info(detector, model_img):
    """Return per-person pose data from whichever API the detector exposes."""
    get_pose = getattr(detector, "get_pose", None)
    if callable(get_pose):
        return get_pose(model_img)
    detect_poses = getattr(detector, "detect_poses", None)
    if callable(detect_poses):
        # NOTE(review): assumes detect_poses takes an RGB HxWx3 array - confirm
        # against the installed controlnet_aux. This runs detection a second
        # time on top of the rendered map.
        return detect_poses(np.asarray(model_img.convert("RGB")))
    return []


def _shoulders_from_pose_info(info, size):
    """Pick the first person with usable shoulders; return ``{}`` if none."""
    width, height = size
    for person in info:
        if isinstance(person, dict):
            if "left_shoulder" in person and "right_shoulder" in person:
                return {
                    "left_shoulder": tuple(person["left_shoulder"]),
                    "right_shoulder": tuple(person["right_shoulder"]),
                }
            continue

        # Result objects exposing ``.body.keypoints`` are checked before the
        # plain-sequence case because such objects may themselves be tuples.
        body = getattr(person, "body", None)
        from_body = body is not None and hasattr(body, "keypoints")
        points = body.keypoints if from_body else person
        if not isinstance(points, (list, tuple)):
            continue

        try:
            # OpenPose body ordering: index 2 = right shoulder, 5 = left.
            right, left = points[2], points[5]
            coords = []
            for point in (right, left):
                x, y = float(point[0]), float(point[1])
                if from_body and 0.0 <= x <= 1.0 and 0.0 <= y <= 1.0:
                    # NOTE(review): presumably normalised to [0, 1] by the
                    # detector; scaled back to pixel coordinates here.
                    x, y = x * width, y * height
                coords.append((int(x), int(y)))
        except (TypeError, IndexError, ValueError, OverflowError):
            # Missing or malformed shoulder for this person; try the next.
            continue
        return {"right_shoulder": coords[0], "left_shoulder": coords[1]}
    return {}


def extract_pose_and_keypoints(model_img: Image.Image) -> Tuple[Image.Image, Dict[str, Tuple[int,int]]]:
    """Produce an OpenPose-style pose map and shoulder keypoints for an image.

    Both steps are best-effort: a failure while rendering the pose map yields
    a blank map with no keypoints, and a failure while reading keypoints
    yields the rendered map with an empty keypoint dict. Errors raised while
    obtaining the detector itself are not swallowed.

    Args:
        model_img: Picture of the person.

    Returns:
        ``(pose_map, keypoints)`` where ``pose_map`` is an RGB image and
        ``keypoints`` is either empty or holds ``"left_shoulder"`` and
        ``"right_shoulder"`` entries with ``(x, y)`` coordinates.
    """
    detector = get_openpose_detector()

    try:
        pose_image = _render_pose_map(detector, model_img).convert("RGB")
    except Exception:
        # Deliberate catch-all around third-party detection code. Black is
        # the "no skeleton" canvas for OpenPose-style control maps.
        return Image.new("RGB", model_img.size, (0, 0, 0)), {}

    try:
        info = _query_pose_info(detector, model_img)
        keypoints = _shoulders_from_pose_info(info, model_img.size)
    except Exception:
        # Structured keypoints are optional; callers fall back to a centred
        # garment placement when the dict is empty.
        keypoints = {}
    return pose_image, keypoints
173
+
174
def run_pipeline(model_image: Image.Image, garment_image: Image.Image, prompt_extra: str = "") -> Tuple[Image.Image, Dict[str,Any]]:
    """Run the full try-on flow and return the generated image.

    Steps:
        1. Downscale the model image so its longer side is at most 768 px.
        2. Remove the garment background and extract the pose map/keypoints.
        3. Paste the garment roughly onto the model to build the init image.
        4. Call the diffusion pipeline with the init image and the pose map.

    NOTE(review): ``control_image`` and ``strength`` are only meaningful for an
    image-to-image ControlNet pipeline; verify what ``load_pipeline`` returns.

    Args:
        model_image: Picture of the person.
        garment_image: Picture of the garment (any background).
        prompt_extra: Optional text that replaces the default garment
            description appended to the base prompt.

    Returns:
        ``(result_image, info)`` where ``info`` records the model ids, the
        sampling parameters and the random seed used for this run.
    """
    # Cap the longer side at 768 px to balance quality against VRAM.
    max_side = 768
    model_img = model_image.convert("RGB")
    width, height = model_img.size
    longest = max(width, height)
    if longest > max_side:
        scale = max_side / longest
        # max(1, ...) keeps extreme aspect ratios from collapsing to 0 px.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        model_img = model_img.resize(new_size, Image.LANCZOS)

    garment_rgba = remove_background(garment_image)
    pose_map, keypoints = extract_pose_and_keypoints(model_img)
    init_composite = simple_align_garment_to_model(model_img, garment_rgba, pose_keypoints=keypoints)

    pipe = load_pipeline()

    prompt = ("photo-realistic fashion try-on, ultra detailed, high resolution, realistic lighting. "
              + (prompt_extra or "garment applied on person, preserve texture and zippers, realistic folds."))

    init_image = init_composite.convert("RGB")
    control_image = pose_map.convert("RGB")
    if control_image.size != init_image.size:
        # Keep the conditioning image pixel-aligned with the init image.
        control_image = control_image.resize(init_image.size, Image.BILINEAR)

    # Inference parameters (lower these if you run out of memory).
    num_inference_steps = 20
    guidance_scale = 7.5
    strength = 0.75  # how far img2img may move away from the init image

    # Draw the seed explicitly so it can be reported for reproducibility.
    seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    pipe.to(DEVICE)

    shared_kwargs = dict(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
        generator=generator,
    )

    try:
        # Autocast is active on CUDA only; on CPU the context is a no-op.
        with torch.autocast(device_type=DEVICE, enabled=DEVICE == "cuda"):
            out = pipe(image=init_image, control_image=control_image, **shared_kwargs)
    except TypeError:
        # Some pipeline variants use different keyword names; retry with the
        # alternate signature.
        out = pipe(
            init_image=init_image,
            controlnet_conditioning_image=control_image,
            **shared_kwargs,
        )
    result_img = out.images[0]  # out.images is a list

    info = {
        "model_id": MODEL_ID,
        "controlnet_id": CONTROLNET_ID,
        "steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "strength": strength,
        "seed": seed,
    }
    return result_img, info