File size: 4,872 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from typing import Dict, List, Optional
from ...tool import Tool
from ...storage_handler import FileStorageHandler, LocalStorageHandler
import requests
import time
class FluxImageGenerationEditTool(Tool):
    """Generate or edit images with the bfl.ai ``flux-kontext-max`` endpoint.

    A call submits a generation request, polls the returned ``polling_url``
    until the job reaches a terminal status, downloads the resulting image and
    persists it through ``self.storage_handler``.
    """

    name: str = "flux_image_generation_edit"
    description: str = (
        "Text-to-image and image-editing using the bfl.ai flux-kontext-max API. "
        "Without input_image: generate from prompt. With input_image (base64): edit/transform."
    )
    inputs: Dict[str, Dict] = {
        "prompt": {"type": "string", "description": "The prompt describing the image to generate."},
        "input_image": {"type": "string", "description": "Base64 encoded input image for editing, optional."},
        "seed": {"type": "integer", "description": "Random seed, default is 42.", "default": 42},
        "aspect_ratio": {"type": "string", "description": "Aspect ratio, e.g. '1:1', optional."},
        "output_format": {"type": "string", "description": "Image format, default is jpeg.", "default": "jpeg"},
        "prompt_upsampling": {"type": "boolean", "description": "Enable prompt upsampling, default is false.", "default": False},
        "safety_tolerance": {"type": "integer", "description": "Safety tolerance level, default is 2.", "default": 2},
    }
    required: List[str] = ["prompt"]

    def __init__(
        self,
        api_key: str,
        storage_handler: Optional[FileStorageHandler] = None,
        base_path: str = "./imgs",
        save_path: Optional[str] = None,
        *,
        poll_interval: float = 2.0,
        max_wait: float = 300.0,
        request_timeout: float = 60.0,
    ):
        """Create the tool.

        Args:
            api_key: bfl.ai API key, sent as the ``x-key`` header.
            storage_handler: Handler used to persist images. When omitted, a
                ``LocalStorageHandler`` rooted at ``base_path`` is created.
            base_path: Root directory for the default local storage handler.
            save_path: Deprecated alias for ``base_path``; takes precedence
                when given (kept for backward compatibility).
            poll_interval: Seconds to sleep between status polls.
            max_wait: Maximum seconds to keep polling before raising
                ``TimeoutError``.
            request_timeout: Per-request timeout (seconds) passed to every
                HTTP call so a stalled connection cannot block forever.
        """
        super().__init__()
        self.api_key = api_key
        self.poll_interval = poll_interval
        self.max_wait = max_wait
        self.request_timeout = request_timeout
        # Backward compatibility: an explicit save_path overrides base_path.
        if save_path is not None:
            base_path = save_path
        if storage_handler is None:
            self.storage_handler = LocalStorageHandler(base_path=base_path)
        else:
            self.storage_handler = storage_handler

    def __call__(
        self,
        prompt: str,
        input_image: Optional[str] = None,
        seed: int = 42,
        aspect_ratio: Optional[str] = None,
        output_format: str = "jpeg",
        prompt_upsampling: bool = False,
        safety_tolerance: int = 2,
    ):
        """Generate (or edit, when ``input_image`` is given) an image and save it.

        Args:
            prompt: Text prompt describing the desired image.
            input_image: Base64-encoded source image; when truthy the request
                becomes an edit of this image.
            seed: Random seed forwarded to the API; also used in the filename.
            aspect_ratio: Optional aspect ratio such as ``"1:1"``; omitted from
                the payload when falsy.
            output_format: Image format requested and used as file extension.
            prompt_upsampling: Forwarded to the API unchanged.
            safety_tolerance: Forwarded to the API unchanged.

        Returns:
            On successful save: ``{"success": True, "file_path", "full_path",
            "message"}``. If the storage handler reports failure:
            ``{"success": False, "error"}``.

        Raises:
            requests.HTTPError: If any HTTP request returns an error status.
            requests.Timeout: If a single request exceeds ``request_timeout``.
            ValueError: If the job ends in a failure or moderation status.
            TimeoutError: If the job is not ready within ``max_wait`` seconds.
            KeyError: If the submit response lacks ``id``/``polling_url``.
        """
        payload = {
            "prompt": prompt,
            "seed": seed,
            "output_format": output_format,
            "prompt_upsampling": prompt_upsampling,
            "safety_tolerance": safety_tolerance,
        }
        if aspect_ratio:
            payload["aspect_ratio"] = aspect_ratio
        if input_image:
            payload["input_image"] = input_image

        image_url = self._submit_and_wait(payload)

        image_response = requests.get(image_url, timeout=self.request_timeout)
        image_response.raise_for_status()
        image_content = image_response.content

        filename = self._get_unique_filename(seed, output_format)
        save_result = self.storage_handler.save(filename, image_content)
        if save_result["success"]:
            return {
                "success": True,
                "file_path": filename,
                "full_path": save_result.get("full_path", filename),
                "message": f"Image saved successfully as {filename}"
            }
        return {
            "success": False,
            "error": f"Failed to save image: {save_result.get('error', 'Unknown error')}"
        }

    def _submit_and_wait(self, payload: Dict) -> str:
        """Submit ``payload`` and poll until the job is ready; return the image URL.

        Raises:
            requests.HTTPError: If the submit or a poll request fails.
            ValueError: If the job reports a failure or moderation status.
            TimeoutError: If the job is not ready within ``self.max_wait`` seconds.
        """
        headers = {
            "accept": "application/json",
            "x-key": self.api_key,
            "Content-Type": "application/json",
        }
        response = requests.post(
            "https://api.bfl.ai/v1/flux-kontext-max",
            json=payload,
            headers=headers,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        request_data = response.json()
        request_id = request_data["id"]
        polling_url = request_data["polling_url"]

        poll_headers = {
            "accept": "application/json",
            "x-key": self.api_key,
        }
        # Statuses after which the job will never become "Ready". The original
        # loop only recognised "Error"/"Failed" and would spin forever on the
        # moderation statuses. Unknown statuses keep polling until the deadline.
        failure_statuses = {"Error", "Failed", "Request Moderated", "Content Moderated"}
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + self.max_wait
        while True:
            time.sleep(self.poll_interval)
            poll_response = requests.get(
                polling_url,
                headers=poll_headers,
                params={"id": request_id},
                timeout=self.request_timeout,
            )
            poll_response.raise_for_status()
            result = poll_response.json()
            status = result.get("status")
            if status == "Ready":
                return result["result"]["sample"]
            if status in failure_statuses:
                raise ValueError(f"Generation failed: {result}")
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Generation {request_id} not ready after {self.max_wait} seconds "
                    f"(last status: {status!r})"
                )

    def _get_unique_filename(self, seed: int, output_format: str) -> str:
        """Return ``flux_<seed>.<ext>``, or ``flux_<seed>_<n>.<ext>`` for the
        smallest ``n >= 1`` not already present in the storage handler."""
        filename = f"flux_{seed}.{output_format}"
        counter = 1
        while self.storage_handler.exists(filename):
            filename = f"flux_{seed}_{counter}.{output_format}"
            counter += 1
        return filename
|