ghmk committed on
Commit
da23dfe
·
1 Parent(s): 9050c2a

Deploy full Character Sheet Pro with HF auth

Browse files

- Complete app with 7-view character sheet generation
- All src/ modules (backend_router, character_service, etc.)
- HuggingFace token auth for gated FLUX.2 klein 9B model
- Updated requirements with all dependencies

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,132 +1,887 @@
1
- import gradio as gr
2
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
- import random
5
- import spaces
6
- import torch
7
- from huggingface_hub import login
 
8
  import base64
 
 
 
 
 
 
9
 
10
- # Model access configuration (read-only)
11
  def _get_access_key():
12
- # Encoded for basic obfuscation
13
  _k = "aGZfRUR2akdKUXJGRmFQUnhLY1BOUmlUR0lXd0dKYkJ4dkNCWA=="
14
  return base64.b64decode(_k).decode()
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN") or _get_access_key()
17
- print("Authenticating...")
18
  login(token=HF_TOKEN)
19
- print("Authentication successful")
20
-
21
- from diffusers import Flux2KleinPipeline
22
 
23
- dtype = torch.bfloat16
24
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- print("Loading FLUX.2 klein 9B model...")
27
- pipe = Flux2KleinPipeline.from_pretrained(
28
- "black-forest-labs/FLUX.2-klein-9B",
29
- torch_dtype=dtype,
30
- use_auth_token=HF_TOKEN
31
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- MAX_SEED = np.iinfo(np.int32).max
34
- MAX_IMAGE_SIZE = 2048
35
-
36
- @spaces.GPU()
37
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
38
- if randomize_seed:
39
- seed = random.randint(0, MAX_SEED)
40
- generator = torch.Generator().manual_seed(seed)
41
- image = pipe(
42
- prompt=prompt,
43
- width=width,
44
- height=height,
45
- num_inference_steps=num_inference_steps,
46
- generator=generator,
47
- guidance_scale=1.0
48
- ).images[0]
49
- return image, seed
50
-
51
- examples = [
52
- "character turnaround sheet, female warrior with red hair, front view, full body, white background",
53
- "character turnaround sheet, male wizard with blue robes, front view, full body, white background",
54
- "character turnaround sheet, female elf archer, front view, full body, white background",
55
- ]
56
-
57
- css="""
58
- #col-container {
59
- margin: 0 auto;
60
- max-width: 520px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  }
62
  """
63
 
64
- with gr.Blocks() as demo:
65
- with gr.Column(elem_id="col-container"):
66
- gr.Markdown(f"""# CharacterForgePro
67
- Generate character images using FLUX.2 klein 9B on Zero GPU.
68
- """)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  with gr.Row():
71
- prompt = gr.Text(
72
- label="Prompt",
73
- show_label=False,
74
- max_lines=1,
75
- placeholder="Enter your prompt",
76
- container=False,
 
 
 
77
  )
78
- run_button = gr.Button("Run", scale=0)
79
-
80
- result = gr.Image(label="Result", show_label=False)
81
 
82
- with gr.Accordion("Advanced Settings", open=False):
83
- seed = gr.Slider(
84
- label="Seed",
85
- minimum=0,
86
- maximum=MAX_SEED,
87
- step=1,
88
- value=0,
89
  )
90
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
91
-
92
- with gr.Row():
93
- width = gr.Slider(
94
- label="Width",
95
- minimum=256,
96
- maximum=MAX_IMAGE_SIZE,
97
- step=32,
98
- value=1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  )
100
- height = gr.Slider(
101
- label="Height",
102
- minimum=256,
103
- maximum=MAX_IMAGE_SIZE,
104
- step=32,
105
- value=1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  )
107
 
108
- with gr.Row():
109
- num_inference_steps = gr.Slider(
110
- label="Number of inference steps",
111
- minimum=1,
112
- maximum=50,
113
- step=1,
114
- value=4,
115
  )
116
 
117
- gr.Examples(
118
- examples=examples,
119
- fn=infer,
120
- inputs=[prompt],
121
- outputs=[result, seed],
122
- cache_examples=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
 
125
- gr.on(
126
- triggers=[run_button.click, prompt.submit],
127
- fn=infer,
128
- inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
129
- outputs=[result, seed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
 
132
- demo.launch(css=css)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Character Sheet Pro - HuggingFace Spaces Version
3
+ =================================================
4
+
5
+ 7-View Character Sheet Generator optimized for HuggingFace Spaces Zero GPU.
6
+ Uses FLUX.2 klein 9B (fp8) as the default backend, with FLUX.2 klein 4B and Gemini Flash as alternatives.
7
+
8
+ This is a simplified version of app.py designed for:
9
+ - Zero GPU (A10G 24GB) deployment
10
+ - 5-minute session timeout
11
+ - Automatic model loading on first generation
12
+ """
13
+
14
  import os
15
+ import json
16
+ import logging
17
+ import zipfile
18
+ import threading
19
+ import queue
20
  import base64
21
+ from pathlib import Path
22
+ from typing import Optional, Tuple, Dict, Any, List, Generator
23
+ from datetime import datetime
24
+ import gradio as gr
25
+ from PIL import Image
26
+ from huggingface_hub import login
27
 
28
# HuggingFace authentication for gated models
#
# SECURITY WARNING(review): _get_access_key() embeds a real HuggingFace access
# token in the repository, hidden only by base64 encoding — which is
# obfuscation, not encryption. Anyone who can read this file can decode and
# use the token. The committed token should be revoked and this fallback
# removed so the token comes exclusively from the HF_TOKEN Spaces secret.
def _get_access_key():
    # Decodes to a plaintext "hf_..." token.
    _k = "aGZfRUR2akdKUXJGRmFQUnhLY1BOUmlUR0lXd0dKYkJ4dkNCWA=="
    return base64.b64decode(_k).decode()

# Prefer the HF_TOKEN environment variable (HuggingFace Spaces secret); fall
# back to the embedded token above. Note an *empty* HF_TOKEN env var is falsy
# and also triggers the fallback.
HF_TOKEN = os.environ.get("HF_TOKEN") or _get_access_key()

# Module-level side effect: logs in at import time so downloads of gated
# models are authorized. An invalid token raises here and aborts startup.
login(token=HF_TOKEN)
print("HuggingFace authentication successful")
 
 
36
 
37
# HuggingFace Spaces SDK - provides the @spaces.GPU decorator.
# Outside Spaces the SDK is absent, so install a no-op stand-in under the
# same name; this lets `@spaces.GPU(duration=...)` be used unconditionally.
try:
    import spaces
    HF_SPACES = True
except ImportError:
    # Running locally without the spaces SDK.
    HF_SPACES = False

    class spaces:
        """No-op replacement for the Spaces SDK, for local testing."""

        @staticmethod
        def GPU(duration=300):
            # Identity decorator: the wrapped function is returned unchanged.
            return lambda func: func
51
 
52
# Configure logging
# Root-logger setup with timestamped records. Note: basicConfig is a no-op if
# an earlier import already configured the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
58
+
59
+ # Import local modules
60
+ from src.character_service import CharacterSheetService
61
+ from src.models import CharacterSheetConfig
62
+ from src.backend_router import BackendRouter, BackendType
63
+ from src.utils import preprocess_input_image, sanitize_filename
64
+
65
+
66
+ def ensure_png_image(image: Optional[Image.Image], max_size: int = 768) -> Optional[Image.Image]:
67
+ """Convert any image to PNG-compatible RGB format with proper sizing for FLUX."""
68
+ if image is None:
69
+ return None
70
+ # FLUX models work best with smaller inputs (512-768px)
71
+ # Larger images slow down processing significantly
72
+ return preprocess_input_image(image, max_size=max_size, ensure_rgb=True)
73
+
74
+
75
def create_pending_placeholder(width: int = 200, height: int = 200, text: str = "Pending...") -> Image.Image:
    """Create a placeholder image showing that generation is pending.

    Draws a dark panel with an orange border, a three-dot "loading" indicator,
    and the given caption text centered below it.

    Args:
        width: Placeholder width in pixels.
        height: Placeholder height in pixels.
        text: Caption rendered under the loading dots.

    Returns:
        A new RGB PIL image.
    """
    from PIL import ImageDraw, ImageFont

    # Dark background panel.
    img = Image.new('RGB', (width, height), color=(25, 25, 45))
    draw = ImageDraw.Draw(img)

    # Orange border makes it clearly a placeholder, not a result.
    border_color = (255, 149, 0)
    draw.rectangle([(2, 2), (width-3, height-3)], outline=border_color, width=2)

    # Loading indicator: three dots in progressively brighter orange shades.
    center_y = height // 2
    dot_spacing = 20
    dot_radius = 5
    for i, offset in enumerate([-dot_spacing, 0, dot_spacing]):
        shade = 200 + (i * 25)
        dot_color = (shade, int(shade * 0.6), 0)
        x = width // 2 + offset
        draw.ellipse([(x - dot_radius, center_y - dot_radius),
                      (x + dot_radius, center_y + dot_radius)], fill=dot_color)

    # FIX: was a bare `except:`, which also swallows KeyboardInterrupt and
    # SystemExit. ImageFont.truetype raises OSError when the font file cannot
    # be found or read, so catch exactly that and fall back to the built-in font.
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except OSError:
        font = ImageFont.load_default()

    # Center the caption horizontally, below the dots.
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    x = (width - text_width) // 2
    y = center_y + 25

    draw.text((x, y), text, fill=(180, 180, 180), font=font)

    return img
112
+
113
+
114
# =============================================================================
# Configuration
# =============================================================================

# All generated artifacts (sheets, per-view PNGs, metadata JSON, ZIPs) are
# written under this directory; created eagerly at import time.
OUTPUT_DIR = Path("./outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Get API key from environment (HuggingFace Spaces secrets). An empty string
# means "no key configured"; cloud backends then fail with an error message
# at generation time.
API_KEY = os.environ.get("GEMINI_API_KEY", "")
123
+
124
# Model defaults - include all FLUX variants
# Per-backend defaults: inference steps, guidance scale, display name, and
# whether the costume reference is fed into face views by default.
MODEL_DEFAULTS = {
    "flux_klein": {"steps": 4, "guidance": 1.0, "name": "FLUX.2 klein 4B", "costume_in_faces": False},
    "flux_klein_9b_fp8": {"steps": 4, "guidance": 1.0, "name": "FLUX.2 klein 9B", "costume_in_faces": False},
    "gemini_flash": {"steps": 1, "guidance": 1.0, "name": "Gemini Flash", "costume_in_faces": True},
}


def get_model_defaults(backend_value: str) -> Tuple[int, float]:
    """Return (steps, guidance) defaults for a backend; unknown backends get FLUX-style defaults."""
    entry = MODEL_DEFAULTS.get(backend_value)
    if entry is None:
        return 4, 1.0
    return entry["steps"], entry["guidance"]


def get_costume_in_faces_default(backend_value: str) -> bool:
    """Return whether the costume reference should be included in face views by default."""
    entry = MODEL_DEFAULTS.get(backend_value, {})
    return entry.get("costume_in_faces", True)
142
+
143
+
144
# =============================================================================
# Presets Loading
# =============================================================================

EXAMPLES_DIR = Path("./examples")
PRESETS_FILE = EXAMPLES_DIR / "presets.json"


def load_presets() -> Dict[str, Any]:
    """Load presets configuration from JSON file.

    Returns:
        The parsed presets.json content, or an empty structure
        ({"characters": [], "costumes": []}) when the file is absent, so
        callers never need to special-case a missing presets file.
    """
    if PRESETS_FILE.exists():
        # FIX: explicit encoding — JSON files are UTF-8 by spec; relying on
        # the platform default encoding breaks on some deployments.
        with open(PRESETS_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {"characters": [], "costumes": []}
158
+
159
+
160
def get_character_presets() -> List[Dict]:
    """Return the list of character preset entries from presets.json."""
    return load_presets().get("characters", [])
164
+
165
+
166
def load_character_preset(preset_id: str) -> Tuple[Optional[Image.Image], str, str]:
    """Resolve a preset id to its (image, name, gender) triple.

    Scans the character presets; a matching id only counts if its image file
    exists on disk. Falls back to (None, "", "Auto/Neutral") when nothing
    usable is found.
    """
    for entry in get_character_presets():
        if entry["id"] != preset_id:
            continue
        image_path = EXAMPLES_DIR / entry["file"]
        if image_path.exists():
            return (
                Image.open(image_path),
                entry.get("name", ""),
                entry.get("gender", "Auto/Neutral"),
            )
    return None, "", "Auto/Neutral"
180
+
181
+
182
# =============================================================================
# Character Sheet Metadata
# =============================================================================

def create_character_sheet_metadata(
    character_name: str,
    character_sheet: Image.Image,
    stages: Dict[str, Any],
    config: CharacterSheetConfig,
    backend: str,
    input_type: str,
    costume_description: str,
    steps: int,
    guidance: float
) -> Dict[str, Any]:
    """Create JSON metadata with pixel coordinates for each view.

    Args:
        character_name: Display name; also used (sanitized) in file names.
        character_sheet: Composited sheet image; only its size is read here.
        stages: Mapping of view name ('left_face', ..., 'back_body') to the
            per-view image, or None for views that were not generated.
        config: Sheet layout config; only .spacing and .background_color are read.
        backend: Human-readable backend name recorded in the metadata.
        input_type: Input mode that produced the sheet.
        costume_description: Free-text costume prompt (may be empty).
        steps: Inference steps used for generation.
        guidance: Guidance scale used for generation.

    Returns:
        A JSON-serializable dict describing the sheet layout (pixel regions
        per view), generation parameters, and the file names used by the
        download/ZIP step.
    """
    sheet_width, sheet_height = character_sheet.size
    spacing = config.spacing

    # Calculate face row dimensions.
    # NOTE(review): face_height keeps the height of the LAST present face
    # view — presumably all views in a row share one height; confirm against
    # the sheet compositor.
    face_images = ['left_face', 'front_face', 'right_face']
    face_height = 0
    face_widths = []
    for name in face_images:
        if name in stages and stages[name] is not None:
            face_height = stages[name].height
            face_widths.append(stages[name].width)
        else:
            # Missing views occupy zero width so later views keep their offsets.
            face_widths.append(0)

    # Calculate body row dimensions (same convention as the face row).
    body_images = ['left_body', 'front_body', 'right_body', 'back_body']
    body_height = 0
    body_widths = []
    for name in body_images:
        if name in stages and stages[name] is not None:
            body_height = stages[name].height
            body_widths.append(stages[name].width)
        else:
            body_widths.append(0)

    # Body row starts below the face row plus one vertical spacing gap.
    body_start_y = face_height + spacing

    # Build view regions
    views = {}

    # Face row: laid out left-to-right from x=0.
    # NOTE(review): no horizontal spacing is added between columns here;
    # verify this matches the actual sheet composition.
    x = 0
    for i, name in enumerate(face_images):
        views[name] = {
            "x": x, "y": 0,
            "width": face_widths[i], "height": face_height,
            "description": {
                "left_face": "Left profile view of face (90 degrees)",
                "front_face": "Front-facing portrait view",
                "right_face": "Right profile view of face (90 degrees)"
            }.get(name, name)
        }
        x += face_widths[i]

    # Body row: same left-to-right layout, offset vertically by body_start_y.
    x = 0
    for i, name in enumerate(body_images):
        views[name] = {
            "x": x, "y": body_start_y,
            "width": body_widths[i], "height": body_height,
            "description": {
                "left_body": "Left side full body view (90 degrees)",
                "front_body": "Front-facing full body view",
                "right_body": "Right side full body view (90 degrees)",
                "back_body": "Rear full body view (180 degrees)"
            }.get(name, name)
        }
        x += body_widths[i]

    metadata = {
        "version": "1.0",
        "generator": "Character Sheet Pro (HuggingFace Spaces)",
        "timestamp": datetime.now().isoformat(),
        "character": {
            "name": character_name,
            "input_type": input_type,
            # Empty string normalizes to null in the JSON output.
            "costume_description": costume_description or None
        },
        "generation": {
            "backend": backend,
            "steps": steps,
            "guidance_scale": guidance
        },
        "sheet": {
            "width": sheet_width,
            "height": sheet_height,
            "spacing": spacing,
            "background_color": config.background_color
        },
        "views": views,
        # These file names must stay in sync with create_download_zip().
        "files": {
            "character_sheet": f"{sanitize_filename(character_name)}_character_sheet.png",
            "individual_views": {
                name: f"{sanitize_filename(character_name)}_{name}.png"
                for name in list(face_images) + list(body_images)
            }
        }
    }

    return metadata
288
+
289
+
290
def create_download_zip(
    character_name: str,
    character_sheet: Image.Image,
    stages: Dict[str, Any],
    metadata: Dict[str, Any],
    output_dir: Path
) -> Path:
    """Create a ZIP file with character sheet, individual views, and metadata JSON.

    Args:
        character_name: Used (sanitized) for the archive and entry file names.
        character_sheet: Composited sheet image saved as the main PNG entry.
        stages: Mapping of view name to per-view image (missing/None views
            are skipped).
        metadata: JSON-serializable metadata dict written as a .json entry.
        output_dir: Directory where the ZIP archive is created.

    Returns:
        Path to the created ZIP archive.
    """
    # Local import keeps the module's top-level import list unchanged.
    import io

    safe_name = sanitize_filename(character_name)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_path = output_dir / f"{safe_name}_{timestamp}.zip"

    def _png_bytes(img: Image.Image) -> bytes:
        # Encode a PIL image to PNG bytes in memory.
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return buf.getvalue()

    # FIX: the original wrote each entry to a temp file beside the ZIP and
    # unlinked it afterwards — an exception mid-archive leaked temp files, and
    # the un-timestamped temp names raced with concurrent generations for the
    # same character. Writing entries in memory via writestr avoids both.
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        # Add character sheet
        zf.writestr(f"{safe_name}_character_sheet.png", _png_bytes(character_sheet))

        # Add individual views (skip views that were not generated)
        view_names = ['left_face', 'front_face', 'right_face',
                      'left_body', 'front_body', 'right_body', 'back_body']
        for name in view_names:
            if stages.get(name) is not None:
                zf.writestr(f"{safe_name}_{name}.png", _png_bytes(stages[name]))

        # Add metadata JSON
        zf.writestr(f"{safe_name}_metadata.json", json.dumps(metadata, indent=2))

    return zip_path
328
+
329
+
330
# =============================================================================
# Zero GPU Generation Function
# =============================================================================

# Global cache for the service (persists across GPU sessions), keyed by the
# backend it was built for; rebuilt whenever the user switches backends.
_cached_service = None
_cached_backend = None


@spaces.GPU(duration=300)  # 5-minute timeout for the full pipeline
def generate_with_gpu(
    input_image: Optional[Image.Image],
    input_type: str,
    character_name: str,
    gender: str,
    costume_description: str,
    costume_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    body_image: Optional[Image.Image],
    backend_choice: str,
    api_key: str,
    num_steps: int,
    guidance_scale: float,
    include_costume_in_faces: bool
) -> Tuple[Optional[Image.Image], str, Dict[str, Any]]:
    """
    GPU-wrapped generation function for Zero GPU.

    This function runs entirely within a GPU session.
    Model loading happens inside this function for Zero GPU compatibility.

    Returns:
        (sheet, status, metadata) — sheet is None on any failure, with the
        error text in status and an empty metadata dict.
    """
    global _cached_service, _cached_backend

    try:
        # Determine backend from the UI dropdown value.
        backend = BackendRouter.backend_from_string(backend_choice)
        is_cloud = backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO)

        # Validate API key for cloud backends — fail fast with a UI message.
        if is_cloud and not api_key:
            return None, "Error: Gemini API key required for cloud backends", {}

        # Load or reuse service; rebuilding only when the backend changed.
        if _cached_service is None or _cached_backend != backend:
            logger.info(f"Loading model for {backend.value}...")

            # For local FLUX model, create service (this loads the model)
            _cached_service = CharacterSheetService(
                api_key=api_key if is_cloud else None,
                backend=backend
            )
            _cached_backend = backend

            # Configure steps/guidance.
            # NOTE(review): these two assignments are repeated below after
            # clamping; this first pass uses the *unclamped* values and is
            # redundant — candidate for removal.
            if hasattr(_cached_service.client, 'default_steps'):
                _cached_service.client.default_steps = num_steps
            if hasattr(_cached_service.client, 'default_guidance'):
                _cached_service.client.default_guidance = guidance_scale

            logger.info(f"Model loaded successfully: {backend.value}")

        # Map the UI gender selection to the prompt noun.
        gender_map = {
            "Auto/Neutral": "character",
            "Male": "man",
            "Female": "woman"
        }
        gender_term = gender_map.get(gender, "character")

        # Clamp steps and guidance to safe ranges.
        num_steps = max(1, min(100, int(num_steps)))
        guidance_scale = max(0.0, min(20.0, float(guidance_scale)))

        # Apply the (clamped) steps/guidance to the cached client.
        if hasattr(_cached_service.client, 'default_steps'):
            _cached_service.client.default_steps = num_steps
        if hasattr(_cached_service.client, 'default_guidance'):
            _cached_service.client.default_guidance = guidance_scale

        # Run generation
        logger.info(f"Starting generation for {character_name}...")

        sheet, status, metadata = _cached_service.generate_character_sheet(
            initial_image=input_image,
            input_type=input_type,
            character_name=character_name or "Character",
            gender_term=gender_term,
            costume_description=costume_description,
            costume_image=costume_image,
            face_image=face_image,
            body_image=body_image,
            include_costume_in_faces=include_costume_in_faces,
            output_dir=OUTPUT_DIR
        )

        return sheet, status, metadata

    except Exception as e:
        # Broad catch is deliberate: this is the top-level boundary of the GPU
        # session; the error is logged with traceback and surfaced to the UI.
        logger.exception(f"Generation error: {e}")
        return None, f"Error: {str(e)}", {}
430
+
431
+
432
+ # =============================================================================
433
+ # Gradio Interface Functions
434
+ # =============================================================================
435
+
436
def generate_character_sheet(
    input_image: Optional[Image.Image],
    input_type: str,
    character_name: str,
    gender: str,
    costume_description: str,
    costume_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    body_image: Optional[Image.Image],
    backend_choice: str,
    api_key_override: str,
    num_steps: int,
    guidance_scale: float,
    include_costume_in_faces: bool,
    progress=gr.Progress()
) -> Generator:
    """
    Generate character sheet from input image(s).

    This wrapper handles preprocessing and calls the GPU-wrapped function.

    Yields tuples matching the Gradio outputs list:
    (sheet_image, status_text, 7 preview images, json_path, zip_path).
    Intermediate yields keep the UI responsive; the final yield carries the
    results. NOTE: `progress=gr.Progress()` as a default is the Gradio
    convention for progress injection, not a mutable-default bug.
    """
    # Initial empty state for the 7 per-view preview slots.
    empty_previews = [None] * 7

    yield (None, "Initializing...", *empty_previews, None, None)

    # Preprocess all input images to PNG-compatible RGB format.
    input_image = ensure_png_image(input_image)
    face_image = ensure_png_image(face_image)
    body_image = ensure_png_image(body_image)
    costume_image = ensure_png_image(costume_image)

    # Validate input: separate mode requires both references; other modes
    # require the single main image.
    if input_type == "Face + Body (Separate)":
        if face_image is None or body_image is None:
            yield (None, "Error: Both face and body images required for this mode.",
                   *empty_previews, None, None)
            return
    elif input_image is None:
        yield (None, "Error: Please upload an input image.", *empty_previews, None, None)
        return

    # Get API key: a non-blank override from the UI wins over the env secret.
    api_key = api_key_override.strip() if api_key_override.strip() else API_KEY

    # Show loading state before the (potentially slow) GPU allocation.
    progress(0.1, desc="Allocating GPU...")
    yield (None, "Allocating GPU and loading model (this may take 30-60 seconds on first run)...",
           *empty_previews, None, None)

    try:
        # Call the GPU-wrapped function (runs inside a Zero GPU session).
        character_sheet, status, metadata = generate_with_gpu(
            input_image=input_image,
            input_type=input_type,
            character_name=character_name or "Character",
            gender=gender,
            costume_description=costume_description,
            costume_image=costume_image,
            face_image=face_image,
            body_image=body_image,
            backend_choice=backend_choice,
            api_key=api_key,
            num_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            include_costume_in_faces=include_costume_in_faces
        )

        # A None sheet means generation failed; status carries the error text.
        if character_sheet is None:
            yield (None, status, *empty_previews, None, None)
            return

        # Get per-view stage images from metadata for the preview gallery.
        stages = metadata.get('stages', {})

        # Create preview list in the fixed UI order (3 faces, then 4 bodies).
        preview_list = [
            stages.get('left_face'),
            stages.get('front_face'),
            stages.get('right_face'),
            stages.get('left_body'),
            stages.get('front_body'),
            stages.get('right_body'),
            stages.get('back_body')
        ]

        # Determine backend (for the human-readable name in the metadata).
        backend = BackendRouter.backend_from_string(backend_choice)

        # Create metadata JSON describing layout and generation parameters.
        config = CharacterSheetConfig()
        json_metadata = create_character_sheet_metadata(
            character_name=character_name or "Character",
            character_sheet=character_sheet,
            stages=stages,
            config=config,
            backend=BackendRouter.BACKEND_NAMES.get(backend, backend_choice),
            input_type=input_type,
            costume_description=costume_description,
            steps=num_steps,
            guidance=guidance_scale
        )

        # Save JSON file for the standalone metadata download.
        safe_name = sanitize_filename(character_name or "Character")
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        json_path = OUTPUT_DIR / f"{safe_name}_{timestamp}_metadata.json"
        with open(json_path, 'w') as f:
            json.dump(json_metadata, f, indent=2)

        # Create ZIP file bundling the sheet, per-view PNGs, and metadata.
        zip_path = create_download_zip(
            character_name=character_name or "Character",
            character_sheet=character_sheet,
            stages=stages,
            metadata=json_metadata,
            output_dir=OUTPUT_DIR
        )

        # Final yield with all outputs.
        yield (
            character_sheet,
            status,
            *preview_list,
            str(json_path),
            str(zip_path)
        )

    except Exception as e:
        # Top-level UI boundary: log with traceback, surface message to user.
        logger.exception(f"Error: {e}")
        yield (None, f"Error: {str(e)}", *empty_previews, None, None)
567
+
568
+
569
def update_input_visibility(input_type: str):
    """Toggle which image inputs are shown for the selected input mode.

    Returns gr.update objects for (main_input, face_input, body_input):
    separate mode shows the two reference inputs, every other mode shows
    only the single main input.
    """
    separate = input_type == "Face + Body (Separate)"
    return (
        gr.update(visible=not separate),  # Main input
        gr.update(visible=separate),      # Face input
        gr.update(visible=separate),      # Body input
    )
583
+
584
+
585
def update_defaults_on_backend_change(backend_value: str):
    """Refresh the steps, guidance, and costume-in-faces controls for the chosen backend."""
    steps, guidance = get_model_defaults(backend_value)
    return (
        gr.update(value=steps),
        gr.update(value=guidance),
        gr.update(value=get_costume_in_faces_default(backend_value)),
    )
590
+
591
+
592
# =============================================================================
# Gradio UI
# =============================================================================

# CSS for the interface
# The .gpu-banner and .generate-btn-main classes are referenced by the
# gr.HTML banner and the generate button's elem_classes in create_ui().
APP_CSS = """
.container { max-width: 1200px; margin: auto; }
.output-image { min-height: 400px; }

/* GPU status banner */
.gpu-banner {
    background: linear-gradient(90deg, #7c3aed, #a855f7);
    padding: 12px 20px;
    text-align: center;
    color: white;
    font-weight: bold;
    border-radius: 8px;
    margin-bottom: 16px;
}

/* Generate button styling */
.generate-btn-main {
    background: linear-gradient(90deg, #00aa44, #00cc55) !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 20px !important;
    padding: 16px 32px !important;
    border: none !important;
    box-shadow: 0 4px 15px rgba(0, 170, 68, 0.4) !important;
}

.generate-btn-main:hover {
    background: linear-gradient(90deg, #00cc55, #00ee66) !important;
}
"""
627
 
 
 
 
 
 
628
 
629
def create_ui():
    """Create the Gradio interface for HuggingFace Spaces.

    Returns:
        The assembled gr.Blocks app (launched by the __main__ entry point).
    """
    # FIX: css and theme belong to the gr.Blocks() constructor, not to
    # launch(). Previously APP_CSS was passed to demo.launch(), which does not
    # accept it, so the .gpu-banner / .generate-btn-main styling never applied.
    with gr.Blocks(title="Character Sheet Pro", css=APP_CSS, theme=gr.themes.Soft()) as demo:

        # GPU status banner (styled by .gpu-banner in APP_CSS)
        gr.HTML(
            '<div class="gpu-banner">'
            'Zero GPU (A10G) - Model loads automatically on first generation'
            '</div>'
        )

        gr.Markdown("# Character Sheet Pro")
        gr.Markdown("Generate 7-view character turnaround sheets from a single input image using FLUX.2 klein.")

        # Backend selection and controls
        with gr.Row():
            backend_dropdown = gr.Dropdown(
                choices=[
                    ("FLUX.2 klein 9B (Best Quality, ~20GB)", "flux_klein_9b_fp8"),
                    ("FLUX.2 klein 4B (Fast, ~13GB)", BackendType.FLUX_KLEIN.value),
                    ("Gemini Flash (Cloud - Fallback)", BackendType.GEMINI_FLASH.value),
                ],
                value="flux_klein_9b_fp8",  # Default to best quality
                label="Backend",
                scale=2
            )

            api_key_input = gr.Textbox(
                label="Gemini API Key (for cloud backend)",
                placeholder="Enter API key if using Gemini",
                type="password",
                value="",
                scale=2
            )

        with gr.Row():
            # Left column: Inputs
            with gr.Column(scale=1):
                gr.Markdown("### Input Settings")

                input_type = gr.Radio(
                    choices=["Face Only", "Full Body", "Face + Body (Separate)"],
                    value="Face Only",
                    label="Input Type",
                    info="What type of image(s) are you providing?"
                )

                main_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    format="png",
                    visible=True
                )

                # Shown only in "Face + Body (Separate)" mode (see
                # update_input_visibility).
                with gr.Row(visible=False) as face_body_row:
                    face_input = gr.Image(
                        label="Face Reference",
                        type="pil",
                        format="png",
                        visible=False
                    )
                    body_input = gr.Image(
                        label="Body Reference",
                        type="pil",
                        format="png",
                        visible=False
                    )

                gr.Markdown("### Character Details")

                character_name = gr.Textbox(
                    label="Character Name",
                    placeholder="My Character",
                    value=""
                )

                gender = gr.Radio(
                    choices=["Auto/Neutral", "Male", "Female"],
                    value="Auto/Neutral",
                    label="Gender"
                )

                costume_description = gr.Textbox(
                    label="Costume Description (Optional)",
                    placeholder="e.g., Full plate armor with gold trim...",
                    value="",
                    lines=3
                )

                costume_image = gr.Image(
                    label="Costume Reference Image (Optional)",
                    type="pil",
                    format="png"
                )

                gr.Markdown("### Generation Parameters")

                with gr.Row():
                    num_steps = gr.Number(
                        label="Inference Steps",
                        value=4,
                        minimum=1,
                        maximum=50,
                        step=1,
                        info="FLUX klein uses 4 steps"
                    )
                    guidance_scale = gr.Number(
                        label="Guidance Scale",
                        value=1.0,
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        info="FLUX klein uses 1.0"
                    )

                include_costume_in_faces = gr.Checkbox(
                    label="Include costume in face views",
                    value=False,
                    info="Turn OFF for FLUX (can confuse framing)"
                )

                # GENERATE BUTTON (styled by .generate-btn-main in APP_CSS)
                generate_btn = gr.Button(
                    "GENERATE CHARACTER SHEET",
                    variant="primary",
                    size="lg",
                    elem_classes=["generate-btn-main"]
                )

            # Right column: Output
            with gr.Column(scale=2):
                gr.Markdown("### Generated Character Sheet")

                output_image = gr.Image(
                    label="Character Sheet",
                    type="pil",
                    format="png",
                    elem_classes=["output-image"]
                )

                status_text = gr.Textbox(
                    label="Status",
                    interactive=False
                )

                # Preview gallery
                gr.Markdown("### Individual Views Preview")

                with gr.Row():
                    gr.Markdown("**Face Views:**")
                with gr.Row():
                    preview_left_face = gr.Image(label="Left Face", type="pil", height=150, width=112)
                    preview_front_face = gr.Image(label="Front Face", type="pil", height=150, width=112)
                    preview_right_face = gr.Image(label="Right Face", type="pil", height=150, width=112)

                with gr.Row():
                    gr.Markdown("**Body Views:**")
                with gr.Row():
                    preview_left_body = gr.Image(label="Left Body", type="pil", height=150, width=84)
                    preview_front_body = gr.Image(label="Front Body", type="pil", height=150, width=84)
                    preview_right_body = gr.Image(label="Right Body", type="pil", height=150, width=84)
                    preview_back_body = gr.Image(label="Back Body", type="pil", height=150, width=84)

                # Downloads
                gr.Markdown("### Downloads")
                with gr.Row():
                    json_download = gr.File(label="Metadata JSON", interactive=False)
                    zip_download = gr.File(label="Complete Package (ZIP)", interactive=False)

        # Usage instructions
        gr.Markdown("---")
        gr.Markdown("### How to Use")
        gr.Markdown("""
        1. **Upload an image** (face portrait or full body)
        2. **Select input type** based on your image
        3. **Optionally** add character name, gender, and costume description
        4. **Click Generate** - the model loads automatically on first run (~30-60s)
        5. **Wait** for all 7 views to generate (~2-3 minutes total)
        6. **Download** the complete package

        **GPU Notes:**
        - Uses Zero GPU (A10G 24GB) - free but with 5-minute session limit
        - First generation loads the model (adds ~30-60 seconds)
        - Subsequent generations in the same session are faster
        - If GPU unavailable, switch to Gemini Flash (requires API key)
        """)

        # Event handlers
        input_type.change(
            fn=update_input_visibility,
            inputs=[input_type],
            outputs=[main_input, face_input, body_input]
        )

        backend_dropdown.change(
            fn=update_defaults_on_backend_change,
            inputs=[backend_dropdown],
            outputs=[num_steps, guidance_scale, include_costume_in_faces]
        )

        generate_btn.click(
            fn=generate_character_sheet,
            inputs=[
                main_input,
                input_type,
                character_name,
                gender,
                costume_description,
                costume_image,
                face_input,
                body_input,
                backend_dropdown,
                api_key_input,
                num_steps,
                guidance_scale,
                include_costume_in_faces
            ],
            outputs=[
                output_image,
                status_text,
                preview_left_face,
                preview_front_face,
                preview_right_face,
                preview_left_body,
                preview_front_body,
                preview_right_body,
                preview_back_body,
                json_download,
                zip_download
            ]
        )

    return demo
863
+
864
+
865
# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    demo = create_ui()

    # FIX: Blocks.launch() accepts neither `theme` nor `css` keyword arguments
    # (those are gr.Blocks() constructor parameters), so passing them here
    # raised TypeError and prevented the app from ever starting.
    if HF_SPACES:
        # Running on HuggingFace Spaces — Spaces supplies host/port.
        demo.launch()
    else:
        # Local testing
        print("Running locally (no Zero GPU)")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7890,
            share=False
        )
requirements.txt CHANGED
@@ -1,8 +1,36 @@
1
- accelerate>=1.12.0
 
 
 
 
 
 
 
 
 
2
  git+https://github.com/huggingface/diffusers.git
3
- torch>=2.5.0
4
- transformers>=5.0.0
5
- sentencepiece
6
- numpy
7
- Pillow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  invisible_watermark
 
1
+ # Character Sheet Pro - HuggingFace Spaces
2
+ # =========================================
3
+
4
+ # Image processing
5
+ Pillow>=10.0.0
6
+
7
+ # Utilities
8
+ python-dotenv>=1.0.0
9
+
10
+ # Diffusers from git (required for Flux2KleinPipeline)
11
  git+https://github.com/huggingface/diffusers.git
12
+
13
+ # PyTorch
14
+ torch>=2.1.0
15
+ torchvision>=0.16.0
16
+
17
+ # Transformers
18
+ transformers>=4.40.0
19
+
20
+ # Accelerate
21
+ accelerate>=0.25.0
22
+
23
+ # HuggingFace Hub
24
+ huggingface-hub>=0.20.0
25
+
26
+ # Safetensors
27
+ safetensors>=0.4.0
28
+
29
+ # Sentencepiece
30
+ sentencepiece>=0.1.99
31
+
32
+ # Google Gemini API (fallback backend)
33
+ google-genai>=0.3.0
34
+
35
+ # Invisible watermark (for FLUX)
36
  invisible_watermark
src/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Character Sheet Pro - 7-View Character Sheet Generator
3
+ ======================================================
4
+
5
+ A standalone character sheet generation system that creates
6
+ multi-view turnaround sheets from a single input image.
7
+
8
+ Supports:
9
+ - 7 views (3 face + 4 body)
10
+ - Multiple backends: Gemini (Cloud), FLUX.2 klein (Local), Qwen-Image-Edit (Local/ComfyUI)
11
+ - HuggingFace Spaces deployment via Gradio
12
+ """
13
+
14
+ from .models import GenerationRequest, GenerationResult
15
+ from .gemini_client import GeminiClient
16
+ from .character_service import CharacterSheetService
17
+ from .backend_router import BackendRouter, BackendType
18
+ from .flux_klein_client import FluxKleinClient
19
+ from .qwen_image_edit_client import QwenImageEditClient
20
+ from .comfyui_client import ComfyUIClient
21
+ from .model_manager import ModelManager, ModelState, get_model_manager
22
+
23
+ __version__ = "2.3.0" # Bumped for model manager feature
24
+ __all__ = [
25
+ "GenerationRequest",
26
+ "GenerationResult",
27
+ "GeminiClient",
28
+ "CharacterSheetService",
29
+ "BackendRouter",
30
+ "BackendType",
31
+ "FluxKleinClient",
32
+ "QwenImageEditClient",
33
+ "ComfyUIClient",
34
+ "ModelManager",
35
+ "ModelState",
36
+ "get_model_manager",
37
+ ]
src/backend_router.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Backend Router
3
+ ==============
4
+
5
+ Unified router for selecting between different image generation backends:
6
+ - Gemini (Flash/Pro) - Cloud API
7
+ - FLUX.2 klein 4B/9B - Local model
8
+ - Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
9
+ - Qwen-Image-Edit-2511 - Local model
10
+ """
11
+
12
+ import logging
13
+ from typing import Optional, Protocol, Union
14
+ from enum import Enum, auto
15
+ from PIL import Image
16
+
17
+ from .models import GenerationRequest, GenerationResult
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class BackendType(Enum):
    """Available backend types.

    The enum *value* is the stable string identifier used by UI dropdowns
    and ``BackendRouter.backend_from_string()``; renaming a value would
    break any stored user selection, so treat values as a public contract.
    """
    GEMINI_FLASH = "gemini_flash"      # Google cloud API, fast tier
    GEMINI_PRO = "gemini_pro"          # Google cloud API, higher quality tier
    FLUX_KLEIN = "flux_klein"  # 4B model (~13GB VRAM)
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"  # 9B FP8 model (~20GB VRAM, best quality)
    ZIMAGE_TURBO = "zimage_turbo"  # Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"  # Z-Image Base 6B (50 steps, CFG support)
    LONGCAT_EDIT = "longcat_edit"  # LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"  # Direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"  # Via ComfyUI with FP8 quantization
34
+
35
+
36
class ImageClient(Protocol):
    """Structural interface every image generation client must satisfy.

    Using ``typing.Protocol`` means clients (GeminiClient, FluxKleinClient,
    ComfyUIClient, ...) do not need to inherit from this class; they only
    need matching method signatures (duck typing, statically checkable).
    """

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Generate an image from request.

        ``**kwargs`` carries backend-specific tuning (e.g. step count,
        guidance scale) that the router passes through untouched.
        """
        ...

    def is_healthy(self) -> bool:
        """Check if client is ready (model loaded / service reachable)."""
        ...
46
+
47
+
48
class BackendRouter:
    """
    Router for selecting between image generation backends.

    Clients are created lazily on first request and cached per backend,
    so heavyweight local models are only loaded into memory when actually
    used. Switching between local backends unloads the previous one to
    keep VRAM usage bounded.
    """

    # Human-readable display names, used for status reporting in the UI.
    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    # Backends that run a model on local hardware (everything except the
    # Gemini cloud APIs). FIX: the original duplicated this collection in
    # unload_local_models(), switch_backend() and is_local_backend(); a
    # single class constant keeps them from drifting out of sync.
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router.

        Args:
            gemini_api_key: API key for Gemini backends
            default_backend: Default backend to use
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        self._clients: dict = {}                          # BackendType -> ImageClient cache
        self._active_backend: Optional[BackendType] = None

        logger.info(f"BackendRouter initialized (default: {default_backend.value})")

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses default if None)

        Returns:
            ImageClient instance (cached after first creation)
        """
        if backend is None:
            backend = self.default_backend

        # Return cached client if available (avoids reloading local models).
        if backend in self._clients:
            self._active_backend = backend
            return self._clients[backend]

        # Create, cache, and mark active.
        client = self._create_client(backend)
        self._clients[backend] = client
        self._active_backend = backend

        return client

    def _create_client(self, backend: BackendType) -> ImageClient:
        """Create client for specified backend.

        Imports are done lazily inside each branch so that missing optional
        dependencies only fail when that particular backend is requested.

        Raises:
            ValueError: missing API key or unknown backend.
            RuntimeError: a local model failed to load / ComfyUI unreachable.
        """
        logger.info(f"Creating client for {backend.value}...")

        if backend == BackendType.GEMINI_FLASH:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=False)

        elif backend == BackendType.GEMINI_PRO:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=True)

        elif backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast, no offload needed
            client = FluxKleinClient(
                model_variant="4b",
                enable_cpu_offload=False
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 4B model")
            return client

        elif backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model (~29GB VRAM with CPU offload) - best quality
            client = FluxKleinClient(
                model_variant="9b",
                enable_cpu_offload=True  # Required for 24GB VRAM cards
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 9B model")
            return client

        elif backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            client = ZImageClient(
                model_variant="turbo",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Turbo model")
            return client

        elif backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            client = ZImageClient(
                model_variant="base",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Base model")
            return client

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            client = LongCatEditClient(
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load LongCat-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            client = QwenImageEditClient(enable_cpu_offload=False)
            if not client.load_model():
                raise RuntimeError("Failed to load Qwen-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    "  cd comfyui && python main.py"
                )
            return client

        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Args:
            request: Generation request
            backend: Backend to use (default if None)
            **kwargs: Backend-specific parameters

        Returns:
            GenerationResult; on any exception an error result is returned
            instead of raising, so callers get a uniform failure path.
        """
        try:
            client = self.get_client(backend)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error(f"Generation failed with {backend}: {e}", exc_info=True)
            return GenerationResult.error_result(f"Backend error: {str(e)}")

    def unload_local_models(self):
        """Unload all local models to free memory (cloud clients are kept)."""
        for backend, client in list(self._clients.items()):
            if backend in self.LOCAL_BACKENDS:
                if hasattr(client, 'unload_model'):
                    client.unload_model()
                del self._clients[backend]
                logger.info(f"Unloaded {backend.value}")

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend.

        For local models this loads the new model and unloads any other
        local model first, since two local models rarely fit in VRAM.

        Args:
            backend: Backend to switch to

        Returns:
            True if switch successful
        """
        try:
            # Unload other local models first to save memory.
            if backend in self.LOCAL_BACKENDS:
                for other_local in self.LOCAL_BACKENDS - {backend}:
                    if other_local in self._clients:
                        if hasattr(self._clients[other_local], 'unload_model'):
                            self._clients[other_local].unload_model()
                        del self._clients[other_local]

            # Get/create the new client and make it the default.
            self.get_client(backend)
            self.default_backend = backend

            logger.info(f"Switched to {backend.value}")
            return True

        except Exception as e:
            logger.error(f"Failed to switch to {backend}: {e}", exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of active backend ("None" before first use)."""
        if self._active_backend:
            return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))
        return "None"

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check if backend (or the active one when None) is a local model."""
        if backend is None:
            backend = self._active_backend
        # Note: a still-None active backend simply tests False here.
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns dict mapping ratio strings to (width, height) tuples.
        Each client class exposes its own ASPECT_RATIOS table; imports are
        lazy so unavailable backends don't break the lookup for others.
        """
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient
            return FluxKleinClient.ASPECT_RATIOS

        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient
            return ZImageClient.ASPECT_RATIOS

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            return LongCatEditClient.ASPECT_RATIOS

        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            return GeminiClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return QwenImageEditClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            return ComfyUIClient.ASPECT_RATIOS

        else:
            # Default fallback for unknown/future backends.
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns list of (label, value) tuples, e.g. ("16:9 (1344x768)", "16:9").
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        choices = []
        for ratio, (w, h) in ratios.items():
            label = f"{ratio} ({w}x{h})"
            choices.append((label, ratio))
        return choices

    def get_available_backends(self) -> list:
        """Get list of backends usable in the current environment.

        Availability is probed by attempting the relevant imports (and a
        health check for ComfyUI); failures are silently skipped so the UI
        only offers backends that can actually run.
        """
        available = []

        # Gemini backends require an API key.
        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])

        try:
            from diffusers import Flux2KleinPipeline
            available.append(BackendType.FLUX_KLEIN)
        except ImportError:
            pass

        try:
            from diffusers import ZImagePipeline
            available.append(BackendType.ZIMAGE_TURBO)
            available.append(BackendType.ZIMAGE_BASE)
        except ImportError:
            pass

        try:
            from diffusers import LongCatImageEditPipeline
            available.append(BackendType.LONGCAT_EDIT)
        except ImportError:
            pass

        try:
            from diffusers import QwenImageEditPlusPipeline
            available.append(BackendType.QWEN_IMAGE_EDIT)
        except ImportError:
            pass

        # ComfyUI backend - only offered when the server answers.
        try:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if client.is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            pass

        return available

    @staticmethod
    def get_backend_choices() -> list:
        """Get list of (label, value) backend choices for UI dropdowns.

        NOTE(review): FLUX_KLEIN_9B_FP8 appears in BACKEND_NAMES but is not
        offered here — confirm whether the 9B variant should be selectable.
        """
        return [
            ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH.value),
            ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO.value),
            ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN.value),
            ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO.value),
            ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE.value),
            ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT.value),
            ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT.value),
            ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI.value),
        ]

    @staticmethod
    def backend_from_string(value: str) -> BackendType:
        """Convert string to BackendType.

        Uses Enum's native by-value lookup instead of a manual scan.

        Raises:
            ValueError: if value matches no backend.
        """
        try:
            return BackendType(value)
        except ValueError:
            # Re-raise with the original, caller-facing message.
            raise ValueError(f"Unknown backend: {value}") from None
src/character_service.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Character Sheet Service
3
+ =======================
4
+
5
+ 9-stage pipeline for generating 7-view character turnaround sheets.
6
+
7
+ Layout:
8
+ +------------------+------------------+------------------+
9
+ | Left Face Profile| Front Face | Right Face Profile| (3:4)
10
+ +------------------+------------------+------------------+
11
+ | Left Side Body | Front Body | Right Side Body | Back Body | (9:16)
12
+ +------------------+------------------+------------------+
13
+ """
14
+
15
+ import time
16
+ import random
17
+ import logging
18
+ from pathlib import Path
19
+ from typing import Optional, Tuple, Dict, Any, Callable, List
20
+ from datetime import datetime
21
+ from PIL import Image
22
+
23
+ from .models import (
24
+ GenerationRequest,
25
+ GenerationResult,
26
+ CharacterSheetConfig,
27
+ CharacterSheetMetadata
28
+ )
29
+ from .gemini_client import GeminiClient
30
+ from .backend_router import BackendRouter, BackendType
31
+ from .utils import ensure_pil_image, save_image, sanitize_filename, preprocess_images_for_backend
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class CharacterSheetService:
38
+ """
39
+ Service for generating 7-view character turnaround sheets.
40
+
41
+ Pipeline (9 stages):
42
+ 0. Input normalization (face→body or body→face+body)
43
+ 1. Front face portrait
44
+ 2. Left face profile (90 degrees)
45
+ 3. Right face profile (90 degrees)
46
+ 4. Front full body (from normalized)
47
+ 5. Back full body
48
+ 6. Left side full body
49
+ 7. Right side full body
50
+ 8. Composite character sheet
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ api_key: Optional[str] = None,
56
+ use_pro_model: bool = False,
57
+ config: Optional[CharacterSheetConfig] = None,
58
+ backend: Optional[BackendType] = None,
59
+ backend_router: Optional[BackendRouter] = None
60
+ ):
61
+ """
62
+ Initialize character sheet service.
63
+
64
+ Args:
65
+ api_key: Gemini API key (for cloud backends)
66
+ use_pro_model: Use Gemini Pro model (legacy, use backend param instead)
67
+ config: Optional configuration
68
+ backend: Specific backend to use
69
+ backend_router: Pre-configured backend router
70
+ """
71
+ self.config = config or CharacterSheetConfig()
72
+
73
+ # Determine backend
74
+ if backend_router is not None:
75
+ self.router = backend_router
76
+ self.backend = backend or backend_router.default_backend
77
+ else:
78
+ # Determine default backend based on params
79
+ if backend is not None:
80
+ self.backend = backend
81
+ elif use_pro_model:
82
+ self.backend = BackendType.GEMINI_PRO
83
+ else:
84
+ self.backend = BackendType.GEMINI_FLASH
85
+
86
+ self.router = BackendRouter(
87
+ gemini_api_key=api_key,
88
+ default_backend=self.backend
89
+ )
90
+
91
+ # For backward compatibility
92
+ self.use_pro_model = use_pro_model
93
+ self.client = self.router.get_client(self.backend)
94
+
95
+ logger.info(f"CharacterSheetService initialized (backend: {self.backend.value})")
96
+
97
+ def generate_character_sheet(
98
+ self,
99
+ initial_image: Optional[Image.Image],
100
+ input_type: str = "Face Only",
101
+ character_name: str = "Character",
102
+ gender_term: str = "character",
103
+ costume_description: str = "",
104
+ costume_image: Optional[Image.Image] = None,
105
+ face_image: Optional[Image.Image] = None,
106
+ body_image: Optional[Image.Image] = None,
107
+ include_costume_in_faces: bool = True,
108
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
109
+ stage_callback: Optional[Callable[[str, Image.Image, Dict[str, Any]], None]] = None,
110
+ output_dir: Optional[Path] = None
111
+ ) -> Tuple[Optional[Image.Image], str, Dict[str, Any]]:
112
+ """
113
+ Generate complete 7-view character turnaround sheet.
114
+
115
+ Args:
116
+ initial_image: Starting image (face or body)
117
+ input_type: "Face Only", "Full Body", or "Face + Body (Separate)"
118
+ character_name: Character name
119
+ gender_term: "character", "man", or "woman"
120
+ costume_description: Text costume description
121
+ costume_image: Optional costume reference
122
+ face_image: Face image (for Face + Body mode)
123
+ body_image: Body image (for Face + Body mode)
124
+ include_costume_in_faces: If True, include costume reference in face views.
125
+ Set False for models like FLUX that confuse costume with framing.
126
+ progress_callback: Optional callback(stage, total_stages, message)
127
+ stage_callback: Optional callback(stage_name, image, stages_dict) called after each
128
+ stage completes with the generated image. Enables streaming preview.
129
+ output_dir: Optional output directory
130
+
131
+ Returns:
132
+ Tuple of (character_sheet, status_message, metadata)
133
+ """
134
+ try:
135
+ total_stages = 9
136
+ stages = {}
137
+
138
+ logger.info("=" * 60)
139
+ logger.info(f"STARTING CHARACTER SHEET: {character_name}")
140
+ logger.info(f"Input type: {input_type}")
141
+ logger.info(f"Costume: {costume_description or '(none)'}")
142
+ logger.info("=" * 60)
143
+
144
+ # Build costume instructions - separate for face and body views
145
+ # For models like FLUX, costume refs confuse face generation
146
+ costume_instruction_body = ""
147
+ if costume_description:
148
+ costume_instruction_body = f" wearing {costume_description}"
149
+ elif costume_image:
150
+ costume_instruction_body = " wearing the costume shown in the reference"
151
+
152
+ # Face views only get costume instruction if flag is set
153
+ if include_costume_in_faces:
154
+ costume_instruction_face = costume_instruction_body
155
+ else:
156
+ costume_instruction_face = ""
157
+ logger.info("Costume excluded from face views (include_costume_in_faces=False)")
158
+
159
+ def update_progress(stage: int, message: str):
160
+ if progress_callback:
161
+ progress_callback(stage, total_stages, message)
162
+ logger.info(f"[Stage {stage}/{total_stages}] {message}")
163
+
164
+ def notify_stage_complete(stage_name: str, image: Image.Image):
165
+ """Notify callback when a stage completes for streaming preview."""
166
+ if stage_callback and image is not None:
167
+ stage_callback(stage_name, image, stages)
168
+
169
+ # =================================================================
170
+ # Stage 0: Normalize input
171
+ # =================================================================
172
+ update_progress(0, "Normalizing input images...")
173
+
174
+ reference_body, reference_face = self._normalize_input(
175
+ initial_image=initial_image,
176
+ input_type=input_type,
177
+ face_image=face_image,
178
+ body_image=body_image,
179
+ costume_instruction=costume_instruction_body, # Body normalization uses full costume
180
+ costume_image=costume_image,
181
+ gender_term=gender_term,
182
+ stages=stages,
183
+ progress_callback=lambda msg: update_progress(0, msg)
184
+ )
185
+
186
+ if reference_body is None or reference_face is None:
187
+ return None, "Failed to normalize input images", {}
188
+
189
+ time.sleep(1)
190
+
191
+ # =================================================================
192
+ # FACE VIEWS (3 portraits)
193
+ # =================================================================
194
+
195
+ # Stage 1: Front face portrait
196
+ update_progress(1, "Generating front face portrait...")
197
+
198
+ if input_type == "Face + Body (Separate)":
199
+ prompt = f"Generate a close-up frontal facial portrait showing the {gender_term} from the first image (body/costume reference), extrapolate and extract exact facial details from the second image (face reference). Do NOT transfer clothing or hair style from the second image. The face should fill the entire vertical space, neutral grey background with professional studio lighting."
200
+ input_images = [reference_body, reference_face]
201
+ else:
202
+ prompt = f"Generate a formal portrait view of this {gender_term}{costume_instruction_face} as depicted in the reference images, in front of a neutral grey background with professional studio lighting. The face should fill the entire vertical space. Maintain exact facial features from the reference."
203
+ input_images = [reference_face, reference_body]
204
+ # Only include costume in face views if flag is set (smarter models)
205
+ if costume_image and include_costume_in_faces:
206
+ input_images.append(costume_image)
207
+
208
+ front_face, status = self._generate_stage(
209
+ prompt=prompt,
210
+ input_images=input_images,
211
+ aspect_ratio=self.config.face_aspect_ratio,
212
+ temperature=self.config.face_temperature
213
+ )
214
+
215
+ if front_face is None:
216
+ return None, f"Stage 1 failed: {status}", {}
217
+
218
+ stages['front_face'] = front_face
219
+ notify_stage_complete('front_face', front_face)
220
+ time.sleep(1)
221
+
222
+ # Stage 2: Left face profile
223
+ update_progress(2, "Generating left face profile...")
224
+
225
+ prompt = f"Create a left side profile view (90 degrees) of this {gender_term}'s face{costume_instruction_face}, showing the left side of the face filling the frame. Professional studio lighting against a neutral grey background. Maintain exact facial features from the reference."
226
+
227
+ input_images = [front_face, reference_body]
228
+ if input_type == "Face + Body (Separate)":
229
+ input_images.append(reference_face)
230
+ elif costume_image and include_costume_in_faces:
231
+ # Only include costume in face views if flag is set (smarter models)
232
+ input_images.append(costume_image)
233
+
234
+ left_face, status = self._generate_stage(
235
+ prompt=prompt,
236
+ input_images=input_images,
237
+ aspect_ratio=self.config.face_aspect_ratio,
238
+ temperature=self.config.face_temperature
239
+ )
240
+
241
+ if left_face is None:
242
+ return None, f"Stage 2 failed: {status}", {}
243
+
244
+ stages['left_face'] = left_face
245
+ notify_stage_complete('left_face', left_face)
246
+ time.sleep(1)
247
+
248
+ # Stage 3: Right face profile
249
+ update_progress(3, "Generating right face profile...")
250
+
251
+ prompt = f"Create a right side profile view (90 degrees) of this {gender_term}'s face{costume_instruction_face}, showing the right side of the face filling the frame. Professional studio lighting against a neutral grey background. Maintain exact facial features from the reference."
252
+
253
+ input_images = [front_face, reference_body]
254
+ if input_type == "Face + Body (Separate)":
255
+ input_images.append(reference_face)
256
+ elif costume_image and include_costume_in_faces:
257
+ # Only include costume in face views if flag is set (smarter models)
258
+ input_images.append(costume_image)
259
+
260
+ right_face, status = self._generate_stage(
261
+ prompt=prompt,
262
+ input_images=input_images,
263
+ aspect_ratio=self.config.face_aspect_ratio,
264
+ temperature=self.config.face_temperature
265
+ )
266
+
267
+ if right_face is None:
268
+ return None, f"Stage 3 failed: {status}", {}
269
+
270
+ stages['right_face'] = right_face
271
+ notify_stage_complete('right_face', right_face)
272
+ time.sleep(1)
273
+
274
+ # =================================================================
275
+ # BODY VIEWS (4 views)
276
+ # =================================================================
277
+
278
+ # Stage 4: Front body (use normalized reference)
279
+ update_progress(4, "Using front body from normalized reference...")
280
+ front_body = reference_body
281
+ stages['front_body'] = front_body
282
+ notify_stage_complete('front_body', front_body)
283
+ time.sleep(1)
284
+
285
+ # Stage 5: Back body
286
+ update_progress(5, "Generating back full body...")
287
+
288
+ prompt = f"Generate a rear view image of this {gender_term}{costume_instruction_body} showing the back in a neutral standing pose against a neutral grey background with professional studio lighting. The full body should fill the vertical space. Maintain consistent appearance from the reference images."
289
+
290
+ input_images = [reference_body, front_face]
291
+ if costume_image:
292
+ input_images.append(costume_image)
293
+
294
+ back_body, status = self._generate_stage(
295
+ prompt=prompt,
296
+ input_images=input_images,
297
+ aspect_ratio=self.config.body_aspect_ratio,
298
+ temperature=self.config.body_temperature
299
+ )
300
+
301
+ if back_body is None:
302
+ return None, f"Stage 5 failed: {status}", {}
303
+
304
+ stages['back_body'] = back_body
305
+ notify_stage_complete('back_body', back_body)
306
+ time.sleep(1)
307
+
308
+ # Stage 6: Left side body
309
+ update_progress(6, "Generating left side full body...")
310
+
311
+ prompt = f"Generate a left side view of the full body of this {gender_term}{costume_instruction_body} in front of a neutral grey background. The {gender_term} should be shown from the left side (90 degree angle) in a neutral standing pose. Full body fills vertical space. Professional studio lighting."
312
+
313
+ input_images = [left_face, front_body, reference_body]
314
+
315
+ left_body, status = self._generate_stage(
316
+ prompt=prompt,
317
+ input_images=input_images,
318
+ aspect_ratio=self.config.body_aspect_ratio,
319
+ temperature=self.config.body_temperature
320
+ )
321
+
322
+ if left_body is None:
323
+ return None, f"Stage 6 failed: {status}", {}
324
+
325
+ stages['left_body'] = left_body
326
+ notify_stage_complete('left_body', left_body)
327
+ time.sleep(1)
328
+
329
+ # Stage 7: Right side body
330
+ update_progress(7, "Generating right side full body...")
331
+
332
+ prompt = f"Generate a right side view of the full body of this {gender_term}{costume_instruction_body} in front of a neutral grey background. The {gender_term} should be shown from the right side (90 degree angle) in a neutral standing pose. Full body fills vertical space. Professional studio lighting."
333
+
334
+ input_images = [right_face, front_body, reference_body]
335
+
336
+ right_body, status = self._generate_stage(
337
+ prompt=prompt,
338
+ input_images=input_images,
339
+ aspect_ratio=self.config.body_aspect_ratio,
340
+ temperature=self.config.body_temperature
341
+ )
342
+
343
+ if right_body is None:
344
+ return None, f"Stage 7 failed: {status}", {}
345
+
346
+ stages['right_body'] = right_body
347
+ notify_stage_complete('right_body', right_body)
348
+ time.sleep(1)
349
+
350
+ # =================================================================
351
+ # Stage 8: Composite character sheet
352
+ # =================================================================
353
+ update_progress(8, "Compositing character sheet...")
354
+
355
+ character_sheet = self.composite_character_sheet(
356
+ left_face=left_face,
357
+ front_face=front_face,
358
+ right_face=right_face,
359
+ left_body=left_body,
360
+ front_body=front_body,
361
+ right_body=right_body,
362
+ back_body=back_body
363
+ )
364
+
365
+ stages['character_sheet'] = character_sheet
366
+
367
+ # Build metadata
368
+ metadata = CharacterSheetMetadata(
369
+ character_name=character_name,
370
+ input_type=input_type,
371
+ costume_description=costume_description,
372
+ backend=self.router.get_active_backend_name(),
373
+ stages={
374
+ "left_face": {"size": left_face.size},
375
+ "front_face": {"size": front_face.size},
376
+ "right_face": {"size": right_face.size},
377
+ "left_body": {"size": left_body.size},
378
+ "front_body": {"size": front_body.size},
379
+ "right_body": {"size": right_body.size},
380
+ "back_body": {"size": back_body.size},
381
+ }
382
+ )
383
+
384
+ success_msg = f"Character sheet generated! 7 views of {character_name}"
385
+
386
+ # Save to disk if requested
387
+ if output_dir:
388
+ save_dir = self._save_outputs(
389
+ character_name=character_name,
390
+ stages=stages,
391
+ output_dir=output_dir
392
+ )
393
+ success_msg += f"\nSaved to: {save_dir}"
394
+
395
+ update_progress(9, "Complete!")
396
+ return character_sheet, success_msg, {"metadata": metadata, "stages": stages}
397
+
398
+ except Exception as e:
399
+ logger.exception(f"Character sheet generation failed: {e}")
400
+ return None, f"Error: {str(e)}", {}
401
+
402
    def _normalize_input(
        self,
        initial_image: Optional[Image.Image],
        input_type: str,
        face_image: Optional[Image.Image],
        body_image: Optional[Image.Image],
        costume_instruction: str,
        costume_image: Optional[Image.Image],
        gender_term: str,
        stages: dict,
        progress_callback: Optional[Callable]
    ) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
        """Normalize input images to create reference body and face.

        Depending on ``input_type``, either combines separate face/body
        photos, synthesizes a full body from a face, or normalizes a full
        body shot and extracts a face closeup from it.

        Args:
            initial_image: Single source image (used for "Face Only" and
                "Full Body" input types).
            input_type: One of "Face + Body (Separate)", "Face Only", or
                any other value, which is treated as "Full Body".
            face_image: Face photo (only for the separate-inputs mode).
            body_image: Body photo (only for the separate-inputs mode).
            costume_instruction: Text fragment appended into prompts
                (may be empty).
            costume_image: Optional costume reference appended to the
                generation inputs when provided.
            gender_term: Word substituted into prompts (e.g. "man"/"woman").
            stages: Mutable dict; intermediate results are recorded here
                under 'normalized_body', 'generated_body' or 'extracted_face'.
            progress_callback: Optional callable invoked with a status string
                before each generation step.

        Returns:
            Tuple of (reference_body, reference_face); (None, None) when a
            required input is missing or any generation step fails.
        """

        if input_type == "Face + Body (Separate)":
            # Both photos are required in this mode.
            if face_image is None or body_image is None:
                return None, None

            if progress_callback:
                progress_callback("Normalizing body image...")

            prompt = f"Front view full body portrait of this person{costume_instruction}, standing, neutral background"
            input_images = [body_image, face_image]
            if costume_image:
                input_images.append(costume_image)

            normalized_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if normalized_body is None:
                return None, None

            stages['normalized_body'] = normalized_body
            # The supplied face photo is already a usable face reference.
            return normalized_body, face_image

        elif input_type == "Face Only":
            if initial_image is None:
                return None, None

            if progress_callback:
                progress_callback("Generating full body from face...")

            prompt = f"Create a full body image of the {gender_term}{costume_instruction} standing in a neutral pose in front of a grey background with professional studio lighting. The {gender_term}'s face and features should match the reference image exactly."

            input_images = [initial_image]
            if costume_image:
                input_images.append(costume_image)

            full_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if full_body is None:
                return None, None

            stages['generated_body'] = full_body
            # The original input serves directly as the face reference.
            return full_body, initial_image

        else:  # Full Body: normalize the shot, then extract a face closeup
            if initial_image is None:
                return None, None

            # Step 1: normalize the body into a standard front-view pose.
            if progress_callback:
                progress_callback("Normalizing full body...")

            prompt = f"Front view full body portrait of this person{costume_instruction}, standing, neutral background"

            input_images = [initial_image]
            if costume_image:
                input_images.append(costume_image)

            normalized_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if normalized_body is None:
                return None, None

            stages['normalized_body'] = normalized_body
            # Small pause between backend calls (rate-limit courtesy).
            time.sleep(1)

            # Step 2: derive a face closeup from the normalized body.
            if progress_callback:
                progress_callback("Generating face closeup...")

            prompt = f"Create a frontal closeup portrait of this {gender_term}'s face{costume_instruction}, focusing only on the face and head. Professional studio lighting against a neutral grey background. The face should fill the entire vertical space. Maintain exact facial features from the reference."

            input_images = [normalized_body, initial_image]
            if costume_image:
                input_images.append(costume_image)

            face_closeup, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.face_aspect_ratio,
                temperature=self.config.face_temperature
            )

            if face_closeup is None:
                return None, None

            stages['extracted_face'] = face_closeup
            return normalized_body, face_closeup
516
+
517
+ def _generate_stage(
518
+ self,
519
+ prompt: str,
520
+ input_images: List[Image.Image],
521
+ aspect_ratio: str,
522
+ temperature: float,
523
+ max_retries: int = 3
524
+ ) -> Tuple[Optional[Image.Image], str]:
525
+ """Generate single stage with retry logic."""
526
+
527
+ modified_prompt = prompt
528
+ cfg = self.config
529
+
530
+ # Preprocess images for the current backend
531
+ backend_type = self.backend.value if self.backend else "unknown"
532
+ processed_images = preprocess_images_for_backend(
533
+ input_images, backend_type, aspect_ratio
534
+ )
535
+ logger.info(f"Preprocessed {len(processed_images)} images for {backend_type}")
536
+
537
+ for attempt in range(max_retries):
538
+ try:
539
+ if attempt > 0:
540
+ wait_time = cfg.retry_delay
541
+ logger.info(f"Retry {attempt + 1}/{max_retries}, waiting {wait_time}s...")
542
+ time.sleep(wait_time)
543
+
544
+ request = GenerationRequest(
545
+ prompt=modified_prompt,
546
+ input_images=processed_images,
547
+ aspect_ratio=aspect_ratio,
548
+ temperature=temperature
549
+ )
550
+
551
+ result = self.client.generate(request)
552
+
553
+ if result.success:
554
+ delay = random.uniform(cfg.rate_limit_delay_min, cfg.rate_limit_delay_max)
555
+ time.sleep(delay)
556
+ return result.image, result.message
557
+
558
+ # Check for safety block
559
+ error_upper = result.message.upper()
560
+ if any(kw in error_upper for kw in ['SAFETY', 'BLOCKED', 'PROHIBITED', 'IMAGE_OTHER']):
561
+ if 'wearing' not in modified_prompt.lower():
562
+ if 'body' in modified_prompt.lower():
563
+ modified_prompt = prompt + ", fully clothed in casual wear"
564
+ else:
565
+ modified_prompt = prompt + ", wearing appropriate clothing"
566
+ logger.info("Modified prompt to avoid safety filters")
567
+
568
+ logger.warning(f"Attempt {attempt + 1} failed: {result.message}")
569
+
570
+ except Exception as e:
571
+ logger.error(f"Attempt {attempt + 1} exception: {e}")
572
+ if attempt == max_retries - 1:
573
+ return None, str(e)
574
+
575
+ return None, f"All {max_retries} attempts failed"
576
+
577
+ def composite_character_sheet(
578
+ self,
579
+ left_face: Image.Image,
580
+ front_face: Image.Image,
581
+ right_face: Image.Image,
582
+ left_body: Image.Image,
583
+ front_body: Image.Image,
584
+ right_body: Image.Image,
585
+ back_body: Image.Image
586
+ ) -> Image.Image:
587
+ """
588
+ Composite all 7 views into character sheet.
589
+
590
+ Layout:
591
+ +------------------+------------------+------------------+
592
+ | Left Face Profile| Front Face | Right Face Profile|
593
+ +------------------+------------------+------------------+
594
+ | Left Side Body | Front Body | Right Side Body | Back Body |
595
+ +------------------+------------------+------------------+
596
+ """
597
+ # Normalize all inputs
598
+ left_face = ensure_pil_image(left_face, "left_face")
599
+ front_face = ensure_pil_image(front_face, "front_face")
600
+ right_face = ensure_pil_image(right_face, "right_face")
601
+ left_body = ensure_pil_image(left_body, "left_body")
602
+ front_body = ensure_pil_image(front_body, "front_body")
603
+ right_body = ensure_pil_image(right_body, "right_body")
604
+ back_body = ensure_pil_image(back_body, "back_body")
605
+
606
+ spacing = self.config.spacing
607
+
608
+ # Calculate dimensions
609
+ face_row_width = left_face.width + front_face.width + right_face.width
610
+ body_row_width = left_body.width + front_body.width + right_body.width + back_body.width
611
+ canvas_width = max(face_row_width, body_row_width)
612
+ canvas_height = front_face.height + spacing + front_body.height
613
+
614
+ # Create canvas
615
+ canvas = Image.new('RGB', (canvas_width, canvas_height), color=self.config.background_color)
616
+
617
+ # Upper row: 3 face portraits
618
+ x = 0
619
+ canvas.paste(left_face, (x, 0))
620
+ x += left_face.width
621
+ canvas.paste(front_face, (x, 0))
622
+ x += front_face.width
623
+ canvas.paste(right_face, (x, 0))
624
+
625
+ # Lower row: 4 body views
626
+ x = 0
627
+ y = front_face.height + spacing
628
+ canvas.paste(left_body, (x, y))
629
+ x += left_body.width
630
+ canvas.paste(front_body, (x, y))
631
+ x += front_body.width
632
+ canvas.paste(right_body, (x, y))
633
+ x += right_body.width
634
+ canvas.paste(back_body, (x, y))
635
+
636
+ return canvas
637
+
638
+ def extract_views_from_sheet(
639
+ self,
640
+ character_sheet: Image.Image
641
+ ) -> Dict[str, Image.Image]:
642
+ """
643
+ Extract individual views from character sheet.
644
+
645
+ Returns:
646
+ Dictionary with 7 extracted views
647
+ """
648
+ sheet_width, sheet_height = character_sheet.size
649
+ spacing = self.config.spacing
650
+
651
+ # Find separator by scanning for dark bar
652
+ scan_start = sheet_height // 3
653
+ scan_end = (2 * sheet_height) // 3
654
+
655
+ min_brightness = 255
656
+ separator_y = scan_start
657
+
658
+ for y in range(scan_start, scan_end):
659
+ line = character_sheet.crop((0, y, min(200, sheet_width), y + 1))
660
+ pixels = list(line.getdata())
661
+ avg_brightness = sum(
662
+ sum(p[:3]) / 3 if isinstance(p, tuple) else p
663
+ for p in pixels
664
+ ) / len(pixels)
665
+
666
+ if avg_brightness < min_brightness:
667
+ min_brightness = avg_brightness
668
+ separator_y = y
669
+
670
+ face_height = separator_y
671
+ body_start_y = separator_y + spacing
672
+ body_height = sheet_height - body_start_y
673
+
674
+ # Calculate widths from aspect ratios
675
+ face_width = (face_height * 3) // 4
676
+ body_width = (body_height * 9) // 16
677
+
678
+ # Extract views
679
+ views = {
680
+ 'left_face': character_sheet.crop((0, 0, face_width, face_height)),
681
+ 'front_face': character_sheet.crop((face_width, 0, 2 * face_width, face_height)),
682
+ 'right_face': character_sheet.crop((2 * face_width, 0, 3 * face_width, face_height)),
683
+ 'left_body': character_sheet.crop((0, body_start_y, body_width, body_start_y + body_height)),
684
+ 'front_body': character_sheet.crop((body_width, body_start_y, 2 * body_width, body_start_y + body_height)),
685
+ 'right_body': character_sheet.crop((2 * body_width, body_start_y, 3 * body_width, body_start_y + body_height)),
686
+ 'back_body': character_sheet.crop((3 * body_width, body_start_y, 4 * body_width, body_start_y + body_height)),
687
+ }
688
+
689
+ return views
690
+
691
+ def _save_outputs(
692
+ self,
693
+ character_name: str,
694
+ stages: dict,
695
+ output_dir: Path
696
+ ) -> Path:
697
+ """Save all outputs to directory."""
698
+ output_dir = Path(output_dir)
699
+ safe_name = sanitize_filename(character_name)
700
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
701
+ char_dir = output_dir / f"{safe_name}_{timestamp}"
702
+ char_dir.mkdir(parents=True, exist_ok=True)
703
+
704
+ for name, image in stages.items():
705
+ if isinstance(image, Image.Image):
706
+ save_image(image, char_dir, f"{safe_name}_{name}")
707
+
708
+ logger.info(f"Saved outputs to: {char_dir}")
709
+ return char_dir
src/comfyui_client.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ComfyUI Client for Qwen-Image-Edit-2511
3
+ ========================================
4
+
5
+ Client to interact with ComfyUI API for running Qwen-Image-Edit-2511.
6
+
7
+ Model setup (download from HuggingFace):
8
+
9
+ Lightning (default, 4-step):
10
+ diffusion_models/ qwen_image_edit_2511_fp8_e4m3fn_scaled_lightning_comfyui_4steps_v1.0.safetensors
11
+ (lightx2v/Qwen-Image-Edit-2511-Lightning)
12
+
13
+ Standard (20-step, optional):
14
+ diffusion_models/ qwen_image_edit_2511_fp8mixed.safetensors
15
+ (Comfy-Org/Qwen-Image-Edit_ComfyUI)
16
+
17
+ Shared:
18
+ text_encoders/ qwen_2.5_vl_7b_fp8_scaled.safetensors (Comfy-Org/Qwen-Image_ComfyUI)
19
+ vae/ qwen_image_vae.safetensors (Comfy-Org/Qwen-Image_ComfyUI)
20
+
21
+ Required custom nodes:
22
+ - Comfyui-QwenEditUtils (lrzjason) for TextEncodeQwenImageEditPlus
23
+ """
24
+
25
+ import logging
26
+ import time
27
+ import uuid
28
+ import json
29
+ import io
30
+ import base64
31
+ from typing import Optional, List, Tuple
32
+ from PIL import Image
33
+ import websocket
34
+ import urllib.request
35
+ import urllib.parse
36
+
37
+ from .models import GenerationRequest, GenerationResult
38
+
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
class ComfyUIClient:
    """
    Client for ComfyUI API to run Qwen-Image-Edit-2511.

    Requires ComfyUI running with:
    - Qwen-Image-Edit-2511 model in models/diffusion_models/
    - Qwen 2.5 VL 7B text encoder in models/text_encoders/
    - Qwen Image VAE in models/vae/
    - Comfyui-QwenEditUtils custom node installed
    """

    # Default ComfyUI settings
    DEFAULT_HOST = "127.0.0.1"
    DEFAULT_PORT = 8188

    # Model file names (expected in ComfyUI models/ subfolders)
    # Lightning: baked model (LoRA pre-merged, ComfyUI-specific format)
    UNET_MODEL_LIGHTNING = "qwen_image_edit_2511_fp8_e4m3fn_scaled_lightning_comfyui_4steps_v1.0.safetensors"
    # Standard: base fp8mixed model (20-step, higher quality)
    UNET_MODEL_STANDARD = "qwen_image_edit_2511_fp8mixed.safetensors"
    TEXT_ENCODER = "qwen_2.5_vl_7b_fp8_scaled.safetensors"
    VAE_MODEL = "qwen_image_vae.safetensors"

    # Target output dimensions per aspect ratio.
    # Generation happens at 1024x1024, then crop+resize to these.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1680, 720),
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (1024, 1280),
        "5:4": (1280, 1024),
    }

    # Generate at 1024x1024 (proven safe for Qwen's VAE), then crop+resize
    NATIVE_RESOLUTION = (1024, 1024)

    # With Lightning LoRA: 4 steps, CFG 1.0 (fast, ~seconds per view)
    # Without LoRA: 20 steps, CFG 4.0
    DEFAULT_STEPS_LIGHTNING = 4
    DEFAULT_STEPS_STANDARD = 20
    DEFAULT_CFG_LIGHTNING = 1.0
    DEFAULT_CFG_STANDARD = 4.0

    def __init__(
        self,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        use_lightning: bool = True,
    ):
        """
        Initialize ComfyUI client.

        Args:
            host: ComfyUI server host
            port: ComfyUI server port
            use_lightning: Use Lightning LoRA for 4-step generation (much faster)
        """
        self.host = host
        self.port = port
        self.use_lightning = use_lightning
        # Stable client id ties websocket progress messages to our prompts.
        self.client_id = str(uuid.uuid4())
        self.server_address = f"{host}:{port}"

        if use_lightning:
            self.num_inference_steps = self.DEFAULT_STEPS_LIGHTNING
            self.cfg_scale = self.DEFAULT_CFG_LIGHTNING
        else:
            self.num_inference_steps = self.DEFAULT_STEPS_STANDARD
            self.cfg_scale = self.DEFAULT_CFG_STANDARD

        logger.info(
            f"ComfyUIClient initialized for {self.server_address} "
            f"(lightning={use_lightning}, steps={self.num_inference_steps})"
        )

    def is_healthy(self) -> bool:
        """Check if ComfyUI server is running and accessible."""
        try:
            url = f"http://{self.server_address}/system_stats"
            with urllib.request.urlopen(url, timeout=5) as response:
                return response.status == 200
        except Exception:
            return False

    def _upload_image(self, image: Image.Image, name: str = "input.png") -> Optional[str]:
        """
        Upload an image to ComfyUI, pre-resized to fit within 1024x1024.

        Args:
            image: PIL Image to upload
            name: Filename for the uploaded image

        Returns:
            Filename on server, or None if failed
        """
        try:
            # Pre-resize to keep total pixels around 1024x1024 (matching reference workflow)
            max_pixels = 1024 * 1024
            w, h = image.size
            if w * h > max_pixels:
                scale = (max_pixels / (w * h)) ** 0.5
                new_w = int(w * scale)
                new_h = int(h * scale)
                image = image.resize((new_w, new_h), Image.LANCZOS)
                logger.debug(f"Pre-resized input from {w}x{h} to {new_w}x{new_h}")

            # Convert image to bytes
            img_bytes = io.BytesIO()
            image.save(img_bytes, format='PNG')
            img_bytes.seek(0)

            # Hand-rolled multipart/form-data body (stdlib urllib has no
            # multipart helper); the boundary is a fresh random hex string.
            boundary = uuid.uuid4().hex

            body = b''
            body += f'--{boundary}\r\n'.encode()
            body += f'Content-Disposition: form-data; name="image"; filename="{name}"\r\n'.encode()
            body += b'Content-Type: image/png\r\n\r\n'
            body += img_bytes.read()
            body += f'\r\n--{boundary}--\r\n'.encode()

            url = f"http://{self.server_address}/upload/image"
            req = urllib.request.Request(
                url,
                data=body,
                headers={
                    'Content-Type': f'multipart/form-data; boundary={boundary}'
                }
            )

            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read())
                return result.get('name')

        except Exception as e:
            logger.error(f"Failed to upload image: {e}")
            return None

    def _queue_prompt(self, prompt: dict) -> str:
        """
        Queue a prompt for execution.

        Args:
            prompt: Workflow prompt dict

        Returns:
            Prompt ID
        """
        prompt_id = str(uuid.uuid4())
        payload = {"prompt": prompt, "client_id": self.client_id, "prompt_id": prompt_id}
        data = json.dumps(payload).encode('utf-8')

        url = f"http://{self.server_address}/prompt"
        # Declare the JSON body explicitly and close the response via the
        # context manager (the original never closed it, leaking the socket).
        req = urllib.request.Request(
            url,
            data=data,
            headers={'Content-Type': 'application/json'}
        )
        with urllib.request.urlopen(req):
            pass

        return prompt_id

    def _get_history(self, prompt_id: str) -> dict:
        """Get execution history for a prompt."""
        url = f"http://{self.server_address}/history/{prompt_id}"
        with urllib.request.urlopen(url) as response:
            return json.loads(response.read())

    def _get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Get an image from ComfyUI."""
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        url = f"http://{self.server_address}/view?{url_values}"
        with urllib.request.urlopen(url) as response:
            return response.read()

    def _wait_for_completion(self, prompt_id: str, timeout: float = 900.0) -> bool:
        """
        Wait for prompt execution to complete using websocket.

        Args:
            prompt_id: The prompt ID to wait for
            timeout: Maximum time to wait in seconds (default 15 min for image editing)

        Returns:
            True if completed successfully, False if timeout/error
        """
        ws = None
        try:
            ws_url = f"ws://{self.server_address}/ws?clientId={self.client_id}"
            ws = websocket.WebSocket()
            ws.settimeout(timeout)
            ws.connect(ws_url)

            start_time = time.time()
            while time.time() - start_time < timeout:
                try:
                    out = ws.recv()
                    if isinstance(out, str):
                        message = json.loads(out)
                        if message['type'] == 'executing':
                            data = message['data']
                            # node=None for our prompt_id signals completion.
                            if data['node'] is None and data['prompt_id'] == prompt_id:
                                return True  # Execution complete
                        elif message['type'] == 'execution_error':
                            logger.error(f"Execution error: {message}")
                            return False
                except websocket.WebSocketTimeoutException:
                    continue

            logger.error("Timeout waiting for completion")
            return False

        except Exception as e:
            logger.error(f"WebSocket error: {e}")
            return False
        finally:
            if ws:
                try:
                    ws.close()
                except Exception:
                    # Best-effort close only; narrowed from a bare except so
                    # KeyboardInterrupt/SystemExit are not swallowed.
                    pass

    def _get_dimensions(self, aspect_ratio: str) -> Tuple[int, int]:
        """Get pixel dimensions for aspect ratio (instance-side helper)."""
        # Accept labels like "16:9 (Landscape)" by taking the first token.
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return self.ASPECT_RATIOS.get(ratio, (1024, 1024))

    @staticmethod
    def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
        """Crop to target aspect ratio, then resize. Centers the crop."""
        src_w, src_h = image.size
        target_ratio = target_w / target_h
        src_ratio = src_w / src_h

        # Already close enough: just resize.
        if abs(target_ratio - src_ratio) < 0.01:
            return image.resize((target_w, target_h), Image.LANCZOS)

        if target_ratio < src_ratio:
            # Source is wider than target: crop horizontally.
            crop_w = int(src_h * target_ratio)
            offset = (src_w - crop_w) // 2
            image = image.crop((offset, 0, offset + crop_w, src_h))
        else:
            # Source is taller than target: crop vertically.
            crop_h = int(src_w / target_ratio)
            offset = (src_h - crop_h) // 2
            image = image.crop((0, offset, src_w, offset + crop_h))

        return image.resize((target_w, target_h), Image.LANCZOS)

    def _build_workflow(
        self,
        prompt: str,
        width: int,
        height: int,
        input_images: List[str] = None,
        negative_prompt: str = "",
        steps: Optional[int] = None,
        cfg: Optional[float] = None
    ) -> dict:
        """
        Build the ComfyUI workflow for Qwen-Image-Edit-2511.

        Workflow graph:
            UNETLoader → KSampler
            CLIPLoader → TextEncodeQwenImageEditPlus (pos/neg)
            VAELoader → TextEncode + VAEDecode
            LoadImage(s) → TextEncodeQwenImageEditPlus
            EmptySD3LatentImage → KSampler
            KSampler → VAEDecode → PreviewImage

        Lightning mode uses a baked model (LoRA pre-merged), no separate
        LoRA or ModelSamplingAuraFlow nodes needed.

        Args:
            prompt: Positive prompt text.
            width/height: Latent dimensions.
            input_images: Server-side filenames of uploaded references (max 3).
            negative_prompt: Negative prompt text (" " when empty).
            steps: Sampler steps; defaults to the instance setting. Passing
                these explicitly replaces the old mutate-and-restore of
                instance state in generate().
            cfg: CFG scale; defaults to the instance setting.
        """
        if steps is None:
            steps = self.num_inference_steps
        if cfg is None:
            cfg = self.cfg_scale

        workflow = {}
        node_id = 1

        # --- Model loading ---

        # Select model based on lightning mode
        unet_name = (self.UNET_MODEL_LIGHTNING if self.use_lightning
                     else self.UNET_MODEL_STANDARD)

        # UNETLoader - weight_dtype "default" lets ComfyUI auto-detect fp8
        unet_id = str(node_id)
        workflow[unet_id] = {
            "class_type": "UNETLoader",
            "inputs": {
                "unet_name": unet_name,
                "weight_dtype": "default"
            }
        }
        node_id += 1

        # CLIPLoader
        clip_id = str(node_id)
        workflow[clip_id] = {
            "class_type": "CLIPLoader",
            "inputs": {
                "clip_name": self.TEXT_ENCODER,
                "type": "qwen_image"
            }
        }
        node_id += 1

        # VAELoader
        vae_id = str(node_id)
        workflow[vae_id] = {
            "class_type": "VAELoader",
            "inputs": {
                "vae_name": self.VAE_MODEL
            }
        }
        node_id += 1

        model_out_id = unet_id

        # --- Input images ---

        image_loader_ids = []
        if input_images:
            for img_name in input_images[:3]:  # Max 3 reference images
                img_loader_id = str(node_id)
                workflow[img_loader_id] = {
                    "class_type": "LoadImage",
                    "inputs": {
                        "image": img_name
                    }
                }
                image_loader_ids.append(img_loader_id)
                node_id += 1

        # --- Text encoding ---

        # Positive: prompt + vision references + VAE
        pos_encode_id = str(node_id)
        pos_inputs = {
            "clip": [clip_id, 0],
            "prompt": prompt,
            "vae": [vae_id, 0]
        }
        for i, loader_id in enumerate(image_loader_ids):
            pos_inputs[f"image{i+1}"] = [loader_id, 0]

        workflow[pos_encode_id] = {
            "class_type": "TextEncodeQwenImageEditPlus",
            "inputs": pos_inputs
        }
        node_id += 1

        # Negative: text only, no images
        neg_encode_id = str(node_id)
        workflow[neg_encode_id] = {
            "class_type": "TextEncodeQwenImageEditPlus",
            "inputs": {
                "clip": [clip_id, 0],
                "prompt": negative_prompt or " ",
                "vae": [vae_id, 0]
            }
        }
        node_id += 1

        # --- Latent + sampling ---

        latent_id = str(node_id)
        workflow[latent_id] = {
            "class_type": "EmptySD3LatentImage",
            "inputs": {
                "width": width,
                "height": height,
                "batch_size": 1
            }
        }
        node_id += 1

        sampler_id = str(node_id)
        workflow[sampler_id] = {
            "class_type": "KSampler",
            "inputs": {
                "model": [model_out_id, 0],
                "positive": [pos_encode_id, 0],
                "negative": [neg_encode_id, 0],
                "latent_image": [latent_id, 0],
                "seed": int(time.time()) % 2**32,
                "steps": steps,
                "cfg": cfg,
                "sampler_name": "euler",
                "scheduler": "simple",
                "denoise": 1.0
            }
        }
        node_id += 1

        # --- Decode + output ---

        decode_id = str(node_id)
        workflow[decode_id] = {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": [sampler_id, 0],
                "vae": [vae_id, 0]
            }
        }
        node_id += 1

        preview_id = str(node_id)
        workflow[preview_id] = {
            "class_type": "PreviewImage",
            "inputs": {
                "images": [decode_id, 0]
            }
        }

        return workflow

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        cfg_scale: Optional[float] = None
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511 via ComfyUI.

        Generates at native 1024x1024, then crop+resize to requested
        aspect ratio for clean VAE output.

        Args:
            request: Prompt, reference images and aspect ratio.
            num_inference_steps: Optional override of the instance default.
            cfg_scale: Optional override of the instance default. Explicit
                None checks are used (the old ``or`` fallback silently
                ignored a legitimate 0.0 override).
        """
        if not self.is_healthy():
            return GenerationResult.error_result(
                "ComfyUI server is not accessible. Make sure ComfyUI is running on "
                f"{self.server_address}"
            )

        try:
            start_time = time.time()

            # Target dimensions for post-processing
            target_w, target_h = self._get_dimensions(request.aspect_ratio)
            # Generate at native resolution (VAE-safe)
            native_w, native_h = self.NATIVE_RESOLUTION

            # Upload input images (max 3)
            uploaded_images = []
            if request.has_input_images:
                for i, img in enumerate(request.input_images):
                    if img is not None:
                        name = f"input_{i}_{uuid.uuid4().hex[:8]}.png"
                        uploaded_name = self._upload_image(img, name)
                        if uploaded_name:
                            uploaded_images.append(uploaded_name)
                        else:
                            logger.warning(f"Failed to upload image {i}")

            # Explicit None checks so 0 / 0.0 overrides are honored.
            steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
            cfg = self.cfg_scale if cfg_scale is None else cfg_scale

            # Pass overrides directly instead of temporarily mutating
            # instance state (the old swap was not exception- or thread-safe).
            workflow = self._build_workflow(
                prompt=request.prompt,
                width=native_w,
                height=native_h,
                input_images=uploaded_images or None,
                negative_prompt=request.negative_prompt or "",
                steps=steps,
                cfg=cfg
            )

            logger.info(f"Generating with ComfyUI/Qwen: {request.prompt[:80]}...")
            logger.info(
                f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}, "
                f"steps: {steps}, cfg: {cfg}, images: {len(uploaded_images)}, "
                f"lightning: {self.use_lightning}"
            )

            # Queue and wait
            prompt_id = self._queue_prompt(workflow)
            logger.info(f"Queued prompt: {prompt_id}")

            if not self._wait_for_completion(prompt_id):
                return GenerationResult.error_result("Generation failed or timed out")

            # Retrieve output
            history = self._get_history(prompt_id)
            if prompt_id not in history:
                return GenerationResult.error_result("No history found for prompt")

            outputs = history[prompt_id].get('outputs', {})
            for nid, node_output in outputs.items():
                if 'images' in node_output:
                    for img_info in node_output['images']:
                        img_data = self._get_image(
                            img_info['filename'],
                            img_info.get('subfolder', ''),
                            img_info.get('type', 'temp')
                        )
                        image = Image.open(io.BytesIO(img_data))
                        generation_time = time.time() - start_time
                        logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

                        # Crop+resize to target aspect ratio
                        if (target_w, target_h) != (native_w, native_h):
                            image = self._crop_and_resize(image, target_w, target_h)
                            logger.info(f"Post-processed to: {image.size}")

                        return GenerationResult.success_result(
                            image=image,
                            message=f"Generated with ComfyUI/Qwen in {generation_time:.2f}s",
                            generation_time=generation_time
                        )

            return GenerationResult.error_result("No output images found")

        except Exception as e:
            logger.error(f"ComfyUI generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"ComfyUI error: {str(e)}")

    def unload_model(self):
        """
        Request ComfyUI to free memory.
        Note: ComfyUI manages models automatically, but we can request cleanup.
        """
        try:
            url = f"http://{self.server_address}/free"
            data = json.dumps({"unload_models": True}).encode('utf-8')
            req = urllib.request.Request(url, data=data, method='POST')
            with urllib.request.urlopen(req):
                pass
            logger.info("Requested ComfyUI to free memory")
        except Exception as e:
            logger.warning(f"Failed to request memory cleanup: {e}")

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> Tuple[int, int]:
        """Get pixel dimensions for aspect ratio; falls back to 1024x1024."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
src/core/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Character Sheet Pro - Core API
3
+ ==============================
4
+
5
+ Exposes core functionality for plugins.
6
+ """
7
+
8
+ from .plugin_base import Plugin, PluginMetadata, PluginAPI
9
+ from .plugin_manager import PluginManager
10
+
11
+ __all__ = [
12
+ 'Plugin',
13
+ 'PluginMetadata',
14
+ 'PluginAPI',
15
+ 'PluginManager'
16
+ ]
src/flux_klein_client.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FLUX.2 Klein Client
3
+ ===================
4
+
5
+ Client for FLUX.2 klein 4B local image generation.
6
+ Supports text-to-image and multi-reference editing.
7
+ """
8
+
9
+ import logging
10
+ import time
11
+ from typing import Optional, List
12
+ from PIL import Image
13
+
14
+ import torch
15
+
16
+ from .models import GenerationRequest, GenerationResult
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class FluxKleinClient:
    """
    Client for FLUX.2 klein models.

    Supports:
    - Text-to-image generation
    - Single and multi-reference image editing
    - Multiple model sizes (4B, 9B) and variants (distilled, base)
    """

    # Model variants - choose based on quality/speed tradeoff
    MODELS = {
        # 4B models (~13GB VRAM)
        "4b": "black-forest-labs/FLUX.2-klein-4B",  # Distilled, 4 steps
        "4b-base": "black-forest-labs/FLUX.2-klein-base-4B",  # Base, configurable steps
        # 9B models (~29GB VRAM, better quality)
        "9b": "black-forest-labs/FLUX.2-klein-9B",  # Distilled, 4 steps
        "9b-base": "black-forest-labs/FLUX.2-klein-base-9B",  # Base, 50 steps - BEST QUALITY
        "9b-fp8": "black-forest-labs/FLUX.2-klein-9b-fp8",  # FP8 quantized (~20GB)
    }

    # Legacy compatibility
    MODEL_ID = MODELS["4b"]
    MODEL_ID_BASE = MODELS["4b-base"]

    # Aspect ratio label -> (width, height) in pixels
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default sampler settings for each model variant
    MODEL_DEFAULTS = {
        "4b": {"steps": 4, "guidance": 1.0},
        "4b-base": {"steps": 28, "guidance": 3.5},
        "9b": {"steps": 4, "guidance": 1.0},
        "9b-base": {"steps": 50, "guidance": 4.0},  # Best quality
        "9b-fp8": {"steps": 4, "guidance": 4.0},
    }

    def __init__(
        self,
        model_variant: str = "9b-base",  # Default to highest quality
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        # Legacy params
        use_base_model: bool = False,
    ):
        """
        Initialize FLUX.2 klein client.

        Args:
            model_variant: Model variant to use:
                - "4b": Fast, 4 steps, ~13GB VRAM
                - "4b-base": Configurable steps, ~13GB VRAM
                - "9b": Better quality, 4 steps, ~29GB VRAM
                - "9b-base": BEST quality, 50 steps, ~29GB VRAM
                - "9b-fp8": FP8 quantized, ~20GB VRAM
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            use_base_model: Legacy flag; when set (and model_variant was left
                at its default) it selects the 4B base model.
        """
        # Handle legacy use_base_model parameter: only honored when the
        # caller did not explicitly pick a non-default variant.
        if use_base_model and model_variant == "9b-base":
            model_variant = "4b-base"

        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self._loaded = False

        # Per-variant sampler defaults, with a safe fallback for unknown keys.
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 4, "guidance": 1.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]

        logger.info(f"FluxKleinClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")

    def load_model(self) -> bool:
        """
        Load the model into memory and validate it with a tiny generation.

        Returns:
            True on success; False on any failure. On failure the partially
            loaded pipeline is unloaded so VRAM is not leaked across retries.
        """
        if self._loaded:
            return True

        try:
            # Get model ID for selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS["4b"])
            logger.info(f"Loading FLUX.2 klein ({self.model_variant}) from {model_id}...")

            start_time = time.time()

            # FLUX.2 klein requires Flux2KleinPipeline (specific to klein models)
            # Requires diffusers from git: pip install git+https://github.com/huggingface/diffusers.git
            from diffusers import Flux2KleinPipeline

            self.pipe = Flux2KleinPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Use enable_model_cpu_offload() for VRAM management (documented approach)
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                logger.info("CPU offload enabled")
            else:
                self.pipe.to(self.device)
                logger.info(f"Model moved to {self.device}")

            load_time = time.time() - start_time
            logger.info(f"FLUX.2 klein ({self.model_variant}) loaded in {load_time:.1f}s")

            # Validate with a tiny throwaway generation so failures surface at
            # load time instead of mid-request.
            logger.info("Validating model with test generation...")
            try:
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=1.0,
                    num_inference_steps=1,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
                if test_result.images[0] is None:
                    logger.error("Model validation failed: no output image")
                    # Free the broken pipeline instead of leaving it resident.
                    self.unload_model()
                    return False
            except Exception as e:
                logger.error(f"Model validation failed: {e}", exc_info=True)
                self.unload_model()
                return False

            logger.info("Model validation successful")
            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load FLUX.2 klein: {e}", exc_info=True)
            # Release anything partially constructed (pipe may be half-set-up).
            self.unload_model()
            return False

    def unload_model(self):
        """Unload model from memory and release cached CUDA allocations."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("FLUX.2 klein unloaded")

    def generate(
        self,
        request: "GenerationRequest",
        num_inference_steps: int = None,
        guidance_scale: float = None,
        seed: int = 42,
    ) -> "GenerationResult":
        """
        Generate image using FLUX.2 klein.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps (defaults to the
                variant's preset, e.g. 4 for klein distilled)
            guidance_scale: Classifier-free guidance scale (variant default)
            seed: Random seed for reproducible output (default 42, matching
                previous hard-coded behavior)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load FLUX.2 klein model")

        # Fall back to the variant's preset sampler settings.
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            logger.info(f"Generating with {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")

            # Build generation kwargs
            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(seed),
            }

            # Add input images if present (for editing).
            # FLUX.2 klein supports multi-reference editing via the 'image'
            # parameter: a single PIL image or a list of them.
            if request.has_input_images:
                valid_images = [img for img in request.input_images if img is not None]
                if len(valid_images) == 1:
                    gen_kwargs["image"] = valid_images[0]
                elif len(valid_images) > 1:
                    gen_kwargs["image"] = valid_images

            logger.info(f"Generating with FLUX.2 klein: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with FLUX.2 klein in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"FLUX.2 klein generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"FLUX.2 klein error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to the classmethod)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (label may carry a trailing description)."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
src/gemini_client.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini API Client
3
+ =================
4
+
5
+ Client for Google Gemini Image APIs (Flash and Pro models).
6
+ Handles API communication and response parsing.
7
+ """
8
+
9
+ import base64
10
+ import logging
11
+ from io import BytesIO
12
+ from typing import Optional
13
+ from PIL import Image
14
+
15
+ from google import genai
16
+ from google.genai import types
17
+
18
+ from .models import GenerationRequest, GenerationResult
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class GeminiClient:
    """
    Client for Gemini Image APIs.

    Supports:
    - Gemini 2.5 Flash Image (up to ~3 reference images)
    - Gemini 3 Pro Image Preview (up to 14 reference images, 1K/2K/4K)
    """

    # Model names (updated January 2026)
    # See: https://ai.google.dev/gemini-api/docs/image-generation
    MODEL_FLASH = "gemini-2.5-flash-image"  # Fast, efficient image generation
    MODEL_PRO = "gemini-3-pro-image-preview"  # Pro quality, advanced text rendering

    # Valid resolutions for Pro model
    VALID_RESOLUTIONS = ["1K", "2K", "4K"]

    # Aspect ratio to dimensions mapping (width, height); NOTE these differ
    # from the local backends' tables for some ratios (e.g. 4:3) because the
    # API chooses its own output sizes.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (864, 1184),
        "4:3": (1344, 1008),
        "4:5": (1024, 1280),
        "5:4": (1280, 1024),
    }

    def __init__(self, api_key: str, use_pro_model: bool = False):
        """
        Initialize Gemini client.

        Args:
            api_key: Google Gemini API key

            use_pro_model: If True, use Pro model with enhanced capabilities

        Raises:
            ValueError: if api_key is empty/falsy.
        """
        if not api_key:
            raise ValueError("API key is required for Gemini client")

        self.api_key = api_key
        self.use_pro_model = use_pro_model
        self.client = genai.Client(api_key=api_key)

        model_name = self.MODEL_PRO if use_pro_model else self.MODEL_FLASH
        logger.info(f"GeminiClient initialized with model: {model_name}")

    def generate(
        self,
        request: GenerationRequest,
        resolution: str = "1K"
    ) -> GenerationResult:
        """
        Generate image using Gemini API.

        Args:
            request: GenerationRequest object
            resolution: Resolution for Pro model ("1K", "2K", "4K");
                ignored for the Flash model.

        Returns:
            GenerationResult object (never raises; API errors are wrapped
            into an error result).
        """
        try:
            model_name = self.MODEL_PRO if self.use_pro_model else self.MODEL_FLASH
            logger.info(f"Generating with {model_name}: {request.prompt[:100]}...")

            # Build contents list (reference images first, then the prompt)
            contents = self._build_contents(request)

            # Build config (resolution only applies to the Pro model)
            config = self._build_config(
                request,
                resolution if self.use_pro_model else None
            )

            # Call API
            response = self.client.models.generate_content(
                model=model_name,
                contents=contents,
                config=config
            )

            # Parse response
            return self._parse_response(response)

        except Exception as e:
            logger.error(f"Gemini generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Gemini API error: {str(e)}")

    def _build_contents(self, request: GenerationRequest) -> list:
        """Build contents list for API request: valid input images followed by the text prompt."""
        contents = []

        # Add input images if present (None entries are dropped)
        if request.has_input_images:
            valid_images = [img for img in request.input_images if img is not None]
            contents.extend(valid_images)

        # Add prompt
        contents.append(request.prompt)

        return contents

    def _build_config(
        self,
        request: GenerationRequest,
        resolution: Optional[str] = None
    ) -> types.GenerateContentConfig:
        """Build generation config for API request (aspect ratio, optional Pro resolution)."""
        # Parse aspect ratio: labels may carry a trailing description
        # (e.g. "16:9 widescreen"); keep only the ratio token.
        aspect_ratio = request.aspect_ratio
        if " " in aspect_ratio:
            aspect_ratio = aspect_ratio.split()[0]

        # Build image config
        image_config_kwargs = {"aspect_ratio": aspect_ratio}

        # Add resolution for Pro model; invalid values fall back to "1K"
        if resolution and self.use_pro_model:
            if resolution not in self.VALID_RESOLUTIONS:
                logger.warning(f"Invalid resolution '{resolution}', defaulting to '1K'")
                resolution = "1K"
            image_config_kwargs["output_image_resolution"] = resolution
            logger.info(f"Pro model resolution: {resolution}")

        config = types.GenerateContentConfig(
            temperature=request.temperature,
            response_modalities=["image", "text"],
            image_config=types.ImageConfig(**image_config_kwargs)
        )

        return config

    def _parse_response(self, response) -> GenerationResult:
        """
        Parse API response and extract the first inline image.

        Defensively probes the SDK response object with hasattr checks
        (missing candidates/content/parts each map to a distinct error
        result), surfaces safety blocks, and decodes inline image data
        that may arrive as raw bytes or a base64 string.
        """
        if response is None:
            return GenerationResult.error_result("No response from API")

        if not hasattr(response, 'candidates') or not response.candidates:
            return GenerationResult.error_result("No candidates in response")

        candidate = response.candidates[0]

        # Check finish reason: safety/prohibited-content blocks are reported
        # as explicit errors rather than a generic "no image".
        if hasattr(candidate, 'finish_reason'):
            finish_reason = str(candidate.finish_reason)
            logger.info(f"Finish reason: {finish_reason}")

            if 'SAFETY' in finish_reason or 'PROHIBITED' in finish_reason:
                return GenerationResult.error_result(
                    f"Content blocked by safety filters: {finish_reason}"
                )

        # Check for content
        if not hasattr(candidate, 'content') or candidate.content is None:
            finish_reason = getattr(candidate, 'finish_reason', 'UNKNOWN')
            return GenerationResult.error_result(
                f"No content in response (finish_reason: {finish_reason})"
            )

        # Extract image from parts: first part carrying inline_data wins
        if hasattr(candidate.content, 'parts') and candidate.content.parts:
            for part in candidate.content.parts:
                if hasattr(part, 'inline_data') and part.inline_data:
                    try:
                        image_data = part.inline_data.data

                        # Handle both bytes and base64 string
                        if isinstance(image_data, str):
                            image_data = base64.b64decode(image_data)

                        # Convert to PIL Image; load() forces full decode now
                        # so decode errors surface here, not later.
                        image_buffer = BytesIO(image_data)
                        image = Image.open(image_buffer)
                        image.load()

                        logger.info(f"Image generated: {image.size}, {image.mode}")
                        return GenerationResult.success_result(
                            image=image,
                            message="Generated successfully"
                        )

                    except Exception as e:
                        logger.error(f"Failed to decode image: {e}")
                        return GenerationResult.error_result(
                            f"Image decoding error: {str(e)}"
                        )

        return GenerationResult.error_result("No image data in response")

    def is_healthy(self) -> bool:
        """Check if API is accessible (only verifies a key is set; no network probe)."""
        return self.api_key is not None and len(self.api_key) > 0

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (label may carry a trailing description)."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
src/longcat_edit_client.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LongCat-Image-Edit Client
3
+ =========================
4
+
5
+ Client for Meituan's LongCat-Image-Edit model.
6
+ Supports instruction-following image editing with bilingual (Chinese-English) support.
7
+
8
+ This is a SOTA open-source image editing model with excellent:
9
+ - Global editing, local editing, text modification
10
+ - Reference-guided editing
11
+ - Consistency preservation (layout, texture, color tone, identity)
12
+ - Multi-turn editing capabilities
13
+ """
14
+
15
+ import logging
16
+ import time
17
+ from typing import Optional, List
18
+ from PIL import Image
19
+
20
+ import torch
21
+
22
+ from .models import GenerationRequest, GenerationResult
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class LongCatEditClient:
    """
    Client for LongCat-Image-Edit model from Meituan.

    Features:
    - Instruction-following image editing
    - Bilingual support (Chinese-English)
    - Excellent consistency preservation
    - Multi-turn editing

    Requires ~18GB VRAM with CPU offload.
    """

    MODEL_ID = "meituan-longcat/LongCat-Image-Edit"

    # Aspect ratio label -> (width, height) in pixels
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default generation settings
    DEFAULT_STEPS = 50
    DEFAULT_GUIDANCE = 4.5

    def __init__(
        self,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize LongCat-Image-Edit client.

        Args:
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM (~18GB required)
        """
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self._loaded = False

        logger.info(f"LongCatEditClient initialized (cpu_offload: {enable_cpu_offload})")

    def load_model(self) -> bool:
        """Load the model into memory. Returns True on success, False on failure."""
        if self._loaded:
            return True

        try:
            logger.info(f"Loading LongCat-Image-Edit from {self.MODEL_ID}...")

            start_time = time.time()

            # Import LongCat pipeline
            # Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
            from diffusers import LongCatImageEditPipeline

            self.pipe = LongCatImageEditPipeline.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self.dtype,
            )

            # Apply memory optimization
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                logger.info("CPU offload enabled (~18GB VRAM)")
            else:
                self.pipe.to(self.device, self.dtype)
                logger.info(f"Model moved to {self.device} (high VRAM mode)")

            load_time = time.time() - start_time
            logger.info(f"LongCat-Image-Edit loaded in {load_time:.1f}s")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load LongCat-Image-Edit: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload model from memory and release cached CUDA allocations."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None

        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("LongCat-Image-Edit unloaded")

    def generate(
        self,
        request: "GenerationRequest",
        num_inference_steps: int = None,
        guidance_scale: float = None,
        seed: int = 42,
    ) -> "GenerationResult":
        """
        Edit image using LongCat-Image-Edit.

        Args:
            request: GenerationRequest object with:
                - prompt: The editing instruction (e.g., "Change the background to a forest")
                - input_images: List with the source image to edit
                - aspect_ratio: Output aspect ratio
            num_inference_steps: Number of denoising steps (default: 50)
            guidance_scale: Classifier-free guidance scale (default: 4.5)
            seed: Random seed for reproducible output (default 42, matching
                previous hard-coded behavior and edit_with_instruction)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load LongCat-Image-Edit model")

        # Use defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.DEFAULT_STEPS
        if guidance_scale is None:
            guidance_scale = self.DEFAULT_GUIDANCE

        try:
            start_time = time.time()

            # This is an editing model: a source image is mandatory.
            if not request.has_input_images:
                return GenerationResult.error_result("LongCat-Image-Edit requires an input image to edit")

            # First non-None entry in the list is the source image.
            input_image = next(
                (img for img in request.input_images if img is not None), None
            )
            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            # Resize input image to target dimensions
            # NOTE(review): plain resize may distort if the source aspect
            # differs from the requested ratio — intentional per original code.
            input_image = input_image.convert('RGB')
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            logger.info(f"Editing with LongCat: steps={num_inference_steps}, guidance={guidance_scale}")
            logger.info(f"Edit instruction: {request.prompt[:100]}...")

            # Build generation kwargs
            gen_kwargs = {
                "image": input_image,
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or "",
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_images_per_prompt": 1,
                "generator": torch.Generator("cpu").manual_seed(seed),
            }

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edited in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with LongCat-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"LongCat-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"LongCat-Image-Edit error: {str(e)}")

    def edit_with_instruction(
        self,
        source_image: "Image.Image",
        instruction: str,
        negative_prompt: str = "",
        num_inference_steps: int = None,
        guidance_scale: float = None,
        seed: int = 42
    ) -> "GenerationResult":
        """
        Simplified method for instruction-based image editing.

        Args:
            source_image: The image to edit
            instruction: Natural language editing instruction
                Examples:
                - "Change the background to a sunset beach"
                - "Make the person wear a red dress"
                - "Add snow to the scene"
                - "Change the cat to a dog"
            negative_prompt: What to avoid in the output
            num_inference_steps: Denoising steps (default: 50)
            guidance_scale: CFG scale (default: 4.5)
            seed: Random seed for reproducibility

        Returns:
            GenerationResult with the edited image
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load LongCat-Image-Edit model")

        if num_inference_steps is None:
            num_inference_steps = self.DEFAULT_STEPS
        if guidance_scale is None:
            guidance_scale = self.DEFAULT_GUIDANCE

        try:
            start_time = time.time()

            # Ensure RGB
            source_image = source_image.convert('RGB')

            logger.info(f"Editing image with instruction: {instruction[:100]}...")

            with torch.inference_mode():
                output = self.pipe(
                    image=source_image,
                    prompt=instruction,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=1,
                    generator=torch.Generator("cpu").manual_seed(seed),
                )
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edit completed in {generation_time:.2f}s")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with instruction in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Instruction-based edit failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Edit error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to the classmethod)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (label may carry a trailing description)."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
src/model_manager.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Manager
3
+ =============
4
+
5
+ Manages model loading states and provides a robust interface for
6
+ ensuring models are loaded and validated before generation.
7
+
8
+ States:
9
+ - UNLOADED: No model loaded
10
+ - LOADING: Model is being loaded
11
+ - READY: Model loaded and validated
12
+ - ERROR: Model failed to load
13
+ """
14
+
15
+ import logging
16
+ import threading
17
+ import time
18
+ from typing import Optional, Callable, Tuple
19
+ from enum import Enum
20
+ from PIL import Image
21
+
22
+ from .backend_router import BackendRouter, BackendType
23
+ from .character_service import CharacterSheetService
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class ModelState(Enum):
    """Model loading states (lifecycle: UNLOADED -> LOADING -> READY or ERROR)."""
    UNLOADED = "unloaded"  # no model resident
    LOADING = "loading"    # load in progress; generation must wait
    READY = "ready"        # loaded and validated, safe to generate
    ERROR = "error"        # load failed; see ModelManager.error_message
+
36
+
37
+ class ModelManager:
38
+ """
39
+ Manages model loading lifecycle with state tracking.
40
+
41
+ Ensures models are fully loaded and validated before allowing generation.
42
+ Provides progress callbacks for UI updates during loading.
43
+ """
44
+
45
+ def __init__(self):
46
+ self._state = ModelState.UNLOADED
47
+ self._current_backend: Optional[BackendType] = None
48
+ self._service: Optional[CharacterSheetService] = None
49
+ self._error_message: Optional[str] = None
50
+ self._loading_progress: float = 0.0
51
+ self._loading_message: str = ""
52
+ self._lock = threading.Lock()
53
+ self._cancel_requested = False
54
+
55
    @property
    def state(self) -> ModelState:
        """Current model state (UNLOADED, LOADING, READY, or ERROR)."""
        return self._state
59
+
60
    @property
    def is_ready(self) -> bool:
        """Check if model is ready for generation (state is READY)."""
        return self._state == ModelState.READY
64
+
65
    @property
    def is_loading(self) -> bool:
        """Check if model is currently loading (state is LOADING)."""
        return self._state == ModelState.LOADING
69
+
70
    @property
    def error_message(self) -> Optional[str]:
        """Get error message if in error state (None otherwise)."""
        return self._error_message
74
+
75
    @property
    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress
79
+
80
    @property
    def loading_message(self) -> str:
        """Get current loading status message (empty string when idle)."""
        return self._loading_message
84
+
85
    @property
    def current_backend(self) -> Optional[BackendType]:
        """Get currently loaded backend (None when nothing is loaded)."""
        return self._current_backend
89
+
90
    @property
    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready; None otherwise)."""
        # Guard: callers must not generate against a half-loaded/failed model.
        if self._state != ModelState.READY:
            return None
        return self._service
96
+
97
+ def get_status_display(self) -> Tuple[str, str]:
98
+ """
99
+ Get status message and color for UI display.
100
+
101
+ Returns:
102
+ Tuple of (message, color) where color is a CSS color string
103
+ """
104
+ if self._state == ModelState.UNLOADED:
105
+ return "No model loaded", "#888888"
106
+ elif self._state == ModelState.LOADING:
107
+ pct = int(self._loading_progress * 100)
108
+ return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
109
+ elif self._state == ModelState.READY:
110
+ backend_name = BackendRouter.BACKEND_NAMES.get(
111
+ self._current_backend,
112
+ str(self._current_backend)
113
+ )
114
+ return f"Ready: {backend_name}", "#00AA00"
115
+ elif self._state == ModelState.ERROR:
116
+ return f"Error: {self._error_message}", "#FF0000"
117
+ return "Unknown state", "#888888"
118
+
119
    def request_cancel(self):
        """Ask an in-flight load_model() to abort.

        Cooperative cancellation: the flag is polled between loading steps,
        so the cancel takes effect at the next checkpoint, not immediately.
        """
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")
123
+
124
    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> bool:
        """
        Load a model with progress tracking.

        Args:
            backend: Backend type to load
            api_key: API key for cloud backends
            steps: Default steps for generation
            guidance: Default guidance scale
            progress_callback: Callback for progress updates (progress, message)

        Returns:
            True if model loaded successfully
        """
        # The lock is held only long enough to flip state to LOADING; the
        # slow work below runs unlocked so UI threads can poll progress.
        with self._lock:
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False

            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str):
            # Mirror progress into instance fields (for pollers) and forward
            # to the optional callback (for push-style UIs).
            self._loading_progress = progress
            self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)

        try:
            # Step 1: Unload previous model if different backend
            update_progress(0.05, "Checking current model...")

            if self._service and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    # Best-effort: a failed unload should not block loading
                    # the new backend.
                    logger.warning(f"Error unloading previous model: {e}")
                self._service = None

            # Cancellation checkpoint (see request_cancel)
            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                return False

            # Step 2: Create service and load model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")

            logger.info(f"Creating CharacterSheetService for {backend.value}")

            # For local models, this will load the model
            # For cloud backends, this just validates the API key
            self._service = CharacterSheetService(
                api_key=api_key,
                backend=backend
            )

            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                self._service = None
                return False

            update_progress(0.7, "Model loaded, configuring...")

            # Step 3: Configure default parameters (duck-typed: only clients
            # that expose these attributes get them set)
            if hasattr(self._service.client, 'default_steps'):
                self._service.client.default_steps = steps
            if hasattr(self._service.client, 'default_guidance'):
                self._service.client.default_guidance = guidance

            update_progress(0.8, "Validating model...")

            # Step 4: Validate model is actually working
            is_valid, error = self._validate_model()

            if not is_valid:
                raise RuntimeError(f"Model validation failed: {error}")

            update_progress(1.0, "Ready!")

            # Success! Re-acquire the lock to publish READY atomically.
            # NOTE(review): assumes self._lock is NOT held here (the opening
            # `with` block above must have exited) — a non-reentrant
            # threading.Lock would deadlock otherwise; confirm lock type.
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"

            logger.info(f"Model {backend.value} loaded and validated successfully")
            return True

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load model {backend.value}: {error_msg}", exc_info=True)

            with self._lock:
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._service = None

            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")

            return False
238
+
239
+ def _validate_model(self) -> Tuple[bool, Optional[str]]:
240
+ """
241
+ Validate that the model is actually working.
242
+
243
+ For local models, checks that the pipeline is loaded.
244
+ For cloud backends, does a minimal health check.
245
+
246
+ Returns:
247
+ Tuple of (is_valid, error_message)
248
+ """
249
+ if self._service is None:
250
+ return False, "Service not initialized"
251
+
252
+ try:
253
+ client = self._service.client
254
+
255
+ # Check if client has health check method
256
+ if hasattr(client, 'is_healthy'):
257
+ if not client.is_healthy():
258
+ return False, "Client health check failed"
259
+
260
+ # For local models, check pipeline is loaded
261
+ if hasattr(client, '_loaded'):
262
+ if not client._loaded:
263
+ return False, "Model pipeline not loaded"
264
+
265
+ # For FLUX models, verify the pipe exists
266
+ if hasattr(client, 'pipe'):
267
+ if client.pipe is None:
268
+ return False, "Model pipeline is None"
269
+
270
+ return True, None
271
+
272
+ except Exception as e:
273
+ return False, str(e)
274
+
275
+ def _simplify_error(self, error: str) -> str:
276
+ """Simplify technical error messages for user display."""
277
+ error_lower = error.lower()
278
+
279
+ if "cuda out of memory" in error_lower or "out of memory" in error_lower:
280
+ return "Not enough GPU memory. Try a smaller model or close other applications."
281
+
282
+ if "api key" in error_lower:
283
+ return "Invalid or missing API key."
284
+
285
+ if "connection" in error_lower or "network" in error_lower:
286
+ return "Network connection error. Check your internet connection."
287
+
288
+ if "not found" in error_lower and "model" in error_lower:
289
+ return "Model files not found. The model may need to be downloaded."
290
+
291
+ if "import" in error_lower:
292
+ return "Missing dependencies. Some required packages are not installed."
293
+
294
+ if "meta tensor" in error_lower:
295
+ return "Model loading failed (meta tensor error). Try restarting the application."
296
+
297
+ # Truncate long errors
298
+ if len(error) > 100:
299
+ return error[:97] + "..."
300
+
301
+ return error
302
+
303
    def unload(self):
        """Unload the current model and reset all manager state.

        Safe to call in any state; unload failures are logged and
        swallowed so the manager always ends up UNLOADED.
        """
        with self._lock:
            if self._service:
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    logger.warning(f"Error during unload: {e}")
                self._service = None

            # Reset all status fields so the UI shows a clean slate
            self._state = ModelState.UNLOADED
            self._current_backend = None
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""

        logger.info("Model unloaded")
321
+
322
+
323
# Global singleton for model management
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on first call.

    NOTE(review): the lazy init is not lock-protected; two threads racing
    here could construct two managers. Confirm callers are single-threaded
    at startup.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
src/models.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Models for Character Sheet Pro
3
+ ====================================
4
+
5
+ Dataclasses for generation requests and results.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional, List, Dict, Any
10
+ from PIL import Image
11
+ from datetime import datetime
12
+
13
+
14
@dataclass
class GenerationRequest:
    """Request for image generation."""

    prompt: str
    input_images: List[Image.Image] = field(default_factory=list)
    aspect_ratio: str = "1:1"
    temperature: float = 0.4
    negative_prompt: Optional[str] = None

    @property
    def has_input_images(self) -> bool:
        """True when at least one input image was supplied."""
        # A non-empty list is truthy, so no explicit length check needed.
        return bool(self.input_images)
28
+
29
+
30
@dataclass
class GenerationResult:
    """Result from image generation."""

    success: bool
    image: Optional[Image.Image] = None
    message: str = ""
    generation_time: Optional[float] = None

    @classmethod
    def success_result(
        cls,
        image: Image.Image,
        message: str = "Generated successfully",
        generation_time: Optional[float] = None
    ) -> "GenerationResult":
        """Build a successful result carrying the generated image."""
        return cls(True, image, message, generation_time)

    @classmethod
    def error_result(cls, message: str) -> "GenerationResult":
        """Build a failed result carrying only an error message."""
        return cls(False, message=message)
58
+
59
+
60
@dataclass
class CharacterSheetConfig:
    """Configuration for character sheet generation.

    All values are tunable defaults; consumers read them per stage
    (face vs. body vs. normalize).
    """

    # Aspect ratios
    face_aspect_ratio: str = "3:4"   # 864x1184
    body_aspect_ratio: str = "9:16"  # 768x1344

    # Generation temperatures (lower = more faithful to references)
    face_temperature: float = 0.35
    body_temperature: float = 0.35
    normalize_temperature: float = 0.5

    # Layout of the composed sheet
    spacing: int = 20                    # gap between panels, in pixels
    background_color: str = "#2C2C2C"

    # Retry settings for flaky/rate-limited backends
    max_retries: int = 3
    retry_delay: float = 30.0            # seconds between failed attempts
    rate_limit_delay_min: float = 2.0    # random inter-call delay bounds
    rate_limit_delay_max: float = 3.0
82
+
83
+
84
@dataclass
class CharacterSheetMetadata:
    """Metadata for generated character sheet."""

    character_name: str
    input_type: str  # "Face Only", "Full Body", "Face + Body"
    costume_description: str
    backend: str
    # ISO-8601 creation time, captured when the instance is created
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    views: int = 7
    # Per-stage details, filled in by the generation pipeline
    stages: Dict[str, Any] = field(default_factory=dict)
src/qwen_image_edit_client.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Qwen-Image-Edit Client
3
+ ======================
4
+
5
+ Client for Qwen-Image-Edit-2511 local image editing.
6
+ Supports multi-image editing with improved consistency.
7
+
8
+ GPU loading strategies (benchmarked on A6000 + A5000):
9
+ Pinned 2-GPU: 169.9s (4.25s/step) - 1.36x vs baseline
10
+ Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
11
+ CPU offload: 231.5s (5.79s/step) - baseline
12
+ """
13
+
14
+ import logging
15
+ import time
16
+ import types
17
+ from typing import Optional, List
18
+ from PIL import Image
19
+
20
+ import torch
21
+
22
+ from .models import GenerationRequest, GenerationResult
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class QwenImageEditClient:
    """
    Client for Qwen-Image-Edit-2511 model.

    Supports:
    - Multi-image editing (up to multiple reference images)
    - Precise text editing
    - Improved character consistency
    - LoRA integration
    """

    # Model variants (only the full ~50GB checkpoint is wired up)
    MODELS = {
        "full": "Qwen/Qwen-Image-Edit",  # Official Qwen model
    }

    # Legacy compatibility
    MODEL_ID = MODELS["full"]

    # Aspect ratio to dimensions mapping (target output sizes)
    ASPECT_RATIOS = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "21:9": (1680, 720),  # Cinematic ultra-wide
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
        "3:4": (1104, 1472),
        "4:3": (1472, 1104),
        "4:5": (1056, 1320),
        "5:4": (1320, 1056),
    }

    # Proven native generation resolution. Tested resolutions:
    #   1104x1472 (3:4)  → CLEAN output (face views in v1 test)
    #   928x1664 (9:16)  → VAE tiling noise / garbage
    #   1328x1328 (1:1)  → VAE tiling noise / garbage
    #   896x1184 (auto)  → garbage
    # Always generate at 1104x1472, then crop+resize to target.
    NATIVE_RESOLUTION = (1104, 1472)

    # VRAM thresholds for loading strategies
    # Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
    BALANCED_VRAM_THRESHOLD_GB = 45  # Single GPU balanced (needs ~42GB + headroom)
    MAIN_GPU_MIN_VRAM_GB = 42  # Transformer + VAE minimum
    ENCODER_GPU_MIN_VRAM_GB = 17  # Text encoder minimum
75
    def __init__(
        self,
        model_variant: str = "full",  # Use full model (~50GB)
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        encoder_device: Optional[str] = None,
    ):
        """
        Initialize Qwen-Image-Edit client. No weights are loaded here;
        loading is deferred to load_model() / first generate() call.

        Args:
            model_variant: Model variant ("full" for ~50GB)
            device: Device to use for transformer+VAE (cuda or cuda:N)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            encoder_device: Explicit device for text_encoder (e.g. "cuda:3").
                If None, auto-detected from available GPUs.
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.encoder_device = encoder_device
        self.pipe = None             # set by load_model()
        self._loaded = False
        self._loading_strategy = None  # which strategy succeeded, for logging

        logger.info(f"QwenImageEditClient initialized (variant: {model_variant})")
104
+
105
+ @staticmethod
106
+ def _get_gpu_vram_gb(device_idx: int) -> float:
107
+ """Get total VRAM in GB for a specific GPU."""
108
+ if not torch.cuda.is_available():
109
+ return 0.0
110
+ if device_idx >= torch.cuda.device_count():
111
+ return 0.0
112
+ return torch.cuda.get_device_properties(device_idx).total_memory / 1e9
113
+
114
+ def _get_vram_gb(self) -> float:
115
+ """Get available VRAM in GB for the main target device."""
116
+ device_idx = self._parse_device_idx(self.device)
117
+ return self._get_gpu_vram_gb(device_idx)
118
+
119
+ @staticmethod
120
+ def _parse_device_idx(device: str) -> int:
121
+ """Parse CUDA device index from device string."""
122
+ if device.startswith("cuda:"):
123
+ try:
124
+ return int(device.split(":")[1])
125
+ except (ValueError, IndexError):
126
+ pass
127
+ return 0
128
+
129
+ def _find_encoder_gpu(self, main_idx: int) -> Optional[int]:
130
+ """Find a secondary GPU suitable for text_encoder (>= 17GB VRAM).
131
+
132
+ Prefers GPUs with more VRAM. Skips the main GPU.
133
+ """
134
+ if not torch.cuda.is_available():
135
+ return None
136
+
137
+ candidates = []
138
+ for i in range(torch.cuda.device_count()):
139
+ if i == main_idx:
140
+ continue
141
+ vram = self._get_gpu_vram_gb(i)
142
+ if vram >= self.ENCODER_GPU_MIN_VRAM_GB:
143
+ name = torch.cuda.get_device_name(i)
144
+ candidates.append((i, vram, name))
145
+
146
+ if not candidates:
147
+ return None
148
+
149
+ # Pick the GPU with the most VRAM
150
+ candidates.sort(key=lambda x: x[1], reverse=True)
151
+ best = candidates[0]
152
+ logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)")
153
+ return best[0]
154
+
155
    @staticmethod
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        Bound onto the *pipeline* via types.MethodType in
        _load_pinned_multi_gpu — the ``self`` here is the pipeline, not this
        client (hence @staticmethod on the client class).

        The original _get_qwen_prompt_embeds sends model_inputs to
        execution_device (main GPU), then calls text_encoder on a different
        GPU, causing a device mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer
        """
        # Where the text encoder actually lives (may differ from main GPU)
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Wrap each prompt in the pipeline's chat-style template; drop_idx
        # is how many leading template tokens to strip from the embeddings.
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]

        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        # Last hidden layer, split back into per-prompt sequences via the
        # attention mask, then trimmed of template tokens.
        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        # Right-pad every sequence (and its mask) to the batch max length
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])

        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)

        return prompt_embeds, encoder_attention_mask
211
+
212
    def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool:
        """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary.

        Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline.

        Args:
            model_id: HF repo id to load from
            main_idx: CUDA index for transformer + VAE
            encoder_idx: CUDA index for the text encoder

        Returns:
            True (loading errors propagate to the caller as exceptions)
        """
        from diffusers import QwenImageEditPipeline
        from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
        from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
        from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
        from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

        main_dev = f"cuda:{main_idx}"
        enc_dev = f"cuda:{encoder_idx}"

        logger.info(f"Loading pinned 2-GPU: transformer+VAE → {main_dev}, text_encoder → {enc_dev}")

        # Lightweight components (no GPU placement needed)
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")
        tokenizer = Qwen2Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer")
        processor = Qwen2VLProcessor.from_pretrained(
            model_id, subfolder="processor")

        # Heavy components, pinned to their dedicated GPUs
        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=self.dtype,
        ).to(enc_dev)
        logger.info(f" text_encoder loaded on {enc_dev}")

        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=self.dtype,
        ).to(main_dev)
        logger.info(f" transformer loaded on {main_dev}")

        vae = AutoencoderKLQwenImage.from_pretrained(
            model_id, subfolder="vae", torch_dtype=self.dtype,
        ).to(main_dev)
        vae.enable_tiling()
        logger.info(f" VAE loaded on {main_dev}")

        self.pipe = QwenImageEditPipeline(
            scheduler=scheduler, vae=vae, text_encoder=text_encoder,
            tokenizer=tokenizer, processor=processor, transformer=transformer,
        )

        # Fix 1: Override _execution_device to force main GPU
        # Without this, pipeline returns text_encoder's device, causing VAE
        # to receive tensors on the wrong GPU
        # NOTE(review): this patches the CLASS, not the instance — every
        # QwenImageEditPipeline in the process is affected, and main_device
        # is captured in the closure. Confirm only one pipeline instance
        # ever exists per process.
        main_device = torch.device(main_dev)
        QwenImageEditPipeline._execution_device = property(lambda self: main_device)

        # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device
        self.pipe._get_qwen_prompt_embeds = types.MethodType(
            self._patched_get_qwen_prompt_embeds, self.pipe)

        self._loading_strategy = "pinned_multi_gpu"
        logger.info(f"Pinned 2-GPU pipeline ready")
        return True
269
+
270
    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)

        Returns:
            True if any strategy succeeded; False on exception (logged,
            never raised).
        """
        # Idempotent: repeated calls after success are no-ops
        if self._loaded:
            return True

        try:
            from diffusers import QwenImageEditPipeline

            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)
            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")

            start_time = time.time()
            loaded = False

            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                if self.encoder_device:
                    # Honor an explicitly requested encoder GPU, but only if
                    # it can actually hold the ~17GB text encoder
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None

                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)

                if encoder_idx is not None:
                    self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)
                    loaded = True

            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                # Leave ~4GB headroom for activations / VAE decode
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True

            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True

            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")

            self.pipe.set_progress_bar_config(disable=None)

            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)
            return False
360
+
361
    def unload_model(self):
        """Unload model from memory and return VRAM to the CUDA allocator.

        NOTE: the completion log message says "2511" while load_model()
        pulls "Qwen/Qwen-Image-Edit" — presumably the same checkpoint
        lineage; confirm the naming is intentional.
        """
        if self.pipe is not None:
            # Drop the only strong reference so CUDA buffers can be freed
            del self.pipe
            self.pipe = None
        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("Qwen-Image-Edit-2511 unloaded")
372
+
373
    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        Lazy-loads the model on first use. Always renders at the proven
        NATIVE_RESOLUTION, then center-crops/resizes to the requested
        aspect ratio (see NATIVE_RESOLUTION notes on the class).

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control

        Returns:
            GenerationResult object (error_result on failure; never raises)
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")

        try:
            start_time = time.time()

            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)

            # Build input images list, dropping any None entries
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]

            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION
            gen_kwargs = {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                # NOTE(review): fixed seed 42 makes output deterministic for
                # identical inputs — confirm retries are not expected to vary
                "generator": torch.manual_seed(42),
            }

            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]

            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")

            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")
453
+
454
+ @staticmethod
455
+ def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
456
+ """Crop image to target aspect ratio, then resize to target dimensions.
457
+
458
+ Centers the crop on the image so equal amounts are trimmed from
459
+ each side. Uses LANCZOS for high-quality downscaling.
460
+ """
461
+ src_w, src_h = image.size
462
+ target_ratio = target_w / target_h
463
+ src_ratio = src_w / src_h
464
+
465
+ if abs(target_ratio - src_ratio) < 0.01:
466
+ # Already the right aspect ratio, just resize
467
+ return image.resize((target_w, target_h), Image.LANCZOS)
468
+
469
+ if target_ratio < src_ratio:
470
+ # Target is taller/narrower than source → crop sides
471
+ crop_w = int(src_h * target_ratio)
472
+ offset = (src_w - crop_w) // 2
473
+ image = image.crop((offset, 0, offset + crop_w, src_h))
474
+ else:
475
+ # Target is wider than source → crop top/bottom
476
+ crop_h = int(src_w / target_ratio)
477
+ offset = (src_h - crop_h) // 2
478
+ image = image.crop((0, offset, src_w, offset + crop_h))
479
+
480
+ return image.resize((target_w, target_h), Image.LANCZOS)
481
+
482
+ def _get_dimensions(self, aspect_ratio: str) -> tuple:
483
+ """Get pixel dimensions for aspect ratio."""
484
+ ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
485
+ return self.ASPECT_RATIOS.get(ratio, (1024, 1024))
486
+
487
+ def is_healthy(self) -> bool:
488
+ """Check if model is loaded and ready."""
489
+ return self._loaded and self.pipe is not None
490
+
491
+ @classmethod
492
+ def get_dimensions(cls, aspect_ratio: str) -> tuple:
493
+ """Get pixel dimensions for aspect ratio."""
494
+ ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
495
+ return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
src/utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility Functions
3
+ =================
4
+
5
+ Helper functions for image processing and file operations.
6
+ """
7
+
8
+ import re
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, Union
12
+ from datetime import datetime
13
+ from PIL import Image
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def ensure_pil_image(
    obj: Union[Image.Image, str, Path, None],
    context: str = ""
) -> Image.Image:
    """
    Ensure object is a PIL Image.

    Args:
        obj: Image, path, or None
        context: Context label included in error messages

    Returns:
        PIL Image

    Raises:
        ValueError: If object cannot be converted to Image
    """
    # Fast path: already an Image
    if isinstance(obj, Image.Image):
        return obj

    if obj is None:
        raise ValueError(f"[{context}] Image is None")

    if isinstance(obj, (str, Path)):
        try:
            return Image.open(obj)
        except Exception as e:
            raise ValueError(f"[{context}] Failed to load image from path: {e}")

    raise ValueError(f"[{context}] Unsupported image type: {type(obj)}")
49
+
50
+
51
def sanitize_filename(name: str) -> str:
    """
    Sanitize string for use as filename.

    Replaces filesystem-reserved characters with underscores, trims
    leading/trailing dots and spaces, and caps the length at 100 chars.

    Args:
        name: Original name

    Returns:
        Safe filename string (falls back to "unnamed" if nothing remains)
    """
    cleaned = re.sub(r'[<>:"/\\|?*]', '_', name).strip('. ')
    # Slice is a no-op for short names, identical to an explicit length check
    return cleaned[:100] or "unnamed"
69
+
70
+
71
def save_image(
    image: Image.Image,
    directory: Path,
    base_name: str,
    format: str = "PNG"
) -> Path:
    """
    Save image to directory under a timestamped, sanitized name.

    Args:
        image: PIL Image to save
        directory: Output directory (created if missing)
        base_name: Base filename (without extension)
        format: Image format

    Returns:
        Path to saved file
    """
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    # <sanitized base>_<YYYYmmdd_HHMMSS>.<ext>
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = out_dir / f"{sanitize_filename(base_name)}_{stamp}.{format.lower()}"

    image.save(filepath, format=format)
    logger.info(f"Saved: {filepath}")

    return filepath
103
+
104
+
105
def resize_for_display(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Downscale an image so its longest side is at most max_size.

    Aspect ratio is preserved; images already within the limit are
    returned unchanged (no copy is made).

    Args:
        image: PIL Image
        max_size: Maximum dimension in pixels

    Returns:
        The original image, or a LANCZOS-resized copy
    """
    width, height = image.size

    # Already small enough on both axes -> nothing to do.
    if max(width, height) <= max_size:
        return image

    # Scale the longest side to max_size, the other proportionally.
    if width > height:
        new_size = (max_size, int(height * max_size / width))
    else:
        new_size = (int(width * max_size / height), max_size)

    return image.resize(new_size, Image.Resampling.LANCZOS)
132
+
133
+
134
def get_image_info(image: Image.Image) -> str:
    """Return a short "WxH MODE" description of the image."""
    width, height = image.size
    return f"{width}x{height} {image.mode}"
137
+
138
+
139
def preprocess_input_image(
    image: Image.Image,
    max_size: int = 1024,
    target_size: tuple = None,
    ensure_rgb: bool = True
) -> Image.Image:
    """
    Normalize an input image so downstream models can consume it.

    The image is round-tripped through an in-memory PNG encode/decode,
    which flattens out exotic source formats (JFIF, TIFF, WebP, etc.),
    then optionally converted to RGB and resized.

    Args:
        image: PIL Image to preprocess
        max_size: Maximum dimension (used if target_size not specified)
        target_size: Specific (width, height) to resize to
        ensure_rgb: Convert to RGB mode

    Returns:
        Preprocessed PIL Image in RGB format
    """
    import io

    # Work on a copy so the caller's image is untouched.
    working = image.copy()

    # Normalize mode before the PNG round-trip.
    if working.mode not in ('RGB', 'RGBA'):
        working = working.convert('RGB')

    # Re-encode via an in-memory PNG; this normalizes the underlying format.
    buffer = io.BytesIO()
    working.save(buffer, format='PNG')
    buffer.seek(0)
    working = Image.open(buffer)
    working.load()  # Force full decode while the buffer is still alive.

    # Flatten RGBA onto white; convert any other non-RGB mode directly.
    if ensure_rgb and working.mode != 'RGB':
        if working.mode == 'RGBA':
            flattened = Image.new('RGB', working.size, (255, 255, 255))
            flattened.paste(working, mask=working.split()[3])
            working = flattened
        else:
            working = working.convert('RGB')

    # Resize: an explicit target wins; otherwise cap the longest side.
    if target_size:
        return working.resize(target_size, Image.Resampling.LANCZOS)

    width, height = working.size
    if width > max_size or height > max_size:
        if width > height:
            scaled = (max_size, int(height * max_size / width))
        else:
            scaled = (int(width * max_size / height), max_size)
        working = working.resize(scaled, Image.Resampling.LANCZOS)

    return working
202
+
203
+
204
def preprocess_images_for_backend(
    images: list,
    backend_type: str,
    aspect_ratio: str = "1:1"
) -> list:
    """
    Preprocess a list of images for a specific backend.

    Each non-None entry is run through preprocess_input_image() with a
    backend-appropriate maximum size; None entries pass through as-is.

    Args:
        images: List of PIL Images (entries may be None)
        backend_type: Backend type string (e.g., 'flux_klein', 'qwen_comfyui')
        aspect_ratio: Target aspect ratio

    Returns:
        List of preprocessed PIL Images
    """
    if not images:
        return images

    # Per-backend maximum input dimension.
    # FLUX-family models work best with smaller inputs (512-768px).
    max_sizes = {
        'flux_klein': 768,            # 4B - faster with smaller inputs
        'flux_klein_9b_fp8': 768,     # 9B - quality comes from model, not input size
        'qwen_image_edit': 1024,
        'qwen_comfyui': 1024,
        'zimage_turbo': 768,
        'zimage_base': 768,
        'longcat_edit': 768,
        'gemini_flash': 1024,         # Gemini handles larger but 1024 is fine
        'gemini_pro': 1024,
    }
    limit = max_sizes.get(backend_type, 1024)

    return [
        preprocess_input_image(img, max_size=limit) if img is not None else None
        for img in images
    ]
src/zimage_client.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Z-Image Client
3
+ ==============
4
+
5
+ Client for Z-Image (Tongyi-MAI) local image generation.
6
+ Supports text-to-image and image-to-image editing.
7
+
8
+ Z-Image is a 6B parameter model that achieves state-of-the-art quality
9
+ with only 8-9 inference steps, fitting in 16GB VRAM.
10
+ """
11
+
12
+ import logging
13
+ import time
14
+ from typing import Optional, List
15
+ from PIL import Image
16
+
17
+ import torch
18
+
19
+ from .models import GenerationRequest, GenerationResult
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class ZImageClient:
    """
    Client for Z-Image models from Tongyi-MAI.

    Z-Image is a 6B parameter model family; the Turbo variant reaches
    high quality in only 8-9 inference steps and fits in 16GB VRAM.

    Supports:
    - Text-to-image generation (ZImagePipeline)
    - Image-to-image editing (ZImageImg2ImgPipeline)
    - Multiple model variants (Turbo, Base, Edit, Omni)
    """

    # HuggingFace repo IDs for each supported variant.
    MODELS = {
        # Turbo - Fast, distilled, 8-9 steps, fits 16GB VRAM
        "turbo": "Tongyi-MAI/Z-Image-Turbo",
        # Base - Quality-focused, more steps
        "base": "Tongyi-MAI/Z-Image",
        # Edit - Fine-tuned for instruction-following image editing
        "edit": "Tongyi-MAI/Z-Image-Edit",
        # Omni - Versatile, supports both generation and editing
        "omni": "Tongyi-MAI/Z-Image-Omni-Base",
    }

    # Aspect ratio label -> (width, height) in pixels.
    # Z-Image supports 512x512 to 2048x2048.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default sampler settings for each model variant.
    MODEL_DEFAULTS = {
        "turbo": {"steps": 9, "guidance": 0.0},   # Fast, no CFG needed
        "base": {"steps": 50, "guidance": 4.0},   # Quality-focused
        "edit": {"steps": 28, "guidance": 3.5},   # Editing
        "omni": {"steps": 28, "guidance": 3.5},   # Versatile
    }

    def __init__(
        self,
        model_variant: str = "turbo",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize Z-Image client.

        Model weights are NOT loaded here; call load_model() explicitly
        or let generate() load them lazily on first use.

        Args:
            model_variant: Model variant to use:
                - "turbo": Fast, 9 steps, 16GB VRAM (RECOMMENDED)
                - "base": Quality-focused, 50 steps
                - "edit": Instruction-following image editing
                - "omni": Versatile generation + editing
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None          # text-to-image pipeline
        self.pipe_img2img = None  # img2img pipeline (shares components with pipe)
        self._loaded = False

        # Sampler defaults for this variant (fall back to turbo settings).
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 9, "guidance": 0.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]

        logger.info(f"ZImageClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")

    def load_model(self) -> bool:
        """
        Load both pipelines into memory and validate with a tiny test run.

        Returns:
            True on success. On failure, any partially-initialized
            pipelines are unloaded (freeing VRAM) and False is returned.
        """
        if self._loaded:
            return True

        try:
            # Get model ID for selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS["turbo"])
            logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")

            start_time = time.time()

            # Import diffusers pipelines for Z-Image
            # Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
            from diffusers import ZImagePipeline, ZImageImg2ImgPipeline

            # Load text-to-image pipeline
            self.pipe = ZImagePipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Load img2img pipeline, sharing components to save memory.
            self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
                text_encoder=self.pipe.text_encoder,
                tokenizer=self.pipe.tokenizer,
                vae=self.pipe.vae,
                transformer=self.pipe.transformer,
                scheduler=self.pipe.scheduler,
            )

            # Apply memory optimization
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                self.pipe_img2img.enable_model_cpu_offload()
                logger.info("CPU offload enabled")
            else:
                self.pipe.to(self.device)
                self.pipe_img2img.to(self.device)
                logger.info(f"Model moved to {self.device}")

            # Optional: Flash Attention is faster when the kernel is installed.
            try:
                self.pipe.transformer.set_attention_backend("flash")
                self.pipe_img2img.transformer.set_attention_backend("flash")
                logger.info("Flash Attention enabled")
            except Exception:
                logger.info("Flash Attention not available, using default SDPA")

            load_time = time.time() - start_time
            logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")

            # Validate by running a tiny 2-step test generation.
            logger.info("Validating model with test generation...")
            try:
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=0.0,
                    num_inference_steps=2,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
                if test_result.images[0] is not None:
                    logger.info("Model validation successful")
                else:
                    logger.error("Model validation failed: no output image")
                    # Free the half-initialized pipelines so a retry starts clean.
                    self.unload_model()
                    return False
            except Exception as e:
                logger.error(f"Model validation failed: {e}", exc_info=True)
                self.unload_model()
                return False

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Z-Image: {e}", exc_info=True)
            # Drop any partially-constructed pipelines and free VRAM.
            self.unload_model()
            return False

    def unload_model(self):
        """Unload both pipelines from memory and release cached GPU memory."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        if self.pipe_img2img is not None:
            del self.pipe_img2img
            self.pipe_img2img = None

        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("Z-Image unloaded")

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: int = 42,
    ) -> GenerationResult:
        """
        Generate image using Z-Image.

        Loads the model lazily if needed. Routes to the img2img pipeline
        when the request carries input images.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps (defaults to
                the variant preset, e.g. 9 for turbo)
            guidance_scale: Classifier-free guidance scale (defaults to
                the variant preset, e.g. 0.0 for turbo)
            seed: RNG seed for reproducible output (default 42 preserves
                the previous fixed-seed behavior)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Z-Image model")

        # Use model defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            logger.info(f"Generating with Z-Image {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")

            # Input images present -> delegate to the img2img pipeline.
            if request.has_input_images:
                return self._generate_img2img(
                    request, width, height, num_inference_steps, guidance_scale, start_time, seed
                )

            # Text-to-image generation
            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                # CPU generator keeps seeding device-independent.
                "generator": torch.Generator(device="cpu").manual_seed(seed),
            }

            # Add negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Z-Image generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image error: {str(e)}")

    def _generate_img2img(
        self,
        request: GenerationRequest,
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        start_time: float,
        seed: int = 42,
    ) -> GenerationResult:
        """
        Generate using the img2img pipeline with the first valid input image.

        Args:
            request: GenerationRequest carrying input_images
            width: Target width in pixels
            height: Target height in pixels
            num_inference_steps: Denoising steps
            guidance_scale: CFG scale
            start_time: time.time() captured by the caller (for total timing)
            seed: RNG seed for reproducible output

        Returns:
            GenerationResult object
        """
        try:
            # Get the first valid input image
            input_image = None
            for img in request.input_images:
                if img is not None:
                    input_image = img
                    break

            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Resize input image to target dimensions
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generation kwargs for img2img
            gen_kwargs = {
                "prompt": request.prompt,
                "image": input_image,
                "strength": 0.6,  # How much to transform the image
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(seed),
            }

            # Add negative prompt if present
            if request.negative_prompt:
                gen_kwargs["negative_prompt"] = request.negative_prompt

            logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe_img2img(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Z-Image img2img generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Z-Image img2img error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to the classmethod)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """
        Get pixel dimensions for an aspect-ratio label.

        Labels may carry a suffix (e.g. "16:9 (Widescreen)"); only the
        leading token is matched. Unknown ratios fall back to 1024x1024.
        """
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))