import os
import random
import subprocess
import sys

# Third-party imports keep their original relative order, with `spaces` first.
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from diffusers import DiffusionPipeline
import numpy as np
# Install flash-attn at startup, before any model is loaded.
#
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is layered on top of the current
# environment: passing a bare dict as `env` replaces the child's environment
# entirely, which drops PATH, HOME, proxy and CUDA variables that pip needs.
# `sys.executable -m pip` guarantees the package lands in the interpreter that
# is running this script, and the argument list avoids going through a shell.
# The return code is intentionally not checked (best-effort, as before).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Hub access token read from the environment; None when the variable is unset.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Text-to-image pipeline, loaded in `dtype` and moved to `device`.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    token=huggingface_token,
)
pipe = pipe.to(device)

# Florence-2 captioning model (switched to eval mode) and its processor.
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True,
)
florence_model = florence_model.to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True,
)

# Prompt enhancer, exposed through the transformers "summarization" pipeline.
enhancer_long = pipeline(
    "summarization",
    model="mobenta/M_Prompter",
    device=device,
)

# Upper bounds used by the seed and image-size sliders in the UI.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
def florence_caption(image):
    """Return the Florence-2 ``<MORE_DETAILED_CAPTION>`` result for *image*.

    *image* may be a ``PIL.Image.Image``; any other value is passed through
    ``Image.fromarray`` first. Generation uses beam search (3 beams, no
    sampling) with at most 1024 new tokens.
    """
    task = "<MORE_DETAILED_CAPTION>"
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(
        text=task,
        images=pil_image,
        return_tensors="pt",
    ).to(device)

    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Decode with special tokens kept, then let the processor's task-specific
    # post-processing turn the raw text into a {task: caption} mapping.
    decoded = florence_processor.batch_decode(output_ids, skip_special_tokens=False)
    parsed = florence_processor.post_process_generation(
        decoded[0],
        task=task,
        image_size=(pil_image.width, pil_image.height),
    )
    return parsed[task]
def enhance_prompt(input_prompt):
    """Return the enhancer pipeline's rewrite of *input_prompt*.

    The prompt is prefixed with a fixed instruction and sent to
    ``enhancer_long``; the ``summary_text`` of the first result is returned.
    """
    request = "Enhance the description: " + input_prompt
    outputs = enhancer_long(request)
    return outputs[0]['summary_text']
@spaces.GPU(duration=190)
def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Build a prompt, optionally enhance it, and generate one image.

    Args:
        image: Optional input image. When given, its Florence-2 caption is
            used as the prompt and ``text_prompt`` is ignored.
        text_prompt: Prompt used when no image is supplied. ``None`` is
            treated as an empty string.
        use_enhancer: When truthy, the prompt is rewritten by
            ``enhance_prompt`` before generation.
        seed: Seed used when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, a random seed in ``[0, MAX_SEED]``
            replaces ``seed``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Not referenced in the body; declared with
            ``track_tqdm=True`` so Gradio can attach its progress tracker.

    Returns:
        A ``(generated_image, prompt_used, seed_used)`` tuple.
    """
    if image is not None:
        # florence_caption converts non-PIL inputs itself, so no conversion
        # is needed here.
        prompt = florence_caption(image)
    else:
        # Guard against a missing prompt (e.g. an API call passing null) so
        # string concatenation in the enhancer and the pipeline call don't fail.
        prompt = text_prompt or ""

    if use_enhancer:
        prompt = enhance_prompt(prompt)

    # UI widgets and API clients may deliver numeric values as floats;
    # torch.Generator.manual_seed and the image dimensions need real ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    generated_image = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
    ).images[0]

    return generated_image, prompt, seed
# Stylesheet passed to gr.Blocks: framed input/output groups and a blue
# submit button with a lighter hover colour.
# NOTE(review): indentation inside the CSS was reconstructed from a
# whitespace-stripped paste; it does not affect how the rules apply.
custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# HTML header rendered at the top of the page via gr.HTML.
title = """<h1 align="center">Mobenta: Creative AI Image Generation</h1>
<p align="center">
your creative partner for generating stunning AI-powered images. Whether you start with a photo or a text prompt, it enhances your ideas with advanced captioning and prompt refinement. Dive into the world of creativity and explore endless possibilities!
</p>
"""
# Page layout: inputs on the left, results on the right.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# paste — confirm which components sit inside each Group/Accordion.
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"),
) as demo:
    gr.HTML(title)

    with gr.Row():
        # Left column: image input, advanced settings, and the trigger button.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="input-group"):
                input_image = gr.Image(label="Input Image (Florence-2 Captioner)")

            with gr.Accordion("Advanced Settings", open=False):
                text_prompt = gr.Textbox(
                    label="Text Prompt (optional, used if no image is uploaded)",
                )
                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

            generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

        # Right column: generated image plus the prompt and seed actually used.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-group"):
                output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
                final_prompt = gr.Textbox(label="Final Prompt Used")
                used_seed = gr.Number(label="Seed Used")

    # Order must match process_workflow's positional parameters and its
    # (image, prompt, seed) return tuple.
    workflow_inputs = [
        input_image,
        text_prompt,
        use_enhancer,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ]
    workflow_outputs = [output_image, final_prompt, used_seed]

    generate_btn.click(
        fn=process_workflow,
        inputs=workflow_inputs,
        outputs=workflow_outputs,
    )


demo.launch(debug=True)
|