# AI_Try_ON / app.py
# (Hugging Face Space page header — author Akshajzclap, commit 1c6ece1 "Update app.py")
import os
import io
from pathlib import Path
from PIL import Image
import gradio as gr
from google import genai
from google.genai import types
from dotenv import load_dotenv
from huggingface_hub import HfApi, Repository
import shutil
load_dotenv()
# ---------------- CONFIG ---------------- #
MODEL_NAME = "gemini-2.5-flash-image"
API_KEY_ENV = "GEMINI_API_KEY"
PORT = 5467
HF_TOKEN_SECRET = "HF_TOKEN"
PRODUCT_DIR = Path("inputs/input_product")
REFERENCE_DIR = Path("inputs/input_reference")
THUMB_W = 400
# ---------------- HELPERS ---------------- #
def list_files(prefix: str, directory: Path, suffix=".jpg"):
return sorted(directory.glob(f"{prefix}*{suffix}"))
def load_image_safe(path: Path):
if path.exists():
return Image.open(path).convert("RGB")
return None
def pil_to_bytes(img: Image.Image):
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
# ---------------- GEMINI ---------------- #
def generate_with_gemini(instruction_prompt, product_img, reference_img):
"""Call Gemini image model using the new google-genai client API."""
api_key = os.getenv(API_KEY_ENV)
if not api_key:
raise gr.Error("Missing GEMINI_API_KEY environment variable.")
# βœ… Use new client interface
client = genai.Client(api_key=api_key)
# βœ… Build an interleaved content list with both images and text
parts = [
types.Part(text="You are an expert at photorealistic product compositing."),
types.Part(text="Here is the product image:"),
types.Part(inline_data=types.Blob(mime_type="image/png", data=pil_to_bytes(product_img))),
types.Part(text="Here is the reference image where the product should be placed:"),
types.Part(inline_data=types.Blob(mime_type="image/png", data=pil_to_bytes(reference_img))),
types.Part(text=f"Follow these detailed instructions carefully: {instruction_prompt}")
]
contents = [types.Content(role="user", parts=parts)]
# βœ… Configure to request both image and text (though we expect an image)
config = types.GenerateContentConfig(response_modalities=["IMAGE"])
try:
# βœ… Stream responses to handle large images efficiently
generated_image = None
for chunk in client.models.generate_content_stream(
model=MODEL_NAME, contents=contents, config=config
):
if not chunk.candidates:
continue
candidate = chunk.candidates[0]
if not candidate.content or not candidate.content.parts:
continue
part = candidate.content.parts[0]
inline = getattr(part, "inline_data", None)
if inline and inline.data:
generated_image = Image.open(io.BytesIO(inline.data)).convert("RGB")
break
if generated_image is None:
raise gr.Error("Model returned no image (possibly blocked or empty output).")
return generated_image
except Exception as e:
print(f"⚠️ Error generating image: {e}")
if "safety" in str(e).lower():
raise gr.Error("⚠️ Blocked by safety filter. Try using a more neutral reference image or prompt.")
raise gr.Error(f"Gemini API error: {e}")
import hashlib
CACHE_DIR = Path("generated_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
def get_cache_key(category: str, product_id: str, reference_id: str):
"""Generate a readable cache key for this combination."""
return f"{category}_{product_id}_{reference_id}"
def get_cache_path(category: str, product_id: str, reference_id: str):
"""Return the file path for this cached combo."""
cache_folder = CACHE_DIR / category
cache_folder.mkdir(parents=True, exist_ok=True)
file_name = f"{product_id}_{reference_id}.png"
return cache_folder / file_name
def generate_action(prompt, selected_product, uploaded_product, selected_reference, uploaded_reference):
"""Main generation logic with robust checks and corrected API calls."""
category = "faucet" if "faucet" in prompt.lower() else "pantry"
# --- ROBUST IMAGE SELECTION (Fix for TypeError) ---
if uploaded_product is not None:
product_img = uploaded_product
product_id = "uploaded_product"
elif selected_product is not None:
product_img = load_image_safe(Path(selected_product))
product_id = Path(selected_product).stem
else:
raise gr.Error("Please select or upload a product image.")
if uploaded_reference is not None:
reference_img = uploaded_reference
reference_id = "uploaded_reference"
elif selected_reference is not None:
reference_img = load_image_safe(Path(selected_reference))
reference_id = Path(selected_reference).stem
else:
raise gr.Error("Please select or upload a reference image.")
# --- Caching Logic ---
cache_path = get_cache_path(category, product_id, reference_id)
if cache_path.exists():
print(f"βœ… Loading from permanent cache: {cache_path}")
cached_output = Image.open(cache_path).convert("RGB")
return [reference_img, cached_output], cached_output
print(f"βš™οΈ Generating new image for: {category} | {product_id} + {reference_id}")
generated = generate_with_gemini(prompt, product_img, reference_img)
if generated is None:
raise gr.Error("Model returned no image.")
generated.save(cache_path)
print(f"πŸ’Ύ Saved new cached result: {cache_path}")
# --- CORRECTED COMMIT AND PUSH BLOCK (Fix for 'SpaceInfo' Error) ---
try:
token = os.getenv("HF_TOKEN")
repo_id = os.getenv("SPACE_ID") # HF Spaces automatically provides this
if token and repo_id:
api = HfApi()
# The incorrect api.space_info call is removed.
api.upload_file(
path_or_fileobj=str(cache_path),
path_in_repo=str(cache_path),
repo_id=repo_id, # Use repo_id directly
repo_type="space",
token=token,
commit_message=f"Cache: Add {cache_path.name}"
)
print(f"βœ… Pushed {cache_path.name} to permanent cache.")
else:
print("⚠️ HF_TOKEN or SPACE_ID not found. Skipping push to repo.")
except Exception as e:
print(f"⚠️ Failed to push to repo: {e}")
return [reference_img, generated], generated
# ---------------- TAB BUILDER ---------------- #
def make_tab(prefix: str, title: str):
product_files = list_files(prefix, PRODUCT_DIR)
reference_files = list_files(prefix, REFERENCE_DIR)
# Prompt per category
if prefix == "faucet":
prompt = "replace the faucet in the kitchen sink with the faucet given in the product image."
elif prefix == "pantry":
prompt = "Put the pantry unit in the kitchen on the ground and fill it with items"
else:
prompt = "Replace the product in the reference image with the product in the second image."
with gr.Tab(title):
selected_product = gr.State(None)
selected_reference = gr.State(None)
# 🧱 PRODUCT SECTION
gr.Markdown("### Product Selection")
with gr.Row():
# Selected preview (shows either clicked or uploaded)
selected_product_img = gr.Image(
label="Selected Product",
interactive=False,
visible=False,
width=150
)
# Product thumbnails
for f in product_files:
img = load_image_safe(f)
with gr.Column():
gr.Image(value=img, interactive=False, width=150)
def select_product(path=str(f)):
img = load_image_safe(Path(path))
return gr.update(value=img, visible=True), path
btn = gr.Button(f"Select {f.name}", size="sm")
btn.click(fn=select_product, inputs=[], outputs=[selected_product_img, selected_product])
# Upload product
upload_product = gr.Image(label="πŸ“€ Upload Product", type="pil", width=150)
# βœ… When upload changes, override selected image and state
def on_product_upload(upload_img):
if upload_img is not None:
return gr.update(value=upload_img, visible=True), "uploaded_product"
return gr.update(), None
upload_product.change(
fn=on_product_upload,
inputs=[upload_product],
outputs=[selected_product_img, selected_product],
)
# 🏠 REFERENCE SECTION
gr.Markdown("### Reference Selection")
with gr.Row():
selected_reference_img = gr.Image(
label="Selected Reference",
interactive=False,
visible=False,
width=150
)
for f in reference_files:
img = load_image_safe(f)
with gr.Column():
gr.Image(value=img, interactive=False, width=150)
def select_reference(path=str(f)):
img = load_image_safe(Path(path))
return gr.update(value=img, visible=True), path
btn = gr.Button(f"Select {f.name}", size="sm")
btn.click(fn=select_reference, inputs=[], outputs=[selected_reference_img, selected_reference])
upload_reference = gr.Image(label="πŸ“€ Upload Reference", type="pil", width=150)
# βœ… When upload changes, override selected reference
def on_reference_upload(upload_img):
if upload_img is not None:
return gr.update(value=upload_img, visible=True), "uploaded_reference"
return gr.update(), None
upload_reference.change(
fn=on_reference_upload,
inputs=[upload_reference],
outputs=[selected_reference_img, selected_reference],
)
# πŸš€ GENERATE SECTION
gr.Markdown("### Generate Composite")
generate_btn = gr.Button("πŸš€ Generate", variant="primary")
result_slider = gr.ImageSlider(label="Before / After Comparison", height=300)
output_img = gr.Image(label="Generated Image", interactive=False, visible=True, height=300)
generate_btn.click(
fn=generate_action,
inputs=[gr.State(prompt), selected_product, upload_product, selected_reference, upload_reference],
outputs=[result_slider, output_img],
)
# ---------------- MAIN APP ---------------- #
with gr.Blocks(
title="AI Product Try-On",
css="""
img {border-radius: 8px;}
.gr-button {margin-top: 4px;}
""",
) as demo:
gr.Markdown("""
# 🏠 AI Product Try-On
Select a product and a reference image (or upload your own).
Click **Generate** to visualize how the product appears in the reference environment.
""")
with gr.Tabs():
make_tab("faucet", "Faucet")
make_tab("pantry", "Pantry Unit")
# ---------------- RUN ---------------- #
if __name__ == "__main__":
demo.launch()