File size: 3,458 Bytes
5b6e956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
OmniGen2 Backend Plugin

Plugin adapter for OmniGen2 local backend.
"""

import sys
from pathlib import Path
from typing import Any, Dict, Optional, List
from PIL import Image

# Add parent directories to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.omnigen2_client import OmniGen2Client
from models.generation_request import GenerationRequest
from config.settings import Settings

# Import from shared plugin system
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'shared'))
from plugin_system.base_plugin import BaseBackendPlugin


class OmniGen2Plugin(BaseBackendPlugin):
    """Plugin adapter for OmniGen2 local backend.

    Wraps an ``OmniGen2Client`` pointed at ``Settings().omnigen2_base_url``.
    Availability is re-probed on every ``health_check()`` call. A backend
    that was down when the plugin was constructed can therefore become
    usable later without recreating the plugin.
    """

    def __init__(self, config_path: Path):
        """Initialize OmniGen2 plugin.

        Args:
            config_path: Plugin configuration path, forwarded unchanged to
                ``BaseBackendPlugin.__init__``.

        Never raises because the backend is unreachable. In that case a
        warning is printed and ``self.available`` is left ``False``.
        Errors from ``Settings()`` itself still propagate, as before.
        """
        super().__init__(config_path)

        # Get settings
        settings = Settings()
        base_url = settings.omnigen2_base_url

        self.client: Optional[OmniGen2Client] = None
        self.available: bool = False

        # Constructing the client and probing the server are kept separate.
        # A probe failure must not throw the client away; otherwise a later
        # health_check() could never succeed once the server comes up.
        try:
            self.client = OmniGen2Client(base_url=base_url)
        except Exception as e:
            print(f"Warning: OmniGen2 backend not available: {e}")
            return

        try:
            # Test connection
            self.available = bool(self.client.health_check())
        except Exception as e:
            print(f"Warning: OmniGen2 backend not available: {e}")
            self.available = False

    def health_check(self) -> bool:
        """Check if OmniGen2 backend is available.

        Queries the client on every call and records the outcome in
        ``self.available``. Returns ``False`` when no client could be
        constructed, or when the probe returns falsy or raises an
        ``Exception``. Interpreter-level exceptions such as
        ``KeyboardInterrupt`` are deliberately not swallowed.
        """
        if self.client is None:
            return False

        try:
            self.available = bool(self.client.health_check())
        except Exception:
            self.available = False
        return self.available

    def generate_image(
        self,
        prompt: str,
        input_images: Optional[List[Image.Image]] = None,
        **kwargs
    ) -> Image.Image:
        """
        Generate image using OmniGen2 backend.

        Args:
            prompt: Text prompt for generation
            input_images: Optional list of input images
            **kwargs: Additional generation parameters. Recognised keys and
                the defaults used when absent: ``aspect_ratio`` ('1:1'),
                ``number_of_images`` (1), ``guidance_scale`` (3.0),
                ``num_inference_steps`` (50), ``seed`` (-1). Other keys are
                ignored.

        Returns:
            Generated PIL Image. Only the first image of the result is
            returned, even if ``number_of_images`` > 1.

        Raises:
            RuntimeError: If the backend is unavailable, or the result
                contains no images.
        """
        if not self.health_check():
            raise RuntimeError("OmniGen2 backend not available")

        # Create generation request
        request = GenerationRequest(
            prompt=prompt,
            input_images=input_images or [],
            aspect_ratio=kwargs.get('aspect_ratio', '1:1'),
            number_of_images=kwargs.get('number_of_images', 1),
            guidance_scale=kwargs.get('guidance_scale', 3.0),
            num_inference_steps=kwargs.get('num_inference_steps', 50),
            seed=kwargs.get('seed', -1)
        )

        # Generate image (health_check() above guarantees self.client is set)
        result = self.client.generate(request)

        if result.images:
            return result.images[0]

        # Avoid a misleading "failed: None" when the backend reports no error text.
        reason = result.error or "no images returned"
        raise RuntimeError(f"OmniGen2 generation failed: {reason}")

    def get_capabilities(self) -> Dict[str, Any]:
        """Report OmniGen2 backend capabilities.

        Returns a new dict on each call, so callers may mutate it freely.
        """
        return {
            'name': 'OmniGen2 Local',
            'type': 'local',
            'supports_input_images': True,
            'supports_multi_image': True,
            'max_input_images': 8,
            'supports_aspect_ratios': True,
            'available_aspect_ratios': ['1:1', '3:4', '4:3', '9:16', '16:9', '3:2', '2:3', '4:5', '5:4', '21:9'],
            'supports_guidance_scale': True,
            'supports_inference_steps': True,
            'supports_seed': True,
            'estimated_time_per_image': 8.0,  # seconds (depends on GPU)
            'cost_per_image': 0.0,  # Free, local
        }