# StagingArena / app.py
# Hugging Face Space by Nightfury16 — commit d4df241 ("update").
import base64
import io
import logging
import os
import time

import fal_client
import gradio as gr
import requests
from google import genai
from google.genai.types import GenerateContentConfig, ImageConfig, Part
from PIL import Image
# --- Configuration (read from the environment, e.g. Space secrets) ---
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY")
FAL_KEY = os.getenv("FAL_KEY")
# NOTE: the Gemini key is read from GEMINI_VLM_KEY, not GEMINI_API_KEY.
GEMINI_API_KEY = os.getenv("GEMINI_VLM_KEY")
# RunPod serverless endpoint that serves the Qwen edit model. It can be
# overridden via QWEN_ENDPOINT_ID; an unset/empty variable keeps the default.
QWEN_ENDPOINT_ID = os.getenv("QWEN_ENDPOINT_ID") or "jzpm1xin5cprff"
# Ensure FAL_KEY exists in the environment (empty string when unset);
# presumably because fal_client picks its credentials up from there.
os.environ["FAL_KEY"] = FAL_KEY if FAL_KEY else ""
# None when no Gemini key is configured; callers must check before use.
gemini_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
# Staging instruction pre-filled in the "Edit Prompt" textbox.
DEFAULT_PROMPT = "Add furnishings and accessories to this room as an interior designer would do for a real estate staging. The generated image shall have the exact same dimensions as the original image and architectural details. Respect doorways and windows and make sure they are consistent with the source image and not blocked by furniture. Use cute accessories and with appropriate wall space, add smart simple graphic paintings. Use neutral colors with light colored accents to match the colors of the room. Give the area an attractive glow."
# Aspect-ratio labels offered to the image API, in tie-break order (on an exact
# tie the earlier label wins, as with min()). Numeric values are derived from
# the labels so they are exact rather than hand-truncated approximations.
_ASPECT_RATIOS = ("9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9", "21:9")


def get_closest_ratio(pil_img, candidates=_ASPECT_RATIOS):
    """Return the ``"W:H"`` label whose ratio is closest to the image's width/height.

    Args:
        pil_img: Object with a ``size`` attribute of ``(width, height)``, e.g. a
            PIL image.
        candidates: Iterable of ``"W:H"`` labels to choose from.

    Returns:
        The label from *candidates* minimizing ``|W/H - width/height|``.

    Raises:
        ValueError: if the image height is zero.
    """
    width, height = pil_img.size
    if height == 0:
        raise ValueError("image height must be non-zero to compute an aspect ratio")
    target = width / height

    def distance(label):
        num, den = label.split(":")
        return abs(int(num) / int(den) - target)

    return min(candidates, key=distance)
def b64_to_pil(b64_str):
    """Decode a base64 image string into a PIL image.

    Everything up to the first ``"base64,"`` marker (e.g. a data-URI header)
    is discarded before decoding. Returns ``None`` for an empty/falsy input.
    """
    if not b64_str:
        return None
    marker = "base64,"
    payload = b64_str.split(marker)[1] if marker in b64_str else b64_str
    decoded = base64.b64decode(payload)
    return Image.open(io.BytesIO(decoded))
def bytes_to_pil(img_bytes):
    """Open encoded image bytes (JPEG, PNG, ...) as a PIL image via ``Image.open``."""
    buffer = io.BytesIO(img_bytes)
    return Image.open(buffer)
def get_image_inputs(image_file, image_url, *, timeout=30):
    """Load the source image from an uploaded file or, failing that, a URL.

    The uploaded file takes precedence when both are given.

    Args:
        image_file: Local path of the uploaded image (Gradio ``type="filepath"``),
            or a falsy value.
        image_url: Image URL typed by the user; ``None``/blank means "none".
        timeout: Seconds to wait for the URL download.

    Returns:
        ``(raw_bytes, raw_b64, web_url)``: the encoded image bytes, the same
        bytes base64-encoded (no data-URI prefix), and a URL usable by the fal
        model. A local upload is pushed to fal storage to obtain that URL.
        Returns ``(None, None, None)`` when neither input is provided.

    Raises:
        OSError: if the uploaded file cannot be read.
        requests.RequestException: if the URL download fails or returns an
            HTTP error status.
        Any error raised by ``fal_client.upload_file``.
    """
    if image_file:
        with open(image_file, "rb") as f:
            raw_bytes = f.read()
        # The fal model is given image URLs, so local uploads must be hosted first.
        web_url = fal_client.upload_file(image_file)
    else:
        image_url = (image_url or "").strip()
        if not image_url:
            return None, None, None
        resp = requests.get(image_url, timeout=timeout)
        # Without this, an HTML error page would be passed on as "image" bytes.
        resp.raise_for_status()
        raw_bytes = resp.content
        web_url = image_url
    raw_b64 = base64.b64encode(raw_bytes).decode("utf-8")
    return raw_bytes, raw_b64, web_url
def run_qwen(raw_b64, prompt, *, timeout=60):
    """Stage the room with the Qwen edit model on a RunPod serverless endpoint.

    Args:
        raw_b64: Base64-encoded source image.
        prompt: Edit instruction.
        timeout: Seconds to wait for the synchronous ``runsync`` call.

    Returns:
        The first image of ``output.images`` as a PIL image, or ``None`` when
        the RunPod key is missing or the request/response fails (the error is
        logged so one failing backend does not abort the comparison).
    """
    if not RUNPOD_API_KEY:
        return None
    url = f"https://api.runpod.ai/v2/{QWEN_ENDPOINT_ID}/runsync"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {RUNPOD_API_KEY}"}
    payload = {
        "input": {
            "image": raw_b64,
            "prompt": prompt,
            "seed": 42,  # fixed seed, presumably for repeatable outputs
            "use_lightning": True,
            "true_guidance_scale": 1.0,
            "num_inference_steps": 4,
        }
    }
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        # A missing "output" key (e.g. job not finished or failed) lands in the except below.
        return b64_to_pil(response.json()["output"]["images"][0])
    except Exception:
        logging.getLogger(__name__).exception("Qwen (RunPod) request failed")
        return None
def run_fal_flux(image_url, prompt, *, timeout=60):
    """Stage the room with the FLUX-2 edit model on fal.ai.

    Args:
        image_url: URL of the source image reachable by fal.
        prompt: Edit instruction.
        timeout: Seconds to wait when downloading the generated image.

    Returns:
        The first generated image as a PIL image, or ``None`` if submission,
        generation or download fails (the error is logged).
    """
    try:
        handler = fal_client.submit(
            "fal-ai/flux-2/edit",
            arguments={"prompt": prompt, "image_urls": [image_url]},
        )
        result = handler.get()  # blocks until the queued job finishes
        download = requests.get(result["images"][0]["url"], timeout=timeout)
        download.raise_for_status()
        return bytes_to_pil(download.content)
    except Exception:
        logging.getLogger(__name__).exception("FLUX-2 edit (fal.ai) request failed")
        return None
def run_gemini(image_bytes, prompt, ratio_str):
    """Stage the room with Gemini 2.5 Flash Image through the google-genai client.

    Args:
        image_bytes: Encoded source image bytes.
        prompt: Edit instruction.
        ratio_str: Aspect-ratio label such as ``"4:3"`` requested for the output.

    Returns:
        The first inline image in the response as a PIL image, or ``None`` when
        no Gemini client is configured, the call fails (logged), or the
        response contains no image (logged as a warning).
    """
    if not gemini_client:
        return None
    log = logging.getLogger(__name__)
    try:
        # Declare the input's real format instead of always claiming JPEG, so
        # PNG/WebP uploads are labeled correctly; fall back to JPEG if unknown.
        mime_type = Image.open(io.BytesIO(image_bytes)).get_format_mimetype() or "image/jpeg"
        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=[Part.from_bytes(data=image_bytes, mime_type=mime_type), prompt],
            config=GenerateContentConfig(
                response_modalities=["IMAGE"],
                image_config=ImageConfig(aspect_ratio=ratio_str),
                candidate_count=1,
            ),
        )
        # candidates / content / parts may be None (e.g. a blocked response).
        candidate = response.candidates[0] if response.candidates else None
        parts = (candidate.content.parts if candidate is not None and candidate.content else None) or []
        for part in parts:
            if part.inline_data:
                return bytes_to_pil(part.inline_data.data)
        log.warning("Gemini returned no image (finish_reason=%s)", getattr(candidate, "finish_reason", None))
    except Exception:
        log.exception("Gemini image request failed")
    return None
def _result_status(img):
    """Status label for a finished model call: success iff an image came back."""
    return "✅ Complete" if img is not None else "❌ Failed"


def compare_all(image_file, image_url, prompt):
    """Run Qwen, FLUX-2 and Gemini one after another, streaming progress to the UI.

    Yields 7-tuples in the order of the Gradio outputs:
    ``(original, qwen_img, flux_img, gemini_img, qwen_status, flux_status,
    gemini_status)``. Each model's image is ``None`` until it has run, and a
    model that returned no image is reported as failed rather than complete.

    Raises:
        gr.Error: if the input image cannot be fetched or decoded, so the user
            sees a readable message instead of a traceback.
    """
    try:
        raw_bytes, raw_b64, web_url = get_image_inputs(image_file, image_url)
    except Exception as err:
        raise gr.Error(f"Could not load the input image: {err}") from err
    if not raw_bytes:
        yield None, None, None, None, "", "", ""
        return
    try:
        og_pil = bytes_to_pil(raw_bytes)
        ratio_str = get_closest_ratio(og_pil)
    except Exception as err:
        raise gr.Error(f"The input is not a readable image: {err}") from err

    q_img = f_img = g_img = None
    yield og_pil, q_img, f_img, g_img, "⏳ Processing (~5s)...", "🕒 Pending...", "🕒 Pending..."
    q_img = run_qwen(raw_b64, prompt)
    yield og_pil, q_img, f_img, g_img, _result_status(q_img), "⏳ Processing (~12s)...", "🕒 Pending..."
    f_img = run_fal_flux(web_url, prompt)
    yield og_pil, q_img, f_img, g_img, _result_status(q_img), _result_status(f_img), "⏳ Processing (~15s)..."
    g_img = run_gemini(raw_bytes, prompt, ratio_str)
    yield og_pil, q_img, f_img, g_img, _result_status(q_img), _result_status(f_img), _result_status(g_img)
# UI layout: original image plus inputs on top, then one column per model with
# its name, price label, a live status line and the generated image.
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align: center; margin: 10px 0;'>🛋️ Interior Design Model Arena</h2>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<center><b>ORIGINAL REFERENCE</b></center>")
            out_og = gr.Image(show_label=False, type="pil", height=320)
        with gr.Column(scale=1):
            input_prompt = gr.Textbox(label="Edit Prompt", value=DEFAULT_PROMPT, lines=4)
            with gr.Row():
                # type="filepath" hands compare_all a local path rather than pixels.
                input_file = gr.Image(label="Upload", type="filepath", height=100)
                input_url = gr.Textbox(label="OR: Image URL", placeholder="Paste URL...")
            run_btn = gr.Button("🚀 Generate Comparison", variant="primary")
    gr.HTML("<hr style='margin: 15px 0;'>")
    with gr.Row():
        with gr.Column():
            gr.HTML("<center><b>QWEN-EDIT</b><br><small>$0.004 - $0.008</small></center>")
            stat_q = gr.Markdown("🕒 Waiting...", elem_classes="status-msg")
            out_qwen = gr.Image(show_label=False, type="pil", height=350)
        with gr.Column():
            gr.HTML("<center><b>FLUX-2 EDIT</b><br><small>$0.03</small></center>")
            stat_f = gr.Markdown("🕒 Waiting...", elem_classes="status-msg")
            out_fal = gr.Image(show_label=False, type="pil", height=350)
        with gr.Column():
            gr.HTML("<center><b>GEMINI 2.5 FLASH</b><br><small>$0.039</small></center>")
            stat_g = gr.Markdown("🕒 Waiting...", elem_classes="status-msg")
            out_gemini = gr.Image(show_label=False, type="pil", height=350)
    # compare_all is a generator, so each yield updates these outputs progressively;
    # the output order must match its yielded 7-tuples.
    run_btn.click(
        fn=compare_all,
        inputs=[input_file, input_url, input_prompt],
        outputs=[out_og, out_qwen, out_fal, out_gemini, stat_q, stat_f, stat_g],
    )
if __name__ == "__main__":
    # Page-level CSS: near-full-width container, letterboxed images, muted status lines.
    arena_css = """
.gradio-container {max-width: 98% !important}
img {object-fit: contain !important;}
.status-msg {text-align: center; margin-top: -10px; font-size: 0.85em; color: #666;}
"""
    # Listen on all interfaces on port 7860; theme and CSS are supplied at launch time.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
        css=arena_css,
    )