File size: 4,071 Bytes
7bdc0ee
 
c6da11a
7bdc0ee
 
 
 
 
 
 
8185baa
 
049a9d0
8185baa
 
 
be9f8ce
8185baa
11d179f
 
 
 
 
 
 
c6da11a
 
7bdc0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6da11a
7bdc0ee
 
c6da11a
11d179f
7bdc0ee
 
11d179f
f11eee8
7bdc0ee
 
 
 
 
 
 
 
 
c6da11a
7bdc0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
c6da11a
7bdc0ee
 
c6da11a
7bdc0ee
 
c6da11a
 
 
 
 
7bdc0ee
711ba8c
7bdc0ee
be9f8ce
711ba8c
 
c6da11a
2e7f9aa
11d179f
711ba8c
ca805f2
5423dae
7bdc0ee
11d179f
f4537e8
94229e7
 
c6da11a
 
 
7bdc0ee
c6da11a
7bdc0ee
a68c643
7bdc0ee
a68c643
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import time

import requests
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os


# URL of the shared CSS theme for the app.
css_url = "https://neurixyufi-aihub.static.hf.space/style.css"

# Fetch the theme CSS.
# - The timeout (seconds) keeps startup from hanging on a slow host.
# - raise_for_status() stops an HTTP error page from being injected as CSS.
# - On any request failure, fall back to the local overrides only instead of
#   crashing the whole app at import time.
try:
    response = requests.get(css_url, timeout=10)
    response.raise_for_status()
    remote_css = response.text
except requests.RequestException:
    remote_css = ""
css = remote_css + ".gradio-container{max-width: 700px !important} h1{text-align:center}"

# UI label (shown in the style radio buttons) -> style code sent to the API.
styles = {
    "Свой стиль": "DEFAULT",
    "Аниме": "ANIME",
    "Детальное фото": "UHD",
    "Кандинский": "KANDINSKY"
}

# HTTP timeout in seconds passed to the FusionBrain API requests.
timeout = 125

# API credentials come from the environment; each is None when its variable is unset.
api_key = os.getenv("api_key")
secret_key = os.getenv("secret_key")

class Text2ImageAPI:
    """Minimal client for the FusionBrain text-to-image HTTP API.

    Every request authenticates with the ``X-Key`` / ``X-Secret`` headers built
    in ``__init__`` and uses the module-level ``timeout`` (seconds), so a stalled
    connection cannot block a worker indefinitely.
    """

    def __init__(self, url, api_key, secret_key):
        """Store the base URL and build the authentication headers.

        Args:
            url: Base URL of the API. Endpoint paths are appended to it directly,
                so it should end with ``/``.
            api_key: Sent in the ``X-Key`` header as ``"Key <api_key>"``.
            secret_key: Sent in the ``X-Secret`` header as ``"Secret <secret_key>"``.
        """
        self.URL = url
        self.AUTH_HEADERS = {
            'X-Key': f'Key {api_key}',
            'X-Secret': f'Secret {secret_key}',
        }

    def get_model(self):
        """Return the id of the first model listed by the API.

        Raises:
            requests.HTTPError: if the models endpoint answers with an error status.
        """
        response = requests.get(self.URL + 'key/api/v1/models', headers=self.AUTH_HEADERS, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        return data[0]['id']

    def generate(self, prompt, negative, style, width, height, model):
        """Queue a single-image generation job and return its uuid.

        Args:
            prompt: Text description of the image.
            negative: Negative prompt (what should not appear in the image).
            style: API style code, e.g. ``"ANIME"``.
            width: Image width in pixels.
            height: Image height in pixels.
            model: Model id, e.g. as returned by :meth:`get_model`.

        Returns:
            The job uuid to pass to :meth:`check_generation`.

        Raises:
            requests.HTTPError: if the run endpoint answers with an error status.
        """
        params = {
            "type": "GENERATE",
            "numImages": 1,
            "style": str(style),
            "width": width,
            "height": height,
            "negativePromptUnclip": negative,
            "censored": False,
            "generateParams": {
                "query": str(prompt)
            }
        }

        # The endpoint expects multipart form data. A (None, value) tuple makes
        # requests send a plain form field rather than a file upload.
        files = {
            'model_id': (None, model),
            'params': (None, json.dumps(params), 'application/json')
        }
        response = requests.post(self.URL + 'key/api/v1/text2image/run', headers=self.AUTH_HEADERS, files=files, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        return data['uuid']

    def check_generation(self, request_id, attempts=10, delay=10):
        """Poll the job status until it finishes.

        Args:
            request_id: Job uuid returned by :meth:`generate`.
            attempts: Maximum number of status requests to make.
            delay: Seconds to wait between consecutive status requests.

        Returns:
            The ``images`` list from the status response once the status is ``DONE``.

        Raises:
            RuntimeError: if the API reports the job status as ``FAIL``.
            TimeoutError: if the job is not ``DONE`` after ``attempts`` polls
                (previously this silently returned ``None``).
            requests.HTTPError: if the status endpoint answers with an error status.
        """
        for attempt in range(attempts):
            # Wait only *between* polls; there is no point sleeping after the last one.
            if attempt:
                time.sleep(delay)
            response = requests.get(self.URL + 'key/api/v1/text2image/status/' + request_id, headers=self.AUTH_HEADERS, timeout=timeout)
            response.raise_for_status()
            data = response.json()
            status = data['status']
            if status == 'DONE':
                return data['images']
            # FAIL is terminal, so stop polling instead of burning the remaining
            # attempts. errorDescription is the API's failure text, if present.
            if status == 'FAIL':
                raise RuntimeError(f"Generation {request_id} failed: {data.get('errorDescription', 'unknown error')}")
        raise TimeoutError(f"Generation {request_id} not finished after {attempts} attempts")


def api_gradio(prompt, negative, style, width, height):
    """Generate one image through the FusionBrain API for the Gradio button.

    Args:
        prompt: Image description entered by the user.
        negative: Negative prompt (things to exclude).
        style: UI style label; a key of the module-level ``styles`` mapping.
        width: Requested width in pixels.
        height: Requested height in pixels.

    Returns:
        A PIL image decoded from the first base64-encoded image in the result.

    Raises:
        gr.Error: if the API returned no image, so Gradio shows a readable
            message instead of a TypeError/IndexError.
    """
    api = Text2ImageAPI('https://api-key.fusionbrain.ai/', api_key, secret_key)
    model_id = api.get_model()
    # Named request_id (not uuid) to avoid shadowing the stdlib module name.
    request_id = api.generate(prompt, negative, styles[style], width, height, model_id)
    images = api.check_generation(request_id)

    # Polling can come back without any image (e.g. attempts exhausted).
    if not images:
        raise gr.Error("Не удалось получить изображение, попробуйте ещё раз")

    decoded_data = base64.b64decode(images[0])
    return Image.open(BytesIO(decoded_data))



# Gradio UI: prompt, style picker and button; a collapsible panel holds the
# negative prompt and image size; the result image is shown at the bottom.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Kandinsky")
    with gr.Column():
        with gr.Row():
            # Image description.
            prompt = gr.Textbox(
                show_label=False,
                placeholder="Описание изображения",
                max_lines=3,
                lines=1,
                interactive=True,
                scale=20,
            )
        with gr.Row():
            style = gr.Radio(show_label=False, value="Свой стиль", choices=list(styles.keys()))
        with gr.Row():
            # "Create" button.
            button = gr.Button(value="Создать")
    # "Additional settings", collapsed by default.
    with gr.Accordion("Дополнительные настройки", open=False):
        with gr.Row():
            # Negative prompt: what must not appear in the image.
            negative = gr.Textbox(
                label="Отрицательная подсказка",
                placeholder="Исключения, чего не должно быть на фото",
                max_lines=3,
                lines=1,
                interactive=True,
                scale=20,
            )
        with gr.Row():
            # Width / height in pixels.
            width = gr.Slider(label="Ширина", minimum=128, maximum=1024, step=1, value=1024, interactive=True)
            height = gr.Slider(label="Высота", minimum=128, maximum=1024, step=1, value=1024, interactive=True)
    with gr.Row():
        gallery = gr.Image(show_label=False)

    # Clicks go through the request queue; up to 250 generations may run concurrently.
    button.click(
        api_gradio,
        inputs=[prompt, negative, style, width, height],
        outputs=gallery,
        queue=True,
        concurrency_limit=250,
    )

demo.queue(max_size=250).launch(show_api=False, share=False)