# Source: FIBO-Mashup / app.py
# Revision adac165 (davidi-bria): "Refactor app.py and update schema.py"
import gradio as gr
import tempfile
import concurrent.futures
import dotenv
import os
from api_utils import get_prompt_api, generate_image
from schema import claude_structured_output
import time
import torch
from diffusers import BriaFiboPipeline
# NOTE: Loading the FIBO pipeline locally is disabled (kept for reference);
# process_images() obtains the final image through api_utils.generate_image.
# pipe = BriaFiboPipeline.from_pretrained(
#     "briaai/FIBO",
#     torch_dtype=torch.bfloat16,
# )
# pipe.to("cpu")

# Load variables from a .env file into the process environment
# (presumably the API credentials used by api_utils / schema -- TODO confirm).
dotenv.load_dotenv()
def get_image_suffix(image):
    """
    Choose a file extension (with leading dot) to use when saving ``image``.

    The extension is picked in this order:
      1. the image's ``format`` attribute, lower-cased (e.g. ``"PNG"`` -> ``".png"``);
      2. the extension of the image's ``filename`` attribute, lower-cased;
      3. a fallback based on the image ``mode``: ``".png"`` for modes that the
         JPEG writer cannot store (alpha / palette), otherwise ``".jpg"``.

    Args:
        image: A PIL-Image-like object. Missing or empty ``format``,
            ``filename`` and ``mode`` attributes are tolerated.

    Returns:
        str: The file suffix, e.g. ``".png"`` or ``".jpg"``.
    """
    # Modes that cannot be written as JPEG but can be written as PNG.
    png_only_modes = ("RGBA", "LA", "P")

    image_format = getattr(image, "format", None)
    if image_format:
        return "." + image_format.lower()

    filename = getattr(image, "filename", None)
    if filename:
        _, ext = os.path.splitext(filename)
        if ext:
            return ext.lower()

    # In-memory images (no format, no filename) used to always get ".jpg",
    # which makes image.save() fail for alpha/palette images. Fall back to
    # PNG for those so the caller's save succeeds.
    if getattr(image, "mode", None) in png_only_modes:
        return ".png"
    return ".jpg"
def process_images(subject_image, scene_image, style_image):
    """
    Process three images and generate a combined image.

    The three images are saved into a temporary directory, described
    concurrently with ``get_prompt_api``, the descriptions are merged into a
    single prompt for ``claude_structured_output``, and that result is handed
    to ``generate_image``. Timing for each stage is printed to stdout.

    Args:
        subject_image: PIL Image for the main subject
        scene_image: PIL Image for the scene/background
        style_image: PIL Image for the artistic style

    Returns:
        PIL Image: The generated combined image

    Raises:
        gr.Error: If any of the three images is missing, or if any step of
            the processing fails (the original exception is chained as the
            cause).
    """
    if subject_image is None or scene_image is None or style_image is None:
        raise gr.Error("Please upload all three images (subject, scene, and style)")

    try:
        # Save images temporarily to pass to the API. The directory and
        # everything in it is removed automatically when the block exits,
        # whether normally or because of an exception.
        with tempfile.TemporaryDirectory() as temp_dir:
            subject_path = os.path.join(temp_dir, "subject" + get_image_suffix(subject_image))
            scene_path = os.path.join(temp_dir, "scene" + get_image_suffix(scene_image))
            style_path = os.path.join(temp_dir, "style" + get_image_suffix(style_image))
            subject_image.save(subject_path)
            scene_image.save(scene_path)
            style_image.save(style_path)

            # Get descriptions for each image (the three requests run in parallel)
            time_start = time.time()
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_subject = executor.submit(get_prompt_api, subject_path, "subject")
                future_scene = executor.submit(get_prompt_api, scene_path, "scene")
                future_style = executor.submit(get_prompt_api, style_path, "style")
                subject = future_subject.result()
                scene = future_scene.result()
                style = future_style.result()
            time_end = time.time()
            print(f"Time taken to get descriptions: {time_end - time_start} seconds")

            # Create combined prompt
            prompt = f"""
place the main subject from the first image description and place it in the scene from the second image description with a style taken from the third image description.
first (subject) image description:
{subject}
second (scene) image description:
{scene}
third (style) image description:
{style}
create a new image description that incorporates all of the descriptions.
put the subject in the scene with the style.
"""

            # Generate structured output using Claude API
            time_start = time.time()
            response = claude_structured_output(prompt)
            time_end = time.time()
            print(f"Time taken to generate structured output: {time_end - time_start} seconds")

            # Generate the final image
            time_start = time.time()
            result_image = generate_image(response)
            time_end = time.time()
            print(f"Time taken to generate image: {time_end - time_start} seconds")
            return result_image
    except Exception as e:
        # No manual cleanup here: TemporaryDirectory has already deleted the
        # saved files by the time this handler runs, so unlinking them again
        # would raise FileNotFoundError and hide the real error from the user.
        raise gr.Error(f"Error processing images: {e}") from e
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

# Introductory text shown at the top of the page.
_INTRO_MARKDOWN = """
# 🎨 FIBO Mashup - Image Combination Generator
## For best performance and results, use the [FAL app](https://fal.ai/models/bria/fibo-mashup).
Combine three images into one:
1. **Subject Image**: The main object or person you want in the final image
2. **Scene Image**: The background/environment for the final image
3. **Style Image**: The artistic style to apply to the final image
"""

# Usage notes shown underneath the generated image.
_HOW_IT_WORKS_MARKDOWN = """
### How it works:
1. Upload three images using the fields above
2. Click "Generate Combined Image"
3. The AI will analyze each image and create a new image that combines the subject from the first image, places it in the scene from the second image, and applies the style from the third image
*Note: Generation may take a minute or two depending on API response times.*
"""

# Keyword arguments shared by the three reference-image inputs.
_INPUT_IMAGE_KWARGS = {"type": "pil", "height": 300}

with gr.Blocks(title="2IM - Image Combination Generator") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # One column per reference image, each pre-filled with a bundled example.
    with gr.Row():
        with gr.Column():
            subject_input = gr.Image(
                label="Subject Image",
                value="assets/subject.jpg",
                **_INPUT_IMAGE_KWARGS,
            )
            gr.Markdown("*Upload the main subject/object*")

        with gr.Column():
            scene_input = gr.Image(
                label="Scene Image",
                value="assets/scene.jpg",
                **_INPUT_IMAGE_KWARGS,
            )
            gr.Markdown("*Upload the scene/background*")

        with gr.Column():
            style_input = gr.Image(
                label="Style Image",
                value="assets/style.png",
                **_INPUT_IMAGE_KWARGS,
            )
            gr.Markdown("*Upload the style reference*")

    generate_btn = gr.Button(
        "🎨 Generate Combined Image",
        variant="primary",
        size="lg",
    )
    gr.Markdown("---")
    output_image = gr.Image(label="Generated Image", type="pil", height=500)

    # Clicking the button runs process_images on the three inputs and shows
    # its return value in the output image component.
    generate_btn.click(
        fn=process_images,
        inputs=[subject_input, scene_input, style_input],
        outputs=output_image,
    )

    gr.Markdown(_HOW_IT_WORKS_MARKDOWN)

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)