File size: 4,872 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Dict, List, Optional
from ...tool import Tool
from ...storage_handler import FileStorageHandler, LocalStorageHandler
import requests
import time


class FluxImageGenerationEditTool(Tool):
    """Generate or edit images through the bfl.ai ``flux-kontext-max`` endpoint.

    Workflow of a call:

    1. Submit the request.
    2. Poll the returned ``polling_url`` until the job reports ``Ready``, fails,
       or exceeds the configured wait budget.
    3. Download the produced image.
    4. Persist it through the configured storage handler.
    """

    name: str = "flux_image_generation_edit"
    description: str = (
        "Text-to-image and image-editing using the bfl.ai flux-kontext-max API. "
        "Without input_image: generate from prompt. With input_image (base64): edit/transform."
    )

    inputs: Dict[str, Dict] = {
        "prompt": {"type": "string", "description": "The prompt describing the image to generate."},
        "input_image": {"type": "string", "description": "Base64 encoded input image for editing, optional."},
        "seed": {"type": "integer", "description": "Random seed, default is 42.", "default": 42},
        "aspect_ratio": {"type": "string", "description": "Aspect ratio, e.g. '1:1', optional."},
        "output_format": {"type": "string", "description": "Image format, default is jpeg.", "default": "jpeg"},
        "prompt_upsampling": {"type": "boolean", "description": "Enable prompt upsampling, default is false.", "default": False},
        "safety_tolerance": {"type": "integer", "description": "Safety tolerance level, default is 2.", "default": 2},
    }
    required: List[str] = ["prompt"]

    def __init__(self, api_key: str, storage_handler: Optional[FileStorageHandler] = None,
                 base_path: str = "./imgs", save_path: Optional[str] = None):
        """Create the tool.

        Args:
            api_key: bfl.ai API key, sent as the ``x-key`` header.
            storage_handler: Where generated images are saved. When omitted, a
                ``LocalStorageHandler`` rooted at ``base_path`` is created.
            base_path: Root directory for the default local storage handler.
            save_path: Legacy alias for ``base_path``; takes precedence when given.
        """
        super().__init__()
        self.api_key = api_key

        # Backward compatibility: older callers pass `save_path` instead of `base_path`.
        if save_path is not None:
            base_path = save_path

        if storage_handler is None:
            self.storage_handler = LocalStorageHandler(base_path=base_path)
        else:
            self.storage_handler = storage_handler

    def __call__(
        self,
        prompt: str,
        input_image: Optional[str] = None,
        seed: int = 42,
        aspect_ratio: Optional[str] = None,
        output_format: str = "jpeg",
        prompt_upsampling: bool = False,
        safety_tolerance: int = 2,
        *,
        poll_interval: float = 2.0,
        max_wait: Optional[float] = 600.0,
        request_timeout: float = 60.0,
    ):
        """Generate (or edit) an image and save it via the storage handler.

        Args:
            prompt: Text describing the desired image.
            input_image: Optional base64-encoded image. When given, the request
                edits/transforms it instead of generating from scratch.
            seed: Random seed forwarded to the API.
            aspect_ratio: Optional aspect ratio such as ``"1:1"``.
            output_format: Image format requested from the API. It is also used
                as the saved file's extension.
            prompt_upsampling: Forwarded to the API unchanged.
            safety_tolerance: Forwarded to the API unchanged.
            poll_interval: Seconds to sleep before each status poll.
            max_wait: Upper bound, in seconds, on the total polling time.
                ``None`` polls without a deadline.
            request_timeout: Per-HTTP-request timeout in seconds. It applies to
                the submit, poll, and download requests.

        Returns:
            On a successful save, a dict with keys ``success``, ``file_path``,
            ``full_path`` and ``message``. If the storage handler reports a
            failure, a dict with keys ``success`` (False) and ``error``.

        Raises:
            requests.HTTPError: If the submit, poll, or download request
                returns an HTTP error status.
            ValueError: If the job reports a failed or moderated status.
            TimeoutError: If the job is not ``Ready`` within ``max_wait`` seconds.
        """
        payload = {
            "prompt": prompt,
            "seed": seed,
            "output_format": output_format,
            "prompt_upsampling": prompt_upsampling,
            "safety_tolerance": safety_tolerance,
        }
        # Optional fields are only sent when provided, letting the API apply its own defaults.
        if aspect_ratio:
            payload["aspect_ratio"] = aspect_ratio
        if input_image:
            payload["input_image"] = input_image

        headers = {
            "accept": "application/json",
            "x-key": self.api_key,
            "Content-Type": "application/json",
        }

        response = requests.post(
            "https://api.bfl.ai/v1/flux-kontext-max",
            json=payload,
            headers=headers,
            timeout=request_timeout,
        )
        response.raise_for_status()
        request_data = response.json()

        image_url = self._wait_for_image_url(
            request_id=request_data["id"],
            polling_url=request_data["polling_url"],
            poll_interval=poll_interval,
            max_wait=max_wait,
            request_timeout=request_timeout,
        )

        image_response = requests.get(image_url, timeout=request_timeout)
        image_response.raise_for_status()
        image_content = image_response.content

        filename = self._get_unique_filename(seed, output_format)
        save_result = self.storage_handler.save(filename, image_content)

        if save_result["success"]:
            return {
                "success": True,
                "file_path": filename,
                "full_path": save_result.get("full_path", filename),
                "message": f"Image saved successfully as {filename}"
            }
        return {
            "success": False,
            "error": f"Failed to save image: {save_result.get('error', 'Unknown error')}"
        }

    def _wait_for_image_url(
        self,
        request_id: str,
        polling_url: str,
        poll_interval: float,
        max_wait: Optional[float],
        request_timeout: float,
    ) -> str:
        """Poll the job status until it is ``Ready`` and return the image URL.

        Raises:
            requests.HTTPError: If a poll request returns an HTTP error status.
            ValueError: If the job reports a failed or moderated status.
            TimeoutError: If ``max_wait`` seconds elapse first (when not ``None``).
        """
        # Statuses after which the job will never become Ready.
        failed_statuses = ("Error", "Failed", "Request Moderated", "Content Moderated")
        deadline = None if max_wait is None else time.monotonic() + max_wait

        while True:
            time.sleep(poll_interval)
            poll_response = requests.get(
                polling_url,
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={"id": request_id},
                timeout=request_timeout,
            )
            poll_response.raise_for_status()
            status_data = poll_response.json()

            status = status_data["status"]
            if status == "Ready":
                return status_data["result"]["sample"]
            if status in failed_statuses:
                raise ValueError(f"Generation failed: {status_data}")

            # Any other status is treated as "still in progress"; the deadline
            # keeps an unexpected or stuck status from looping forever.
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Image generation {request_id} not ready after {max_wait} seconds "
                    f"(last status: {status!r})"
                )

    def _get_unique_filename(self, seed: int, output_format: str) -> str:
        """Return ``flux_<seed>.<ext>``, or the first free ``flux_<seed>_<n>.<ext>``.

        Existence is checked through the storage handler. The check and the
        later save are separate calls, so this is not atomic: a concurrent
        writer could claim the same name in between.
        """
        filename = f"flux_{seed}.{output_format}"
        counter = 1

        while self.storage_handler.exists(filename):
            filename = f"flux_{seed}_{counter}.{output_format}"
            counter += 1

        return filename