Spaces:
Sleeping
Sleeping
| """ | |
| OmniGen2 Backend Plugin | |
| Plugin adapter for OmniGen2 local backend. | |
| """ | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, List | |
| from PIL import Image | |
| # Add parent directories to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from core.omnigen2_client import OmniGen2Client | |
| from models.generation_request import GenerationRequest | |
| from config.settings import Settings | |
| # Import from shared plugin system | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'shared')) | |
| from plugin_system.base_plugin import BaseBackendPlugin | |
class OmniGen2Plugin(BaseBackendPlugin):
    """Plugin adapter for the OmniGen2 local backend.

    Wraps an ``OmniGen2Client`` pointed at ``Settings().omnigen2_base_url``.
    If the client cannot be created or probed at construction time, the
    plugin still loads but reports itself as unavailable.
    """

    def __init__(self, config_path: Path):
        """Initialize the plugin and probe the OmniGen2 backend once.

        Args:
            config_path: Plugin configuration path, forwarded unchanged to
                ``BaseBackendPlugin.__init__``.

        Errors raised while constructing or probing the client are not
        propagated: a warning is printed, ``self.client`` is set to ``None``
        and ``self.available`` to ``False``. Errors from ``Settings()``
        itself still propagate.
        """
        super().__init__(config_path)

        settings = Settings()
        base_url = settings.omnigen2_base_url

        try:
            self.client = OmniGen2Client(base_url=base_url)
            # Initial connectivity probe; health_check() refreshes this later.
            self.available = self.client.health_check()
        except Exception as e:
            print(f"Warning: OmniGen2 backend not available: {e}")
            self.client = None
            self.available = False

    def health_check(self) -> bool:
        """Return whether the OmniGen2 backend currently answers a health probe.

        The probe is re-run on every call as long as a client exists, and the
        outcome is stored in ``self.available``. A backend that was down when
        the plugin was created is therefore picked up once it comes online,
        instead of being reported unavailable forever.

        Returns:
            ``True`` if the client's health check succeeds, ``False`` if there
            is no client or the probe fails or raises.
        """
        if self.client is None:
            return False
        try:
            self.available = self.client.health_check()
        except Exception:
            # Treat any probe error as "unavailable"; a health check must not
            # raise. (Narrowed from a bare ``except:`` so that
            # KeyboardInterrupt/SystemExit still propagate.)
            self.available = False
        return self.available

    def generate_image(
        self,
        prompt: str,
        input_images: Optional[List[Image.Image]] = None,
        **kwargs
    ) -> Image.Image:
        """
        Generate an image using the OmniGen2 backend.

        Args:
            prompt: Text prompt for generation.
            input_images: Optional list of input (conditioning) images.
            **kwargs: Additional generation parameters. Recognised keys and
                their defaults: ``aspect_ratio`` ('1:1'), ``number_of_images``
                (1), ``guidance_scale`` (3.0), ``num_inference_steps`` (50),
                ``seed`` (-1). Other keys are ignored.

        Returns:
            The first generated PIL Image. Only the first image is returned,
            even if ``number_of_images`` > 1.

        Raises:
            RuntimeError: If the backend is not available, or if the backend
                returns no images.
        """
        if not self.health_check():
            raise RuntimeError("OmniGen2 backend not available")

        request = GenerationRequest(
            prompt=prompt,
            input_images=input_images or [],
            aspect_ratio=kwargs.get('aspect_ratio', '1:1'),
            number_of_images=kwargs.get('number_of_images', 1),
            guidance_scale=kwargs.get('guidance_scale', 3.0),
            num_inference_steps=kwargs.get('num_inference_steps', 50),
            seed=kwargs.get('seed', -1)
        )

        result = self.client.generate(request)

        if not result.images:
            # Avoid reporting "failed: None" when the backend returns neither
            # images nor an error description.
            error = getattr(result, 'error', None) or 'no images returned'
            raise RuntimeError(f"OmniGen2 generation failed: {error}")
        return result.images[0]

    def get_capabilities(self) -> Dict[str, Any]:
        """Report OmniGen2 backend capabilities.

        Returns:
            A new dict describing the backend: display name, backend type,
            input-image support and limit, supported aspect ratios, tunable
            parameters, and rough time/cost estimates per image.
        """
        return {
            'name': 'OmniGen2 Local',
            'type': 'local',
            'supports_input_images': True,
            'supports_multi_image': True,
            'max_input_images': 8,
            'supports_aspect_ratios': True,
            'available_aspect_ratios': ['1:1', '3:4', '4:3', '9:16', '16:9', '3:2', '2:3', '4:5', '5:4', '21:9'],
            'supports_guidance_scale': True,
            'supports_inference_steps': True,
            'supports_seed': True,
            'estimated_time_per_image': 8.0,  # seconds (depends on GPU)
            'cost_per_image': 0.0,  # Free, local
        }