# Krea-Lightning / app.py
# namelessai's picture
# Update app.py
# 7c4da30 verified
import gradio as gr
import requests
import io
from PIL import Image
import os
import sys
import base64 # <-- Import for Base64 decoding
from dotenv import load_dotenv
# --- Configuration ---
# Merge settings from a local .env file (if one exists) into the environment.
load_dotenv()

# Endpoint of the image-generation API, read from the environment.
MODAL_API_URL = os.getenv("MODAL_API_URL")

# --- Startup Check ---
# The app is useless without an endpoint, so explain the fix and exit non-zero.
if not MODAL_API_URL:
    print(
        "Error: MODAL_API_URL environment variable not set.",
        "Please create a .env file and add: MODAL_API_URL='your_api_url_here'",
        sep="\n",
    )
    sys.exit(1)
# --- End Configuration ---
# Seconds to wait for the API before giving up on a request.
_REQUEST_TIMEOUT_SECONDS = 120


def _load_image(image_bytes):
    """Decode raw image bytes into a fully loaded PIL image.

    Errors raised by Pillow while opening or decoding (caught by callers as
    OSError) propagate unchanged.
    """
    img = Image.open(io.BytesIO(image_bytes))
    # Image.open() is lazy; load() forces the pixel data to be decoded now so
    # truncated or corrupt payloads fail here instead of later in the UI.
    img.load()
    return img


def _image_from_json_response(response):
    """Extract the Base64 image from a JSON API response.

    Expects a JSON object of the form {"success": true, "image": "<base64>"}.
    Returns a tuple: (PIL.Image or None, status_message).
    """
    try:
        payload = response.json()
    except ValueError:
        # Every JSON decode error raised by requests (old and new versions)
        # derives from ValueError.
        print("Error: API said it was JSON but failed to decode.")
        return None, "Error: API returned invalid JSON."

    # Guard against valid JSON that is not an object (e.g. a list or string),
    # which has no .get() and would otherwise crash.
    is_expected = (
        isinstance(payload, dict)
        and payload.get("success")
        and "image" in payload
    )
    if not is_expected:
        print("Error: Received JSON, but 'success' or 'image' key is missing or false.")
        error_msg = f"Error: API returned unexpected JSON.\nResponse: {response.text[:500]}"
        return None, error_msg

    try:
        image_bytes = base64.b64decode(payload["image"])
    except (ValueError, TypeError) as e:
        # binascii.Error (bad padding) is a ValueError; non-ASCII strings raise
        # a plain ValueError; non-string values raise TypeError.
        print(f"Error: Failed to decode Base64 string. {e}")
        return None, "Error: API returned a corrupt image (Base64 decode failed)."

    try:
        img = _load_image(image_bytes)
    except OSError as e:
        print(f"Error: Could not open image from decoded Base64. {e}")
        return None, "Error: API returned valid Base64, but it could not be processed as an image."

    print("Success: JSON response with Base64 image received and processed.")
    return img, "Generation successful!"


def generate_image(prompt, steps):
    """
    Calls the Modal API to generate an image.

    Args:
        prompt: Text prompt describing the image; must be a non-blank string.
        steps: Number of inference steps; must convert to an integer >= 1.

    Returns a tuple: (PIL.Image or None, status_message). Failures (invalid
    input, network errors, non-200 statuses, malformed responses) are reported
    through the status message with None as the image; the request and
    response errors handled here are never raised to the caller.
    """
    # Validate inputs up front so an obviously bad request never hits the API.
    if not isinstance(prompt, str) or not prompt.strip():
        return None, "Error: Please enter a prompt."
    try:
        num_steps = int(steps)
    except (TypeError, ValueError, OverflowError):
        num_steps = 0
    if num_steps < 1:
        return None, f"Error: Invalid number of inference steps: {steps!r}"

    print(f"Requesting image for prompt: '{prompt}' with {steps} steps.")

    # Only the network call itself can raise request exceptions, so keep the
    # try body limited to it.
    try:
        response = requests.post(
            MODAL_API_URL,
            headers={"Content-Type": "application/json"},
            json={"prompt": prompt, "num_inference_steps": num_steps},
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.exceptions.Timeout:
        print("Error: Request timed out.")
        return None, (
            f"Error: The request timed out ({_REQUEST_TIMEOUT_SECONDS} seconds). "
            "The server might be busy or cold-starting. Please try again."
        )
    except requests.exceptions.RequestException as e:
        print(f"Error: Request failed. {e}")
        return None, f"Error: Request failed.\n{e}"

    if response.status_code != 200:
        # The API returned an error (e.g., 404, 500, 503)
        print(f"Error: API request failed. Status Code: {response.status_code}")
        error_msg = f"Error: API request failed.\nStatus Code: {response.status_code}\nResponse: {response.text[:500]}"
        return None, error_msg

    content_type = response.headers.get("Content-Type", "").lower()

    # Primary format: JSON envelope carrying a Base64-encoded image.
    if "application/json" in content_type:
        return _image_from_json_response(response)

    # Fallback: the body is the raw image itself.
    if "image" in content_type:
        try:
            img = _load_image(response.content)
        except OSError as e:
            print(f"Error: Could not open raw image from response. {e}")
            return None, f"Error: API returned raw image data, but it could not be processed.\n{e}"
        print("Success: Raw image received and processed.")
        return img, "Generation successful!"

    # Anything else is an unexpected response type.
    print(f"Error: Received unexpected Content-Type: {content_type}")
    error_msg = f"Error: Received unexpected response type.\nContent-Type: {content_type}\nResponse: {response.text[:500]}"
    return None, error_msg
# --- Gradio UI: inputs in the first column, results in the second ---
with gr.Blocks(theme=gr.themes.Soft(), title="FLUX.1 Image Generator") as demo:
    # Page header, rendered as Markdown.
    gr.Markdown(
        """
## ⚡️ Super-Fast FLUX.1 Image Generation (via Modal H100s)
Enter a prompt and select the number of steps to generate an image. Please note that our cloud GPUs take about 30 seconds to cold-start. The more the app is used, the faster it will be because the GPUs will be active. The app can also autoscale, so image generation will always take around 8 seconds for 30 steps. *The app's access to our API is currently turned off remotely due to NSFW abuse. We are working on a moderation system, and then the app will be turned back on.*
"""
    )

    with gr.Row():
        # Input controls: prompt text, step count, and the trigger button.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="A high-quality photo of a red panda wearing a tiny chef hat",
            )
            steps_input = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=28,
                label="Inference Steps",
            )
            generate_btn = gr.Button("Generate Image", variant="primary")

        # Read-only outputs: the generated picture and a status line.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image", interactive=False)
            output_status = gr.Textbox(label="Status", interactive=False)

    # Example (prompt, steps) pairs that fill the two input controls.
    gr.Examples(
        inputs=[prompt_input, steps_input],
        examples=[
            ["A cat holding a sign that says hello world", 28],
            ["A cinematic photo of a robot reading a book in a library, 4k", 30],
            ["Logo for a coffee shop named 'The Daily Grind', minimalist, vector art", 20],
            ["A vibrant watercolor painting of a bustling European city street", 35],
        ],
    )

    # Clicking the button runs generate_image and shows its (image, status) result.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, steps_input],
        outputs=[output_image, output_status],
    )
# --- Run the App ---
# Only start the server when this file is executed directly, not on import.
if __name__ == "__main__":
    print("Starting Gradio app, API endpoint loaded.")
    demo.launch()