File size: 5,171 Bytes
02e14ff
 
 
 
 
 
 
8d015f7
 
 
 
 
 
 
 
 
 
02e14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d015f7
 
 
 
02e14ff
 
8d015f7
 
02e14ff
 
8d015f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02e14ff
8d015f7
 
 
02e14ff
 
 
 
 
 
 
 
 
8d015f7
02e14ff
 
 
 
8d015f7
 
 
 
 
02e14ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d015f7
02e14ff
 
 
 
 
 
 
8d015f7
02e14ff
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import base64
import time

import requests
import streamlit as st


def encode_image(uploaded_file=None, image_url=None, timeout=30):
    """Return an image as a base64 ``data:`` URI.

    The uploaded file takes precedence over the URL when both are given.

    Args:
        uploaded_file: File-like object exposing ``getvalue()`` (e.g. a
            Streamlit ``UploadedFile``). If it has a ``type`` attribute holding
            an ``image/*`` MIME type, that type is used in the URI.
        image_url: URL to download the image from when no file is given.
        timeout: Seconds to wait for the download before giving up.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``. ``<mime>``
        falls back to ``image/jpeg`` when it cannot be determined.

    Raises:
        ValueError: If neither ``uploaded_file`` nor ``image_url`` is given.
        requests.exceptions.RequestException: If the download fails or the
            server answers with an HTTP error status.
    """
    mime_type = None
    if uploaded_file:
        # Handle file upload
        bytes_data = uploaded_file.getvalue()
        mime_type = getattr(uploaded_file, "type", None)
    elif image_url:
        # Handle image URL. Without a timeout a stalled server would hang
        # the app indefinitely.
        response = requests.get(image_url, timeout=timeout)
        # Don't silently encode an HTML error page as if it were an image.
        response.raise_for_status()
        bytes_data = response.content
        mime_type = response.headers.get("Content-Type", "").split(";")[0].strip()
    else:
        raise ValueError("Either uploaded_file or image_url must be provided.")

    # Only trust an explicit image/* type; otherwise keep the previous default.
    if not (isinstance(mime_type, str) and mime_type.startswith("image/")):
        mime_type = "image/jpeg"
    encoded = base64.b64encode(bytes_data).decode()
    return f"data:{mime_type};base64,{encoded}"


class Output:
    """One entry of a run's ``outputs`` list.

    Attributes:
        data: The entry's ``"data"`` payload.
        images: ``ImageOutput`` objects built from ``data["images"]``; empty
            when the entry carries no ``"images"`` key.
        output_image: True when at least one image has type ``"output"``.
        output_url: URL of the first ``"output"`` image, or None if there is
            no such image.
    """

    def __init__(self, obj: dict):
        self.data = obj["data"]
        # Defaults so every instance exposes these attributes, not only the
        # entries that happen to contain an output image.
        self.images = []
        self.output_image = False
        self.output_url = None
        if self.is_image():
            self.get_result()

    def is_image(self) -> bool:
        """Return True if this entry's data has an ``"images"`` key."""
        return "images" in self.data

    def get_result(self) -> None:
        """Parse ``data["images"]`` and record the first output image's URL."""
        self.images = [ImageOutput(**img) for img in self.data["images"]]
        for img in self.images:
            if img.is_output():
                self.output_image = True
                self.output_url = img.url
                break


class ImageOutput:
    """An image record from a run output; each dict key becomes an attribute."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def is_output(self) -> bool:
        """Return True if this record's ``type`` field equals ``"output"``.

        Records without a ``type`` key are treated as non-output rather than
        raising AttributeError.
        """
        return getattr(self, "type", None) == "output"


# ComfyDeploy endpoints and identifiers used by this app.
# NOTE(review): run creation and run polling use different hosts (api. vs
# www.); kept exactly as before — confirm which host is intended.
_CREATE_RUN_URL = "https://api.comfydeploy.com/api/run"
_GET_RUN_URL = "https://www.comfydeploy.com/api/run"
_WORKFLOW_VERSION_ID = "c5249acc-2cda-4734-b4f8-7823cecbce3d"
_MACHINE_ID = "353b27c7-6bce-4472-b5eb-18f22d2373fc"

# Network and polling limits; previously the status loop had no upper bound.
_REQUEST_TIMEOUT_S = 30
_POLL_INTERVAL_S = 2
_MAX_POLLS = 300  # 300 polls * 2 s sleep ~= 10 minutes, plus request time


def _build_payload(encoded_image, input_variant, input_category, input_number):
    """Build the JSON body for the run-creation request."""
    inputs = {"input_image": encoded_image, "input_text": input_variant}
    # Optional inputs are only sent when the user provided them.
    if input_category:
        inputs["input_category"] = input_category
    if input_number:
        inputs["input_number"] = input_number
    return {
        "execution_mode": "async",
        "inputs": inputs,
        "workflow_version_id": _WORKFLOW_VERSION_ID,
        "machine_id": _MACHINE_ID,
    }


def _wait_for_run(run_id, headers):
    """Poll a run until its status is "success" or "failed".

    Returns:
        The final status payload (a dict), or None if the run still has not
        finished after ``_MAX_POLLS`` polls.

    Raises:
        requests.exceptions.RequestException: On network or HTTP errors.
    """
    for _ in range(_MAX_POLLS):
        time.sleep(_POLL_INTERVAL_S)
        response = requests.get(
            _GET_RUN_URL,
            params={"run_id": run_id},
            headers=headers,
            timeout=_REQUEST_TIMEOUT_S,
        )
        response.raise_for_status()
        result = response.json()
        # Any other status (e.g. "running") means the run is still in progress.
        if result.get("status") in ("success", "failed"):
            return result
    return None


def _first_output_url(result):
    """Return the URL of the first output image in a finished run, or None."""
    for raw_output in result.get("outputs") or []:
        output = Output(raw_output)
        if output.output_image:
            return output.output_url
    return None


def main():
    """Render the Streamlit UI and, on "Generate", run the image workflow."""
    st.title("Image Generator")

    # Sidebar with a text field for the API token.
    st.sidebar.title("Settings")
    st.sidebar.subheader("Enter your API token")
    token = st.sidebar.text_input("Token")

    # Image source: an uploaded file takes precedence over a URL.
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader(
            "Choose an image", type=["png", "jpg", "jpeg"]
        )
    with col2:
        image_url = st.text_input(
            "Or enter an image URL",
            help="Remove the uploaded image if you want to use an image URL",
        )
    if uploaded_file:
        st.image(uploaded_file, caption="Uploaded Image")
    elif image_url:
        st.image(image_url, caption="Image URL")

    # Text inputs
    input_category = st.text_input("Enter category (Optional)")
    input_variant = st.text_input("Enter variant")

    # Method input: each label shown to the user maps to the numeric value
    # sent to the workflow as "input_number".
    options = {
        "Value 1: For replacement of element in foreground": 1,
        "Value 2: For replacement of background": 2,
        "Value 3: To generate a new image with the same dimensions based on the prompt": 3,
        "Value 4: Workflow stop": 4,
    }
    selected_option = st.selectbox(
        "Select an option:",
        index=None,
        options=list(options),
    )
    input_number = None if selected_option is None else options[selected_option]

    # Generation
    if not st.button("Generate"):
        return
    if not (uploaded_file or image_url) or not input_variant:
        st.warning("Provide an image (upload or URL) and a variant before generating.")
        return
    if not token:
        st.warning("Enter your API token in the sidebar.")
        return

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    with st.spinner("Generating image..."):
        try:
            # Encoding is inside the try so a failed image-URL download is
            # reported in the UI instead of crashing the script.
            encoded_image = encode_image(uploaded_file=uploaded_file, image_url=image_url)
            response = requests.post(
                _CREATE_RUN_URL,
                headers=headers,
                json=_build_payload(
                    encoded_image, input_variant, input_category, input_number
                ),
                timeout=_REQUEST_TIMEOUT_S,
            )
            response.raise_for_status()
            run_id = response.json()["run_id"]
            st.write(f"Run ID: {run_id}")
            result = _wait_for_run(run_id, headers)
        except requests.exceptions.RequestException as e:
            st.error(f"Error occurred: {e}")
            return

    if result is None:
        st.error("Timed out waiting for the image to be generated.")
        return
    if result["status"] == "failed":
        # Previously execution fell through here and crashed reading outputs.
        st.error("Image generation failed")
        return

    st.success("Image generated successfully!")
    final_img = _first_output_url(result)
    if final_img is None:
        st.error("The run finished but returned no output image.")
        return
    st.image(final_img, caption="Generated Image")


if __name__ == "__main__":
    main()