# app.py — uploaded by guyinbal (commit c78482c, verified)
import gradio as gr
from PIL import Image
import os
import torch
import numpy as np
from transformers import ViTFeatureExtractor, ViTModel
from sklearn.neighbors import NearestNeighbors
# --- Load model and extractor ---
# Both the backbone and its preprocessing config come from the same checkpoint,
# so the identifier is spelled once and shared.
_VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
model = ViTModel.from_pretrained(_VIT_CHECKPOINT)
extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
# --- Helper: extract embedding from image ---
def get_embedding(img):
    """Return the ViT embedding of *img* as a NumPy array.

    The image is preprocessed by the module-level ``extractor``, run through
    the module-level ``model`` with gradients disabled, and the hidden state
    of the first token (index 0 on the sequence axis) is returned with
    singleton dimensions squeezed away.
    """
    batch = extractor(images=img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        hidden_states = model(**batch).last_hidden_state
    first_token = hidden_states[:, 0, :]
    return first_token.squeeze().numpy()
# --- Load image paths ---
def _list_jpgs(prefix):
    """Return the sorted ``.jpg`` names in the working directory starting with *prefix*."""
    matches = []
    for name in os.listdir():
        if name.startswith(prefix) and name.endswith(".jpg"):
            matches.append(name)
    matches.sort()
    return matches


# Real paintings are named "image*.jpg"; generated ones "generated_custom*.jpg".
dataset_image_paths = _list_jpgs("image")
generated_image_paths = _list_jpgs("generated_custom")
# --- Compute embeddings ---
def compute_embeddings(paths):
    """Embed every image in *paths* and stack the results row-wise.

    Args:
        paths: Iterable of image file paths that PIL can open.

    Returns:
        A NumPy array built with ``np.vstack``: one row per path, in the
        same order as *paths*.

    Raises:
        ValueError: If *paths* is empty (there is nothing to stack).
    """
    paths = list(paths)
    if not paths:
        # np.vstack would fail anyway, but with a message that does not
        # point at the real cause (no matching image files were found).
        raise ValueError("compute_embeddings: no image paths were provided")

    embeddings = []
    for path in paths:
        # convert() loads the pixels into a new image, so the file handle
        # can be released immediately instead of waiting for the GC.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        embeddings.append(get_embedding(rgb))
    return np.vstack(embeddings)
# Embed both image collections once at start-up; queries reuse these arrays.
dataset_embeddings = compute_embeddings(dataset_image_paths)
generated_embeddings = compute_embeddings(generated_image_paths)
# --- Fit Nearest Neighbors models ---
# kneighbors() raises when asked for more neighbours than fitted samples, so
# cap the request at the number of dataset images actually available.
nn_dataset = NearestNeighbors(
    n_neighbors=min(3, len(dataset_embeddings)),
    metric="cosine",
).fit(dataset_embeddings)
nn_generated = NearestNeighbors(n_neighbors=1, metric="cosine").fit(generated_embeddings)
# --- Recommendation function ---
def recommend_image(user_image):
    """Find the stored images closest to an uploaded image.

    Args:
        user_image: The uploaded picture (a PIL image when called from the
            UI), or ``None`` when nothing has been uploaded.

    Returns:
        A 3-tuple ``(user_image, similar_images, generated_image)``: the
        input unchanged, a list of the nearest dataset images, and the
        single nearest generated image. All three are ``None`` when
        *user_image* is ``None``.
    """
    if user_image is None:
        # Button pressed with an empty upload box: blank the outputs
        # rather than failing inside the feature extractor.
        return None, None, None

    # Uploads may be RGBA, greyscale or palette images; the indexed images
    # were embedded as RGB, so the query is normalised the same way.
    query_image = user_image
    if isinstance(user_image, Image.Image):
        query_image = user_image.convert("RGB")
    query = [get_embedding(query_image)]

    # Find the most similar real images (Van Gogh)
    _, dataset_indices = nn_dataset.kneighbors(query)
    recommended_images = [Image.open(dataset_image_paths[i]) for i in dataset_indices[0]]

    # Find the single most similar generated image
    _, gen_indices = nn_generated.kneighbors(query)
    generated_image = Image.open(generated_image_paths[gen_indices[0][0]])

    return user_image, recommended_images, generated_image
# --- UI ---
example_images = [
    "image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg", "image5.jpg"
]


def _recommend_from_example(sample):
    """Run the recommender on a clicked example row.

    *sample* is the row handed over by the examples ``gr.Dataset``: a
    one-element list holding the example's file path. The returned tuple is
    the loaded example (shown in the upload box) followed by the three
    values produced by ``recommend_image``.
    """
    with Image.open(sample[0]) as img:
        example = img.convert("RGB")
    return (example, *recommend_image(example))


with gr.Blocks(title="Van Gogh Style Image Recommendation") as demo:
    gr.Markdown("## 🎨 Van Gogh Style Image Recommendation")
    gr.Markdown("""Upload a painting-style image and get:
- 🖼️ The uploaded image preview
- 🎯 Top 3 similar real Van Gogh paintings
- 🧠 1 AI-generated painting that best matches your input""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="🖼️ Upload an Image")
            # Samples are given as file paths, the form the Dataset
            # component expects for image columns.
            examples = gr.Dataset(components=[input_image],
                                  samples=[[p] for p in example_images],
                                  label="✨ 1-Click Example Images")
            submit_btn = gr.Button("🔍 Recommend", variant="primary")
            clear_btn = gr.Button("❌ Clear")
        with gr.Column():
            output_user_image = gr.Image(label="📥 Your Image")
            output_similar = gr.Gallery(label="🎯 Top 3 Real Paintings", columns=3)
            output_generated = gr.Image(label="🧠 Generated AI Painting")
    submit_btn.click(fn=recommend_image,
                     inputs=input_image,
                     outputs=[output_user_image, output_similar, output_generated])
    clear_btn.click(fn=lambda: (None, None, None),
                    inputs=[],
                    outputs=[output_user_image, output_similar, output_generated])
    # The clicked sample must come from the Dataset itself; reading
    # `input_image` here would use whatever is (or is not) in the upload box
    # and ignore the example that was clicked.
    examples.click(fn=_recommend_from_example,
                   inputs=examples,
                   outputs=[input_image, output_user_image, output_similar, output_generated])
demo.launch()