Spaces:
Running
Running
import base64
import io
import logging
import os
import time

import fal_client
import gradio as gr
import requests
from google import genai
from google.genai.types import GenerateContentConfig, ImageConfig, Part
from PIL import Image
# --- Credentials and backend configuration (all read from the environment) ---
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY")
FAL_KEY = os.getenv("FAL_KEY")
# Note: the Gemini key is read from GEMINI_VLM_KEY, not GEMINI_API_KEY.
GEMINI_API_KEY = os.getenv("GEMINI_VLM_KEY")
# RunPod serverless endpoint hosting Qwen-Edit; override via QWEN_ENDPOINT_ID
# (an empty value falls back to the built-in default).
QWEN_ENDPOINT_ID = os.getenv("QWEN_ENDPOINT_ID") or "jzpm1xin5cprff"
# Ensure FAL_KEY exists in the environment: an existing value is left unchanged,
# and an unset variable becomes the empty string.
os.environ.setdefault("FAL_KEY", "")
# Gemini is optional: without a key the client stays None and run_gemini is skipped.
gemini_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
# Initial text of the "Edit Prompt" box; sent verbatim to every model.
DEFAULT_PROMPT = "Add furnishings and accessories to this room as an interior designer would do for a real estate staging. The generated image shall have the exact same dimensions as the original image and architectural details. Respect doorways and windows and make sure they are consistent with the source image and not blocked by furniture. Use cute accessories and with appropriate wall space, add smart simple graphic paintings. Use neutral colors with light colored accents to match the colors of the room. Give the area an attractive glow."
def get_closest_ratio(pil_img):
    """Return the supported aspect-ratio label nearest to the image's shape.

    Args:
        pil_img: An object with a ``size`` attribute of ``(width, height)``,
            e.g. a PIL image.

    Returns:
        One of "9:16", "2:3", "3:4", "4:5", "1:1", "5:4", "4:3", "3:2", "16:9",
        "21:9": the label whose ratio has the smallest absolute difference from
        width / height. Ties go to the label listed first.

    Raises:
        ZeroDivisionError: If the image height is 0.
    """
    w, h = pil_img.size
    ratio = w / h
    # Exact fractions rather than rounded decimals, so the boundary between two
    # neighbouring labels sits at their true midpoint.
    ratios = {
        "9:16": 9 / 16,
        "2:3": 2 / 3,
        "3:4": 3 / 4,
        "4:5": 4 / 5,
        "1:1": 1.0,
        "5:4": 5 / 4,
        "4:3": 4 / 3,
        "3:2": 3 / 2,
        "16:9": 16 / 9,
        "21:9": 21 / 9,
    }
    return min(ratios, key=lambda label: abs(ratios[label] - ratio))
def b64_to_pil(b64_str):
    """Decode base64 image text into a PIL image.

    The input may be plain base64 or carry a data-URI style prefix ending in
    ``base64,``; only the part after that marker is decoded. Empty or None
    input yields None.
    """
    if b64_str:
        marker = "base64,"
        payload = b64_str.split(marker)[1] if marker in b64_str else b64_str
        return Image.open(io.BytesIO(base64.b64decode(payload)))
    return None
def bytes_to_pil(img_bytes):
    """Open encoded image bytes (any format PIL can read) via an in-memory buffer."""
    stream = io.BytesIO(img_bytes)
    return Image.open(stream)
def get_image_inputs(image_file, image_url):
    """Resolve the user's image into the three forms the model backends need.

    An uploaded file takes precedence over a URL.

    Args:
        image_file: Local path of the uploaded image, or None/empty.
        image_url: URL typed by the user, or None/empty; surrounding
            whitespace is ignored.

    Returns:
        ``(raw_bytes, raw_b64, url)``. These are the image bytes, their base64
        text, and a URL for fal.ai: for uploads this is the value returned by
        ``fal_client.upload_file``, otherwise the (stripped) user URL.
        ``(None, None, None)`` when neither input is given.

    Raises:
        OSError: If the uploaded file cannot be read.
        requests.RequestException: If the URL fetch fails, times out, or
            returns an HTTP error status.
    """
    image_url = (image_url or "").strip()
    if image_file:
        with open(image_file, "rb") as f:
            raw_bytes = f.read()
        web_url = fal_client.upload_file(image_file)
    elif image_url:
        # Bounded wait, and reject error pages instead of treating them as image bytes.
        resp = requests.get(image_url, timeout=30)
        resp.raise_for_status()
        raw_bytes = resp.content
        web_url = image_url
    else:
        return None, None, None
    raw_b64 = base64.b64encode(raw_bytes).decode("utf-8")
    return raw_bytes, raw_b64, web_url
def run_qwen(raw_b64, prompt):
    """Edit the image with Qwen-Edit served from a RunPod serverless endpoint.

    Posts a synchronous ``/runsync`` job and decodes the first base64 image
    found at ``output.images`` in the JSON response.

    Args:
        raw_b64: Base64-encoded source image.
        prompt: Edit instruction passed through to the worker.

    Returns:
        The edited PIL image, or None when RUNPOD_API_KEY is not configured or
        the request/decoding fails. Failures are logged rather than raised, so
        one backend cannot abort the whole comparison.
    """
    # Same guard style as run_gemini: don't send "Bearer None" when no key is set.
    if not RUNPOD_API_KEY:
        return None
    url = f"https://api.runpod.ai/v2/{QWEN_ENDPOINT_ID}/runsync"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {RUNPOD_API_KEY}"}
    # Fixed seed and a 4-step "lightning" configuration; how these fields are
    # interpreted is up to the RunPod worker.
    payload = {
        "input": {
            "image": raw_b64,
            "prompt": prompt,
            "seed": 42,
            "use_lightning": True,
            "true_guidance_scale": 1.0,
            "num_inference_steps": 4,
        }
    }
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        # Surface HTTP errors with their status code instead of a confusing KeyError.
        response.raise_for_status()
        # A body without output.images raises KeyError/IndexError/TypeError here
        # and is logged below.
        return b64_to_pil(response.json()["output"]["images"][0])
    except Exception:
        logging.getLogger(__name__).exception("Qwen-Edit request to RunPod failed")
        return None
def run_fal_flux(image_url, prompt):
    """Edit the image with FLUX-2 Edit on fal.ai.

    Submits a job to ``fal-ai/flux-2/edit``, waits for its result, then
    downloads the first output image.

    Args:
        image_url: URL of the source image, as produced by get_image_inputs.
        prompt: Edit instruction.

    Returns:
        The edited PIL image, or None if submission, the job, or the download
        fails. Failures are logged rather than raised.
    """
    try:
        handler = fal_client.submit("fal-ai/flux-2/edit", arguments={"prompt": prompt, "image_urls": [image_url]})
        result = handler.get()
        # Bounded download with a status check, so a stalled server or an error
        # page is reported instead of hanging or being decoded as an image.
        download = requests.get(result["images"][0]["url"], timeout=60)
        download.raise_for_status()
        return bytes_to_pil(download.content)
    except Exception:
        logging.getLogger(__name__).exception("FLUX-2 edit request to fal.ai failed")
        return None
def _guess_image_mime(image_bytes):
    """Return the MIME type indicated by the image's leading magic bytes.

    Recognises PNG and WebP signatures; anything else falls back to "image/jpeg".
    """
    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "image/webp"
    return "image/jpeg"


def run_gemini(image_bytes, prompt, ratio_str):
    """Edit the image with Gemini 2.5 Flash Image.

    Args:
        image_bytes: Encoded source image bytes (PNG, WebP or JPEG are labelled
            correctly; other formats are sent as "image/jpeg").
        prompt: Edit instruction.
        ratio_str: Aspect-ratio label requested for the output, e.g. "16:9".

    Returns:
        The first image part of the first candidate as a PIL image, or None
        when no Gemini client is configured, the call fails, or the response
        has no image part. Failures are logged rather than raised.
    """
    if not gemini_client:
        return None
    log = logging.getLogger(__name__)
    try:
        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-image",
            # Label the payload with its real type instead of assuming JPEG.
            contents=[Part.from_bytes(data=image_bytes, mime_type=_guess_image_mime(image_bytes)), prompt],
            config=GenerateContentConfig(
                response_modalities=["IMAGE"],
                image_config=ImageConfig(aspect_ratio=ratio_str),
                candidate_count=1,
            ),
        )
        for part in response.candidates[0].content.parts:
            if part.inline_data:
                return bytes_to_pil(part.inline_data.data)
    except Exception:
        log.exception("Gemini image edit request failed")
        return None
    log.warning("Gemini response contained no image part")
    return None
def compare_all(image_file, image_url, prompt):
    """Run the three editing backends one after another, streaming progress.

    Generator used as the Gradio click handler. Each yielded 7-tuple matches the
    ``outputs`` wired to ``run_btn.click``: (original, qwen, flux, gemini,
    qwen status, flux status, gemini status).

    Args:
        image_file: Local path from the upload widget, or None.
        image_url: URL from the text box, or empty/None.
        prompt: Edit instruction sent to every backend.

    Raises:
        gr.Error: If the input image cannot be fetched, uploaded or decoded,
            so the UI shows a readable message instead of a generic error.
    """
    def _done(img):
        # The run_* wrappers return None on failure; don't claim success then.
        return "β Complete" if img is not None else "❌ No result"

    try:
        raw_bytes, raw_b64, web_url = get_image_inputs(image_file, image_url)
        og_pil = bytes_to_pil(raw_bytes) if raw_bytes else None
    except Exception as exc:
        raise gr.Error(f"Could not load the input image: {exc}") from exc
    if og_pil is None:
        yield None, None, None, None, "", "", ""
        return
    ratio_str = get_closest_ratio(og_pil)
    yield og_pil, None, None, None, "β³ Processing (~5s)...", "π Pending...", "π Pending..."
    q_img = run_qwen(raw_b64, prompt)
    yield og_pil, q_img, None, None, _done(q_img), "β³ Processing (~12s)...", "π Pending..."
    f_img = run_fal_flux(web_url, prompt)
    yield og_pil, q_img, f_img, None, _done(q_img), _done(f_img), "β³ Processing (~15s)..."
    g_img = run_gemini(raw_bytes, prompt, ratio_str)
    yield og_pil, q_img, f_img, g_img, _done(q_img), _done(f_img), _done(g_img)
def _model_column(title, price):
    """Build one result column (header, status line, output image); return (status, image)."""
    with gr.Column():
        gr.HTML(f"<center><b>{title}</b><br><small>{price}</small></center>")
        status = gr.Markdown("π Waiting...", elem_classes="status-msg")
        image = gr.Image(show_label=False, type="pil", height=350)
    return status, image


# Page layout: top row = original preview + inputs, bottom row = one column per model.
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align: center; margin: 10px 0;'>ποΈ Interior Design Model Arena</h2>")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<center><b>ORIGINAL REFERENCE</b></center>")
            out_og = gr.Image(show_label=False, type="pil", height=320)
        with gr.Column(scale=1):
            input_prompt = gr.Textbox(label="Edit Prompt", value=DEFAULT_PROMPT, lines=4)
            with gr.Row():
                input_file = gr.Image(label="Upload", type="filepath", height=100)
                input_url = gr.Textbox(label="OR: Image URL", placeholder="Paste URL...")
            run_btn = gr.Button("π Generate Comparison", variant="primary")
    gr.HTML("<hr style='margin: 15px 0;'>")
    with gr.Row():
        stat_q, out_qwen = _model_column("QWEN-EDIT", "$0.004 - $0.008")
        stat_f, out_fal = _model_column("FLUX-2 EDIT", "$0.03")
        stat_g, out_gemini = _model_column("GEMINI 2.5 FLASH", "$0.039")
    # Output order must match the 7-tuples yielded by compare_all.
    run_btn.click(
        fn=compare_all,
        inputs=[input_file, input_url, input_prompt],
        outputs=[out_og, out_qwen, out_fal, out_gemini, stat_q, stat_f, stat_g],
    )
if __name__ == "__main__":
    # Bind all interfaces on port 7860; theme and CSS are supplied at launch time.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "theme": gr.themes.Soft(),
        "css": """
.gradio-container {max-width: 98% !important}
img {object-fit: contain !important;}
.status-msg {text-align: center; margin-top: -10px; font-size: 0.85em; color: #666;}
""",
    }
    demo.launch(**launch_options)