# character_forge/shared/plugin_system/prompt_transformer.py
# Initial deployment of Character Forge (commit 5b6e956)
"""
Prompt Transformation Layer
Transforms standard internal prompts to backend-specific formats.
Each backend may have different:
- Prompt structure (text, JSON, special tokens)
- Parameter names
- Value formats
- Special requirements
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from PIL import Image
@dataclass
class StandardGenerationRequest:
    """
    Standard internal format for generation requests.

    This is the ONE format the application uses.
    Backend adapters transform this to backend-specific formats.
    """
    # Core request
    prompt: str
    negative_prompt: Optional[str] = None

    # Input images (for img2img, controlnet, etc.).
    # default_factory gives each instance its own list and makes the
    # annotation honest: the old `= None` default contradicted the
    # non-Optional type and leaned on __post_init__ to patch it up.
    input_images: List[Image.Image] = field(default_factory=list)

    # Generation parameters
    width: int = 1024
    height: int = 1024
    num_images: int = 1

    # Quality controls
    guidance_scale: float = 7.5
    num_inference_steps: int = 50
    seed: Optional[int] = None

    # Advanced options
    control_mode: Optional[str] = None  # "canny", "depth", "pose", etc.
    strength: float = 0.8  # For img2img

    # Backend hints (preferences, not requirements)
    preferred_model: Optional[str] = None
    quality_preset: str = "balanced"  # "fast", "balanced", "quality"

    def __post_init__(self):
        """Coerce an explicitly passed ``input_images=None`` to an empty list.

        Kept for backward compatibility with callers that still pass None.
        """
        if self.input_images is None:
            self.input_images = []
class PromptTransformer(ABC):
    """
    Abstract base class for prompt transformers.

    One concrete transformer exists per backend type. It converts the
    internal StandardGenerationRequest into whatever request shape that
    backend expects, and maps the backend's response back to PIL images.
    """

    @abstractmethod
    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """
        Convert the standard request into the backend's native format.

        Args:
            request: Standard internal request.

        Returns:
            Backend-specific request dict.
        """
        ...

    @abstractmethod
    def transform_response(self, response: Any) -> List[Image.Image]:
        """
        Convert a backend-specific response into the standard result form.

        Args:
            response: Backend-specific response object.

        Returns:
            List of generated images.
        """
        ...
class GeminiPromptTransformer(PromptTransformer):
    """Transformer for Gemini API format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Gemini API format.

        Gemini has no native negative-prompt parameter, so a negative
        prompt (when present) is folded into the prompt text as an
        "(avoid: ...)" clause instead of being silently dropped.
        """
        prompt = request.prompt
        if request.negative_prompt:
            prompt = f"{prompt} (avoid: {request.negative_prompt})"

        # Gemini uses aspect ratios instead of width/height
        aspect_ratio = self._calculate_aspect_ratio(request.width, request.height)
        return {
            'prompt': prompt,
            'aspect_ratio': aspect_ratio,
            'number_of_images': request.num_images,
            'safety_filter_level': 'block_some',
            'person_generation': 'allow_all',
        }

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Gemini response."""
        # Gemini returns GenerationResult with .images list
        if hasattr(response, 'images'):
            return response.images
        return []

    def _calculate_aspect_ratio(self, width: int, height: int) -> str:
        """Map pixel dimensions to the closest supported ratio string.

        Falls back to "1:1" for degenerate dimensions (width or height
        <= 0, which would otherwise divide by zero) and for ratios not
        close to any supported value.
        """
        if width <= 0 or height <= 0:
            return "1:1"
        ratios = {
            (1, 1): "1:1",
            (16, 9): "16:9",
            (9, 16): "9:16",
            (4, 3): "4:3",
            (3, 4): "3:4",
        }
        # Find the first named ratio within tolerance of the actual ratio
        ratio = width / height
        for (w, h), name in ratios.items():
            if abs(ratio - (w / h)) < 0.1:
                return name
        return "1:1"  # Default
class OmniGen2PromptTransformer(PromptTransformer):
    """Transformer for OmniGen2 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to OmniGen2 format."""
        # OmniGen2 takes explicit width/height rather than an aspect ratio.
        payload: Dict[str, Any] = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }

        if request.negative_prompt:
            payload['negative_prompt'] = request.negative_prompt

        # -1 asks the backend to pick a random seed.
        payload['seed'] = request.seed if request.seed is not None else -1

        if request.input_images:
            # img2img path: forward the reference images and their strength.
            payload['input_images'] = request.input_images
            payload['strength'] = request.strength

        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform OmniGen2 response."""
        _missing = object()
        images = getattr(response, 'images', _missing)
        return [] if images is _missing else images
class ComfyUIPromptTransformer(PromptTransformer):
    """Transformer for ComfyUI workflow format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to ComfyUI workflow format.

        Builds a simplified node graph (real ComfyUI workflows are far
        more complex): positive/negative CLIP text encoders feeding a
        KSampler node.
        """
        workflow = {
            'nodes': {
                # Text encoder (positive conditioning)
                'prompt_positive': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.prompt
                    }
                },
                # Negative prompt (empty string when none was given)
                'prompt_negative': {
                    'class_type': 'CLIPTextEncode',
                    'inputs': {
                        'text': request.negative_prompt or ''
                    }
                },
                # KSampler
                'sampler': {
                    'class_type': 'KSampler',
                    'inputs': {
                        # Must be an explicit None check: the old truthiness
                        # test turned a deliberate seed=0 into -1 (random).
                        'seed': request.seed if request.seed is not None else -1,
                        'steps': request.num_inference_steps,
                        'cfg': request.guidance_scale,
                        'width': request.width,
                        'height': request.height,
                    }
                },
            }
        }
        return workflow

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform ComfyUI response."""
        # ComfyUI returns images in specific format
        if isinstance(response, dict) and 'images' in response:
            return response['images']
        return []
class FluxPromptTransformer(PromptTransformer):
    """Transformer for Flux.1 Kontext AI format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to Flux format."""
        payload: Dict[str, Any] = {
            'prompt': request.prompt,
            'width': request.width,
            'height': request.height,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
        }

        # Flux supports conditioning on context images.
        if request.input_images:
            payload['context_images'] = request.input_images
            payload['context_strength'] = request.strength

        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform Flux response."""
        _missing = object()
        images = getattr(response, 'images', _missing)
        return [] if images is _missing else images
class QwenPromptTransformer(PromptTransformer):
    """Transformer for qwen_image_edit_2509 format."""

    def transform_request(self, request: StandardGenerationRequest) -> Dict[str, Any]:
        """Transform to qwen format.

        Raises:
            ValueError: If the request carries no input image — qwen is an
                editing model and cannot run without a source image.
        """
        if not request.input_images:
            raise ValueError("qwen requires input image for editing")

        payload: Dict[str, Any] = {
            'instruction': request.prompt,  # qwen uses 'instruction' not 'prompt'
            'input_image': request.input_images[0],  # only the first image is used
            'guidance_scale': request.guidance_scale,
            'num_inference_steps': request.num_inference_steps,
        }
        if request.seed is not None:
            payload['seed'] = request.seed
        return payload

    def transform_response(self, response: Any) -> List[Image.Image]:
        """Transform qwen response."""
        # EAFP: responses lacking .edited_image yield an empty result.
        try:
            return [response.edited_image]
        except AttributeError:
            return []
# Registry of transformers.
# Maps backend-type key -> transformer class; get_transformer() below
# looks up and instantiates from this table.
TRANSFORMER_REGISTRY: Dict[str, type] = {
    'gemini': GeminiPromptTransformer,
    'omnigen2': OmniGen2PromptTransformer,
    'comfyui': ComfyUIPromptTransformer,
    'flux': FluxPromptTransformer,
    'qwen': QwenPromptTransformer,
}
def get_transformer(backend_type: str) -> PromptTransformer:
    """
    Look up and instantiate the transformer for a backend type.

    Args:
        backend_type: Backend type (e.g., 'gemini', 'omnigen2')

    Returns:
        PromptTransformer instance

    Raises:
        ValueError: If no transformer is registered for ``backend_type``.
    """
    transformer_cls = TRANSFORMER_REGISTRY.get(backend_type)
    if transformer_cls is None:
        raise ValueError(f"No transformer found for backend type: {backend_type}")
    return transformer_cls()