import contextvars
import json
import os
import tempfile
import uuid

import gradio as gr
from gradio.context import LocalContext
from huggingface_hub import InferenceClient
from huggingface_hub import get_token as hf_get_token
|
|
| workflow_token = contextvars.ContextVar("workflow_token", default=None) |
|
|
|
|
def get_hf_token() -> str | None:
    """Resolve the Hugging Face API token for the current call.

    Sources are tried in order and the first non-empty value wins:

    1. the ``workflow_token`` context variable;
    2. ``access_token`` inside the ``oauth_info`` entry of the current Gradio
       request's session (the local-development mock token is ignored);
    3. ``huggingface_hub.get_token()``.

    Returns:
        The token string, or ``None`` when no source yields one.
    """
    override = workflow_token.get()
    if override:
        return override

    request = LocalContext.request.get(None)
    if request is not None:
        try:
            session = getattr(request, "session", None) or {}
            oauth_info = session.get("oauth_info") or {}
            token = oauth_info.get("access_token")
        except Exception:
            # Best-effort lookup: reading ``session`` can fail with something
            # other than AttributeError (presumably when no session middleware
            # is attached to the request -- TODO confirm), and such a failure
            # must not prevent the fallback below.
            token = None
        if token and token != "mock-oauth-token-for-local-dev":
            return token

    try:
        return hf_get_token()
    except Exception:
        # A missing or unreadable local token is reported as "no token".
        return None
|
|
|
|
def generate_prompt(concept: str) -> str:
    """Expand a simple concept into a detailed image prompt using the NVIDIA Nemotron model.

    Args:
        concept: Short description of the desired image.

    Returns:
        The model's prompt text with surrounding whitespace and one pair of
        matching wrapping quotes removed. When ``concept`` is empty, a built-in
        default prompt is returned without calling the model. When the model
        call fails, or yields no usable text, a generic prompt built from
        ``concept`` is returned instead; this function does not raise for
        inference errors.
    """
    if not concept:
        return "a ginger cat wearing a tiny wizard hat reading a spellbook"

    fallback = f"A detailed, high-quality, professional commercial product photograph of {concept}"
    try:
        token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
        client = InferenceClient(
            provider="together",
            api_key=token,
            bill_to="huggingface",
        )
        system_instruction = (
            "You are an expert prompt engineer for text-to-image models. "
            "Your task is to take a simple concept and expand it into a detailed, "
            "vivid, and high-quality image prompt for FLUX.1-dev. "
            "Describe the scene, lighting, materials, and aesthetic in detail. "
            "Provide ONLY the final prompt text. Do not include any introductory or concluding text, "
            "do not provide multiple options, and do not wrap the prompt in quotes."
        )
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": f"Concept: {concept}"},
        ]
        response = client.chat_completion(
            model="nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4",
            messages=messages,
            temperature=0.7,
            max_tokens=256,
        )
        result = response.choices[0].message.content
    except Exception as e:
        # Deliberate best-effort: any failure degrades to the generic prompt.
        print(f"Error calling Nemotron model: {e}")
        return fallback

    # str(None) would yield the literal prompt "None", so a missing completion
    # is treated as empty text instead.
    clean_result = "" if result is None else str(result).strip()

    # The model is told not to quote its answer; strip one wrapping pair if it
    # does so anyway. The length check keeps a lone quote character intact
    # rather than slicing it down to an empty string.
    if (
        len(clean_result) >= 2
        and clean_result[0] == clean_result[-1]
        and clean_result[0] in ('"', "'")
    ):
        clean_result = clean_result[1:-1]

    return clean_result or fallback
|
|
|
|
def generate_image(prompt: str) -> dict:
    """Generate an image from a prompt using the FLUX.1-dev model.

    The image is written as a uniquely named PNG in the system temporary
    directory.

    Args:
        prompt: Text description of the image. An empty prompt is replaced by
            a built-in default.

    Returns:
        A dict with ``path`` (the saved file), ``url`` (``/gradio_api/file=``
        followed by that path) and ``is_file`` set to ``True``; the structure
        is intended for Gradio's image viewer.

    Raises:
        Exception: Any error raised while resolving the token, calling the
            model or saving the file is printed and then re-raised unchanged.
    """
    if not prompt:
        prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"
    try:
        token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
        client = InferenceClient(
            provider="auto",
            api_key=token,
            bill_to="huggingface",
        )
        image = client.text_to_image(
            prompt,
            model="black-forest-labs/FLUX.1-dev",
        )

        # uuid4 keeps concurrent generations from overwriting each other.
        filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
        image.save(filepath)
    except Exception as e:
        print(f"Error calling FLUX.1-dev model: {e}")
        # Bare raise propagates the active exception with its traceback intact.
        raise

    return {
        "path": filepath,
        "url": f"/gradio_api/file={filepath}",
        "is_file": True,
    }
|
|
|
|
def generate_z_image(prompt: str) -> dict:
    """Generate an image from a prompt using the Tongyi-MAI/Z-Image-Turbo model.

    The image is written as a uniquely named PNG in the system temporary
    directory.

    Args:
        prompt: Text description of the image. An empty prompt is replaced by
            a built-in default.

    Returns:
        A dict with ``path`` (the saved file), ``url`` (``/gradio_api/file=``
        followed by that path) and ``is_file`` set to ``True``; the structure
        is intended for Gradio's image viewer.

    Raises:
        Exception: Any error raised while resolving the token, calling the
            model or saving the file is printed and then re-raised unchanged.
    """
    if not prompt:
        prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"
    try:
        token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
        client = InferenceClient(
            provider="auto",
            api_key=token,
            bill_to="huggingface",
        )
        image = client.text_to_image(
            prompt,
            model="Tongyi-MAI/Z-Image-Turbo",
        )

        # uuid4 keeps concurrent generations from overwriting each other.
        filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
        image.save(filepath)
    except Exception as e:
        print(f"Error calling Z-Image-Turbo model: {e}")
        # Bare raise propagates the active exception with its traceback intact.
        raise

    return {
        "path": filepath,
        "url": f"/gradio_api/file={filepath}",
        "is_file": True,
    }
|
|
|
|
| def edit_image(image_input: dict | str, prompt: str) -> dict | None: |
| """ |
| Edits a base image using the FLUX.2-klein-9B model. |
| Returns a dictionary structure compatible with Gradio's image viewer. |
| """ |
| print(f"DEBUG: edit_image called with image_input={image_input}, prompt={prompt}") |
| if not image_input or image_input == "None": |
| return None |
| if not prompt: |
| prompt = "Turn the cat into a tiger" |
| |
| try: |
| |
| if isinstance(image_input, dict): |
| image_path = image_input.get("path") |
| if not image_path: |
| url = image_input.get("url") |
| if url and "/gradio_api/file=" in url: |
| image_path = url.split("/gradio_api/file=")[-1] |
| else: |
| image_path = image_input |
| |
| if not image_path or image_path == "None" or not os.path.exists(image_path): |
| print(f"Workflow: Base image not generated/ready yet (path: {image_path})") |
| return None |
| |
| with open(image_path, "rb") as f: |
| input_image_bytes = f.read() |
| |
| token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN") |
| client = InferenceClient( |
| provider="auto", |
| api_key=token, |
| bill_to="huggingface", |
| ) |
| image = client.image_to_image( |
| input_image_bytes, |
| prompt=prompt, |
| model="black-forest-labs/FLUX.2-klein-9B", |
| ) |
| |
| import tempfile |
| import uuid |
| |
| temp_dir = tempfile.gettempdir() |
| filepath = os.path.join(temp_dir, f"{uuid.uuid4()}.png") |
| image.save(filepath) |
| |
| return { |
| "path": filepath, |
| "url": f"/gradio_api/file={filepath}", |
| "is_file": True |
| } |
| except Exception as e: |
| print(f"Error calling FLUX.2-klein-9B model: {e}") |
| raise e |
|
|
|
|
def generate_ideogram_image(prompt: str) -> dict | None:
    """Generate an image from a prompt using the ideogram-ai/ideogram4 Space.

    The Space's ``/generate`` endpoint is called with a fixed 1024x1024 size
    and a randomized seed.

    Args:
        prompt: Text description of the image. An empty prompt is replaced by
            a built-in default.

    Returns:
        A dict with ``path`` (the first element of the Space's result),
        ``url`` (``/gradio_api/file=`` followed by that path) and ``is_file``
        set to ``True``; the structure is intended for Gradio's image viewer.
        The current implementation never returns ``None``; the annotation is
        kept for interface compatibility.

    Raises:
        Exception: Any error raised while importing the client, calling the
            Space or reading its result is printed and then re-raised
            unchanged.
    """
    if not prompt:
        prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"
    try:
        # Imported inside the try so an import failure is printed and
        # re-raised like any other error.
        from gradio_client import Client

        client = Client("ideogram-ai/ideogram4")
        result = client.predict(
            prompt=prompt,
            mode="Default · 20 steps",
            upsampler="Ideogram (remote)",
            width=1024,
            height=1024,
            seed=0,
            randomize_seed=True,
            api_name="/generate",
        )

        # Assumes the first element of the result is a local file path to the
        # generated image -- TODO confirm against the Space's API.
        filepath = result[0]
    except Exception as e:
        print(f"Error calling ideogram-ai/ideogram4 Space: {e}")
        # Bare raise propagates the active exception with its traceback intact.
        raise

    return {
        "path": filepath,
        "url": f"/gradio_api/file={filepath}",
        "is_file": True,
    }
|
|
|
|
# Workflow app with the five step functions bound, in this order.
demo = gr.Workflow(
    bind=[
        generate_prompt,
        generate_image,
        generate_z_image,
        edit_image,
        generate_ideogram_image,
    ]
)


if __name__ == "__main__":
    # Only start the app when this file is run as a script.
    demo.launch()
|
|