# File size: 2,925 Bytes
import logging
import os
import random
import gradio as gr
from hautech import HautechRequest, OperationData, OperationInput, Poller
logger = logging.getLogger(__name__)
def inference(file, prompt, quality, seed, version):
    """Generate images for an uploaded product image via the Hautech API.

    The image is uploaded, a ``generate`` operation with a fixed ``1:1``
    aspect ratio is submitted to ``/operations``, the operation is polled
    every 4 seconds through ``Poller.poll``, and the resulting image ids are
    posted to ``/images/urls``.

    Args:
        file: Uploaded image; passed unchanged to
            ``HautechRequest.upload_image``.
        prompt: Text prompt forwarded to the operation input.
        quality: Quality setting forwarded to the operation input.
        seed: Seed forwarded unchanged to the operation input.
        version: Model version; converted with ``int()`` before sending.

    Returns:
        A list with the values of the JSON object returned by
        ``/images/urls`` (presumably one URL per generated image -- confirm
        against the API).

    Raises:
        Exception: If ``file`` is falsy, if the ``/operations`` response
            carries no ``id``, if the polled result has no ``imageIds``, or
            if ``/images/urls`` returns an empty object.
    """
    if not file:
        logger.debug("No file provided")
        raise Exception("No file provided")

    hautech_request = HautechRequest()
    img_id = hautech_request.upload_image(file)

    operations = hautech_request.post(
        "/operations",
        json=OperationData(
            input=OperationInput(
                aspectRatio="1:1",
                productImageId=img_id,
                quality=quality,
                prompt=prompt,
                seed=seed,
                version=int(version),
            ),
            type="generate",
        ).model_dump(),
    )

    # A truthiness check also covers a missing or null "id", which would
    # make len() raise TypeError.
    generation_id = operations.json().get("id")
    if not generation_id:
        # Lazy %-style args: the message needs one placeholder per argument,
        # otherwise logging reports a formatting error instead of the record.
        logger.debug(
            "Operations returned no data: status=%s body=%s",
            operations.status_code,
            operations.text,
        )
        raise Exception("Generation ID is empty")

    poller = Poller(hautech_request)
    result = poller.poll(generation_id=generation_id, interval_sec=4)

    # "output" or "imageIds" may be absent or null; treat both as "no images".
    image_ids = (result.get("output") or {}).get("imageIds") or []
    if not image_ids:
        # Log the polled payload: the create-operation response was already
        # validated above and says nothing about why no images came back.
        logger.debug("Empty array for `imageIds`: poll result=%s", result)
        raise Exception("Failed to generate images")

    images_urls = hautech_request.post("/images/urls", json={"ids": image_ids})
    image_data = images_urls.json()

    # Materialize the dict view so the caller receives a plain list.
    urls = list(image_data.values())
    if not urls:
        logger.debug(
            "Fetching from: 'images/urls' %s returned empty array: "
            "status=%s body=%s data=%s",
            image_ids,
            images_urls.status_code,
            images_urls.text,
            image_data,
        )
        raise Exception("Failed to get images")

    return urls
# Input components, created in the positional order of `inference`'s
# parameters: file, prompt, quality, seed, version.
_garment_file = gr.File(label="Upload Garment Image")
_prompt_box = gr.Textbox(
    label="Enter Prompt",
    placeholder="Enter your description here",
)
_quality_choice = gr.Dropdown(
    choices=["low", "high"],
    label="Quality",
    value="low",
    info="Select image quality",
)
# The default seed is drawn once, when this statement runs at module load.
# NOTE(review): values up to 2**64 - 1 may not survive a round trip through
# a browser number field exactly -- confirm the seed range the API expects.
_seed_field = gr.Number(label="Seed", value=random.randint(1, 2**64 - 1))
_version_choice = gr.Dropdown(
    choices=["1"],
    label="Version",
    value="1",
    info="Select model version",
)

# Output component showing whatever `inference` returns.
_result_gallery = gr.Gallery(label="Generated Images")

# Gradio UI wiring the components above to `inference`.
interface = gr.Interface(
    fn=inference,
    inputs=[
        _garment_file,
        _prompt_box,
        _quality_choice,
        _seed_field,
        _version_choice,
    ],
    outputs=_result_gallery,
    title="Hautech",
    description="Upload a garment image and provide a prompt to generate related content.",
    theme="huggingface",
)
if __name__ == "__main__":
    # Debug logging is opt-in: only LOG=debug (any letter case) configures
    # the root logger; an unset or different value leaves logging untouched.
    if (os.getenv("LOG") or "").lower() == "debug":
        logging.basicConfig(level=logging.DEBUG)

    # Refuse to start when TOKEN is unset or empty. The value is only checked
    # here, not passed on (presumably the hautech client reads it from the
    # environment itself -- confirm).
    if not os.getenv("TOKEN"):
        raise Exception("Token environment variable is required")

    interface.launch()
|