"""
Prompt Transformation Layer

Transforms standard internal prompts to backend-specific formats.

Each backend may have different:
- Prompt structure (text, JSON, special tokens)
- Parameter names
- Value formats
- Special requirements
"""
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List, Optional | |
| from dataclasses import dataclass | |
| from PIL import Image | |
@dataclass
class StandardGenerationRequest:
    """
    Standard internal format for generation requests.

    This is the ONE format the application uses.
    Backend adapters transform this to backend-specific formats.
    """

    # Core request
    prompt: str
    negative_prompt: Optional[str] = None

    # Input images (for img2img, controlnet, etc.).
    # Declared Optional with a None default because a mutable list default
    # would be rejected by @dataclass (and shared across instances otherwise);
    # __post_init__ normalizes None to a fresh list.  The PIL type is a
    # string forward reference so the class can be defined without
    # evaluating PIL types at class-creation time.
    input_images: Optional[List["Image.Image"]] = None

    # Generation parameters
    width: int = 1024
    height: int = 1024
    num_images: int = 1

    # Quality controls
    guidance_scale: float = 7.5
    num_inference_steps: int = 50
    seed: Optional[int] = None

    # Advanced options
    control_mode: Optional[str] = None  # "canny", "depth", "pose", etc.
    strength: float = 0.8  # For img2img

    # Backend hints (preferences, not requirements)
    preferred_model: Optional[str] = None
    quality_preset: str = "balanced"  # "fast", "balanced", "quality"

    def __post_init__(self):
        """Initialize mutable defaults (input_images is always a list)."""
        # BUG FIX: the @dataclass decorator was missing, so no __init__ was
        # generated and __post_init__ was never invoked.
        if self.input_images is None:
            self.input_images = []
class PromptTransformer(ABC):
    """
    Abstract base class for prompt transformers.

    Each backend type has a transformer that converts
    StandardGenerationRequest to backend-specific format.
    """

    # BUG FIX: abstractmethod was imported but never applied, so a subclass
    # that forgot to implement these methods would instantiate fine and
    # silently return None from the inherited pass-bodies.

    @abstractmethod
    def transform_request(self, request: "StandardGenerationRequest") -> Dict[str, Any]:
        """
        Transform standard request to backend-specific format.

        Args:
            request: Standard internal format

        Returns:
            Backend-specific request dict
        """

    @abstractmethod
    def transform_response(self, response: Any) -> "List[Image.Image]":
        """
        Transform backend response to standard format.

        Args:
            response: Backend-specific response

        Returns:
            List of generated images
        """
class GeminiPromptTransformer(PromptTransformer):
    """Transformer for Gemini API format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Gemini API format."""
        # Gemini takes an aspect-ratio string rather than explicit pixels.
        return {
            'prompt': request.prompt,
            'aspect_ratio': self._calculate_aspect_ratio(request.width, request.height),
            'number_of_images': request.num_images,
            'safety_filter_level': 'block_some',
            'person_generation': 'allow_all',
            # Gemini doesn't support negative prompts directly
            # Could append to prompt: "... (avoid: {negative_prompt})"
        }

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Gemini response."""
        # Gemini returns a GenerationResult carrying an .images list.
        return response.images if hasattr(response, 'images') else []

    def _calculate_aspect_ratio(self, width: int, height: int) -> str:
        """Calculate aspect ratio string from dimensions."""
        named_ratios = (
            ((1, 1), "1:1"),
            ((16, 9), "16:9"),
            ((9, 16), "9:16"),
            ((4, 3), "4:3"),
            ((3, 4), "3:4"),
        )
        actual = width / height
        # Return the first named ratio within tolerance of the real one.
        for (w, h), label in named_ratios:
            if abs(actual - w / h) < 0.1:
                return label
        return "1:1"  # Default to square
class OmniGen2PromptTransformer(PromptTransformer):
    """Transformer for OmniGen2 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to OmniGen2 format."""
        # OmniGen2 accepts explicit pixel dimensions directly.
        payload = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }
        if request.negative_prompt:
            payload['negative_prompt'] = request.negative_prompt
        # -1 asks the backend for a random seed.
        payload['seed'] = -1 if request.seed is None else request.seed
        if request.input_images:
            # img2img path: forward source images plus denoising strength.
            payload['input_images'] = request.input_images
            payload['strength'] = request.strength
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform OmniGen2 response."""
        return response.images if hasattr(response, 'images') else []
class ComfyUIPromptTransformer(PromptTransformer):
    """Transformer for ComfyUI workflow format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to ComfyUI workflow format.

        ComfyUI consumes a workflow JSON graph of nodes.
        This is a simplified example - actual workflows are complex.
        """
        workflow = {
            'nodes': {
                # Positive text encoder
                'prompt_positive': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.prompt
                    }
                },
                # Negative prompt (empty string when not provided)
                'prompt_negative': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.negative_prompt or ''
                    }
                },
                # KSampler
                'sampler': {
                    'class_type': 'KSampler',
                    'inputs': {
                        # BUG FIX: `request.seed if request.seed else -1`
                        # turned a legitimate seed of 0 into -1 (random).
                        # Only substitute -1 when the seed is absent,
                        # matching the other transformers' handling.
                        'seed': request.seed if request.seed is not None else -1,
                        'steps': request.num_inference_steps,
                        'cfg': request.guidance_scale,
                        'width': request.width,
                        'height': request.height,
                    }
                },
            }
        }
        return workflow

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform ComfyUI response."""
        # ComfyUI returns generated images under the 'images' key of a dict.
        if isinstance(response, dict) and 'images' in response:
            return response['images']
        return []
class FluxPromptTransformer(PromptTransformer):
    """Transformer for Flux.1 Kontext AI format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Flux format."""
        payload = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }
        # Flux supports conditioning on context images.
        if request.input_images:
            payload['context_images'] = request.input_images
            payload['context_strength'] = request.strength
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Flux response."""
        return response.images if hasattr(response, 'images') else []
class QwenPromptTransformer(PromptTransformer):
    """Transformer for qwen_image_edit_2509 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to qwen format."""
        # qwen is an editing backend: a source image is mandatory.
        if not request.input_images:
            raise ValueError("qwen requires input image for editing")
        payload = {
            'instruction': request.prompt,  # qwen uses 'instruction' not 'prompt'
            'input_image': request.input_images[0],  # First image
            'guidance_scale': request.guidance_scale,
            'num_inference_steps': request.num_inference_steps,
        }
        if request.seed is not None:
            payload['seed'] = request.seed
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform qwen response."""
        # qwen hands back a single edited image, not a list.
        return [response.edited_image] if hasattr(response, 'edited_image') else []
# Registry of transformers: backend type name -> transformer class.
TRANSFORMER_REGISTRY = dict(
    gemini=GeminiPromptTransformer,
    omnigen2=OmniGen2PromptTransformer,
    comfyui=ComfyUIPromptTransformer,
    flux=FluxPromptTransformer,
    qwen=QwenPromptTransformer,
)
def get_transformer(backend_type: str) -> PromptTransformer:
    """
    Get transformer for backend type.

    Args:
        backend_type: Backend type (e.g., 'gemini', 'omnigen2')

    Returns:
        PromptTransformer instance

    Raises:
        ValueError: If no transformer is registered for backend_type.
    """
    if backend_type not in TRANSFORMER_REGISTRY:
        raise ValueError(f"No transformer found for backend type: {backend_type}")
    # Instantiate a fresh transformer on every call.
    return TRANSFORMER_REGISTRY[backend_type]()