| import logging |
| import os |
| import random |
|
|
| import gradio as gr |
| from hautech import HautechRequest, OperationData, OperationInput, Poller |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def inference(file, prompt, quality, seed, version):
    """Generate images for an uploaded garment image through the Hautech API.

    Uploads ``file``, creates a ``"generate"`` operation with the given
    parameters, polls it with ``Poller`` (4 second interval), and finally
    resolves the produced image ids to URLs via ``/images/urls``.

    Args:
        file: Value of the ``gr.File`` input; passed unchanged to
            ``HautechRequest.upload_image``.
        prompt: Free-text description sent with the operation.
        quality: Quality string (the UI offers ``"low"`` / ``"high"``).
        seed: Generation seed. ``gr.Number`` may deliver a float, so it is
            coerced with ``int()``; ``None`` is passed through unchanged.
        version: Model version; coerced with ``int()``.

    Returns:
        list: The values of the ``/images/urls`` JSON response, as a list
        suitable for a ``gr.Gallery`` output.

    Raises:
        Exception: If no file is given, the operation id is empty, the
            polled operation has no image ids, or no URLs are returned.
    """
    if not file:
        logger.debug("No file provided")
        raise Exception("No file provided")

    hautech_request = HautechRequest()
    img_id = hautech_request.upload_image(file)

    operations = hautech_request.post(
        "/operations",
        json=OperationData(
            input=OperationInput(
                aspectRatio="1:1",
                productImageId=img_id,
                quality=quality,
                prompt=prompt,
                # gr.Number yields floats unless precision=0; seeds are integers.
                seed=int(seed) if seed is not None else None,
                version=int(version),
            ),
            type="generate",
        ).model_dump(),
    )
    generation_id = operations.json().get("id")
    # `not` also covers a JSON null or missing id, on which len() would crash.
    if not generation_id:
        logger.debug(
            "Operations returned no data: status=%s body=%s",
            operations.status_code,
            operations.text,
        )
        raise Exception("Generation ID is empty")

    poller = Poller(hautech_request)
    result = poller.poll(generation_id=generation_id, interval_sec=4)

    # Tolerate a null/missing "output" object as well as missing "imageIds".
    image_ids = ((result or {}).get("output") or {}).get("imageIds") or []
    if not image_ids:
        # Log the polled result: the earlier /operations response is not what failed.
        logger.debug("Empty array for `imageIds`, poll result: %s", result)
        raise Exception("Failed to generate images")

    images_urls = hautech_request.post("/images/urls", json={"ids": image_ids})
    image_data = images_urls.json()

    # Materialize the dict view so the Gallery receives a plain list.
    urls = list(image_data.values())
    if not urls:
        logger.debug(
            "Fetching from 'images/urls' %s returned empty result: status=%s body=%s",
            image_ids,
            images_urls.status_code,
            images_urls.text,
        )
        raise Exception("Failed to get images")

    return urls
|
|
|
|
# Gradio UI wiring: the input components are passed positionally to
# `inference` (file, prompt, quality, seed, version), and its result is
# shown in a gallery.
interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.File(label="Upload Garment Image"),
        gr.Textbox(label="Enter Prompt", placeholder="Enter your description here"),
        gr.Dropdown(
            choices=["low", "high"],
            label="Quality",
            value="low",
            info="Select image quality",
        ),
        # A callable default is evaluated when the app loads rather than once at
        # import time, so visitors do not all share the same "random" seed.
        # precision=0 makes Gradio hand the seed to `inference` as an int.
        # NOTE(review): values above 2**53 cannot be represented exactly by the
        # browser's number input — confirm the seed range the API accepts.
        gr.Number(
            label="Seed",
            value=lambda: random.randint(1, 2**64 - 1),
            precision=0,
        ),
        gr.Dropdown(
            choices=["1"], label="Version", value="1", info="Select model version"
        ),
    ],
    outputs=gr.Gallery(label="Generated Images"),
    title="Hautech",
    description="Upload a garment image and provide a prompt to generate related content.",
    theme="huggingface",
)
|
|
if __name__ == "__main__":
    # Verbose logging is opt-in: LOG=debug (any letter case) turns on DEBUG output.
    if (os.getenv("LOG") or "").lower() == "debug":
        logging.basicConfig(level=logging.DEBUG)

    # Refuse to start the server when TOKEN is unset or empty.
    # NOTE(review): the token is not passed anywhere here; presumably
    # HautechRequest reads it from the environment itself — confirm.
    if not os.getenv("TOKEN"):
        raise Exception("Token environment variable is required")

    interface.launch()
|
|