""" Prompt Transformation Layer Transforms standard internal prompts to backend-specific formats. Each backend may have different: - Prompt structure (text, JSON, special tokens) - Parameter names - Value formats - Special requirements """ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional from dataclasses import dataclass from PIL import Image @dataclass class StandardGenerationRequest: """ Standard internal format for generation requests. This is the ONE format the application uses. Backend adapters transform this to backend-specific formats. """ # Core request prompt: str negative_prompt: Optional[str] = None # Input images (for img2img, controlnet, etc.) input_images: List[Image.Image] = None # Generation parameters width: int = 1024 height: int = 1024 num_images: int = 1 # Quality controls guidance_scale: float = 7.5 num_inference_steps: int = 50 seed: Optional[int] = None # Advanced options control_mode: Optional[str] = None # "canny", "depth", "pose", etc. strength: float = 0.8 # For img2img # Backend hints (preferences, not requirements) preferred_model: Optional[str] = None quality_preset: str = "balanced" # "fast", "balanced", "quality" def __post_init__(self): """Initialize mutable defaults.""" if self.input_images is None: self.input_images = [] class PromptTransformer(ABC): """ Abstract base class for prompt transformers. Each backend type has a transformer that converts StandardGenerationRequest to backend-specific format. """ @abstractmethod def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]: """ Transform standard request to backend-specific format. Args: request: Standard internal format Returns: Backend-specific request dict """ pass @abstractmethod def transform_response(self, response: Any) -> List[Image.Image]: """ Transform backend response to standard format. Args: response: Backend-specific response Returns: List of generated images """ pass class GeminiPromptTransformer(PromptTransformer): """Transformer for Gemini API format.""" def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]: """Transform to Gemini API format.""" # Gemini uses aspect ratios instead of width/height aspect_ratio = self._calculate_aspect_ratio(request.width, request.height) return { 'prompt': request.prompt, 'aspect_ratio': aspect_ratio, 'number_of_images': request.num_images, 'safety_filter_level': 'block_some', 'person_generation': 'allow_all', # Gemini doesn't support negative prompts directly # Could append to prompt: "... 
(avoid: {negative_prompt})" } def transform_response(self, response: Any) -> List[Image.Image]: """Transform Gemini response.""" # Gemini returns GenerationResult with .images list if hasattr(response, 'images'): return response.images return [] def _calculate_aspect_ratio(self, width: int, height: int) -> str: """Calculate aspect ratio string from dimensions.""" ratios = { (1, 1): "1:1", (16, 9): "16:9", (9, 16): "9:16", (4, 3): "4:3", (3, 4): "3:4", } # Find closest ratio ratio = width / height for (w, h), name in ratios.items(): if abs(ratio - (w/h)) < 0.1: return name return "1:1" # Default class OmniGen2PromptTransformer(PromptTransformer): """Transformer for OmniGen2 format.""" def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]: """Transform to OmniGen2 format.""" # OmniGen2 uses direct width/height transformed = { 'prompt': request.prompt, 'width': request.width, 'height': request.height, 'num_inference_steps': request.num_inference_steps, 'guidance_scale': request.guidance_scale, } # Add negative prompt if provided if request.negative_prompt: transformed['negative_prompt'] = request.negative_prompt # Add seed if provided if request.seed is not None: transformed['seed'] = request.seed else: transformed['seed'] = -1 # Random # Handle input images if request.input_images: transformed['input_images'] = request.input_images transformed['strength'] = request.strength return transformed def transform_response(self, response: Any) -> List[Image.Image]: """Transform OmniGen2 response.""" if hasattr(response, 'images'): return response.images return [] class ComfyUIPromptTransformer(PromptTransformer): """Transformer for ComfyUI workflow format.""" def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]: """Transform to ComfyUI workflow format.""" # ComfyUI uses workflow JSON with nodes # This is a simplified example - actual workflows are complex workflow = { 'nodes': { # Text encoder 'prompt_positive': { 'class_type': 'CLIPTextEncode', 'inputs': { 'text': request.prompt } }, # Negative prompt 'prompt_negative': { 'class_type': 'CLIPTextEncode', 'inputs': { 'text': request.negative_prompt or '' } }, # KSampler 'sampler': { 'class_type': 'KSampler', 'inputs': { 'seed': request.seed if request.seed else -1, 'steps': request.num_inference_steps, 'cfg': request.guidance_scale, 'width': request.width, 'height': request.height, } }, } } return workflow def transform_response(self, response: Any) -> List[Image.Image]: """Transform ComfyUI response.""" # ComfyUI returns images in specific format if isinstance(response, dict) and 'images' in response: return response['images'] return [] class FluxPromptTransformer(PromptTransformer): """Transformer for Flux.1 Kontext AI format.""" def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]: """Transform to Flux format.""" transformed = { 'prompt': request.prompt, 'width': request.width, 'height': request.height, 'num_inference_steps': request.num_inference_steps, 'guidance_scale': request.guidance_scale, } # Flux supports context images if request.input_images: transformed['context_images'] = request.input_images transformed['context_strength'] = request.strength return transformed def transform_response(self, response: Any) -> List[Image.Image]: """Transform Flux response.""" if hasattr(response, 'images'): return response.images return [] class QwenPromptTransformer(PromptTransformer): """Transformer for qwen_image_edit_2509 format.""" def transform_request(self, 
request: StandardGenerationRequest) -> Dict[str, Any]: """Transform to qwen format.""" # qwen is specifically for image editing if not request.input_images: raise ValueError("qwen requires input image for editing") transformed = { 'instruction': request.prompt, # qwen uses 'instruction' not 'prompt' 'input_image': request.input_images[0], # First image 'guidance_scale': request.guidance_scale, 'num_inference_steps': request.num_inference_steps, } if request.seed is not None: transformed['seed'] = request.seed return transformed def transform_response(self, response: Any) -> List[Image.Image]: """Transform qwen response.""" if hasattr(response, 'edited_image'): return [response.edited_image] return [] # Registry of transformers TRANSFORMER_REGISTRY = { 'gemini': GeminiPromptTransformer, 'omnigen2': OmniGen2PromptTransformer, 'comfyui': ComfyUIPromptTransformer, 'flux': FluxPromptTransformer, 'qwen': QwenPromptTransformer, } def get_transformer(backend_type: str) -> PromptTransformer: """ Get transformer for backend type. Args: backend_type: Backend type (e.g., 'gemini', 'omnigen2') Returns: PromptTransformer instance """ transformer_class = TRANSFORMER_REGISTRY.get(backend_type) if not transformer_class: raise ValueError(f"No transformer found for backend type: {backend_type}") return transformer_class()
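

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only).
# Shows the intended call pattern: build one StandardGenerationRequest, look
# up a backend's transformer via get_transformer(), and inspect the payload
# that would be handed to that backend. No real backend client is invoked;
# printing the payload stands in for the actual dispatch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_request = StandardGenerationRequest(
        prompt="A lighthouse at dusk, oil painting",
        width=1024,
        height=576,  # ~16:9, so the Gemini transformer maps this to '16:9'
        num_images=2,
    )

    for backend in ('gemini', 'omnigen2'):
        transformer = get_transformer(backend)
        payload = transformer.transform_request(example_request)
        print(f"{backend}: {payload}")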