Ammmob committed on
Commit
89ec264
·
1 Parent(s): b8d3068

Refactor Gradio demo structure and add inference step control

Browse files
.gitattributes CHANGED
File without changes
.gitignore CHANGED
File without changes
README.md CHANGED
File without changes
app.py CHANGED
@@ -1,285 +1,12 @@
1
- from pathlib import Path
2
 
3
- def patch_asyncio_cleanup_error() -> None:
4
- try:
5
- import asyncio.base_events as base_events
6
- except Exception:
7
- return
8
 
9
- original_del = getattr(base_events.BaseEventLoop, "__del__", None)
10
- if original_del is None:
11
- return
12
 
13
- def patched_del(self):
14
- try:
15
- original_del(self)
16
- except ValueError as exc:
17
- if "Invalid file descriptor" not in str(exc):
18
- raise
19
 
20
- base_events.BaseEventLoop.__del__ = patched_del
21
 
22
-
23
- patch_asyncio_cleanup_error()
24
-
25
- import gradio as gr
26
- import torch
27
- from huggingface_hub import hf_hub_download
28
- from PIL import Image
29
-
30
-
31
- def patch_qwen_diffusers_bug() -> None:
32
- import importlib.util
33
-
34
- spec = importlib.util.find_spec("diffusers")
35
- if spec is None or spec.origin is None:
36
- return
37
-
38
- target_file = (
39
- Path(spec.origin).resolve().parent
40
- / "pipelines"
41
- / "qwenimage"
42
- / "pipeline_qwenimage_edit_plus.py"
43
- )
44
- if not target_file.exists():
45
- return
46
-
47
- text = target_file.read_text(encoding="utf-8")
48
- match = "if prompt_embeds_mask is not None and prompt_embeds_mask.all()"
49
- if f"# {match}" in text:
50
- return
51
-
52
- lines = text.splitlines()
53
- for idx, line in enumerate(lines):
54
- if line.strip() == match:
55
- if not lines[idx].lstrip().startswith("#"):
56
- lines[idx] = f"# {lines[idx]}"
57
- if idx + 1 < len(lines) and not lines[idx + 1].lstrip().startswith("#"):
58
- lines[idx + 1] = f"# {lines[idx + 1]}"
59
- break
60
-
61
- target_file.write_text("\n".join(lines) + "\n", encoding="utf-8")
62
-
63
-
64
- patch_qwen_diffusers_bug()
65
-
66
- from diffusers import QwenImageEditPlusPipeline
67
-
68
- from pixelsmile.linear_conditioning import compute_text_embeddings
69
- from pixelsmile.utils.image import resize
70
-
71
-
72
- SUPPORTED_EXPRESSIONS = [
73
- "angry",
74
- "confused",
75
- "contempt",
76
- "confident",
77
- "disgust",
78
- "fear",
79
- "happy",
80
- "sad",
81
- "shy",
82
- "sleepy",
83
- "surprised",
84
- "anxious",
85
- ]
86
-
87
- DEFAULT_METHOD = "score_one_all"
88
- DEFAULT_INF_STEPS = 50
89
- DEFAULT_RESIZE_MODE = "crop"
90
- DEFAULT_WIDTH = 512
91
- DEFAULT_HEIGHT = 512
92
- DEFAULT_DATA_TYPE = "human"
93
- DEFAULT_SEED = 42
94
- DEFAULT_WEIGHT_VERSION = "preview"
95
-
96
- ROOT_DIR = Path(__file__).resolve().parent
97
- WEIGHTS_DIR = ROOT_DIR / "weights"
98
- BASE_MODEL_REPO = "Qwen/Qwen-Image-Edit-2511"
99
- PIXELSMILE_DIR = WEIGHTS_DIR / "PixelSmile"
100
- PIXELSMILE_REPO = "PixelSmile/PixelSmile"
101
- WEIGHT_FILES = {
102
- "preview": "PixelSmile-preview.safetensors",
103
- "stable": "PixelSmile-stable.safetensors",
104
- }
105
-
106
- PIPE = None
107
- PIPE_STATE = {"version": None, "device": None}
108
-
109
-
110
- def get_subject_name(data_type: str) -> str:
111
- if data_type == "human":
112
- return "person"
113
- if data_type == "anime":
114
- return "character"
115
- raise ValueError(f"Unsupported data_type: {data_type}")
116
-
117
-
118
- def build_edit_condition(subject: str, expression: str, scale: float) -> dict:
119
- return {
120
- "prompt": f"Edit the {subject} to show a {expression} expression",
121
- "prompt_neu": f"Edit the {subject} to show a neutral expression",
122
- "category": expression,
123
- "scores": {expression: scale},
124
- }
125
-
126
-
127
- def resolve_lora_path(weight_version: str) -> Path:
128
- if weight_version not in WEIGHT_FILES:
129
- raise ValueError(f"Unsupported weight version: {weight_version}")
130
- return PIXELSMILE_DIR / WEIGHT_FILES[weight_version]
131
-
132
-
133
- def ensure_lora_path(weight_version: str) -> Path:
134
- PIXELSMILE_DIR.mkdir(parents=True, exist_ok=True)
135
- lora_path = resolve_lora_path(weight_version)
136
- if lora_path.exists():
137
- return lora_path
138
-
139
- filename = WEIGHT_FILES[weight_version]
140
- downloaded_path = hf_hub_download(
141
- repo_id=PIXELSMILE_REPO,
142
- filename=filename,
143
- local_dir=str(PIXELSMILE_DIR),
144
- )
145
- return Path(downloaded_path)
146
-
147
-
148
- def get_device() -> torch.device:
149
- return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
150
-
151
-
152
- def load_pipe(weight_version: str) -> QwenImageEditPlusPipeline:
153
- global PIPE
154
-
155
- device = get_device()
156
- device_key = str(device)
157
- if PIPE is not None and PIPE_STATE["version"] == weight_version and PIPE_STATE["device"] == device_key:
158
- return PIPE
159
-
160
- lora_path = ensure_lora_path(weight_version)
161
- pipe = QwenImageEditPlusPipeline.from_pretrained(
162
- BASE_MODEL_REPO,
163
- torch_dtype=torch.bfloat16,
164
- cache_dir=str(WEIGHTS_DIR),
165
- )
166
- pipe.load_lora_weights(str(lora_path))
167
- pipe.to(device)
168
-
169
- PIPE = pipe
170
- PIPE_STATE["version"] = weight_version
171
- PIPE_STATE["device"] = device_key
172
- return PIPE
173
-
174
-
175
- def prepare_input_image(image: Image.Image) -> Image.Image:
176
- if image is None:
177
- raise gr.Error("Please upload an input image.")
178
- if not isinstance(image, Image.Image):
179
- image = Image.fromarray(image)
180
- image = image.convert("RGB")
181
- return resize(image, (DEFAULT_WIDTH, DEFAULT_HEIGHT), DEFAULT_RESIZE_MODE)
182
-
183
-
184
- def run_edit(
185
- image: Image.Image,
186
- expression: str,
187
- scale: float,
188
- data_type: str,
189
- seed: int,
190
- weight_version: str,
191
- ) -> Image.Image:
192
- subject = get_subject_name(data_type)
193
- pipe = load_pipe(weight_version)
194
- input_image = prepare_input_image(image)
195
- edit_condition = build_edit_condition(subject, expression, float(scale))
196
-
197
- prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
198
- method=DEFAULT_METHOD,
199
- pipeline=pipe,
200
- data=edit_condition,
201
- image=input_image,
202
- max_sequence_length=1024,
203
- )
204
-
205
- generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
206
- with torch.no_grad():
207
- output = pipe(
208
- image=input_image,
209
- prompt_embeds=prompt_embeds,
210
- prompt_embeds_mask=prompt_embeds_mask,
211
- num_inference_steps=DEFAULT_INF_STEPS,
212
- true_cfg_scale=0,
213
- output_type="pil",
214
- generator=generator,
215
- )
216
- return output.images[0]
217
-
218
-
219
- def run_demo(
220
- image: Image.Image,
221
- expression: str,
222
- scale: float,
223
- data_type: str,
224
- seed: int,
225
- weight_version: str,
226
- ):
227
- try:
228
- result = run_edit(
229
- image=image,
230
- expression=expression,
231
- scale=scale,
232
- data_type=data_type,
233
- seed=seed,
234
- weight_version=weight_version,
235
- )
236
- return result
237
- except Exception as exc:
238
- raise gr.Error(str(exc)) from exc
239
-
240
-
241
- with gr.Blocks(title="PixelSmile Demo") as demo:
242
- gr.Markdown("# PixelSmile Demo")
243
- gr.Markdown(
244
- "Fine-grained facial expression editing with Qwen-Image-Edit-2511 and PixelSmile weights."
245
- )
246
-
247
- with gr.Row():
248
- with gr.Column(scale=1):
249
- input_image = gr.Image(type="pil", label="Input Image")
250
- expression = gr.Dropdown(
251
- choices=SUPPORTED_EXPRESSIONS,
252
- value="happy",
253
- label="Target Expression",
254
- )
255
- scale = gr.Slider(
256
- minimum=0.0,
257
- maximum=1.5,
258
- step=0.1,
259
- value=0.8,
260
- label="Expression Strength",
261
- )
262
- data_type = gr.Radio(
263
- choices=["human", "anime"],
264
- value=DEFAULT_DATA_TYPE,
265
- label="Data Type",
266
- )
267
- weight_version = gr.Radio(
268
- choices=["preview", "stable"],
269
- value=DEFAULT_WEIGHT_VERSION,
270
- label="PixelSmile Weight Version",
271
- )
272
- seed = gr.Number(value=DEFAULT_SEED, precision=0, label="Seed")
273
- run_button = gr.Button("Run Inference", variant="primary")
274
-
275
- with gr.Column(scale=1):
276
- output_image = gr.Image(type="pil", label="Edited Image")
277
-
278
- run_button.click(
279
- fn=run_demo,
280
- inputs=[input_image, expression, scale, data_type, seed, weight_version],
281
- outputs=output_image,
282
- )
283
 
284
 
285
  if __name__ == "__main__":
 
1
# Apply runtime workarounds (asyncio shutdown noise, HF Hub timeouts,
# diffusers source patch) before anything heavy is imported.
from gradio_app.patches import apply_runtime_patches

apply_runtime_patches()

# Imported only after patching so the demo's imports see the patched state.
from gradio_app.demo import create_demo

# Module-level handle so hosting runtimes (e.g. HF Spaces) can pick up `demo`.
demo = create_demo()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  if __name__ == "__main__":
gradio_app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Gradio application package for PixelSmile Space.
gradio_app/config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path


# Expression categories offered by the demo dropdown; each value is spliced
# directly into the edit prompt built in gradio_app.edit.build_edit_condition.
SUPPORTED_EXPRESSIONS = [
    "angry",
    "confused",
    "contempt",
    "confident",
    "disgust",
    "fear",
    "happy",
    "sad",
    "shy",
    "sleepy",
    "surprised",
    "anxious",
]

# Conditioning method name passed to compute_text_embeddings.
DEFAULT_METHOD = "score_one_all"
# Default diffusion step count shown in the UI (user-adjustable).
DEFAULT_INF_STEPS = 50
# Input images are normalized to DEFAULT_WIDTH x DEFAULT_HEIGHT with this mode.
DEFAULT_RESIZE_MODE = "crop"
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 512
DEFAULT_DATA_TYPE = "human"
DEFAULT_SEED = 42
# LoRA checkpoint preloaded at startup; must be a key of WEIGHT_FILES.
DEFAULT_WEIGHT_VERSION = "preview"

# Repository root: this file lives in <root>/gradio_app/, hence parent.parent.
ROOT_DIR = Path(__file__).resolve().parent.parent
WEIGHTS_DIR = ROOT_DIR / "weights"
BASE_MODEL_REPO = "Qwen/Qwen-Image-Edit-2511"
PIXELSMILE_DIR = WEIGHTS_DIR / "PixelSmile"
PIXELSMILE_REPO = "PixelSmile/PixelSmile"
# Hub filenames for each supported PixelSmile LoRA weight version.
WEIGHT_FILES = {
    "preview": "PixelSmile-preview.safetensors",
    "stable": "PixelSmile-stable.safetensors",
}
gradio_app/demo.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from PIL import Image

from gradio_app.config import (
    DEFAULT_DATA_TYPE,
    DEFAULT_INF_STEPS,
    DEFAULT_SEED,
    DEFAULT_WEIGHT_VERSION,
    SUPPORTED_EXPRESSIONS,
)
from gradio_app.edit import run_edit
from gradio_app.pipeline import PRELOAD_STATE, start_preload


def run_demo(
    image: Image.Image,
    expression: str,
    scale: float,
    data_type: str,
    seed: int,
    weight_version: str,
    num_inference_steps: int,
):
    """Gradio click handler: check preload status, then run a single edit.

    All failures surface as ``gr.Error`` so the UI shows the underlying
    message instead of a generic server error.
    """
    try:
        if PRELOAD_STATE["loading"]:
            raise gr.Error(
                "The model is still loading. Please wait for the startup preload to finish and try again."
            )
        if PRELOAD_STATE["error"] is not None:
            raise gr.Error(f"Model preload failed: {PRELOAD_STATE['error']}")

        return run_edit(
            image=image,
            expression=expression,
            scale=scale,
            data_type=data_type,
            seed=seed,
            weight_version=weight_version,
            num_inference_steps=num_inference_steps,
        )
    except gr.Error:
        # Already user-facing: re-raise unchanged. The generic handler below
        # would re-wrap it and garble the message shown in the UI.
        raise
    except Exception as exc:
        raise gr.Error(str(exc)) from exc


def create_demo() -> gr.Blocks:
    """Build the PixelSmile Blocks UI and start the background model preload."""
    with gr.Blocks(title="PixelSmile Demo") as demo:
        gr.Markdown("# PixelSmile Demo")
        gr.Markdown(
            "Fine-grained facial expression editing with Qwen-Image-Edit-2511 and PixelSmile weights."
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image", height=420)
                expression = gr.Dropdown(
                    choices=SUPPORTED_EXPRESSIONS,
                    value="happy",
                    label="Target Expression",
                )
                scale = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    step=0.1,
                    value=0.8,
                    label="Expression Strength",
                )
                data_type = gr.Dropdown(
                    choices=["human"],
                    value=DEFAULT_DATA_TYPE,
                    label="Data Type",
                )
                gr.Markdown("<span style='font-size: 12px;'>Anime editing support is coming soon.</span>")
                weight_version = gr.Dropdown(
                    choices=["preview"],
                    value=DEFAULT_WEIGHT_VERSION,
                    label="PixelSmile Weight Version",
                )
                gr.Markdown("<span style='font-size: 12px;'>Stable weights are coming soon.</span>")
                seed = gr.Number(value=DEFAULT_SEED, precision=0, label="Seed")
                # minimum=1 stops users from submitting zero/negative step
                # counts, which the diffusion scheduler cannot run.
                num_inference_steps = gr.Number(
                    value=DEFAULT_INF_STEPS,
                    precision=0,
                    minimum=1,
                    label="Inference Steps",
                )
                run_button = gr.Button("Run Inference", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="Edited Image", height=420)

        run_button.click(
            fn=run_demo,
            inputs=[input_image, expression, scale, data_type, seed, weight_version, num_inference_steps],
            outputs=output_image,
        )

    # Spawns a daemon thread; safe to call outside the Blocks context.
    start_preload()
    return demo
gradio_app/edit.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from PIL import Image
import torch
import gradio as gr

from gradio_app.config import (
    DEFAULT_HEIGHT,
    DEFAULT_INF_STEPS,
    DEFAULT_METHOD,
    DEFAULT_RESIZE_MODE,
    DEFAULT_WIDTH,
)
from gradio_app.pipeline import load_lora
from pixelsmile.linear_conditioning import compute_text_embeddings
from pixelsmile.utils.image import resize


def get_subject_name(data_type: str) -> str:
    """Map a UI data type to the subject noun used in the edit prompts.

    Raises:
        ValueError: for any data type other than "human" or "anime".
    """
    if data_type == "human":
        return "person"
    if data_type == "anime":
        return "character"
    raise ValueError(f"Unsupported data_type: {data_type}")


def build_edit_condition(subject: str, expression: str, scale: float) -> dict:
    """Build the conditioning dict consumed by compute_text_embeddings."""
    return {
        "prompt": f"Edit the {subject} to show a {expression} expression",
        "prompt_neu": f"Edit the {subject} to show a neutral expression",
        "category": expression,
        "scores": {expression: scale},
    }


def prepare_input_image(image: Image.Image) -> Image.Image:
    """Validate the upload and normalize it to the model's input size."""
    if image is None:
        raise gr.Error("Please upload an input image.")
    if not isinstance(image, Image.Image):
        # Gradio may deliver a numpy array depending on component config.
        image = Image.fromarray(image)
    image = image.convert("RGB")
    return resize(image, (DEFAULT_WIDTH, DEFAULT_HEIGHT), DEFAULT_RESIZE_MODE)


def run_edit(
    image: Image.Image,
    expression: str,
    scale: float,
    data_type: str,
    seed: int,
    weight_version: str,
    num_inference_steps: int = DEFAULT_INF_STEPS,
) -> Image.Image:
    """Run one expression edit and return the resulting PIL image.

    Raises:
        gr.Error: on missing image or a non-positive step count.
        ValueError: on an unsupported data type or weight version.
    """
    steps = int(num_inference_steps)
    if steps < 1:
        # The scheduler cannot run with zero or negative steps; fail early
        # with a user-facing message instead of a cryptic pipeline error.
        raise gr.Error("Inference steps must be a positive integer.")

    subject = get_subject_name(data_type)
    pipe = load_lora(weight_version)
    input_image = prepare_input_image(image)
    edit_condition = build_edit_condition(subject, expression, float(scale))

    prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
        method=DEFAULT_METHOD,
        pipeline=pipe,
        data=edit_condition,
        image=input_image,
        max_sequence_length=1024,
    )

    # Seeded generator on the pipeline's device for reproducible sampling.
    generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
    with torch.no_grad():
        output = pipe(
            image=input_image,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            num_inference_steps=steps,
            true_cfg_scale=0,
            output_type="pil",
            generator=generator,
        )
    return output.images[0]
gradio_app/patches.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from pathlib import Path


def patch_asyncio_cleanup_error() -> None:
    """Silence the spurious 'Invalid file descriptor' ValueError that asyncio
    event-loop finalizers can raise during interpreter shutdown."""
    try:
        import asyncio.base_events as base_events
    except Exception:
        return

    original_del = getattr(base_events.BaseEventLoop, "__del__", None)
    if original_del is None:
        return

    def patched_del(self):
        try:
            original_del(self)
        except ValueError as exc:
            # Only swallow the known shutdown-time error; anything else is real.
            if "Invalid file descriptor" not in str(exc):
                raise

    base_events.BaseEventLoop.__del__ = patched_del


def configure_hf_download_env() -> None:
    """Raise Hugging Face Hub timeouts so multi-GB downloads survive slow links.

    setdefault keeps any value the operator has already exported.
    """
    os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "1800")
    os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "1800")


def patch_qwen_diffusers_bug() -> None:
    """Comment out a faulty early-return in diffusers' Qwen edit-plus pipeline.

    Best-effort: silently skips when diffusers is absent, the target file is
    missing, already patched, the pattern is not found (different diffusers
    version), or the install location is not writable.
    """
    import importlib.util

    spec = importlib.util.find_spec("diffusers")
    if spec is None or spec.origin is None:
        return

    target_file = (
        Path(spec.origin).resolve().parent
        / "pipelines"
        / "qwenimage"
        / "pipeline_qwenimage_edit_plus.py"
    )
    if not target_file.exists():
        return

    text = target_file.read_text(encoding="utf-8")
    match = "if prompt_embeds_mask is not None and prompt_embeds_mask.all()"
    if f"# {match}" in text:
        return  # already patched

    lines = text.splitlines()
    modified = False
    for idx, line in enumerate(lines):
        if line.strip() == match:
            # Comment out the condition and the statement on the next line.
            if not lines[idx].lstrip().startswith("#"):
                lines[idx] = f"# {lines[idx]}"
                modified = True
            if idx + 1 < len(lines) and not lines[idx + 1].lstrip().startswith("#"):
                lines[idx + 1] = f"# {lines[idx + 1]}"
                modified = True
            break

    if not modified:
        # Pattern not found: leave the installed file untouched rather than
        # rewriting it with normalized newlines.
        return
    try:
        target_file.write_text("\n".join(lines) + "\n", encoding="utf-8")
    except OSError:
        # site-packages may live on a read-only filesystem (e.g. container
        # image layers); the patch is best-effort, so don't crash startup.
        pass


def apply_runtime_patches() -> None:
    """Apply all startup-time workarounds; call once before heavy imports."""
    patch_asyncio_cleanup_error()
    configure_hf_download_env()
    patch_qwen_diffusers_bug()
gradio_app/pipeline.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import threading
from pathlib import Path

import torch
from diffusers import QwenImageEditPlusPipeline
from huggingface_hub import hf_hub_download, snapshot_download

from gradio_app.config import (
    BASE_MODEL_REPO,
    DEFAULT_WEIGHT_VERSION,
    PIXELSMILE_DIR,
    PIXELSMILE_REPO,
    WEIGHTS_DIR,
    WEIGHT_FILES,
)


# Singleton pipeline plus bookkeeping about what is currently loaded on it.
PIPE = None
PIPE_STATE = {"version": None, "device": None}
# Startup preload progress, polled by the click handler in gradio_app.demo.
PRELOAD_STATE = {"loading": False, "ready": False, "error": None}
# Reentrant lock serializing pipeline construction and LoRA swaps: the startup
# preload thread and the first user request could otherwise both build the
# multi-gigabyte pipeline concurrently.
_PIPE_LOCK = threading.RLock()


def resolve_lora_path(weight_version: str) -> Path:
    """Return the local path for *weight_version*'s LoRA file.

    Raises:
        ValueError: if *weight_version* is not a key of WEIGHT_FILES.
    """
    if weight_version not in WEIGHT_FILES:
        raise ValueError(f"Unsupported weight version: {weight_version}")
    return PIXELSMILE_DIR / WEIGHT_FILES[weight_version]


def ensure_lora_path(weight_version: str) -> Path:
    """Return the LoRA path, downloading the file from the Hub if missing."""
    PIXELSMILE_DIR.mkdir(parents=True, exist_ok=True)
    lora_path = resolve_lora_path(weight_version)
    if lora_path.exists():
        return lora_path

    filename = WEIGHT_FILES[weight_version]
    downloaded_path = hf_hub_download(
        repo_id=PIXELSMILE_REPO,
        filename=filename,
        local_dir=str(PIXELSMILE_DIR),
    )
    return Path(downloaded_path)


def get_device() -> torch.device:
    """Pick the first CUDA device when available, else the CPU."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_pipe() -> QwenImageEditPlusPipeline:
    """Load (or return the cached) base Qwen edit pipeline on the current device."""
    global PIPE

    with _PIPE_LOCK:
        device = get_device()
        device_key = str(device)
        if PIPE is not None and PIPE_STATE["device"] == device_key:
            return PIPE

        try:
            # NOTE: resume_download is deprecated in huggingface_hub (downloads
            # always resume), so it is deliberately not passed here.
            model_path = snapshot_download(
                repo_id=BASE_MODEL_REPO,
                cache_dir=str(WEIGHTS_DIR),
            )
        except Exception:
            # Network failure: fall back to whatever is already cached locally.
            model_path = snapshot_download(
                repo_id=BASE_MODEL_REPO,
                cache_dir=str(WEIGHTS_DIR),
                local_files_only=True,
            )

        pipe = QwenImageEditPlusPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            cache_dir=str(WEIGHTS_DIR),
        )
        pipe.to(device)

        PIPE = pipe
        # A fresh pipeline carries no LoRA yet; force load_lora() to apply one.
        PIPE_STATE["version"] = None
        PIPE_STATE["device"] = device_key
        return PIPE


def load_lora(weight_version: str) -> QwenImageEditPlusPipeline:
    """Return the shared pipeline with the requested LoRA weights applied."""
    with _PIPE_LOCK:
        pipe = load_pipe()
        device_key = str(get_device())
        if PIPE_STATE["version"] == weight_version and PIPE_STATE["device"] == device_key:
            return pipe

        lora_path = ensure_lora_path(weight_version)
        try:
            pipe.unload_lora_weights()
        except AttributeError:
            # Older diffusers releases expose no unload hook; loading onto a
            # fresh pipeline is still safe.
            pass
        pipe.load_lora_weights(str(lora_path))
        PIPE_STATE["version"] = weight_version
        return pipe


def preload_default_pipe() -> None:
    """Warm up the default pipeline, recording progress in PRELOAD_STATE."""
    try:
        PRELOAD_STATE["loading"] = True
        PRELOAD_STATE["ready"] = False
        PRELOAD_STATE["error"] = None
        load_lora(DEFAULT_WEIGHT_VERSION)
        PRELOAD_STATE["ready"] = True
    except Exception as exc:
        # Record the failure for the UI; the app still starts so the user
        # sees an actionable message instead of a dead Space.
        PRELOAD_STATE["error"] = str(exc)
        print(f"[WARN] Failed to preload PixelSmile pipeline: {exc}")
    finally:
        PRELOAD_STATE["loading"] = False


def start_preload() -> None:
    """Start the model preload on a daemon thread so UI startup is not blocked."""
    threading.Thread(target=preload_default_pipe, daemon=True).start()
pixelsmile/__init__.py CHANGED
@@ -1 +1 @@
1
- # PixelSmile demo package.
 
1
+ # Shared PixelSmile demo core package.
pixelsmile/utils/__init__.py CHANGED
File without changes
requirements.txt CHANGED
File without changes
weights/.gitkeep CHANGED
File without changes