File size: 6,050 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import Dict, Optional, List
from ...tool import Tool
from ...storage_handler import FileStorageHandler, LocalStorageHandler
from .openai_utils import (
    create_openai_client,
    build_validation_params,
    validate_parameters,
    handle_validation_result,
)


class OpenAIImageGenerationTool(Tool):
    """Generate images with the OpenAI images API and persist them.

    A call validates the supplied parameters for the selected model, invokes
    ``client.images.generate`` with the validated parameters, and then writes
    every returned image through the configured storage handler.
    """

    name: str = "openai_image_generation"
    description: str = "OpenAI image generation supporting dall-e-2, dall-e-3, gpt-image-1 (with validation)."

    inputs: Dict[str, Dict[str, str]] = {
        "prompt": {"type": "string", "description": "Prompt text. Required."},
        "image_name": {"type": "string", "description": "Optional save name."},
        "model": {"type": "string", "description": "dall-e-2 | dall-e-3 | gpt-image-1"},
        "size": {"type": "string", "description": "Model-specific size."},
        "quality": {"type": "string", "description": "quality for gpt-image-1/dall-e-3"},
        "n": {"type": "integer", "description": "1-10 (1 for dalle-3)"},
        "background": {"type": "string", "description": "gpt-image-1 only"},
        "moderation": {"type": "string", "description": "gpt-image-1 only"},
        "output_compression": {"type": "integer", "description": "gpt-image-1 jpeg/webp"},
        "output_format": {"type": "string", "description": "gpt-image-1 png/jpeg/webp"},
        "partial_images": {"type": "integer", "description": "gpt-image-1 streaming partials"},
        "response_format": {"type": "string", "description": "url | b64_json for dalle-2/3"},
        "stream": {"type": "boolean", "description": "gpt-image-1 streaming"},
        "style": {"type": "string", "description": "dall-e-3 vivid|natural"},
    }
    required: Optional[List[str]] = ["prompt"]

    def __init__(
        self,
        api_key: str,
        organization_id: Optional[str] = None,
        model: str = "dall-e-3",
        save_path: str = "./generated_images",
        storage_handler: Optional[FileStorageHandler] = None,
    ):
        """Store credentials, the default model and the storage backend.

        Args:
            api_key: Key passed to ``create_openai_client`` on every call.
            organization_id: Optional organization passed alongside the key.
            model: Model used when a call does not name one.
            save_path: Base path handed to the default ``LocalStorageHandler``.
            storage_handler: Backend used to save images. When omitted, a
                ``LocalStorageHandler`` rooted at ``save_path`` is created.
        """
        super().__init__()
        self.api_key = api_key
        self.organization_id = organization_id
        self.model = model
        self.save_path = save_path
        self.storage_handler = storage_handler or LocalStorageHandler(base_path=save_path)

    def __call__(
        self,
        prompt: str,
        image_name: Optional[str] = None,
        model: Optional[str] = None,
        size: Optional[str] = None,
        quality: Optional[str] = None,
        n: Optional[int] = None,
        background: Optional[str] = None,
        moderation: Optional[str] = None,
        output_compression: Optional[int] = None,
        output_format: Optional[str] = None,
        partial_images: Optional[int] = None,
        response_format: Optional[str] = None,
        stream: Optional[bool] = None,
        style: Optional[str] = None,
    ):
        """Generate images for ``prompt`` and save them.

        Args:
            prompt: Text prompt forwarded to the API.
            image_name: Optional base name for the saved files; it is only
                used for naming and is never sent to the API.
            model: Overrides the instance's default model when truthy.
            size, quality, n, background, moderation, output_compression,
            output_format, partial_images, response_format, stream, style:
                Optional API parameters; they are validated for the chosen
                model before the request is made.

        Returns:
            * whatever ``handle_validation_result`` returns, when it reports
              a validation problem;
            * ``{"results": [...], "count": int}`` otherwise, where each
              entry is either a saved filename or an
              ``"Error saving image ..."`` message (``count`` includes both);
            * ``{"error": "Image generation failed: ..."}`` when anything
              outside the per-image handling raises.
        """
        try:
            client = create_openai_client(self.api_key, self.organization_id)
            actual_model = model or self.model

            params_to_validate = build_validation_params(
                model=actual_model,
                prompt=prompt,
                size=size,
                quality=quality,
                n=n,
                background=background,
                moderation=moderation,
                output_compression=output_compression,
                output_format=output_format,
                partial_images=partial_images,
                response_format=response_format,
                stream=stream,
                style=style,
            )

            validation_result = validate_parameters(actual_model, params_to_validate, "generation")
            error = handle_validation_result(validation_result)
            if error:
                return error

            # Work on a copy so the validation result is left untouched, and
            # make sure the local-only save name never reaches the API.
            api_params = validation_result["validated_params"].copy()
            api_params.pop("image_name", None)

            response = client.images.generate(**api_params)

            # A failure on one image is recorded in its slot so the remaining
            # images are still processed.
            results = []
            for i, image_data in enumerate(response.data):
                try:
                    image_bytes = self._extract_image_bytes(image_data)
                    filename = self._get_unique_filename(image_name, i)
                    result = self.storage_handler.save(filename, image_bytes)

                    if result.get("success"):
                        results.append(filename)
                    else:
                        results.append(f"Error saving image {i+1}: {result.get('error', 'Unknown error')}")
                except Exception as e:
                    results.append(f"Error saving image {i+1}: {e}")

            return {"results": results, "count": len(results)}
        except Exception as e:
            return {"error": f"Image generation failed: {e}"}

    @staticmethod
    def _extract_image_bytes(image_data, timeout: float = 60.0) -> bytes:
        """Return the raw bytes of one response item.

        Inline base64 data (``b64_json``) is preferred; otherwise the image is
        downloaded from ``url``.

        Args:
            image_data: One element of ``response.data``.
            timeout: Seconds passed to ``requests.get`` so a stalled download
                cannot block the call forever.

        Raises:
            ValueError: If the item carries neither ``b64_json`` nor ``url``.
            requests.RequestException: If the download fails or times out.
        """
        encoded = getattr(image_data, "b64_json", None)
        if encoded:
            import base64

            return base64.b64decode(encoded)

        url = getattr(image_data, "url", None)
        if url:
            import requests

            reply = requests.get(url, timeout=timeout)
            reply.raise_for_status()
            return reply.content

        raise ValueError("No valid image data in response")

    def _get_unique_filename(self, image_name: Optional[str], index: int) -> str:
        """Build a ``.png`` filename that does not yet exist in storage.

        Args:
            image_name: Caller-supplied name; its extension (if any) is
                dropped. When falsy, a ``generated_<unix-seconds>`` name is
                used instead.
            index: Zero-based position of the image in the response; the
                filename carries ``index + 1``.

        Returns:
            ``<stem>.png``, or ``<stem>_<counter>.png`` with the smallest
            counter (starting at 1) for which the storage handler reports no
            existing file.
        """
        import os
        import time

        if image_name:
            # splitext only looks at the final path component, so names such
            # as "./cat" or "v1.0/cat" keep their directory part intact.
            stem = f"{os.path.splitext(image_name)[0]}_{index+1}"
        else:
            stem = f"generated_{int(time.time())}_{index+1}"

        # NOTE(review): the extension is always ".png" even when another
        # output_format was requested; confirm whether that is intended.
        filename = f"{stem}.png"
        counter = 1
        while self.storage_handler.exists(filename):
            filename = f"{stem}_{counter}.png"
            counter += 1

        return filename