# Spaces: Paused
| import spaces | |
| import gradio as gr | |
| import random | |
| import os | |
| import time | |
| import torch | |
| from diffusers import FluxPipeline | |
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Fixed generation settings passed to the pipeline on each call.
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_NUM_INFERENCE_STEPS = 15
DEFAULT_MAX_SEQUENCE_LENGTH = 512

# Access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_ACCESS_TOKEN")

# Module-level slot for the loaded pipeline so it is only built once.
CACHED_PIPE = None
def load_bnb_4bit_pipeline():
    """Return the shared 4-bit quantized pipeline, loading it on first use.

    The first successful call loads the pipeline with
    ``FluxPipeline.from_pretrained``, calls ``enable_model_cpu_offload()`` on
    it, and stores it in the module-level ``CACHED_PIPE``. Later calls return
    that cached object without loading anything again.

    Returns:
        The loaded pipeline object.

    Raises:
        Exception: Any error raised while loading the pipeline or enabling
            offload is printed and re-raised unchanged; nothing is cached in
            that case.
    """
    global CACHED_PIPE
    if CACHED_PIPE is not None:
        return CACHED_PIPE

    print("Loading 4-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-nf4"

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    started = time.perf_counter()
    try:
        # Only the calls that can actually fail to load live inside the try,
        # so the "Error loading" message is never printed for other failures.
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
        )
        pipe.enable_model_cpu_offload()
    except Exception as exc:
        print(f"Error loading 4-bit BNB pipeline: {exc}")
        raise
    elapsed = time.perf_counter() - started

    # Publish the pipeline before the diagnostics below so that a problem
    # while reporting statistics cannot force a second full load.
    CACHED_PIPE = pipe

    mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
    print(f"4-bit BNB pipeline loaded in {elapsed:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
    return pipe
def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* using the 4-bit quantized pipeline.

    Args:
        prompt: Text prompt. ``None``, empty, and whitespace-only prompts are
            rejected before the model is touched.
        progress: Gradio progress tracker; the default also tracks tqdm bars.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the first image returned by
        the pipeline, or ``None`` when the prompt is blank or generation
        fails. ``status`` is a message for the status textbox; on success it
        includes the seed that was used.
    """
    # NOTE(review): `spaces` is imported at the top of this file but no
    # `@spaces.GPU` decorator is applied here - confirm whether this app is
    # meant to run on ZeroGPU hardware.

    # A prompt made only of whitespace is as unusable as an empty one.
    if not prompt or (isinstance(prompt, str) and not prompt.strip()):
        return None, "Please enter a prompt."

    progress(0.2, desc="Loading 4-bit quantized model...")
    try:
        # Load (or fetch the cached) 4-bit pipeline
        pipe = load_bnb_4bit_pipeline()

        # Set up generation parameters
        pipe_kwargs = {
            "prompt": prompt,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "guidance_scale": DEFAULT_GUIDANCE_SCALE,
            "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
            "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
        }

        # Fresh random seed per request, echoed back so a result can be traced
        seed = random.getrandbits(64)
        print(f"Using seed: {seed}")

        # Use a private generator instead of torch.manual_seed(): reseeding
        # the process-wide RNG would let overlapping requests disturb each
        # other's random streams.
        generator = torch.Generator(device="cpu").manual_seed(seed)

        progress(0.5, desc="Generating image...")

        # Generate image; perf_counter() is monotonic, unlike time.time()
        gen_started = time.perf_counter()
        image = pipe(**pipe_kwargs, generator=generator).images[0]
        gen_elapsed = time.perf_counter() - gen_started
        print(f"Image generated in {gen_elapsed:.2f} seconds")

        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"Memory reserved: {mem_reserved:.2f} GB")
        return image, f"Generation complete! (Seed: {seed})"
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of raising
        print(f"Error during generation: {exc}")
        return None, f"Error: {exc}"
# Build the Gradio UI: header, badge links, prompt row, outputs, examples.
with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
    # Page title and subtitle
    gr.HTML(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
        <h1>FLUXllama</h1>
        <p>FLUX.1-dev 4-bit Quantized Version</p>
        </div>
        """
    )

    # Badge links
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px; margin-bottom: 20px;'>
        <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
        <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
        </a>
        <a href="https://discord.gg/openfreeai" target="_blank">
        <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
        </a>
        </div>
        """
    )

    # Prompt box and button share one row; `scale` sets their relative width.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
            lines=2,
            scale=4,
        )
        generate_button = gr.Button("Generate", variant="primary", scale=1)

    output_image = gr.Image(
        label="Generated Image (4-bit Quantized)",
        type="pil",
        height=600,
    )
    status_text = gr.Textbox(
        label="Status",
        interactive=False,
        lines=1,
    )

    # The button click and the Enter key in the prompt box run the same
    # handler with the same inputs and outputs.
    _generation_event = dict(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )
    generate_button.click(**_generation_event)
    prompt_input.submit(**_generation_event)

    # Example prompts that fill the prompt box
    gr.Examples(
        examples=[
            "A photorealistic portrait of an astronaut on Mars",
            "Water-color painting of a cat wearing sunglasses",
            "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8K",
            "A majestic dragon flying over a medieval castle at sunset",
            "Abstract art representing the concept of time and space",
            "Detailed oil painting of a steampunk clockwork city",
            "Underwater scene with bioluminescent creatures in deep ocean",
            "Japanese garden in autumn with falling maple leaves",
        ],
        inputs=prompt_input,
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)