workflow / app.py
akhaliq's picture
akhaliq HF Staff
Remove workflow lock subclass and update Gradio wheel URL
c8a3a45
import os
import json
import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub import get_token as hf_get_token
from gradio.context import LocalContext
import contextvars
# Context-local override for the Hugging Face token. It defaults to None, and
# get_hf_token() consults it before any other token source.
workflow_token = contextvars.ContextVar("workflow_token", default=None)
def get_hf_token() -> str | None:
    """
    Resolve the Hugging Face API token to use for the current call.

    Sources are tried in this order:

    1. the ``workflow_token`` context variable, when it holds a truthy value;
    2. the ``access_token`` stored under ``oauth_info`` in the session of the
       current Gradio request, ignoring the local-development placeholder;
    3. whatever ``huggingface_hub.get_token()`` returns.

    Returns:
        The token string, or ``None`` when no source yields one.
    """
    w_token = workflow_token.get()
    if w_token:
        return w_token

    request = LocalContext.request.get(None)
    if request is not None:
        # ``getattr(request, "session", {})`` only covers AttributeError.
        # Starlette's ``Request.session`` asserts that SessionMiddleware is
        # installed, so reading it can raise AssertionError instead. A request
        # without a usable session must not block the fallback below.
        try:
            session = request.session
        except (AttributeError, AssertionError):
            session = None

        oauth_info = session.get("oauth_info", {}) if session is not None else {}
        if oauth_info:
            token = oauth_info.get("access_token")
            # The placeholder token is not a real credential; skip it.
            if token and token != "mock-oauth-token-for-local-dev":
                return token

    try:
        return hf_get_token()
    except Exception:
        # Deliberate best effort: callers treat None as "no token available".
        return None
def generate_prompt(concept: str) -> str:
    """
    Expand a short concept into a detailed text-to-image prompt.

    The concept is sent to the NVIDIA Nemotron chat model through
    ``InferenceClient`` and the model's reply is returned after trimming
    whitespace and one pair of wrapping quotes.

    Args:
        concept: Short description of the desired image. When empty, a fixed
            example prompt is returned without calling the model.

    Returns:
        The expanded prompt. If the model call fails, or the model returns no
        usable text, a generic product-photograph prompt built from
        ``concept`` is returned instead; the error is printed, not raised.
    """
    if not concept:
        return "a ginger cat wearing a tiny wizard hat reading a spellbook"

    try:
        token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
        client = InferenceClient(
            provider="together",
            api_key=token,
            bill_to="huggingface",
        )

        system_instruction = (
            "You are an expert prompt engineer for text-to-image models. "
            "Your task is to take a simple concept and expand it into a detailed, "
            "vivid, and high-quality image prompt for FLUX.1-dev. "
            "Describe the scene, lighting, materials, and aesthetic in detail. "
            "Provide ONLY the final prompt text. Do not include any introductory or concluding text, "
            "do not provide multiple options, and do not wrap the prompt in quotes."
        )
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": f"Concept: {concept}"},
        ]

        response = client.chat_completion(
            model="nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-NVFP4",
            messages=messages,
            temperature=0.7,
            max_tokens=256,
        )

        result = response.choices[0].message.content
        # A missing content field must not become the literal prompt "None".
        clean_result = "" if result is None else str(result).strip()

        # The model is told not to quote its answer; strip one wrapping pair
        # in case it does so anyway.
        if clean_result.startswith('"') and clean_result.endswith('"'):
            clean_result = clean_result[1:-1]
        elif clean_result.startswith("'") and clean_result.endswith("'"):
            clean_result = clean_result[1:-1]

        if not clean_result:
            # Send empty output down the same fallback path as a failed call.
            raise ValueError("model returned no prompt text")
        return clean_result
    except Exception as e:
        print(f"Error calling Nemotron model: {e}")
        return f"A detailed, high-quality, professional commercial product photograph of {concept}"
def generate_image(prompt: str) -> dict:
    """
    Render ``prompt`` with the FLUX.1-dev model and save the result as a PNG.

    An empty prompt is replaced by a fixed example prompt. The image is
    written to a uniquely named file in the system temp directory.

    Returns:
        A dict with ``path`` (the saved file), ``url`` (the same path behind
        the ``/gradio_api/file=`` prefix) and ``is_file`` set to ``True``.

    Raises:
        Exception: any error from the inference call or from saving the
            image is printed and then re-raised.
    """
    import tempfile
    import uuid

    text = prompt if prompt else "a ginger cat wearing a tiny wizard hat reading a spellbook"

    try:
        api_key = (
            get_hf_token()
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HF_API_TOKEN")
        )
        inference = InferenceClient(provider="auto", api_key=api_key, bill_to="huggingface")
        picture = inference.text_to_image(text, model="black-forest-labs/FLUX.1-dev")

        out_path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.png")
        picture.save(out_path)
        return {"path": out_path, "url": f"/gradio_api/file={out_path}", "is_file": True}
    except Exception as exc:
        print(f"Error calling FLUX.1-dev model: {exc}")
        raise exc
def generate_z_image(prompt: str) -> dict:
    """
    Render ``prompt`` with the Tongyi-MAI/Z-Image-Turbo model and save a PNG.

    An empty prompt is replaced by a fixed example prompt. The image is
    written to a uniquely named file in the system temp directory.

    Returns:
        A dict with ``path`` (the saved file), ``url`` (the same path behind
        the ``/gradio_api/file=`` prefix) and ``is_file`` set to ``True``.

    Raises:
        Exception: any error from the inference call or from saving the
            image is printed and then re-raised.
    """
    import tempfile
    import uuid

    if not prompt:
        prompt = "a ginger cat wearing a tiny wizard hat reading a spellbook"

    try:
        env_token = os.environ.get
        credentials = get_hf_token() or env_token("HF_TOKEN") or env_token("HF_API_TOKEN")

        hub_client = InferenceClient(
            provider="auto",
            api_key=credentials,
            bill_to="huggingface",
        )
        rendered = hub_client.text_to_image(prompt, model="Tongyi-MAI/Z-Image-Turbo")

        target_dir = tempfile.gettempdir()
        target = os.path.join(target_dir, f"{uuid.uuid4()}.png")
        rendered.save(target)

        payload = {
            "path": target,
            "url": f"/gradio_api/file={target}",
            "is_file": True,
        }
        return payload
    except Exception as err:
        print(f"Error calling Z-Image-Turbo model: {err}")
        raise err
def edit_image(image_input: dict | str, prompt: str) -> dict | None:
"""
Edits a base image using the FLUX.2-klein-9B model.
Returns a dictionary structure compatible with Gradio's image viewer.
"""
print(f"DEBUG: edit_image called with image_input={image_input}, prompt={prompt}")
if not image_input or image_input == "None":
return None
if not prompt:
prompt = "Turn the cat into a tiger"
try:
# Extract file path from Gradio image dictionary or string
if isinstance(image_input, dict):
image_path = image_input.get("path")
if not image_path:
url = image_input.get("url")
if url and "/gradio_api/file=" in url:
image_path = url.split("/gradio_api/file=")[-1]
else:
image_path = image_input
if not image_path or image_path == "None" or not os.path.exists(image_path):
print(f"Workflow: Base image not generated/ready yet (path: {image_path})")
return None
with open(image_path, "rb") as f:
input_image_bytes = f.read()
token = get_hf_token() or os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
client = InferenceClient(
provider="auto",
api_key=token,
bill_to="huggingface",
)
image = client.image_to_image(
input_image_bytes,
prompt=prompt,
model="black-forest-labs/FLUX.2-klein-9B",
)
import tempfile
import uuid
temp_dir = tempfile.gettempdir()
filepath = os.path.join(temp_dir, f"{uuid.uuid4()}.png")
image.save(filepath)
return {
"path": filepath,
"url": f"/gradio_api/file={filepath}",
"is_file": True
}
except Exception as e:
print(f"Error calling FLUX.2-klein-9B model: {e}")
raise e
def generate_ideogram_image(prompt: str) -> dict | None:
    """
    Generate an image by calling the ``ideogram-ai/ideogram4`` Space.

    An empty prompt is replaced by a fixed example prompt. The Space's
    ``/generate`` endpoint is called through ``gradio_client`` with a
    1024x1024 size and a randomized seed.

    Returns:
        A dict with ``path`` (the first element of the Space's result),
        ``url`` (the same value behind the ``/gradio_api/file=`` prefix) and
        ``is_file`` set to ``True``.

    Raises:
        Exception: any error from importing the client or calling the Space
            is printed and then re-raised.
    """
    text = prompt if prompt else "a ginger cat wearing a tiny wizard hat reading a spellbook"

    try:
        from gradio_client import Client

        space = Client("ideogram-ai/ideogram4")
        request_args = {
            "prompt": text,
            "mode": "Default · 20 steps",
            "upsampler": "Ideogram (remote)",
            "width": 1024,
            "height": 1024,
            "seed": 0,
            "randomize_seed": True,
            "api_name": "/generate",
        }
        outputs = space.predict(**request_args)

        image_file = outputs[0]
        return {
            "path": image_file,
            "url": f"/gradio_api/file={image_file}",
            "is_file": True,
        }
    except Exception as exc:
        print(f"Error calling ideogram-ai/ideogram4 Space: {exc}")
        raise exc
# Hand every step function defined above to gr.Workflow through ``bind``.
demo = gr.Workflow(
    bind=[
        generate_prompt,
        generate_image,
        generate_z_image,
        edit_image,
        generate_ideogram_image,
    ]
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()