Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from diffusers import FluxPipeline | |
| import io | |
| import os | |
# --- Page configuration -----------------------------------------------------
# Configure the browser tab and layout, then render the static page header.
st.set_page_config(
    page_title="Flux Image Generator",
    layout="centered",
)
st.title("🎨 AI Image Generator")
st.caption("Powered by FLUX.1 [schnell]")
# 1. THE LOAD FUNCTION
# Streamlit re-executes this entire script on every widget interaction.
# st.cache_resource keeps a single pipeline object alive across those reruns
# (shared by all sessions of this server process), so the model weights are
# loaded once instead of on every slider move or button click.
@st.cache_resource
def load_pipeline():
    """Load and return the FLUX.1 [schnell] text-to-image pipeline.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable (on Spaces, the secret of that name from the Space settings) and
    forwarded to ``FluxPipeline.from_pretrained``; if the variable is unset,
    ``token`` is ``None``.

    Returns:
        A ``FluxPipeline`` loaded in ``torch.bfloat16`` with model CPU
        offload enabled.

    Raises:
        Any exception raised by ``FluxPipeline.from_pretrained`` or
        ``enable_model_cpu_offload``. st.cache_resource does not store a
        raised exception, so a failed load is retried on the next rerun.
    """
    token = os.getenv("HF_TOKEN")

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
        token=token,
    )

    # Memory saving for constrained hardware: submodules are moved to the
    # accelerator only while they run.
    # NOTE(review): CPU offload is designed for GPU setups; confirm it behaves
    # as intended on the Space's actual hardware tier.
    pipe.enable_model_cpu_offload()
    return pipe
# Initialize the pipeline.
# On failure, show the setup hint together with the underlying exception, so
# a missing token, a download error, or a resource problem can be told apart.
# Then halt this script run so the UI below never references an undefined
# `pipeline`.
try:
    pipeline = load_pipeline()
except Exception as e:
    st.error(
        "Could not load the model. Make sure you accepted the terms on the "
        "model page and added your HF_TOKEN to secrets."
    )
    st.exception(e)
    st.stop()
# 2. SIDEBAR SETTINGS
# Generation controls rendered in the sidebar. Their current values
# (width, height, num_steps) are read by the generation code on each rerun.
st.sidebar.header("Settings")
width = st.sidebar.slider(
    "Width", min_value=512, max_value=1024, value=1024, step=128
)
height = st.sidebar.slider(
    "Height", min_value=512, max_value=1024, value=1024, step=128
)
# Step count is capped at 4, which the tip below recommends.
num_steps = st.sidebar.slider(
    "Inference Steps", min_value=1, max_value=4, value=4
)
st.sidebar.info("Tip: Use 4 steps for the best quality.")
# 3. USER INTERFACE
prompt = st.text_area(
    "Enter your prompt:",
    "A futuristic robotic arm building a circuit board, cinematic lighting, 8k",
)

# st.button is True only during the single rerun triggered by the click, so the
# result is stored in session state instead of being rendered only inside this
# branch. Otherwise it would disappear on the next rerun, including the rerun
# caused by clicking the download button.
if st.button("Generate Image"):
    # Reject empty and whitespace-only prompts.
    if prompt.strip():
        with st.spinner("Generating... this may take a minute."):
            try:
                image = pipeline(
                    prompt,
                    num_inference_steps=num_steps,
                    guidance_scale=0.0,
                    width=width,
                    height=height,
                    max_sequence_length=256,
                ).images[0]

                # Encode once; the same PNG bytes feed both the preview and
                # the download button.
                buf = io.BytesIO()
                image.save(buf, format="PNG")
                st.session_state["generated_png"] = buf.getvalue()
            except Exception as e:
                # Drop any previous result so it is not mistaken for output
                # of the failed request.
                st.session_state.pop("generated_png", None)
                st.error(f"Error during generation: {e}")
    else:
        st.warning("Please enter a prompt first!")

# Render the most recent successful result (if any) on every rerun.
generated_png = st.session_state.get("generated_png")
if generated_png is not None:
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the Space's pinned
    # Streamlit version is confirmed to support it.
    st.image(generated_png, caption="Generated Result", use_column_width=True)
    st.download_button(
        label="Download Image",
        data=generated_png,
        file_name="generated.png",
        mime="image/png",
    )