Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import io | |
| from PIL import Image | |
| import os | |
| import sys | |
| import base64 # <-- Import for Base64 decoding | |
| from dotenv import load_dotenv | |
# --- Configuration ---
# Merge variables from a local .env file (when present) into the environment
# before anything below reads it.
load_dotenv()

# URL of the image-generation endpoint; None when the variable is absent.
MODAL_API_URL = os.getenv("MODAL_API_URL")

# --- Startup Check ---
# An unset or empty URL leaves the app with nothing to call, so explain how to
# configure it and stop with a non-zero exit status.
if not MODAL_API_URL:
    print(
        "Error: MODAL_API_URL environment variable not set.",
        "Please create a .env file and add: MODAL_API_URL='your_api_url_here'",
        sep="\n",
    )
    raise SystemExit(1)
# --- End Configuration ---
def _load_image(raw_bytes):
    """Decode *raw_bytes* into a fully loaded PIL image.

    Raises OSError (which includes PIL's UnidentifiedImageError) when the
    bytes are not a readable image.
    """
    img = Image.open(io.BytesIO(raw_bytes))
    # Image.open only reads the header; force the pixel data to be decoded now
    # so a truncated payload fails here instead of later inside Gradio.
    img.load()
    return img


def _image_from_json(response):
    """Extract the Base64 image from a JSON response.

    Expects a body shaped like {"success": true, "image": "<base64>"}.
    Returns a tuple: (PIL.Image or None, status_message).
    """
    try:
        payload = response.json()
    except ValueError:
        # Every JSON decode error raised by requests (old or new versions)
        # derives from ValueError.
        print("Error: API said it was JSON but failed to decode.")
        return None, "Error: API returned invalid JSON."

    # The body may be valid JSON without being an object (e.g. a list or a
    # string); treat that the same as a missing/false key.
    well_formed = (
        isinstance(payload, dict)
        and payload.get("success")
        and "image" in payload
    )
    if not well_formed:
        print("Error: Received JSON, but 'success' or 'image' key is missing or false.")
        return None, f"Error: API returned unexpected JSON.\nResponse: {response.text[:500]}"

    try:
        raw_bytes = base64.b64decode(payload["image"])
    except (ValueError, TypeError) as e:
        # binascii.Error is a ValueError subclass; a plain ValueError is raised
        # for non-ASCII text, TypeError for a non-string value.
        print(f"Error: Failed to decode Base64 string. {e}")
        return None, "Error: API returned a corrupt image (Base64 decode failed)."

    try:
        img = _load_image(raw_bytes)
    except OSError as e:
        print(f"Error: Could not open image from decoded Base64. {e}")
        return None, "Error: API returned valid Base64, but it could not be processed as an image."

    print("Success: JSON response with Base64 image received and processed.")
    return img, "Generation successful!"


def generate_image(prompt, steps):
    """
    Calls the Modal API to generate an image.

    Args:
        prompt: Text prompt forwarded to the API.
        steps: Number of inference steps; converted with int() before sending.

    Returns a tuple: (PIL.Image or None, status_message)
    """
    print(f"Requesting image for prompt: '{prompt}' with {steps} steps.")
    headers = {"Content-Type": "application/json"}
    data = {
        "prompt": prompt,
        "num_inference_steps": int(steps),
    }

    # Only the network call can raise a requests exception, so keep the try
    # body limited to it.
    try:
        response = requests.post(MODAL_API_URL, headers=headers, json=data, timeout=120)
    except requests.exceptions.Timeout:
        print("Error: Request timed out.")
        return None, "Error: The request timed out (120 seconds). The server might be busy or cold-starting. Please try again."
    except requests.exceptions.RequestException as e:
        print(f"Error: Request failed. {e}")
        return None, f"Error: Request failed.\n{e}"

    if response.status_code != 200:
        # The API returned an error (e.g., 404, 500, 503)
        print(f"Error: API request failed. Status Code: {response.status_code}")
        return None, f"Error: API request failed.\nStatus Code: {response.status_code}\nResponse: {response.text[:500]}"

    content_type = response.headers.get('Content-Type', '').lower()

    # Primary path: JSON envelope carrying a Base64-encoded image.
    if 'application/json' in content_type:
        return _image_from_json(response)

    # Fallback: the body itself is the image.
    if 'image' in content_type:
        try:
            img = _load_image(response.content)
        except OSError as e:
            print(f"Error: Could not open raw image from response. {e}")
            return None, f"Error: API returned raw image data, but it could not be processed.\n{e}"
        print("Success: Raw image received and processed.")
        return img, "Generation successful!"

    # Anything else is unexpected; surface a short excerpt to the user.
    print(f"Error: Received unexpected Content-Type: {content_type}")
    return None, f"Error: Received unexpected response type.\nContent-Type: {content_type}\nResponse: {response.text[:500]}"
# --- Build the Gradio UI ---
# Example rows pair a prompt with a step count, in the same order as the two
# input components they populate.
_EXAMPLE_ROWS = [
    ["A cat holding a sign that says hello world", 28],
    ["A cinematic photo of a robot reading a book in a library, 4k", 30],
    ["Logo for a coffee shop named 'The Daily Grind', minimalist, vector art", 20],
    ["A vibrant watercolor painting of a bustling European city street", 35],
]

with gr.Blocks(theme=gr.themes.Soft(), title="FLUX.1 Image Generator") as demo:
    # Heading and usage notes displayed above the controls.
    gr.Markdown(
        """
        ## ⚡️ Super-Fast FLUX.1 Image Generation (via Modal H100s)
        Enter a prompt and select the number of steps to generate an image. Please note that our cloud GPUs take about 30 seconds to cold-start. The more the app is used, the faster it will be because the GPUs will be active. The app can also autoscale, so image generation will always take around 8 seconds for 30 steps. *The app's access to our API is currently turned off remotely due to NSFW abuse. We are working on a moderation system, and then the app will be turned back on.*
        """
    )

    with gr.Row():
        # Left column: the inputs and the button that triggers generation.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="A high-quality photo of a red panda wearing a tiny chef hat",
            )
            steps_input = gr.Slider(
                minimum=1,
                maximum=50,
                value=28,
                step=1,
                label="Inference Steps",
            )
            generate_btn = gr.Button("Generate Image", variant="primary")

        # Right column: the generated image and a read-only status line.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image", interactive=False)
            output_status = gr.Textbox(label="Status", interactive=False)

    gr.Examples(examples=_EXAMPLE_ROWS, inputs=[prompt_input, steps_input])

    # generate_image returns (image, status), matching the two outputs in order.
    generate_btn.click(
        generate_image,
        inputs=[prompt_input, steps_input],
        outputs=[output_image, output_status],
    )
# --- Run the App ---
# Start the web server only when this file is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    print("Starting Gradio app, API endpoint loaded.")
    demo.launch()