import hashlib
import io
import os
import shutil
from pathlib import Path

import gradio as gr
from dotenv import load_dotenv
from google import genai
from google.genai import types
from huggingface_hub import HfApi, Repository
from PIL import Image
| load_dotenv() |
|
|
| |
# Gemini image-generation model used for all compositing requests.
MODEL_NAME = "gemini-2.5-flash-image"
# Name of the environment variable that holds the Gemini API key.
API_KEY_ENV = "GEMINI_API_KEY"
# Server port. NOTE(review): currently unused — demo.launch() is called without it.
PORT = 5467
# Name of the HF token secret. NOTE(review): currently unused — code reads "HF_TOKEN" directly.
HF_TOKEN_SECRET = "HF_TOKEN"
# Folders holding the bundled product / reference images shown in the galleries.
PRODUCT_DIR = Path("inputs/input_product")
REFERENCE_DIR = Path("inputs/input_reference")


# Thumbnail width in pixels. NOTE(review): currently unused; widths are hard-coded as 150 below.
THUMB_W = 400
|
|
| |
def list_files(prefix: str, directory: Path, suffix=".jpg"):
    """Return all files in *directory* named ``<prefix>*<suffix>``, sorted by name."""
    pattern = f"{prefix}*{suffix}"
    matches = directory.glob(pattern)
    return sorted(matches)
|
|
def load_image_safe(path: Path):
    """Open *path* as an RGB PIL image, or return None if the file is missing."""
    if not path.exists():
        return None
    return Image.open(path).convert("RGB")
|
|
def pil_to_bytes(img: Image.Image):
    """Serialize a PIL image into PNG-encoded bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
|
|
| |
|
|
def generate_with_gemini(instruction_prompt, product_img, reference_img):
    """Call the Gemini image model and return the generated composite image.

    Args:
        instruction_prompt: Free-text instructions for the compositing task.
        product_img: PIL image of the product to insert.
        reference_img: PIL image of the scene the product should be placed into.

    Returns:
        A PIL RGB image produced by the model.

    Raises:
        gr.Error: if the API key is missing, the model returns no image,
            the request is blocked by a safety filter, or the API call fails.
    """
    api_key = os.getenv(API_KEY_ENV)
    if not api_key:
        raise gr.Error("Missing GEMINI_API_KEY environment variable.")

    client = genai.Client(api_key=api_key)

    # Interleave labelled text and inline PNG data so the model can tell
    # the product apart from the target scene.
    parts = [
        types.Part(text="You are an expert at photorealistic product compositing."),
        types.Part(text="Here is the product image:"),
        types.Part(inline_data=types.Blob(mime_type="image/png", data=pil_to_bytes(product_img))),
        types.Part(text="Here is the reference image where the product should be placed:"),
        types.Part(inline_data=types.Blob(mime_type="image/png", data=pil_to_bytes(reference_img))),
        types.Part(text=f"Follow these detailed instructions carefully: {instruction_prompt}")
    ]

    contents = [types.Content(role="user", parts=parts)]

    config = types.GenerateContentConfig(response_modalities=["IMAGE"])

    try:
        generated_image = None
        for chunk in client.models.generate_content_stream(
            model=MODEL_NAME, contents=contents, config=config
        ):
            if not chunk.candidates:
                continue
            candidate = chunk.candidates[0]
            if not candidate.content or not candidate.content.parts:
                continue
            # BUGFIX: scan every part instead of only parts[0] — the model may
            # emit a text part (e.g. a caption) before the image part, which the
            # old code would skip, reporting "no image" on a successful response.
            for part in candidate.content.parts:
                inline = getattr(part, "inline_data", None)
                if inline and inline.data:
                    generated_image = Image.open(io.BytesIO(inline.data)).convert("RGB")
                    break
            if generated_image is not None:
                break

        if generated_image is None:
            raise gr.Error("Model returned no image (possibly blocked or empty output).")

        return generated_image

    except gr.Error:
        # BUGFIX: our own user-facing errors used to be swallowed by the broad
        # Exception handler below and re-wrapped as "Gemini API error: ...".
        raise
    except Exception as e:
        print(f"β οΈ Error generating image: {e}")
        if "safety" in str(e).lower():
            raise gr.Error("β οΈ Blocked by safety filter. Try using a more neutral reference image or prompt.")
        raise gr.Error(f"Gemini API error: {e}")
|
|
|
|
|
|
|
|
| import hashlib |
|
|
# Local folder where generated composites are cached (one subfolder per category).
# Created eagerly at import time so later writes never race on a missing parent.
CACHE_DIR = Path("generated_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
def get_cache_key(category: str, product_id: str, reference_id: str):
    """Build a human-readable cache key for a (category, product, reference) combo."""
    return "_".join((category, product_id, reference_id))
|
|
def get_cache_path(category: str, product_id: str, reference_id: str):
    """Return the cache file path for this combo, creating its folder if needed."""
    folder = CACHE_DIR / category
    folder.mkdir(parents=True, exist_ok=True)
    return folder / f"{product_id}_{reference_id}.png"
|
|
|
|
|
|
|
|
def generate_action(prompt, selected_product, uploaded_product, selected_reference, uploaded_reference):
    """Resolve the chosen images, serve from cache when possible, else generate.

    Args:
        prompt: Instruction prompt for the model (also determines the category).
        selected_product / selected_reference: file path chosen from the gallery
            (or a marker string set by the upload handlers), or None.
        uploaded_product / uploaded_reference: PIL image uploaded by the user,
            or None. Uploads take precedence over gallery selections.

    Returns:
        ([reference_img, output_img], output_img) — a before/after pair for the
        ImageSlider plus the standalone generated image.

    Raises:
        gr.Error: if inputs are missing/unloadable or generation fails.
    """
    category = "faucet" if "faucet" in prompt.lower() else "pantry"

    # Resolve the product image; uploads win over gallery selections.
    if uploaded_product is not None:
        product_img = uploaded_product
        # BUGFIX: key uploads by content hash. The old fixed "uploaded_product"
        # id made every upload collide on one cache entry, so a second upload
        # silently returned the first upload's cached result.
        product_id = f"uploaded_{hashlib.md5(uploaded_product.tobytes()).hexdigest()[:12]}"
    elif selected_product is not None:
        product_img = load_image_safe(Path(selected_product))
        product_id = Path(selected_product).stem
    else:
        raise gr.Error("Please select or upload a product image.")

    if uploaded_reference is not None:
        reference_img = uploaded_reference
        reference_id = f"uploaded_{hashlib.md5(uploaded_reference.tobytes()).hexdigest()[:12]}"
    elif selected_reference is not None:
        reference_img = load_image_safe(Path(selected_reference))
        reference_id = Path(selected_reference).stem
    else:
        raise gr.Error("Please select or upload a reference image.")

    # BUGFIX: load_image_safe returns None for a missing file (e.g. a stale
    # selection path); fail with a clear message instead of crashing later.
    if product_img is None or reference_img is None:
        raise gr.Error("Could not load the selected image(s). Please re-select or upload.")

    cache_path = get_cache_path(category, product_id, reference_id)
    if cache_path.exists():
        # BUGFIX: this print was a broken multi-line f-string (syntax error).
        print(f"Loading from permanent cache: {cache_path}")
        cached_output = Image.open(cache_path).convert("RGB")
        return [reference_img, cached_output], cached_output

    print(f"βοΈ Generating new image for: {category} | {product_id} + {reference_id}")
    generated = generate_with_gemini(prompt, product_img, reference_img)
    if generated is None:
        raise gr.Error("Model returned no image.")

    generated.save(cache_path)
    print(f"πΎ Saved new cached result: {cache_path}")

    # Best-effort: push the cached file back to the HF Space repo so it survives
    # Space restarts. Failures are logged but never fatal.
    try:
        token = os.getenv("HF_TOKEN")
        repo_id = os.getenv("SPACE_ID")
        if token and repo_id:
            api = HfApi()
            api.upload_file(
                path_or_fileobj=str(cache_path),
                path_in_repo=str(cache_path),
                repo_id=repo_id,
                repo_type="space",
                token=token,
                commit_message=f"Cache: Add {cache_path.name}"
            )
            # BUGFIX: this print was a broken multi-line f-string (syntax error).
            print(f"Pushed {cache_path.name} to permanent cache.")
        else:
            print("β οΈ HF_TOKEN or SPACE_ID not found. Skipping push to repo.")
    except Exception as e:
        print(f"β οΈ Failed to push to repo: {e}")

    return [reference_img, generated], generated
|
|
|
|
| |
def make_tab(prefix: str, title: str):
    """Build one Gradio tab for a product category (e.g. "faucet").

    Lists the bundled product/reference images whose filenames start with
    *prefix*, lets the user pick one of them or upload their own, and wires
    the Generate button to generate_action.
    """
    product_files = list_files(prefix, PRODUCT_DIR)
    reference_files = list_files(prefix, REFERENCE_DIR)

    # Category-specific default instruction prompt passed to the model.
    if prefix == "faucet":
        prompt = "replace the faucet in the kitchen sink with the faucet given in the product image."
    elif prefix == "pantry":
        prompt = "Put the pantry unit in the kitchen on the ground and fill it with items"
    else:
        prompt = "Replace the product in the reference image with the product in the second image."

    with gr.Tab(title):
        # Hidden state holding the chosen file path, or the marker string
        # "uploaded_product"/"uploaded_reference" after an upload.
        selected_product = gr.State(None)
        selected_reference = gr.State(None)

        gr.Markdown("### Product Selection")
        with gr.Row():
            # Preview of the currently selected product; hidden until a choice is made.
            selected_product_img = gr.Image(
                label="Selected Product",
                interactive=False,
                visible=False,
                width=150
            )

            # One thumbnail + "Select" button per bundled product image.
            for f in product_files:
                img = load_image_safe(f)
                with gr.Column():
                    gr.Image(value=img, interactive=False, width=150)
                    # NOTE: path=str(f) binds the current loop value as a default
                    # argument, avoiding Python's late-binding closure pitfall.
                    def select_product(path=str(f)):
                        img = load_image_safe(Path(path))
                        return gr.update(value=img, visible=True), path
                    btn = gr.Button(f"Select {f.name}", size="sm")
                    btn.click(fn=select_product, inputs=[], outputs=[selected_product_img, selected_product])

            # Free-form upload as an alternative to the bundled products.
            upload_product = gr.Image(label="π€ Upload Product", type="pil", width=150)

        # Show the upload in the preview and mark the selection state as "uploaded".
        def on_product_upload(upload_img):
            if upload_img is not None:
                return gr.update(value=upload_img, visible=True), "uploaded_product"
            return gr.update(), None

        upload_product.change(
            fn=on_product_upload,
            inputs=[upload_product],
            outputs=[selected_product_img, selected_product],
        )

        # Same layout and wiring as above, but for the reference (scene) image.
        gr.Markdown("### Reference Selection")
        with gr.Row():
            selected_reference_img = gr.Image(
                label="Selected Reference",
                interactive=False,
                visible=False,
                width=150
            )

            for f in reference_files:
                img = load_image_safe(f)
                with gr.Column():
                    gr.Image(value=img, interactive=False, width=150)
                    # Same default-argument binding trick as select_product above.
                    def select_reference(path=str(f)):
                        img = load_image_safe(Path(path))
                        return gr.update(value=img, visible=True), path
                    btn = gr.Button(f"Select {f.name}", size="sm")
                    btn.click(fn=select_reference, inputs=[], outputs=[selected_reference_img, selected_reference])

            upload_reference = gr.Image(label="π€ Upload Reference", type="pil", width=150)

        def on_reference_upload(upload_img):
            if upload_img is not None:
                return gr.update(value=upload_img, visible=True), "uploaded_reference"
            return gr.update(), None

        upload_reference.change(
            fn=on_reference_upload,
            inputs=[upload_reference],
            outputs=[selected_reference_img, selected_reference],
        )

        # Generation controls: before/after slider plus the raw generated image.
        gr.Markdown("### Generate Composite")
        generate_btn = gr.Button("π Generate", variant="primary")
        result_slider = gr.ImageSlider(label="Before / After Comparison", height=300)
        output_img = gr.Image(label="Generated Image", interactive=False, visible=True, height=300)

        # gr.State(prompt) freezes this tab's prompt as a constant input.
        generate_btn.click(
            fn=generate_action,
            inputs=[gr.State(prompt), selected_product, upload_product, selected_reference, upload_reference],
            outputs=[result_slider, output_img],
        )
|
|
|
|
|
|
| |
# App layout: header text plus one tab per supported product category.
with gr.Blocks(
    title="AI Product Try-On",
    css="""
    img {border-radius: 8px;}
    .gr-button {margin-top: 4px;}
    """,
) as demo:
    gr.Markdown("""
    # π AI Product Try-On
    Select a product and a reference image (or upload your own).
    Click **Generate** to visualize how the product appears in the reference environment.
    """)

    with gr.Tabs():
        make_tab("faucet", "Faucet")
        make_tab("pantry", "Pantry Unit")


# Launch only when run as a script; on HF Spaces the `demo` object is picked
# up automatically. NOTE(review): the PORT constant above is not passed here.
if __name__ == "__main__":
    demo.launch()
|
|