File size: 2,925 Bytes
b0f5e54
f2f2537
b0f5e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2f2537
b0f5e54
 
f2f2537
 
 
 
b0f5e54
f2f2537
b0f5e54
 
 
 
 
f2f2537
b0f5e54
 
 
 
 
f2f2537
 
 
 
 
 
 
 
b0f5e54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging
import os
import random

import gradio as gr
from hautech import HautechRequest, OperationData, OperationInput, Poller

# Module-level logger. Every call in this file logs at DEBUG level, which the
# __main__ block enables via basicConfig when the LOG env var is "debug".
logger = logging.getLogger(__name__)


def inference(file, prompt, quality, seed, version):
    """Generate images for an uploaded garment image via the Hautech API.

    Uploads ``file``, submits a ``"generate"`` operation, waits for it with
    ``Poller`` (``interval_sec=4``), then resolves the resulting image ids
    through ``/images/urls``.

    Args:
        file: The uploaded image as delivered by ``gr.File``; passed unchanged
            to ``HautechRequest.upload_image``. Falsy means nothing uploaded.
        prompt: Free-text description forwarded to the operation.
        quality: Quality setting forwarded to the operation (the UI offers
            ``"low"`` and ``"high"``).
        seed: Seed forwarded to the operation.
        version: Model version; converted with ``int()`` before sending.

    Returns:
        A list of the values of the JSON object returned by ``/images/urls``
        (presumably one URL per generated image id — confirm against the API).

    Raises:
        Exception: if no file is given, the operation response has no id,
            the finished operation reports no ``imageIds``, or ``/images/urls``
            returns an empty object.
    """
    if not file:
        logger.debug("No file provided")
        raise Exception("No file provided")

    hautech_request = HautechRequest()
    img_id = hautech_request.upload_image(file)
    operation = hautech_request.post(
        "/operations",
        json=OperationData(
            input=OperationInput(
                aspectRatio="1:1",
                productImageId=img_id,
                quality=quality,
                prompt=prompt,
                seed=seed,
                version=int(version),
            ),
            type="generate",
        ).model_dump(),
    )

    generation_id = operation.json().get("id")
    # `not` covers a missing key, "" and an explicit null alike; `len()` would
    # raise TypeError on null.
    if not generation_id:
        logger.debug(
            "Operations returned no data: status=%s body=%s",
            operation.status_code,
            operation.text,
        )
        raise Exception("Generation ID is empty")

    poller = Poller(hautech_request)
    result = poller.poll(generation_id=generation_id, interval_sec=4)

    # `or` guards against explicit nulls for "output"/"imageIds", which a
    # plain `.get(key, default)` would pass straight through.
    image_ids = (result.get("output") or {}).get("imageIds") or []
    if not image_ids:
        # Log the poll result: that is where `imageIds` came from.
        logger.debug("Empty array for `imageIds`, operation result: %s", result)
        raise Exception("Failed to generate images")

    images_urls = hautech_request.post("/images/urls", json={"ids": image_ids})
    image_data = images_urls.json()

    # Materialise as a list so the Gallery gets a concrete sequence rather
    # than a live dict view.
    urls = list(image_data.values())
    if not urls:
        logger.debug(
            "Fetching from: 'images/urls' %s returned empty array: "
            "status=%s body=%s parsed=%s",
            image_ids,
            images_urls.status_code,
            images_urls.text,
            image_data,
        )
        raise Exception("Failed to get images")

    return urls


# Largest integer a browser holds exactly (JS Number.MAX_SAFE_INTEGER); larger
# seeds get rounded on the way to/from the UI (and 2**64 - 1 even rounds up to
# 2**64 as a float), so the seed sent would not match the seed shown.
_MAX_SEED = 2**53 - 1

# Gradio UI wiring the upload/prompt/quality/seed/version inputs to
# `inference` and showing the returned images in a gallery.
interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.File(label="Upload Garment Image"),
        gr.Textbox(label="Enter Prompt", placeholder="Enter your description here"),
        gr.Dropdown(
            choices=["low", "high"],
            label="Quality",
            value="low",
            info="Select image quality",
        ),
        # precision=0 makes Gradio pass an int instead of a float. A callable
        # default draws a fresh seed on each page load rather than once at
        # import time (which gave every visitor the same "random" seed).
        gr.Number(
            label="Seed",
            value=lambda: random.randint(1, _MAX_SEED),
            precision=0,
        ),
        gr.Dropdown(
            choices=["1"], label="Version", value="1", info="Select model version"
        ),
    ],
    outputs=gr.Gallery(label="Generated Images"),
    title="Hautech",
    description="Upload a garment image and provide a prompt to generate related content.",
    theme="huggingface",
)

if __name__ == "__main__":
    # Opt-in verbose output: LOG=debug (any letter case) turns on DEBUG logging.
    # An unset LOG is treated like an empty string, i.e. no change.
    if (os.getenv("LOG") or "").lower() == "debug":
        logging.basicConfig(level=logging.DEBUG)

    # Refuse to start the server when TOKEN is unset or empty.
    if not os.getenv("TOKEN"):
        raise Exception("Token environment variable is required")

    interface.launch()