"""Scene-aware bird image recommendation demo.

Embeds a user-uploaded bird photo with CLIP, detects a coarse scene
category from mean RGB colour, and recommends the three most similar
images from a precomputed embedding table, preferring same-scene matches.
"""

import numpy as np
import pandas as pd
import torch
from PIL import Image

import gradio as gr
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor

# --- Device & model ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(MODEL_NAME).to(DEVICE)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()

# --- Load dataset & precomputed embeddings ---
dataset = load_dataset("JotDe/birds", split="train")
embedding_df = pd.read_parquet("bird_image_embeddings.parquet")
# Columns named "emb_*" hold the precomputed CLIP embedding components.
embedding_vectors = embedding_df.filter(like="emb_").values


def detect_scene(image: Image.Image) -> str:
    """Classify an image into a coarse scene via mean-channel heuristics.

    Args:
        image: Any PIL image; converted to RGB internally.

    Returns:
        One of ``"sky"``, ``"forest"``, ``"water"``, or ``"mixed"``.
    """
    arr = np.array(image.convert("RGB")) / 255.0
    mean_r = arr[:, :, 0].mean()
    mean_g = arr[:, :, 1].mean()
    mean_b = arr[:, :, 2].mean()
    # Thresholds are hand-tuned heuristics: blue-dominant -> sky,
    # green-dominant -> forest, blue-ish with low green -> water.
    if mean_b > 0.45 and mean_b > mean_g and mean_b > mean_r:
        return "sky"
    if mean_g > 0.40 and mean_g > mean_r:
        return "forest"
    if mean_b > 0.35 and mean_g < 0.35:
        return "water"
    return "mixed"


def embed_user_image(image: Image.Image) -> np.ndarray:
    """Return the L2-normalised CLIP embedding of *image* as a 1-D array."""
    inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Normalise so dot products against (normalised) stored embeddings
    # are cosine similarities.
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy().flatten()


def recommend_images(image: Image.Image) -> list:
    """Return up to three dataset images most similar to the upload.

    Preference is given to dataset images sharing the detected scene;
    falls back to the full table when the same-scene subset has fewer
    than 10 rows.

    Args:
        image: User-uploaded PIL image.

    Returns:
        List of up to three PIL images, most similar first.
    """
    user_emb = embed_user_image(image)
    scene = detect_scene(image)

    # Lazily tag every dataset row with its scene the first time this
    # runs (slow on large datasets — each row's image is loaded once).
    if "scene" not in embedding_df.columns:
        embedding_df["scene"] = [
            detect_scene(dataset[int(row_idx)]["image"])
            for row_idx in embedding_df["row_in_subset"]
        ]

    # Filter by scene, falling back to everything if too few matches.
    filtered_df = embedding_df[embedding_df["scene"] == scene]
    if len(filtered_df) < 10:
        filtered_df = embedding_df

    filtered_vectors = filtered_df.filter(like="emb_").values
    similarities = filtered_vectors @ user_emb
    # Indices of the top-3 most similar rows, best first.
    top_indices = np.argsort(similarities)[-3:][::-1]

    results = []
    for idx in top_indices:
        row = filtered_df.iloc[int(idx)]
        results.append(dataset[int(row["row_in_subset"])]["image"])
    return results


# --- Gradio interface ---
interface = gr.Interface(
    fn=recommend_images,
    inputs=gr.Image(type="pil", label="Upload a bird image"),
    outputs=[
        gr.Image(label="Recommendation 1"),
        gr.Image(label="Recommendation 2"),
        gr.Image(label="Recommendation 3"),
    ],
    title="Bird Image Recommendation System",
    description="Scene-aware bird image recommendations using CLIP embeddings.",
)

if __name__ == "__main__":
    # NOTE(review): the original file re-imported gradio after launching and
    # started a second, empty gr.Blocks() demo — dead code, removed here.
    interface.launch()