import os
import json
import tempfile
import uuid
import contextvars

import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub import get_token as hf_get_token
from gradio.context import LocalContext

# Prompt substituted whenever a caller passes an empty concept/prompt.
DEFAULT_PROMPT = "a ginger cat wearing a tiny wizard hat reading a spellbook"

# Placeholder OAuth token used for local development; never forwarded as a credential.
_MOCK_OAUTH_TOKEN = "mock-oauth-token-for-local-dev"

# Token override for the current context; when set it takes precedence in get_hf_token().
workflow_token = contextvars.ContextVar("workflow_token", default=None)


def get_hf_token() -> str | None:
    """
    Retrieves the HF API token from either the workflow context, the user's
    Gradio OAuth session, or falls back to the system environment.

    Lookup order:
      1. The ``workflow_token`` context variable, if set.
      2. ``access_token`` inside the ``oauth_info`` entry of the current Gradio
         request's session (the local-dev mock token is ignored).
      3. ``huggingface_hub.get_token()``; any error there yields ``None``.
    """
    w_token = workflow_token.get()
    if w_token:
        return w_token

    request = LocalContext.request.get(None)
    if request is not None:
        try:
            session = getattr(request, "session", None) or {}
        except Exception:
            # Starlette's Request.session raises AssertionError (not
            # AttributeError) when SessionMiddleware is not installed, so the
            # getattr default alone does not cover it. Treat it as "no session".
            session = {}
        oauth_info = session.get("oauth_info") or {}
        token = oauth_info.get("access_token") if oauth_info else None
        if token and token != _MOCK_OAUTH_TOKEN:
            return token

    try:
        return hf_get_token()
    except Exception:
        return None


def _resolve_token() -> str | None:
    """Return get_hf_token(), else the HF_TOKEN / HF_API_TOKEN environment variables."""
    return get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")


def _make_client(provider: str) -> InferenceClient:
    """Build an InferenceClient for *provider* with the resolved token and bill_to="huggingface"."""
    return InferenceClient(
        provider=provider,
        api_key=_resolve_token(),
        bill_to="huggingface",
    )


def _file_payload(filepath: str) -> dict:
    """Wrap a local file path in the dict shape handed to Gradio's image viewer."""
    return {
        "path": filepath,
        "url": f"/gradio_api/file={filepath}",
        "is_file": True,
    }


def _save_image(image) -> dict:
    """Save *image* (anything with ``.save(path)``) as a uniquely named PNG in the temp dir."""
    filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
    image.save(filepath)
    return _file_payload(filepath)


def _strip_wrapping_quotes(text: str) -> str:
    """Remove one matching pair of surrounding single or double quotes, if present."""
    # The length guard stops a lone quote character from collapsing to "".
    if len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', "'"}:
        return text[1:-1].strip()
    return text


def generate_prompt(concept: str) -> str:
    """
    Expands a simple concept into a detailed image prompt using the NVIDIA Nemotron model.

    Returns DEFAULT_PROMPT for an empty concept. If the model call fails, or it
    returns no usable text, returns a generic product-photo prompt built from
    *concept* instead.
    """
    if not concept:
        return DEFAULT_PROMPT

    fallback = f"A detailed, high-quality, professional commercial product photograph of {concept}"
    try:
        client = _make_client("together")
        system_instruction = (
            "You are an expert prompt engineer for text-to-image models. "
            "Your task is to take a simple concept and expand it into a detailed, "
            "vivid, and high-quality image prompt for FLUX.1-dev. "
            "Describe the scene, lighting, materials, and aesthetic in detail. "
            "Provide ONLY the final prompt text. Do not include any introductory or concluding text, "
            "do not provide multiple options, and do not wrap the prompt in quotes."
        )
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": f"Concept: {concept}"},
        ]
        response = client.chat_completion(
            model="nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4",
            messages=messages,
            temperature=0.7,
            max_tokens=256,
        )
        # message.content may be None; avoid turning that into the prompt "None".
        content = response.choices[0].message.content
        result = _strip_wrapping_quotes(str(content or "").strip())
    except Exception as e:
        print(f"Error calling Nemotron model: {e}")
        return fallback
    # Never hand an empty prompt to the downstream image models.
    return result or fallback


def _text_to_image(prompt: str, model: str, label: str) -> dict:
    """Run text-to-image on *model* via the "auto" provider and save the result.

    An empty *prompt* is replaced with DEFAULT_PROMPT. Errors are printed with
    *label* and re-raised unchanged.
    """
    if not prompt:
        prompt = DEFAULT_PROMPT
    try:
        image = _make_client("auto").text_to_image(prompt, model=model)
        return _save_image(image)
    except Exception as e:
        print(f"Error calling {label} model: {e}")
        raise


def generate_image(prompt: str) -> dict:
    """
    Generates an image from a prompt using the FLUX.1-dev model.
    Returns a dictionary structure compatible with Gradio's image viewer.
    """
    return _text_to_image(prompt, "black-forest-labs/FLUX.1-dev", "FLUX.1-dev")


def generate_z_image(prompt: str) -> dict:
    """
    Generates an image from a prompt using the Tongyi-MAI/Z-Image-Turbo model.
    Returns a dictionary structure compatible with Gradio's image viewer.
    """
    return _text_to_image(prompt, "Tongyi-MAI/Z-Image-Turbo", "Z-Image-Turbo")


def _extract_image_path(image_input: dict | str) -> str | None:
    """Get a local file path from a Gradio image dict ("path", else "url") or a plain string."""
    if isinstance(image_input, dict):
        image_path = image_input.get("path")
        if not image_path:
            url = image_input.get("url")
            if url and "/gradio_api/file=" in url:
                image_path = url.split("/gradio_api/file=")[-1]
        return image_path
    return image_input


def edit_image(image_input: dict | str, prompt: str) -> dict | None:
    """
    Edits a base image using the FLUX.2-klein-9B model.
    Returns a dictionary structure compatible with Gradio's image viewer.

    Returns None when no base image is supplied or its file does not exist yet.
    An empty *prompt* defaults to "Turn the cat into a tiger". Inference errors
    are printed and re-raised.
    """
    print(f"DEBUG: edit_image called with image_input={image_input}, prompt={prompt}")
    if not image_input or image_input == "None":
        return None
    if not prompt:
        prompt = "Turn the cat into a tiger"

    try:
        image_path = _extract_image_path(image_input)
        # NOTE(review): any existing server path passed as a string is read and
        # sent to the inference provider; consider restricting to the upload/temp dir.
        if not image_path or image_path == "None" or not os.path.exists(image_path):
            print(f"Workflow: Base image not generated/ready yet (path: {image_path})")
            return None

        with open(image_path, "rb") as f:
            input_image_bytes = f.read()

        image = _make_client("auto").image_to_image(
            input_image_bytes,
            prompt=prompt,
            model="black-forest-labs/FLUX.2-klein-9B",
        )
        return _save_image(image)
    except Exception as e:
        print(f"Error calling FLUX.2-klein-9B model: {e}")
        raise


def generate_ideogram_image(prompt: str) -> dict | None:
    """
    Generates an image from a prompt using the ideogram-ai/ideogram4 Space.
    Returns a dictionary structure compatible with Gradio's image viewer.

    The first element of the Space's result is used as the image file path.
    Errors are printed and re-raised.
    """
    if not prompt:
        prompt = DEFAULT_PROMPT
    try:
        from gradio_client import Client

        client = Client("ideogram-ai/ideogram4")
        result = client.predict(
            prompt=prompt,
            # Must match the Space's mode option exactly; the middle dot (U+00B7)
            # is written as an escape so re-encoding the file cannot garble it.
            mode="Default \u00b7 20 steps",
            upsampler="Ideogram (remote)",
            width=1024,
            height=1024,
            seed=0,
            randomize_seed=True,
            api_name="/generate",
        )
        return _file_payload(result[0])
    except Exception as e:
        print(f"Error calling ideogram-ai/ideogram4 Space: {e}")
        raise


demo = gr.Workflow(bind=[generate_prompt, generate_image, generate_z_image, edit_image, generate_ideogram_image])

if __name__ == "__main__":
    demo.launch()