"""
Prompt Transformation Layer

Transforms standard internal prompts to backend-specific formats.
Each backend may have different:
- Prompt structure (text, JSON, special tokens)
- Parameter names
- Value formats
- Special requirements
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from PIL import Image
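
# Typical flow (illustrative):
#
#   request = StandardGenerationRequest(prompt="a red bicycle")
#   transformer = get_transformer('omnigen2')
#   payload = transformer.transform_request(request)       # backend-specific dict
#   images = transformer.transform_response(raw_response)  # back to PIL images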


@dataclass
class StandardGenerationRequest:
    """
    Standard internal format for generation requests.

    This is the ONE format the application uses.
    Backend adapters transform this to backend-specific formats.
    """

    # Core request
    prompt: str
    negative_prompt: Optional[str] = None

    # Input images (for img2img, controlnet, etc.)
    input_images: List[Image.Image] = field(default_factory=list)

    # Generation parameters
    width: int = 1024
    height: int = 1024
    num_images: int = 1

    # Quality controls
    guidance_scale: float = 7.5
    num_inference_steps: int = 50
    seed: Optional[int] = None

    # Advanced options
    control_mode: Optional[str] = None  # "canny", "depth", "pose", etc.
    strength: float = 0.8  # For img2img

    # Backend hints (preferences, not requirements)
    preferred_model: Optional[str] = None
    quality_preset: str = "balanced"  # "fast", "balanced", "quality"

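
# Illustrative only: a typical text-to-image request. Fields not passed
# explicitly fall back to the defaults declared above; values are arbitrary.
#
#   request = StandardGenerationRequest(
#       prompt="a watercolor fox in a pine forest",
#       negative_prompt="oversaturated",
#       width=768,
#       height=1024,
#       seed=123,
#   )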


class PromptTransformer(ABC):
    """
    Abstract base class for prompt transformers.

    Each backend type has a transformer that converts
    StandardGenerationRequest to backend-specific format.
    """

    @abstractmethod
    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """
        Transform standard request to backend-specific format.

        Args:
            request: Standard internal format

        Returns:
            Backend-specific request dict
        """
        pass

    @abstractmethod
    def transform_response(self, response: Any) -> List[Image.Image]:
        """
        Transform backend response to standard format.

        Args:
            response: Backend-specific response

        Returns:
            List of generated images
        """
        pass


class GeminiPromptTransformer(PromptTransformer):
    """Transformer for Gemini API format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Gemini API format."""

        # Gemini uses aspect ratios instead of width/height
        aspect_ratio = self._calculate_aspect_ratio(request.width, request.height)

        return {
            'prompt': request.prompt,
            'aspect_ratio': aspect_ratio,
            'number_of_images': request.num_images,
            'safety_filter_level': 'block_some',
            'person_generation': 'allow_all',
            # Gemini doesn't support negative prompts directly; one
            # workaround is folding the text into the prompt itself
            # (see the sketch after this class).
        }

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Gemini response."""
        # Gemini returns GenerationResult with .images list
        if hasattr(response, 'images'):
            return response.images
        return []

    def _calculate_aspect_ratio(self, width: int, height: int) -> str:
        """Calculate aspect ratio string from dimensions."""
        ratios = {
            (1, 1): "1:1",
            (16, 9): "16:9",
            (9, 16): "9:16",
            (4, 3): "4:3",
            (3, 4): "3:4",
        }

        # Return the first named ratio within tolerance (not a true nearest match)
        ratio = width / height
        for (w, h), name in ratios.items():
            if abs(ratio - (w/h)) < 0.1:
                return name

        return "1:1"  # Default


class OmniGen2PromptTransformer(PromptTransformer):
    """Transformer for OmniGen2 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to OmniGen2 format."""

        # OmniGen2 uses direct width/height
        transformed = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }

        # Add negative prompt if provided
        if request.negative_prompt:
            transformed['negative_prompt'] = request.negative_prompt

        # Add seed if provided
        if request.seed is not None:
            transformed['seed'] = request.seed
        else:
            transformed['seed'] = -1  # Random

        # Handle input images
        if request.input_images:
            transformed['input_images'] = request.input_images
            transformed['strength'] = request.strength

        return transformed

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform OmniGen2 response."""
        if hasattr(response, 'images'):
            return response.images
        return []


class ComfyUIPromptTransformer(PromptTransformer):
    """Transformer for ComfyUI workflow format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to ComfyUI workflow format."""

        # ComfyUI expects a workflow graph of nodes. This is a simplified
        # sketch: a real workflow also wires model loading, node-to-node
        # links, and a VAE decode stage.

        workflow = {
            'nodes': {
                # Positive text encoder
                'prompt_positive': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.prompt
                    }
                },

                # Negative prompt
                'prompt_negative': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.negative_prompt or ''
                    }
                },

                # Output dimensions belong to the empty latent, not the sampler
                'latent': {
                    'class_type': 'EmptyLatentImage',
                    'inputs': {
                        'width': request.width,
                        'height': request.height,
                        'batch_size': request.num_images,
                    }
                },

                # KSampler
                'sampler': {
                    'class_type': 'KSampler',
                    'inputs': {
                        # A seed of 0 is valid, so test against None explicitly
                        'seed': request.seed if request.seed is not None else -1,
                        'steps': request.num_inference_steps,
                        'cfg': request.guidance_scale,
                    }
                },
            }
        }

        return workflow

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform ComfyUI response."""
        # Assumes the backend adapter has already decoded ComfyUI's raw
        # output into PIL images under an 'images' key
        if isinstance(response, dict) and 'images' in response:
            return response['images']
        return []
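

# Sketch (assumption): how a workflow dict like the one built above might be
# submitted to a locally running ComfyUI server. Endpoint, port, and payload
# shape depend on the deployment; the 'requests' dependency is illustrative.
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8188/prompt",       # ComfyUI's default local port
#       json={"prompt": workflow},
#   )
#   prompt_id = resp.json().get("prompt_id")  # used to poll for results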


class FluxPromptTransformer(PromptTransformer):
    """Transformer for Flux.1 Kontext AI format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Flux format."""

        transformed = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }

        # Flux supports context images
        if request.input_images:
            transformed['context_images'] = request.input_images
            transformed['context_strength'] = request.strength

        return transformed

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Flux response."""
        if hasattr(response, 'images'):
            return response.images
        return []


class QwenPromptTransformer(PromptTransformer):
    """Transformer for qwen_image_edit_2509 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to qwen format."""

        # qwen is specifically for image editing
        if not request.input_images:
            raise ValueError("qwen requires input image for editing")

        transformed = {
            'instruction': request.prompt,  # qwen uses 'instruction' not 'prompt'
            'input_image': request.input_images[0],  # First image
            'guidance_scale': request.guidance_scale,
            'num_inference_steps': request.num_inference_steps,
        }

        if request.seed is not None:
            transformed['seed'] = request.seed

        return transformed

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform qwen response."""
        if hasattr(response, 'edited_image'):
            return [response.edited_image]
        return []


# Registry of transformers
TRANSFORMER_REGISTRY = {
    'gemini': GeminiPromptTransformer,
    'omnigen2': OmniGen2PromptTransformer,
    'comfyui': ComfyUIPromptTransformer,
    'flux': FluxPromptTransformer,
    'qwen': QwenPromptTransformer,
}
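

# Registering a new backend (illustrative): subclass PromptTransformer and
# add it to the registry. 'replicate' is a made-up backend name.
#
#   class ReplicatePromptTransformer(PromptTransformer):
#       def transform_request(self, request): ...
#       def transform_response(self, response): ...
#
#   TRANSFORMER_REGISTRY['replicate'] = ReplicatePromptTransformer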


def get_transformer(backend_type: str) -> PromptTransformer:
    """
    Get transformer for backend type.

    Args:
        backend_type: Backend type (e.g., 'gemini', 'omnigen2')

    Returns:
        PromptTransformer instance
    """
    transformer_class = TRANSFORMER_REGISTRY.get(backend_type)
    if not transformer_class:
        raise ValueError(f"No transformer found for backend type: {backend_type}")

    return transformer_class()
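

if __name__ == "__main__":
    # Smoke test (illustrative): run one request through two transformers and
    # show how parameter names diverge per backend. Values are arbitrary.
    demo = StandardGenerationRequest(
        prompt="a lighthouse at dusk, cinematic lighting",
        negative_prompt="blurry, low quality",
        width=1024,
        height=576,  # 16:9, so the Gemini transformer emits aspect_ratio='16:9'
        seed=42,
    )

    for backend in ("gemini", "omnigen2"):
        payload = get_transformer(backend).transform_request(demo)
        print(f"{backend}: {sorted(payload)}")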