# NOTE(review): the lines originally here were scrape artifacts from the
# hosting page (Hugging Face Space header, "File size: 9,458 Bytes",
# commit 5b6e956, and a line-number gutter). Converted to this comment so
# the file parses as Python; no program content was removed.
"""
Prompt Transformation Layer
Transforms standard internal prompts to backend-specific formats.
Each backend may have different:
- Prompt structure (text, JSON, special tokens)
- Parameter names
- Value formats
- Special requirements
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from PIL import Image
@dataclass
class StandardGenerationRequest:
    """
    Standard internal format for generation requests.

    This is the ONE format the application uses; backend adapters
    transform it into backend-specific formats.
    """
    # Core request
    prompt: str
    negative_prompt: Optional[str] = None
    # Input images (for img2img, controlnet, etc.).
    # `default_factory` gives each instance its own fresh list (the original
    # `= None` default forced callers through __post_init__ and mistyped the
    # field). `Optional` is kept so existing callers may still pass None
    # explicitly; __post_init__ normalizes it to [].
    input_images: Optional[List[Image.Image]] = field(default_factory=list)
    # Generation parameters
    width: int = 1024
    height: int = 1024
    num_images: int = 1
    # Quality controls
    guidance_scale: float = 7.5
    num_inference_steps: int = 50
    seed: Optional[int] = None
    # Advanced options
    control_mode: Optional[str] = None  # "canny", "depth", "pose", etc.
    strength: float = 0.8  # For img2img
    # Backend hints (preferences, not requirements)
    preferred_model: Optional[str] = None
    quality_preset: str = "balanced"  # "fast", "balanced", "quality"

    def __post_init__(self):
        """Normalize input_images so it is always a list, never None."""
        if self.input_images is None:
            self.input_images = []
class PromptTransformer(ABC):
    """
    Base interface for prompt transformers.

    A transformer converts a StandardGenerationRequest into whatever
    request format a particular backend expects, and converts that
    backend's response back into a plain list of PIL images.
    """

    @abstractmethod
    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """
        Convert the standard internal request to this backend's format.

        Args:
            request: Standard internal request.

        Returns:
            Backend-specific request dict.
        """
        ...

    @abstractmethod
    def transform_response(self, response: Any) -> List[Image.Image]:
        """
        Convert a backend-specific response to the standard form.

        Args:
            response: Backend-specific response object.

        Returns:
            List of generated images.
        """
        ...
class GeminiPromptTransformer(PromptTransformer):
    """Transformer for Gemini API format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Build a Gemini-style request dict from the standard request."""
        # Gemini takes a named aspect ratio instead of pixel dimensions.
        return {
            'prompt': request.prompt,
            'aspect_ratio': self._calculate_aspect_ratio(request.width, request.height),
            'number_of_images': request.num_images,
            'safety_filter_level': 'block_some',
            'person_generation': 'allow_all',
            # Gemini doesn't support negative prompts directly.
            # Could append to prompt: "... (avoid: {negative_prompt})"
        }

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Extract the image list from a Gemini GenerationResult."""
        # Gemini returns an object carrying an `.images` list.
        return response.images if hasattr(response, 'images') else []

    def _calculate_aspect_ratio(self, width: int, height: int) -> str:
        """Map pixel dimensions onto the nearest supported ratio name."""
        named_ratios = {
            (1, 1): "1:1",
            (16, 9): "16:9",
            (9, 16): "9:16",
            (4, 3): "4:3",
            (3, 4): "3:4",
        }
        actual = width / height
        # Accept the first named ratio within a 0.1 tolerance.
        for (w, h), label in named_ratios.items():
            if abs(actual - w / h) < 0.1:
                return label
        # Nothing close enough: fall back to square.
        return "1:1"
class OmniGen2PromptTransformer(PromptTransformer):
    """Transformer for OmniGen2 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Build an OmniGen2-style request dict from the standard request."""
        # OmniGen2 takes explicit pixel dimensions directly.
        payload = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }
        if request.negative_prompt:
            payload['negative_prompt'] = request.negative_prompt
        # -1 asks the backend to pick a random seed.
        payload['seed'] = -1 if request.seed is None else request.seed
        if request.input_images:
            # img2img-style call: forward the images plus blend strength.
            payload['input_images'] = request.input_images
            payload['strength'] = request.strength
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Extract the image list from an OmniGen2 result object."""
        return response.images if hasattr(response, 'images') else []
class ComfyUIPromptTransformer(PromptTransformer):
    """Transformer for ComfyUI workflow format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """
        Transform to ComfyUI workflow format.

        ComfyUI consumes a workflow JSON of named nodes; this is a
        simplified example — real workflows are more complex.

        Args:
            request: Standard internal request.

        Returns:
            Workflow dict with text-encode and sampler nodes.
        """
        workflow = {
            'nodes': {
                # Text encoder for the positive prompt
                'prompt_positive': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.prompt
                    }
                },
                # Negative prompt (empty string when none was given)
                'prompt_negative': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.negative_prompt or ''
                    }
                },
                # KSampler
                'sampler': {
                    'class_type': 'KSampler',
                    'inputs': {
                        # BUG FIX: the original used a truthiness test
                        # (`request.seed if request.seed else -1`), which
                        # silently turned a legitimate seed of 0 into -1
                        # (random). Test for None explicitly, matching the
                        # other transformers.
                        'seed': request.seed if request.seed is not None else -1,
                        'steps': request.num_inference_steps,
                        'cfg': request.guidance_scale,
                        'width': request.width,
                        'height': request.height,
                    }
                },
            }
        }
        return workflow

    def transform_response(self, response: Any) -> List[Image.Image]:
        """
        Transform ComfyUI response.

        Args:
            response: Backend response; expected to be a dict with an
                'images' entry.

        Returns:
            List of generated images (empty when the shape is unexpected).
        """
        if isinstance(response, dict) and 'images' in response:
            return response['images']
        return []
class FluxPromptTransformer(PromptTransformer):
    """Transformer for Flux.1 Kontext AI format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Build a Flux-style request dict from the standard request."""
        payload = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }
        # Flux supports conditioning on context images.
        if request.input_images:
            payload['context_images'] = request.input_images
            payload['context_strength'] = request.strength
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Extract the image list from a Flux result object."""
        return response.images if hasattr(response, 'images') else []
class QwenPromptTransformer(PromptTransformer):
    """Transformer for qwen_image_edit_2509 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """
        Build a qwen-style edit request from the standard request.

        Raises:
            ValueError: If no input image is present — qwen is an
                image-editing backend and cannot work from a prompt alone.
        """
        if not request.input_images:
            raise ValueError("qwen requires input image for editing")
        payload = {
            # qwen uses 'instruction' rather than 'prompt'.
            'instruction': request.prompt,
            # Only the first image is edited; extras are ignored.
            'input_image': request.input_images[0],
            'guidance_scale': request.guidance_scale,
            'num_inference_steps': request.num_inference_steps,
        }
        if request.seed is not None:
            payload['seed'] = request.seed
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Wrap qwen's single edited image in a one-element list."""
        if hasattr(response, 'edited_image'):
            return [response.edited_image]
        return []
# Registry of transformers
# Maps backend-type key -> transformer CLASS (not instance);
# get_transformer() instantiates on demand.
TRANSFORMER_REGISTRY = {
    'gemini': GeminiPromptTransformer,
    'omnigen2': OmniGen2PromptTransformer,
    'comfyui': ComfyUIPromptTransformer,
    'flux': FluxPromptTransformer,
    'qwen': QwenPromptTransformer,
}
def get_transformer(backend_type: str) -> PromptTransformer:
    """
    Get transformer for backend type.

    Args:
        backend_type: Backend type key (e.g., 'gemini', 'omnigen2')

    Returns:
        A new PromptTransformer instance for that backend.

    Raises:
        ValueError: If no transformer is registered for *backend_type*.
            The message now lists the valid keys to aid debugging.
    """
    # EAFP: single dict lookup instead of .get() + falsy test.
    try:
        transformer_class = TRANSFORMER_REGISTRY[backend_type]
    except KeyError:
        available = ', '.join(sorted(TRANSFORMER_REGISTRY))
        raise ValueError(
            f"No transformer found for backend type: {backend_type} "
            f"(available: {available})"
        ) from None
    return transformer_class()
|