"""Van Gogh style image recommender.

Embeds every ``image*.jpg`` (real paintings) and ``generated_custom*.jpg``
(AI-generated paintings) in the working directory with a ViT model, then
serves a Gradio UI that, for an uploaded image, returns the most similar real
paintings and the single most similar generated painting (cosine distance).
"""
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from transformers import ViTFeatureExtractor, ViTModel

# --- Load model and extractor ---
MODEL_NAME = "google/vit-base-patch16-224-in21k"
model = ViTModel.from_pretrained(MODEL_NAME)
extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)

# Number of real paintings returned per query (capped at the index size below).
TOP_K_REAL = 3


# --- Helper: extract embedding from image ---
def get_embedding(img):
    """Return the ViT embedding of a single PIL image as a 1-D NumPy array.

    The image is converted to RGB first, so grayscale/RGBA uploads match the
    RGB images the index is built from. The model runs without gradient
    tracking, and the hidden state of the first token (position 0, ViT's
    [CLS] token) is returned.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    inputs = extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state[:, 0, :].squeeze().numpy()


def _list_images(prefix):
    """Return sorted ``.jpg`` filenames in the working directory starting with *prefix*."""
    return sorted(
        f for f in os.listdir() if f.startswith(prefix) and f.endswith(".jpg")
    )


def _load_image(path):
    """Open *path*, read it fully, and return an in-memory copy (file handle closed)."""
    with Image.open(path) as img:
        # copy() forces a full load; the closed original must not be used afterwards.
        return img.copy()


# --- Load image paths ---
dataset_image_paths = _list_images("image")
generated_image_paths = _list_images("generated_custom")


# --- Compute embeddings ---
def compute_embeddings(paths, label="images"):
    """Embed every image in *paths* and stack the results row-wise.

    Args:
        paths: Image file paths, in the order their rows should appear.
        label: Human-readable name used in the error message.

    Returns:
        A 2-D array with one embedding row per path.

    Raises:
        ValueError: If *paths* is empty (an empty index cannot be searched).
    """
    if not paths:
        raise ValueError(
            f"No {label} found in {os.getcwd()!r}; cannot build the "
            "nearest-neighbour index."
        )
    embeddings = []
    for path in paths:
        with Image.open(path) as img:
            embeddings.append(get_embedding(img.convert("RGB")))
    return np.vstack(embeddings)


dataset_embeddings = compute_embeddings(dataset_image_paths, "dataset images (image*.jpg)")
generated_embeddings = compute_embeddings(
    generated_image_paths, "generated images (generated_custom*.jpg)"
)

# --- Fit Nearest Neighbors models ---
# kneighbors() raises if n_neighbors exceeds the number of indexed samples.
nn_dataset = NearestNeighbors(
    n_neighbors=min(TOP_K_REAL, len(dataset_image_paths)), metric="cosine"
).fit(dataset_embeddings)
nn_generated = NearestNeighbors(n_neighbors=1, metric="cosine").fit(generated_embeddings)


# --- Recommendation function ---
def recommend_image(user_image):
    """Return the user's image, its closest real paintings, and the closest generated one.

    Args:
        user_image: PIL image supplied by the Gradio input component (may be None).

    Returns:
        ``(user_image, [real PIL images], generated PIL image)``.

    Raises:
        gr.Error: If no image was provided, shown to the user in the UI.
    """
    if user_image is None:
        raise gr.Error("Please upload an image first.")

    query = get_embedding(user_image).reshape(1, -1)

    # Most similar real (Van Gogh) images.
    _, dataset_indices = nn_dataset.kneighbors(query)
    recommended_images = [_load_image(dataset_image_paths[i]) for i in dataset_indices[0]]

    # Single most similar generated image.
    _, gen_indices = nn_generated.kneighbors(query)
    generated_image = _load_image(generated_image_paths[gen_indices[0][0]])

    return user_image, recommended_images, generated_image


# --- UI ---
example_images = [
    "image1.jpg",
    "image2.jpg",
    "image3.jpg",
    "image4.jpg",
    "image5.jpg",
]

with gr.Blocks(title="Van Gogh Style Image Recommendation") as demo:
    gr.Markdown("## 🎨 Van Gogh Style Image Recommendation")
    gr.Markdown(
        "Upload a painting-style image and get:\n"
        "- 🖼️ The uploaded image preview\n"
        "- 🎯 Top 3 similar real Van Gogh paintings\n"
        "- 🧠 1 AI-generated painting that best matches your input"
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="🖼️ Upload an Image")
            submit_btn = gr.Button("🔍 Recommend", variant="primary")
            clear_btn = gr.Button("❌ Clear")
        with gr.Column():
            output_user_image = gr.Image(label="📥 Your Image")
            output_similar = gr.Gallery(label="🎯 Top 3 Real Paintings", columns=3)
            output_generated = gr.Image(label="🧠 Generated AI Painting")

    result_outputs = [output_user_image, output_similar, output_generated]

    # gr.Examples fills input_image on click and (run_on_click) runs the
    # recommender. It must be declared after the output components exist.
    # Missing example files are skipped rather than crashing startup.
    available_examples = [[p] for p in example_images if os.path.isfile(p)]
    examples = None
    if available_examples:
        examples = gr.Examples(
            examples=available_examples,
            inputs=input_image,
            outputs=result_outputs,
            fn=recommend_image,
            run_on_click=True,
            cache_examples=False,
            label="✨ 1-Click Example Images",
        )

    submit_btn.click(fn=recommend_image, inputs=input_image, outputs=result_outputs)
    # Clear the input as well as every output.
    clear_btn.click(
        fn=lambda: (None, None, None, None),
        inputs=[],
        outputs=[input_image, *result_outputs],
    )

if __name__ == "__main__":
    demo.launch()