# VTO / gradio_app.py
# (Hugging Face Space page residue: "Akshajzclap's picture — Update gradio_app.py — 60abc56 verified")
import gradio as gr
import requests
from io import BytesIO
from PIL import Image
import tempfile
import os
import mimetypes
import google.generativeai as genai
import io
from dotenv import load_dotenv
# already have: from PIL import Image
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY")
if API_KEY:
genai.configure(api_key=API_KEY)
else:
print("Warning: GEMINI_API_KEY not found in .env")
API_BASE = "http://127.0.0.1:5350/api/v1"
BED_IMAGE_PATHS = [
"app/furniture/beds/bed -1.jpeg",
"app/furniture/beds/bed - 2.jpg",
]
SOFA_IMAGE_PATHS = [
"app/furniture/sofa/sofa-1.jpg",
"app/furniture/sofa/sofa-2.jpg",
]
def pil_to_tempfile(pil_img, suffix=".jpg"):
    """Write a PIL image to a temporary file and return the filepath.

    The caller should delete the file when done.  The save format is derived
    from *suffix* (defaulting to JPEG), fixing the original behavior of
    always writing JPEG bytes even into a ``.png``-suffixed file.  Images
    with an alpha channel or palette are converted to RGB before a JPEG
    save, which would otherwise raise ``OSError``.
    """
    # Derive the Pillow format name from the requested suffix.
    fmt = {".jpg": "JPEG", ".jpeg": "JPEG", ".png": "PNG", ".webp": "WEBP"}.get(
        suffix.lower(), "JPEG"
    )
    if fmt == "JPEG" and pil_img.mode not in ("RGB", "L"):
        # JPEG cannot encode alpha/palette modes (e.g. RGBA PNG uploads).
        pil_img = pil_img.convert("RGB")
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        pil_img.save(tmp, format=fmt)
        tmp.flush()
        tmp.close()
        return tmp.name
    except Exception:
        # Best-effort cleanup so a failed save doesn't leak the temp file on disk.
        try:
            tmp.close()
            os.unlink(tmp.name)
        except Exception:
            pass
        raise
def image_to_bytes_pil(img: "Image.Image", fmt="JPEG"):
    """Serialize *img* to raw bytes in the given format.

    Images with an alpha channel or palette are converted to RGB first when
    targeting JPEG, since JPEG cannot encode alpha (common for PNG uploads
    coming from the Gradio image widget) and Pillow would raise otherwise.
    """
    if fmt.upper() == "JPEG" and img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    bio = BytesIO()
    img.save(bio, format=fmt)
    # getvalue() returns the whole buffer regardless of position; no seek needed.
    return bio.getvalue()
def call_api(endpoint: str, data: dict, main_image: "Image.Image" = None, extra_files: dict = None, timeout=60):
    """Send a multipart POST to ``API_BASE/endpoint``.

    Parameters
    ----------
    data : dict of form fields.
    main_image : optional PIL image attached under the 'garment_images' field.
    extra_files : dict mapping fieldname -> (filename, bytes, mime) OR
        fieldname -> [ (filename, bytes, mime), ... ] for several files under
        one field name.
    timeout : request timeout in seconds.

    Returns ``(error_message_or_None, result)`` where *result* is a PIL image
    when the 200 response body decodes as one, otherwise the response text.
    """
    files = []
    if main_image is not None:
        files.append(("garment_images", ("garment.jpg", image_to_bytes_pil(main_image), "image/jpeg")))
    if extra_files:
        for fieldname, fileval in extra_files.items():
            # A list, or a tuple whose first element is itself a tuple, holds
            # several (filename, bytes, mime) entries for the same field.
            # The truthiness check guards the fileval[0] access against an
            # empty tuple (the original raised IndexError there).
            if isinstance(fileval, list) or (
                isinstance(fileval, tuple) and fileval and isinstance(fileval[0], tuple)
            ):
                for single in fileval:
                    files.append((fieldname, single))
            else:
                # Single (filename, bytes, mime) tuple.
                files.append((fieldname, fileval))
    try:
        resp = requests.post(f"{API_BASE}/{endpoint}", data=data, files=files, timeout=timeout)
    except requests.RequestException as e:
        # Network failure / timeout / bad URL — surface as the error string.
        return f"Exception: {e}", None
    if resp.status_code != 200:
        return f"Error {resp.status_code}: {resp.text}", None
    try:
        return None, Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception:
        # Some endpoints return JSON or plain text on success.
        return None, resp.text
# ---- Shared choices ----
garment_types = ["clothing", "jewellery and watches", "wallet", "shoes", "handbags"]
genders = ["female"]
ages = ["teen", "18-25", "26-35", "36-45", "46-55", "56-65", "66+"]
body_shapes = ["rectangle", "pear", "hourglass", "inverted_triangle"]
ethnicities = ["white", "black", "asian", "latino", "mixed"]
poses = ["standing", "sitting", "lying down", "dancing", "running", "jumping", "walking", "bending", "twisting", "stretching", "flexing", "posing"]
view_angles = ["front", "45deg", "left", "right", "back"]
lighting_conditions = ["studio_softbox", "outdoor_sunny", "indoor_warm", "flat"]
backgrounds = [
"White", "Lifestyle", "Beach Settings", "Cafe Environment", "Spring Garden", "Winter Snow", "Professional Settings"
]
surface_types = ["cotton", "silk", "denim", "leather", "synthetic"]
# Generic shared submit for endpoints that use model settings + single garment image
def submit_shared(endpoint, garment_type, model_gender, model_age_range, model_body_shape, model_race_ethnicity,
                  model_pose, camera_view_angle, camera_distance_meters, camera_focal_length_mm, camera_aperture_f_number,
                  camera_lighting_condition, camera_background, garment_image):
    """Submit the shared model/camera settings plus one garment image to *endpoint*.

    Returns whatever ``call_api`` returns: ``(error_or_None, result)``.
    """
    # Field names in the payload mirror the parameter names one-to-one.
    payload = dict(
        garment_type=garment_type,
        model_gender=model_gender,
        model_age_range=model_age_range,
        model_body_shape=model_body_shape,
        model_race_ethnicity=model_race_ethnicity,
        model_pose=model_pose,
        camera_view_angle=camera_view_angle,
        camera_distance_meters=camera_distance_meters,
        camera_focal_length_mm=camera_focal_length_mm,
        camera_aperture_f_number=camera_aperture_f_number,
        camera_lighting_condition=camera_lighting_condition,
        camera_background=camera_background,
    )
    return call_api(endpoint, payload, main_image=garment_image)
# ---------------- UI ----------------
with gr.Blocks(title="FashionAI Studio") as demo:
# CSS to visually swap columns while allowing us to define settings first in code
gr.HTML(
"""
<style>
#left_col { order: 1; flex: 2 1 0%; }
#right_col { order: 2; flex: 1 1 0%; padding-left: 18px; border-left: 1px solid rgba(0,0,0,0.06); }
#right_col .gradio-container { padding-left: 0; }
.sidebar-title { font-weight:600; margin-bottom:6px; }
</style>
"""
)
gr.Markdown("# FashionAI Studio")
# Create Model Settings first (so variables exist), CSS makes it render to the right.
with gr.Row():
with gr.Column(scale=1, elem_id="right_col"):
gr.HTML('<div class="sidebar-title">### Model Settings (shared)</div>')
garment_type_sel = gr.Dropdown(garment_types, value=garment_types[0], label="Garment Type")
model_gender_sel = gr.Dropdown(genders, value=genders[0], label="Model Gender")
model_age_sel = gr.Dropdown(ages, value=ages[1], label="Model Age Range")
model_body_sel = gr.Dropdown(body_shapes, value=body_shapes[0], label="Model Body Shape")
model_ethnicity_sel = gr.Dropdown(ethnicities, value=ethnicities[0], label="Model Race & Ethnicity")
model_pose_sel = gr.Dropdown(poses, value=poses[0], label="Model Pose")
view_angle_sel = gr.Dropdown(view_angles, value=view_angles[0], label="View Angle")
distance_num = gr.Number(label="Distance (meters)", value=2)
focal_num = gr.Number(label="Focal Length (mm)", value=50)
aperture_num = gr.Number(label="Aperture (f-number)", value=2.8)
lighting_sel = gr.Dropdown(lighting_conditions, value=lighting_conditions[0], label="Lighting Condition")
background_sel = gr.Dropdown(backgrounds, value=backgrounds[0], label="Background")
# Main workspace (left)
with gr.Column(scale=2, elem_id="left_col"):
# Helper for simple subtab that uses shared settings + garment image
def make_shared_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
img = gr.Image(label="Garment Image", type="pil")
btn = gr.Button(f"Generate — {tab_label}")
with gr.Column(scale=2):
out_img = gr.Image(label="Generated Result")
err = gr.Textbox(label="Error", interactive=False)
btn.click(
fn=lambda *args, ep=endpoint: submit_shared(ep, *args),
inputs=[
garment_type_sel, model_gender_sel, model_age_sel, model_body_sel, model_ethnicity_sel,
model_pose_sel, view_angle_sel, distance_num, focal_num, aperture_num,
lighting_sel, background_sel, img
],
outputs=[err, out_img]
)
# Helper for background_edit (special inputs)
def make_background_edit_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
gtype = gr.Dropdown(garment_types, value=garment_types[0], label="Garment Type")
surface = gr.Dropdown(surface_types, value=surface_types[0], label="Surface Type")
view = gr.Dropdown(view_angles, value=view_angles[0], label="Camera View Angle")
distance = gr.Number(label="Camera Distance (meters)", value=2.0)
focal = gr.Number(label="Camera Focal Length (mm)", value=50)
aperture = gr.Number(label="Camera Aperture (f-number)", value=2.8)
lighting = gr.Dropdown(["natural", "studio", "warm"], value="studio", label="Camera Lighting Condition")
background = gr.Dropdown(backgrounds, value=backgrounds[0], label="Camera Background")
img = gr.Image(label="Garment Image", type="pil")
btn = gr.Button("Generate Background Edit")
with gr.Column(scale=2):
out_img = gr.Image(label="Generated Result")
err = gr.Textbox(label="Error", interactive=False)
def bg_submit(endpoint, gtype, surface, view, distance, focal, aperture, lighting, background, img):
data = {
"garment_type": gtype,
"surface_type": surface,
"camera_view_angle": view,
"camera_distance_meters": distance,
"camera_focal_length_mm": focal,
"camera_aperture_f_number": aperture,
"camera_lighting_condition": lighting,
"camera_background": background,
}
return call_api(endpoint, data, main_image=img)
btn.click(
fn=lambda *args, ep=endpoint: bg_submit(ep, *args),
inputs=[gtype, surface, view, distance, focal, aperture, lighting, background, img],
outputs=[err, out_img]
)
# Helper for bags -> what fits inside (dimensions)
def make_whatfits_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
length = gr.Number(label="Length (cm)", value=20)
width = gr.Number(label="Width (cm)", value=10)
height = gr.Number(label="Height (cm)", value=8)
img = gr.Image(label="Bag Image (optional)", type="pil")
btn = gr.Button("Generate — What Fits Inside")
with gr.Column(scale=2):
out_txt = gr.Textbox(label="Result / Suggestions", interactive=False)
out_img = gr.Image(label="Preview (if any)")
err = gr.Textbox(label="Error", interactive=False)
def wf_submit(endpoint, length, width, height, img):
"""
- Map your UI fields to the expected API field names.
- Always return three values: (err_str_or_None, out_image_or_None, out_text_or_empty_str)
"""
# Map to expected backend keys (change these if your backend uses different names)
data = {
"product_length": length,
"product_width": width,
"product_height": height,
}
extra = {}
# Some backends expect the main image field to be 'garment_images' in files.
if img is not None:
extra["garment_images"] = ("bag.jpg", image_to_bytes_pil(img), "image/jpeg")
err_msg, resp = call_api(endpoint, data, main_image=None, extra_files=extra)
# If call_api reported an error, return that as the error textbox + empty image/text outputs
if err_msg:
return err_msg, None, ""
# resp may be a PIL.Image or text
try:
if isinstance(resp, Image.Image):
return None, resp, ""
except Exception:
pass
# fallback: return textual response into out_txt
return None, None, str(resp)
btn.click(fn=lambda *args, ep=endpoint: wf_submit(ep, *args),
inputs=[length, width, height, img],
outputs=[err, out_img, out_txt])
# Helper for size comparison (dimensions)
# Replace your current make_sizecompare_subtab with this version:
def make_sizecompare_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
ref_img = gr.Image(label="Reference Product Image (optional)", type="pil")
ref_l = gr.Number(label="Reference Length (cm)", value=10)
ref_w = gr.Number(label="Reference Width (cm)", value=5)
ref_h = gr.Number(label="Reference Height (cm)", value=2)
cmp_img = gr.Image(label="Compare Product Image (optional)", type="pil")
cmp_l = gr.Number(label="Compare Length (cm)", value=12)
cmp_w = gr.Number(label="Compare Width (cm)", value=6)
cmp_h = gr.Number(label="Compare Height (cm)", value=3)
btn = gr.Button("Compare Sizes")
with gr.Column(scale=2):
out_img = gr.Image(label="Comparison Visualization / Preview")
out_txt = gr.Textbox(label="Comparison Result", interactive=False)
err = gr.Textbox(label="Error", interactive=False)
def size_submit(ep, garment_type_val, ref_img, ref_l, ref_w, ref_h, cmp_img, cmp_l, cmp_w, cmp_h):
# Map UI label -> backend expected token (edit this mapping to match your backend)
mapping = {
"clothing": "clothing",
"jewellery and watches": "jewellery and watches",
"wallet": "wallet",
"shoes": "shoes",
"handbags": "handbags", # you might need "bag" or "handbag" depending on backend
"bags": "handbags",
}
# Defensive: ensure we have a garment type
if not garment_type_val:
return "Please select a garment type in the Model Settings (right panel).", None, ""
# translate using mapping; fall back to original value
send_garment_type = mapping.get(str(garment_type_val).strip().lower(), garment_type_val)
# The backend requires at least one image as 'garment_images'
if (ref_img is None) and (cmp_img is None):
return "Please upload at least one product image (reference or compare).", None, ""
# Prepare payload
data = {
"garment_type": send_garment_type,
"ref_length_cm": ref_l,
"ref_width_cm": ref_w,
"ref_height_cm": ref_h,
"cmp_length_cm": cmp_l,
"cmp_width_cm": cmp_w,
"cmp_height_cm": cmp_h,
}
# Build files ensuring 'garment_images' exists (required by your backend)
extra_files = {}
if ref_img is not None:
extra_files["garment_images"] = ("ref.jpg", image_to_bytes_pil(ref_img), "image/jpeg")
extra_files["ref_image"] = ("ref.jpg", image_to_bytes_pil(ref_img), "image/jpeg")
if cmp_img is not None:
if "garment_images" not in extra_files:
extra_files["garment_images"] = ("cmp.jpg", image_to_bytes_pil(cmp_img), "image/jpeg")
extra_files["cmp_image"] = ("cmp.jpg", image_to_bytes_pil(cmp_img), "image/jpeg")
# debug log: print what we are sending
print("Calling size endpoint:", ep, "garment_type_sent:", send_garment_type, "data:", data.keys(), "files:", list(extra_files.keys()))
err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
# If server returned an error, include the garment_type we sent to help debugging
if err_msg:
# Show server message + what we attempted to send
return f"{err_msg}\n(garment_type sent: {send_garment_type})", None, ""
# resp may be an image or text
if isinstance(resp, Image.Image):
return None, resp, ""
return None, None, str(resp)
# Note: we capture the endpoint via default arg ep=endpoint in the lambda below.
# Inputs list must contain component objects only (no plain strings).
btn.click(
fn=lambda garment_type_val, ref_img, ref_l, ref_w, ref_h, cmp_img, cmp_l, cmp_w, cmp_h, ep=endpoint:
size_submit(ep, garment_type_val, ref_img, ref_l, ref_w, ref_h, cmp_img, cmp_l, cmp_w, cmp_h),
inputs=[garment_type_sel, ref_img, ref_l, ref_w, ref_h, cmp_img, cmp_l, cmp_w, cmp_h],
outputs=[err, out_img, out_txt]
)
# Helper for Outfit & Product Visualization (bag image + additional product image)
def make_outfit_product_viz_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
bag_img = gr.Image(label="Bag Image (primary)", type="pil")
add_img = gr.Image(label="Additional Product Image (secondary)", type="pil")
btn = gr.Button("Generate Outfit & Product Visualization")
with gr.Column(scale=2):
out_img = gr.Image(label="Visualization")
err = gr.Textbox(label="Error", interactive=False)
def opv_submit(
ep,
garment_type_val,
model_gender_val,
model_age_val,
model_body_val,
model_ethnicity_val,
model_pose_val,
camera_view_val,
camera_distance_val,
camera_focal_val,
camera_aperture_val,
camera_lighting_val,
camera_background_val,
bag_img,
add_img,
):
# Required fields check
missing = []
if not garment_type_val:
missing.append("garment_type")
# require exactly two images
if bag_img is None or add_img is None:
missing.append("two garment images (bag_img and add_img)")
required_fields = {
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
for k, v in required_fields.items():
if v is None or (isinstance(v, str) and not str(v).strip()):
missing.append(k)
if missing:
return f"Missing required fields: {', '.join(missing)}", None
data = {
"garment_type": garment_type_val,
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
# Build extra_files: put both images under 'garment_images'
extra_files = {}
# supply a list of tuples so call_api will send two files under the same field
extra_files["garment_images"] = [
("bag_primary.jpg", image_to_bytes_pil(bag_img), "image/jpeg"),
("bag_secondary.jpg", image_to_bytes_pil(add_img), "image/jpeg")
]
err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
if err_msg:
return err_msg, None
if isinstance(resp, Image.Image):
return None, resp
return None, None
btn.click(
fn=lambda *args, ep=endpoint: opv_submit(
ep,
args[0], # garment_type_sel
args[1], # model_gender_sel
args[2], # model_age_sel
args[3], # model_body_sel
args[4], # model_ethnicity_sel
args[5], # model_pose_sel
args[6], # view_angle_sel
args[7], # distance_num
args[8], # focal_num
args[9], # aperture_num
args[10], # lighting_sel
args[11], # background_sel
args[12], # bag_img
args[13], # add_img
),
inputs=[
garment_type_sel, model_gender_sel, model_age_sel, model_body_sel, model_ethnicity_sel,
model_pose_sel, view_angle_sel, distance_num, focal_num, aperture_num,
lighting_sel, background_sel, bag_img, add_img
],
outputs=[err, out_img]
)
# Wallets: interactive size guide (two wallet images + dims)
def make_wallet_sizeguide_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
wallet1 = gr.Image(label="Upload wallet image 1", type="pil")
l1 = gr.Number(label="First Image Length (cm)", value=10)
w1 = gr.Number(label="First Image Width (cm)", value=7)
h1 = gr.Number(label="First Image Height (cm)", value=2)
wallet2 = gr.Image(label="Upload wallet image 2", type="pil")
l2 = gr.Number(label="Second Image Length (cm)", value=10)
w2 = gr.Number(label="Second Image Width (cm)", value=7)
h2 = gr.Number(label="Second Image Height (cm)", value=2)
btn = gr.Button("Generate Interactive Size Guide")
with gr.Column(scale=2):
out_img = gr.Image(label="Guide / Visualization")
out_txt = gr.Textbox(label="Notes", interactive=False)
err = gr.Textbox(label="Error", interactive=False)
def walletsize_submit(endpoint, w1_img, l1, w1_, h1_, w2_img, l2, w2_, h2_):
"""
Returns a tuple of (err_text_or_None, out_image_or_None, out_text_or_empty_string)
Always returns 3 values so Gradio won't raise 'didn't return enough outputs'.
"""
# --- Input validation (avoid 422 from backend) ---
missing = []
if l1 is None or w1_ is None or h1_ is None:
missing.append("first image dimensions (product_length1/product_width1/product_height1)")
if l2 is None or w2_ is None or h2_ is None:
missing.append("second image dimensions (product_length2/product_width2/product_height2)")
if missing:
return (f"Missing required dimensions: {', '.join(missing)}", None, "")
# Backend expects field names like product_length1/product_width1/product_height1 etc.
data = {
"product_length1": l1,
"product_width1": w1_,
"product_height1": h1_,
"product_length2": l2,
"product_width2": w2_,
"product_height2": h2_,
}
# Build files payload. Ensure 'garment_images' exists and is a list when multiple files.
extra = {}
images_list = []
if w1_img is not None:
images_list.append(("wallet1.jpg", image_to_bytes_pil(w1_img), "image/jpeg"))
extra["wallet_image_1"] = ("wallet1.jpg", image_to_bytes_pil(w1_img), "image/jpeg")
if w2_img is not None:
images_list.append(("wallet2.jpg", image_to_bytes_pil(w2_img), "image/jpeg"))
extra["wallet_image_2"] = ("wallet2.jpg", image_to_bytes_pil(w2_img), "image/jpeg")
if images_list:
# If only one image, call_api accepts tuple or list; prefer list for consistency
extra["garment_images"] = images_list
# Call API
err_msg, resp = call_api(endpoint, data, main_image=None, extra_files=extra)
# Always return triple: (err, image, text)
if err_msg:
# err_msg is the server error string (e.g. the 422 text). Return it in the first textbox.
return (err_msg, None, "")
# resp might be a PIL.Image or textual response
try:
if isinstance(resp, Image.Image):
return (None, resp, "")
except Exception:
# in case Image isn't imported here or type-checking fails, fallback to str
pass
# fallback: return textual response into the third textbox
return (None, None, str(resp) if resp is not None else "")
btn.click(fn=lambda *args, ep=endpoint: walletsize_submit(ep, *args),
inputs=[wallet1, l1, w1, h1, wallet2, l2, w2, h2],
outputs=[err, out_img, out_txt])
# Wallets: cross-category pairing (wallet image + garment image)
def make_wallet_crosspair_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
wallet_img = gr.Image(label="Wallet Image", type="pil")
garment_img = gr.Image(label="Garment Image", type="pil")
btn = gr.Button("Generate Cross-Category Pairing")
with gr.Column(scale=2):
out_img = gr.Image(label="Visualization")
err = gr.Textbox(label="Error", interactive=False)
def wc_submit(
ep,
garment_type_val,
model_gender_val,
model_age_val,
model_body_val,
model_ethnicity_val,
model_pose_val,
camera_view_val,
camera_distance_val,
camera_focal_val,
camera_aperture_val,
camera_lighting_val,
camera_background_val,
wallet_img,
garment_img,
):
# Validate required fields
missing = []
if not garment_type_val:
missing.append("garment_type")
# at least one image is required (server also expects garment_images)
if wallet_img is None and garment_img is None:
missing.append("at least one image (wallet or garment)")
required_fields = {
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
for k, v in required_fields.items():
if v is None or (isinstance(v, str) and not str(v).strip()):
missing.append(k)
if missing:
return f"Missing required fields: {', '.join(missing)}", None
# Prepare payload with all required model/camera fields
data = {
"garment_type": garment_type_val,
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
# Build extra_files: put available images under 'garment_images' (list)
extra_files = {}
images_list = []
if wallet_img is not None:
images_list.append(("wallet.jpg", image_to_bytes_pil(wallet_img), "image/jpeg"))
if garment_img is not None:
images_list.append(("garment.jpg", image_to_bytes_pil(garment_img), "image/jpeg"))
# Must send at least one under 'garment_images' — send as list so call_api sends multiple files if present
if images_list:
extra_files["garment_images"] = images_list
# Optionally also attach them under named keys if backend accepts them
if wallet_img is not None:
extra_files["wallet_image"] = ("wallet.jpg", image_to_bytes_pil(wallet_img), "image/jpeg")
if garment_img is not None:
extra_files["garment_image"] = ("garment.jpg", image_to_bytes_pil(garment_img), "image/jpeg")
# Debug log
print("Calling cross-pair endpoint:", ep, "garment_type_sent:", garment_type_val, "files:", list(extra_files.keys()))
err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
if err_msg:
return err_msg, None
if isinstance(resp, Image.Image):
return None, resp
return None, None
btn.click(
fn=lambda *args, ep=endpoint: wc_submit(
ep,
args[0], # garment_type_sel
args[1], # model_gender_sel
args[2], # model_age_sel
args[3], # model_body_sel
args[4], # model_ethnicity_sel
args[5], # model_pose_sel
args[6], # view_angle_sel
args[7], # distance_num
args[8], # focal_num
args[9], # aperture_num
args[10], # lighting_sel
args[11], # background_sel
args[12], # wallet_img
args[13], # garment_img
),
inputs=[
garment_type_sel, model_gender_sel, model_age_sel, model_body_sel, model_ethnicity_sel,
model_pose_sel, view_angle_sel, distance_num, focal_num, aperture_num,
lighting_sel, background_sel, wallet_img, garment_img
],
outputs=[err, out_img]
)
# Jewellery: outfit-to-jewelry match (jewelry image + garment image)
def make_jewellery_match_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
jewelry_img = gr.Image(label="Jewelry Image", type="pil")
garment_img = gr.Image(label="Garment Image", type="pil")
btn = gr.Button("Generate Match Visualization")
with gr.Column(scale=2):
out_img = gr.Image(label="Visualization")
err = gr.Textbox(label="Error", interactive=False)
def jm_submit(
ep,
garment_type_val,
model_gender_val,
model_age_val,
model_body_val,
model_ethnicity_val,
model_pose_val,
camera_view_val,
camera_distance_val,
camera_focal_val,
camera_aperture_val,
camera_lighting_val,
camera_background_val,
jewelry_img,
garment_img,
):
# Validate required fields
missing = []
if not garment_type_val:
missing.append("garment_type")
# at least one image must be provided as garment_images
if jewelry_img is None and garment_img is None:
missing.append("at least one image (jewelry or garment)")
required_fields = {
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
for k, v in required_fields.items():
if v is None or (isinstance(v, str) and not str(v).strip()):
missing.append(k)
if missing:
return f"Missing required fields: {', '.join(missing)}", None
# Build payload
data = {
"garment_type": garment_type_val,
"model_gender": model_gender_val,
"model_age_range": model_age_val,
"model_body_shape": model_body_val,
"model_race_ethnicity": model_ethnicity_val,
"model_pose": model_pose_val,
"camera_view_angle": camera_view_val,
"camera_distance_meters": camera_distance_val,
"camera_focal_length_mm": camera_focal_val,
"camera_aperture_f_number": camera_aperture_val,
"camera_lighting_condition": camera_lighting_val,
"camera_background": camera_background_val,
}
# Build files: attach both images under 'garment_images' (list)
images_list = []
if jewelry_img is not None:
images_list.append(("jewelry.jpg", image_to_bytes_pil(jewelry_img), "image/jpeg"))
if garment_img is not None:
images_list.append(("garment.jpg", image_to_bytes_pil(garment_img), "image/jpeg"))
extra_files = {}
if images_list:
extra_files["garment_images"] = images_list
# Also attach named keys for compatibility
if jewelry_img is not None:
extra_files["jewelry_image"] = ("jewelry.jpg", image_to_bytes_pil(jewelry_img), "image/jpeg")
if garment_img is not None:
extra_files["garment_image"] = ("garment.jpg", image_to_bytes_pil(garment_img), "image/jpeg")
# Debug
print("Calling outfit->jewelry endpoint:", ep, "garment_type:", garment_type_val, "files:", list(extra_files.keys()))
err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
if err_msg:
return err_msg, None
if isinstance(resp, Image.Image):
return None, resp
# fallback: show textual response in err box
return None, None
btn.click(
fn=lambda *args, ep=endpoint: jm_submit(
ep,
args[0], # garment_type_sel
args[1], # model_gender_sel
args[2], # model_age_sel
args[3], # model_body_sel
args[4], # model_ethnicity_sel
args[5], # model_pose_sel
args[6], # view_angle_sel
args[7], # distance_num
args[8], # focal_num
args[9], # aperture_num
args[10], # lighting_sel
args[11], # background_sel
args[12], # jewelry_img
args[13], # garment_img
),
inputs=[
garment_type_sel, model_gender_sel, model_age_sel, model_body_sel, model_ethnicity_sel,
model_pose_sel, view_angle_sel, distance_num, focal_num, aperture_num,
lighting_sel, background_sel, jewelry_img, garment_img
],
outputs=[err, out_img]
)
# Jewellery: size comparison with dimensions
def make_jewellery_size_subtab(tab_label, endpoint):
with gr.Tab(tab_label):
with gr.Row():
with gr.Column(scale=1):
jewelry_img = gr.Image(label="Jewelry Image", type="pil")
l = gr.Number(label="Length (cm)")
w = gr.Number(label="Width (cm)")
h = gr.Number(label="Height (cm)")
btn = gr.Button("Generate Size Comparison")
with gr.Column(scale=2):
out_img = gr.Image(label="Size Comparison Result")
err = gr.Textbox(label="Error", interactive=False)
def js_submit(ep, garment_type_val, jewelry_img, l, w, h):
# Validate required fields
missing = []
if not garment_type_val:
missing.append("garment_type")
if jewelry_img is None:
missing.append("garment_images (jewelry image)")
if missing:
return f"Missing required fields: {', '.join(missing)}", None
# Build payload using expected keys
data = {
"garment_type": garment_type_val,
"length_cm": l,
"width_cm": w,
"height_cm": h,
}
# Files: send the jewelry image as 'garment_images' (server expects this)
extra_files = {}
extra_files["garment_images"] = [("jewelry.jpg", image_to_bytes_pil(jewelry_img), "image/jpeg")]
# also attach under a named key for compatibility
extra_files["jewelry_image"] = ("jewelry.jpg", image_to_bytes_pil(jewelry_img), "image/jpeg")
# Debug log (optional)
print("Calling jewellery size endpoint:", ep, "garment_type_sent:", garment_type_val, "data_keys:", data.keys())
err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
if err_msg:
return err_msg, None
# If the API returns an image, show it in out_img
if isinstance(resp, Image.Image):
return None, resp
# Otherwise show textual response as an error/info message and no image
return None, None
btn.click(
fn=lambda *args, ep=endpoint: js_submit(ep, *args),
inputs=[garment_type_sel, jewelry_img, l, w, h],
outputs=[err, out_img]
)
# Shoes: outfit match preview (shoe image + garment image)
def make_shoes_outfit_subtab(tab_label, endpoint):
    """Build the Shoes 'Outfit Match Preview' sub-tab.

    Combines an optional shoe image and an optional garment image (at least
    one is required) with the shared model/camera selectors from the
    enclosing UI scope, and POSTs everything to `endpoint` via call_api.
    """
    with gr.Tab(tab_label):
        with gr.Row():
            with gr.Column(scale=1):
                shoe_img = gr.Image(label="Shoe Image", type="pil")
                garment_img = gr.Image(label="Garment Image", type="pil")
                btn = gr.Button("Generate Outfit Match Preview")
            with gr.Column(scale=2):
                out_img = gr.Image(label="Visualization")
                err = gr.Textbox(label="Error", interactive=False)

        def so_submit(
            ep,
            garment_type_val,
            model_gender_val,
            model_age_val,
            model_body_val,
            model_ethnicity_val,
            model_pose_val,
            camera_view_val,
            camera_distance_val,
            camera_focal_val,
            camera_aperture_val,
            camera_lighting_val,
            camera_background_val,
            shoe_img,
            garment_img,
        ):
            """Validate inputs and call the outfit-match endpoint.

            Returns (error_message, image); at most one is non-None.
            """
            # Collect every missing required field so the user sees one message.
            missing = []
            if not garment_type_val:
                missing.append("garment_type")
            # The endpoint needs at least one of the two images.
            if shoe_img is None and garment_img is None:
                missing.append("at least one image (shoe or garment)")
            required_fields = {
                "model_gender": model_gender_val,
                "model_age_range": model_age_val,
                "model_body_shape": model_body_val,
                "model_race_ethnicity": model_ethnicity_val,
                "model_pose": model_pose_val,
                "camera_view_angle": camera_view_val,
                "camera_distance_meters": camera_distance_val,
                "camera_focal_length_mm": camera_focal_val,
                "camera_aperture_f_number": camera_aperture_val,
                "camera_lighting_condition": camera_lighting_val,
                "camera_background": camera_background_val,
            }
            for field_name, field_val in required_fields.items():
                if field_val is None or (isinstance(field_val, str) and not str(field_val).strip()):
                    missing.append(field_name)
            if missing:
                return f"Missing required fields: {', '.join(missing)}", None
            # Form payload: garment_type plus all model/camera fields
            # (same keys/values as required_fields above).
            data = {"garment_type": garment_type_val, **required_fields}
            # Encode each provided image ONCE and reuse the bytes for both the
            # 'garment_images' list and the named compatibility keys
            # (previously every image was JPEG-encoded twice).
            extra_files = {}
            images_list = []
            if shoe_img is not None:
                shoe_bytes = image_to_bytes_pil(shoe_img)
                images_list.append(("shoe.jpg", shoe_bytes, "image/jpeg"))
                extra_files["shoe_image"] = ("shoe.jpg", shoe_bytes, "image/jpeg")
            if garment_img is not None:
                garment_bytes = image_to_bytes_pil(garment_img)
                images_list.append(("garment.jpg", garment_bytes, "image/jpeg"))
                extra_files["garment_image"] = ("garment.jpg", garment_bytes, "image/jpeg")
            if images_list:
                extra_files["garment_images"] = images_list
            # Debug log
            print("Calling shoes outfit-match endpoint:", ep, "garment_type:", garment_type_val, "files:", list(extra_files.keys()))
            err_msg, resp = call_api(ep, data, main_image=None, extra_files=extra_files)
            if err_msg:
                return err_msg, None
            if isinstance(resp, Image.Image):
                return None, resp
            # Surface textual responses instead of silently dropping them.
            if resp is not None:
                return str(resp), None
            return None, None

        # Inputs arrive in the same positional order as so_submit's parameters,
        # so bind the endpoint and splat the args — consistent with the other
        # sub-tab builders instead of unpacking args[0]…args[13] by index.
        btn.click(
            fn=lambda *args, ep=endpoint: so_submit(ep, *args),
            inputs=[
                garment_type_sel, model_gender_sel, model_age_sel, model_body_sel, model_ethnicity_sel,
                model_pose_sel, view_angle_sel, distance_num, focal_num, aperture_num,
                lighting_sel, background_sel, shoe_img, garment_img
            ],
            outputs=[err, out_img]
        )
# Now build the tabs & subtabs.
# Each category tab reuses the shared builders defined above; the second
# argument is the API endpoint suffix passed through to call_api.
with gr.Tabs():
    # Clothes
    with gr.Tab("👕 Clothes"):
        with gr.Tabs():
            make_shared_subtab("Lifestyle Images", "tryon")
            make_shared_subtab("Studio & Minimal", "tryon")
            make_shared_subtab("Flat Lay & Ghost", "mannequin")
            make_shared_subtab("Editorial & Fashion", "editorial")
            make_shared_subtab("Detail Shots", "detail")
            # make_background_edit_subtab("Background Edit", "background_edit")
    # Bags
    with gr.Tab("👜 Bags"):
        with gr.Tabs():
            make_shared_subtab("Model Try-On Simulation", "tryon")
            make_whatfits_subtab("What Fits Inside", "whatfits")
            make_sizecompare_subtab("Smart Size Comparison", "size")
            make_outfit_product_viz_subtab("Outfit & Product Visualization", "outfit-match")
            make_shared_subtab("Multiview", "multi_view")
            make_background_edit_subtab("Background Edit", "background_edit")
            make_shared_subtab("Detail Shots", "detail")
    # Wallets
    with gr.Tab("👛 Wallets"):
        with gr.Tabs():
            make_wallet_sizeguide_subtab("Interactive Size Guide", "walletsize")
            make_wallet_crosspair_subtab("Cross-Category Pairing", "outfit-match")
            make_shared_subtab("Occasion-Based Styling", "occasion")
            make_shared_subtab("Detail Shots", "detail")
            make_shared_subtab("Multiview", "multi_view")
            make_background_edit_subtab("Background Edit", "background_edit")
    # Jewellery & Watches
    with gr.Tab("💍 Jewelry & Watches"):
        with gr.Tabs():
            make_shared_subtab("AI Model Shot Generator", "tryon")
            make_jewellery_match_subtab("Outfit-to-Jewelry Match Visualizer", "outfit-match")
            make_shared_subtab("Occasion-Based Styling Suggestions", "occasion")
            # (removed a no-op self-assignment of make_jewellery_size_subtab
            # that had no effect)
            make_jewellery_size_subtab("Visual Size Comparison Tool", "size")
    # Shoes
    with gr.Tab("👟 Shoes"):
        with gr.Tabs():
            make_shared_subtab("Model Shot Generator", "tryon")
            make_shoes_outfit_subtab("Outfit Match Preview", "outfit-match")
            make_shared_subtab("Multiview", "multi_view")
            make_background_edit_subtab("Background Edit", "background_edit")
            make_shared_subtab("Detail Shots", "detail")
with gr.Tab("🛋 Furniture"):
# --- Reusable helper (make sure it's defined once in the file) ---
def analyze_room(image_path: str):
    """Classify an uploaded room photo with Gemini and pick example images.

    Returns a 3-tuple matching the outputs wiring:
      - status text (analysis summary or error message)
      - hero object ("bed" / "sofa" / None), kept in gr.State for the
        replacement step
      - gr.update for the example gallery (visible with images, or hidden)
    """
    if not API_KEY:
        return "GEMINI_API_KEY not configured.", None, gr.update(visible=False)
    if not image_path:
        return "Please upload an image first.", None, gr.update(visible=False)

    def _gallery_result(paths, hero, room_label):
        # Load example images as in-memory PIL images for the gallery;
        # hide the gallery when none of the bundled files exist on disk.
        images = [Image.open(p).convert("RGB") for p in paths if os.path.exists(p)]
        if not images:
            return f"No example {hero} images found.", hero, gr.update(visible=False)
        return (
            f"Analysis: {room_label}\nHero Object: {hero.capitalize()}",
            hero,
            gr.update(visible=True, value=images),
        )

    model = genai.GenerativeModel(model_name="gemini-2.5-pro")
    prompt = """
    Based on the image provided:
    In one word, what type of room is this?
    In one word, what is the hero object in the room (the object that covers the most area)?
    """
    try:
        if not os.path.exists(image_path):
            return f"Error: Image file not found at '{image_path}'", None, gr.update(visible=False)
        image_file = genai.upload_file(path=image_path)
        response = model.generate_content([prompt, image_file])
        response_text = response.text.strip().lower()
        # Best-effort cleanup of the upload on the genai side.
        try:
            genai.delete_file(image_file.name)
        except Exception:
            pass
        # Keyword match on the model's free-text answer.
        if "bedroom" in response_text or "bed" in response_text:
            return _gallery_result(BED_IMAGE_PATHS, "bed", "Bedroom")
        elif "livingroom" in response_text or "sofa" in response_text or "living room" in response_text:
            return _gallery_result(SOFA_IMAGE_PATHS, "sofa", "Living Room")
        else:
            return "Could not determine room type.", None, gr.update(visible=False)
    except Exception as e:
        # If the upload happened before the failure, try to delete it.
        try:
            if 'image_file' in locals() and hasattr(image_file, 'name'):
                genai.delete_file(image_file.name)
        except Exception:
            pass
        return f"An error occurred: {e}", None, gr.update(visible=False)
def replace_object(original_image_arg, example_image_arg, hero_object):
    """Ask Gemini to replace the hero object in the room with the example's style.

    original_image_arg: filepath (str) or PIL.Image
    example_image_arg: PIL.Image or filepath (str)
    hero_object: "bed"/"sofa" string produced by analyze_room

    Returns (status_text, generated_PIL_image_or_None).
    """
    if not API_KEY:
        return "GEMINI_API_KEY not configured.", None
    if original_image_arg is None or example_image_arg is None:
        return "Please upload an image and select an example style.", None
    if not hero_object:
        return "Analysis must be run first to identify the object.", None
    model = genai.GenerativeModel(model_name="gemini-2.5-flash-image")
    prompt = f"Replace the {hero_object} in the first image with the style and type of the {hero_object} from the second image. Maintain the original room's background, lighting, and perspective."
    orig_tmp = None
    ex_tmp = None
    original_image_file = None
    example_image_file = None
    try:
        # Original image may arrive as a PIL object (write to a temp file so
        # it can be uploaded) or as a filepath from gr.Image(type="filepath").
        if isinstance(original_image_arg, Image.Image):
            orig_tmp = pil_to_tempfile(original_image_arg)
            original_image_file = genai.upload_file(path=orig_tmp)
        elif isinstance(original_image_arg, str) and os.path.exists(original_image_arg):
            original_image_file = genai.upload_file(path=original_image_arg)
        else:
            return "Original image invalid.", None
        # Same handling for the selected example image.
        if isinstance(example_image_arg, Image.Image):
            ex_tmp = pil_to_tempfile(example_image_arg)
            example_image_file = genai.upload_file(path=ex_tmp)
        elif isinstance(example_image_arg, str) and os.path.exists(example_image_arg):
            example_image_file = genai.upload_file(path=example_image_arg)
        else:
            return "Selected example image invalid.", None
        response = model.generate_content([
            prompt,
            original_image_file,
            example_image_file
        ])
        if not getattr(response, "candidates", None):
            return "Generation failed: no candidates in response.", None
        generated_image_data = response.candidates[0].content.parts[0].inline_data.data
        generated_image = Image.open(io.BytesIO(generated_image_data)).convert("RGB")
        return "Image generated successfully.", generated_image
    except Exception as e:
        # debug print
        print("replace_object error:", e)
        if 'response' in locals():
            print("RAW RESPONSE:", response)
        return f"An unexpected error occurred: {e}", None
    finally:
        # Always delete genai-side uploads — the previous version only deleted
        # them on the success path, leaking files on early returns/exceptions.
        for uploaded in (original_image_file, example_image_file):
            if uploaded is not None and hasattr(uploaded, 'name'):
                try:
                    genai.delete_file(uploaded.name)
                except Exception:
                    pass
        # Remove any local temp files created for PIL inputs.
        for p in (orig_tmp, ex_tmp):
            if p and os.path.exists(p):
                try:
                    os.remove(p)
                except Exception:
                    pass
# --- UI layout for the Furniture tab ---
gr.Markdown("### Furniture / Room Styler")
with gr.Row():
    with gr.Column(scale=1):
        gr.Markdown("#### Step 1 — Upload & Analyze")
        # type="filepath" so analyze_room/replace_object receive a path string.
        furniture_input_image = gr.Image(type="filepath", label="Upload Room Image")
        furniture_analyze_button = gr.Button("Analyze Room", variant="secondary")
        furniture_analysis_output = gr.Textbox(label="Analysis Result", interactive=False)
        # states
        # hero object ("bed"/"sofa") detected by analyze_room
        furniture_hero_state = gr.State()
        # example image (PIL or filepath) chosen from the gallery
        furniture_selected_example_state = gr.State()
    with gr.Column(scale=2):
        gr.Markdown("#### Step 2 — Choose Style Example")
        # Hidden until analyze_room succeeds and supplies example images.
        furniture_example_gallery = gr.Gallery(
            label="Example Images (select one)",
            visible=False,
            columns=2,
            height="auto",
            object_fit="contain"
        )
gr.Markdown("---")
gr.Markdown("#### Step 3 — Generate Replacement")
furniture_replace_button = gr.Button("Generate New Room Image", variant="primary")
with gr.Row():
    furniture_output_text = gr.Textbox(label="Generation Status", interactive=False, scale=1)
    furniture_output_image = gr.Image(label="Your New Room", interactive=False, scale=2)
# Wire up events
furniture_analyze_button.click(
    fn=analyze_room,
    inputs=furniture_input_image,
    outputs=[furniture_analysis_output, furniture_hero_state, furniture_example_gallery]
)
# gallery selection returns object; we extract path -> store in state
def get_selected_image(evt: gr.SelectData):
    """Normalize a gallery selection event into a usable image reference.

    Returns a PIL.Image when one can be found, a filepath string otherwise,
    or None if extraction fails entirely.
    """
    try:
        selected = evt.value
        # A PIL image or plain path string passes through unchanged.
        if isinstance(selected, (Image.Image, str)):
            return selected
        # Dict wrappers: probe the common layouts gradio may produce.
        if isinstance(selected, dict):
            inner = selected.get("image")
            if isinstance(inner, Image.Image):
                return inner
            if isinstance(inner, dict) and "path" in inner:
                return inner["path"]
            # sometimes gallery returns {"name":..., "data":...}
            # best-effort: return any PIL object stored in the dict
            for candidate in selected.values():
                if isinstance(candidate, Image.Image):
                    return candidate
            # fallback: stringify the wrapper
            return str(selected)
        # fallback for any other payload type
        return selected
    except Exception as exc:
        print("Error extracting selected example:", exc)
        return None
furniture_example_gallery.select(
    fn=get_selected_image,
    inputs=None,  # select events deliver their payload via gr.SelectData
    outputs=furniture_selected_example_state
)
furniture_replace_button.click(
    fn=replace_object,
    inputs=[furniture_input_image, furniture_selected_example_state, furniture_hero_state],
    outputs=[furniture_output_text, furniture_output_image]
)
if __name__ == "__main__":
    # allowed_paths lets Gradio serve the bundled furniture example images.
    demo.launch(allowed_paths=["app/furniture"])