File size: 4,899 Bytes
42337fe
facc56d
42337fe
7696246
3f15ec4
 
ba44318
42337fe
 
ba44318
7696246
 
 
 
 
 
 
 
ba44318
42337fe
7696246
facc56d
42337fe
facc56d
 
 
42337fe
 
ba44318
42337fe
 
facc56d
 
 
 
9ee2942
 
facc56d
7696246
facc56d
 
 
 
 
 
 
 
 
 
 
42337fe
 
 
 
 
 
 
 
7696246
42337fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7696246
42337fe
 
 
facc56d
7696246
42337fe
 
 
 
7696246
 
ba44318
 
42337fe
7696246
facc56d
ba44318
1b7bc3c
 
facc56d
1b7bc3c
facc56d
42337fe
1b7bc3c
 
 
 
 
 
 
 
 
 
 
 
ba44318
 
facc56d
 
 
 
ba44318
 
facc56d
42337fe
 
facc56d
 
d17404a
facc56d
 
 
 
 
 
 
 
 
 
b6b2d99
facc56d
 
d17404a
 
facc56d
 
1b7bc3c
d17404a
1b7bc3c
ba44318
facc56d
 
 
 
 
ba44318
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import asyncio
import base64
import io
import logging
import mimetypes
import os

import gradio as gr
import httpx
from PIL import Image

# Root logging is configured a single time, when this module is first
# imported (logging.basicConfig is a no-op if the root logger already has
# handlers attached).
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by every function in this module.
logger = logging.getLogger(__name__)


def _encode_image_as_data_url(path):
    """Read the image file at *path* and return it as a base64 ``data:`` URL.

    The MIME type is guessed from the file extension so JPEG/WebP uploads are
    not mislabelled; it falls back to ``image/png`` when it cannot be
    determined or is not an image type.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    mime_type, _ = mimetypes.guess_type(path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def _decode_b64_images(payload):
    """Decode every ``b64_json`` entry of an images-generation response.

    *payload* is the parsed JSON body; its ``data`` field is expected to be a
    list of dicts. Entries without ``b64_json`` are skipped, and entries that
    fail to decode are logged and skipped.

    Returns:
        A list of fully loaded PIL images, in response order.
    """
    decoded = []
    # `or []` also tolerates an explicit `"data": null` in the response.
    for item in payload.get("data") or []:
        b64_str = item.get("b64_json")
        if not b64_str:
            continue
        try:
            image = Image.open(io.BytesIO(base64.b64decode(b64_str)))
            # Image.open is lazy and only reads the header; force a full
            # decode so corrupt data is caught here rather than later when
            # Gradio serialises the image.
            image.load()
        except (ValueError, OSError) as e:
            logger.warning(f"⚠️ Failed to decode base64 image: {e}")
            continue
        decoded.append(image)
    return decoded


async def generate_async(images, prompt, variation, size):
    """Generate images with the BytePlus Seedream API.

    Configuration is read from the ``BYTEPLUS_API_KEY`` and ``BYTEPLUS_URL``
    environment variables.

    Args:
        images: Uploaded reference images from ``gr.File`` (objects exposing
            ``.name`` or plain path strings), or ``None`` for text-to-image.
        prompt: Text prompt sent to the model.
        variation: Number of variations; one API request is made per variation.
        size: Output size string such as ``"2048x2048"``.

    Returns:
        A list of PIL images. An empty list is returned when configuration is
        missing or any request fails; errors are logged, not raised, so the
        UI simply shows no results.
    """
    logger.info("giving it to BytePlus...")
    api_key = os.getenv("BYTEPLUS_API_KEY")
    base_url = os.getenv("BYTEPLUS_URL", "").rstrip("/")
    if not api_key or not base_url:
        logger.error("⚠️ BYTEPLUS_API_KEY and BYTEPLUS_URL must both be set")
        return []

    # Convert uploaded images to base64 data URLs. Depending on the Gradio
    # version, uploads are either tempfile-like objects or path strings.
    images_input = []
    for img in images or []:
        path = getattr(img, "name", img)
        try:
            images_input.append(_encode_image_as_data_url(path))
        except (OSError, TypeError) as e:
            logger.error(f"⚠️ Failed to process image {path}: {e}")

    request_data = {
        "model": "seedream-4-0-250828",
        "prompt": prompt,
        "response_format": "b64_json",
        "sequential_image_generation": "disabled",
        "size": size,
        "watermark": False,
    }
    if images_input:
        request_data["image"] = images_input

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    images_output = []
    try:
        logger.info("Sending request to BytePlus...")
        async with httpx.AsyncClient(timeout=120) as client:
            for _ in range(int(variation)):
                response = await client.post(
                    f"{base_url}/images/generations",
                    json=request_data,
                    headers=headers,
                )
                response.raise_for_status()
                images_output.extend(_decode_b64_images(response.json()))
    except httpx.HTTPStatusError as e:
        # The API's error body usually explains the rejection (bad size,
        # invalid key, ...), so surface it in the log.
        logger.error(
            "⚠️ BytePlus returned HTTP %s: %s",
            e.response.status_code,
            e.response.text,
        )
        return []
    except Exception as e:
        # Top-level boundary for the UI handler: log with traceback and fall
        # back to an empty result instead of crashing the Gradio event.
        logger.exception(f"⚠️ Failed to process everything: {e}")
        return []

    return images_output


# Click handler for the Generate button. Gradio runs ``async def`` event
# handlers natively, so this simply delegates to ``generate_async``.
async def generate(images, prompt, variation, size):
    """Handle a Generate click by delegating to ``generate_async``.

    Returns the list of generated PIL images (empty on failure).
    """
    return await generate_async(images, prompt, variation, size)


# ------------------ Gradio UI ------------------ #


def _uploaded_paths(files):
    """Return the filesystem paths of the files in the upload widget.

    Depending on the Gradio version, ``gr.File`` passes uploads either as
    objects exposing ``.name`` or as plain path strings; both are accepted.
    Returns an empty list when nothing is uploaded.
    """
    return [getattr(f, "name", f) for f in files or []]


with gr.Blocks(theme=gr.themes.Glass()) as demo:
    gr.Markdown("## 🔥 Multi-API Image-to-Image Generator")

    with gr.Row():
        # === Left Column: inputs ===
        with gr.Column(scale=1):
            # Optional reference images for image-to-image generation.
            image_input = gr.File(
                label="Upload your reference images (optional)",
                file_count="multiple",
                file_types=["image"],
            )

            # Thumbnails of whatever is currently uploaded.
            input_preview = gr.Gallery(label="Preview", columns=3, height="auto")

            # Refresh the preview whenever the uploaded files change.
            image_input.change(
                _uploaded_paths,
                inputs=image_input,
                outputs=input_preview,
            )

            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe how you want to modify or generate the image...",
                lines=2,
            )

            # One API request is issued per variation.
            variation_choice = gr.Dropdown(
                choices=[1, 2, 3, 4, 5],
                label="Number of Variations",
                value=1,
            )

            size_choice = gr.Dropdown(
                choices=[
                    "2048x2048",
                    "1728x2304",
                    "2304x1728",
                    "2560x1440",
                    "1440x2560",
                    "1664x2496",
                    "2496x1664",
                    "3024x1296",
                ],
                label="Output Size",
                value="2048x2048",
            )

            btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        # === Right Column: results ===
        with gr.Column(scale=1):
            gallery = gr.Gallery(label="Generated Results", columns=2, height="auto")

    btn.click(
        fn=generate,
        inputs=[image_input, prompt_input, variation_choice, size_choice],
        outputs=gallery,
    )

# Only start the server when run as a script, so importing this module
# (e.g. from tests) does not block on launch.
if __name__ == "__main__":
    demo.launch()