Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import torch | |
| import numpy as np | |
| from transformers import ViTFeatureExtractor, ViTModel | |
| from sklearn.neighbors import NearestNeighbors | |
# --- Load model and extractor ---
# The backbone and its preprocessing configuration are loaded from the same
# pretrained checkpoint identifier.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
model = ViTModel.from_pretrained(VIT_CHECKPOINT)
extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
# --- Helper: extract embedding from image ---
def get_embedding(img):
    """Return the ViT embedding of ``img`` as a NumPy array.

    The image is preprocessed with the module-level ``extractor`` and run
    through the module-level ``model`` with gradient tracking disabled. The
    hidden state at sequence position 0 is squeezed and converted to NumPy.
    """
    pixel_inputs = extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        hidden_states = model(**pixel_inputs).last_hidden_state
        # Keep only position 0 along the sequence axis (presumably the [CLS]
        # token -- confirm against the ViT model documentation).
        first_token = hidden_states[:, 0, :]
        return first_token.squeeze().numpy()
# --- Load image paths ---
def _list_jpgs(prefix):
    """Return the sorted ``.jpg`` names in the working directory starting with ``prefix``."""
    return sorted(
        name
        for name in os.listdir()
        if name.startswith(prefix) and name.endswith(".jpg")
    )


dataset_image_paths = _list_jpgs("image")
generated_image_paths = _list_jpgs("generated_custom")
# --- Compute embeddings ---
def compute_embeddings(paths):
    """Embed every image in ``paths`` and stack the vectors row-wise.

    Each path is opened with PIL, converted to RGB and passed to
    ``get_embedding``; the per-image results are combined with ``np.vstack``
    in the same order as ``paths``.
    """
    return np.vstack(
        [get_embedding(Image.open(image_path).convert("RGB")) for image_path in paths]
    )
dataset_embeddings = compute_embeddings(dataset_image_paths)
generated_embeddings = compute_embeddings(generated_image_paths)

# --- Fit Nearest Neighbors models ---
# Cosine-distance indexes: queries against the dataset images return three
# neighbours, queries against the generated images return one.
nn_dataset = NearestNeighbors(n_neighbors=3, metric='cosine')
nn_dataset.fit(dataset_embeddings)
nn_generated = NearestNeighbors(n_neighbors=1, metric='cosine')
nn_generated.fit(generated_embeddings)
# --- Recommendation function ---
def recommend_image(user_image):
    """Return the query image together with its nearest stored matches.

    Args:
        user_image: Image to query with (the UI supplies a PIL image), or
            ``None`` when nothing has been uploaded.

    Returns:
        A 3-tuple ``(user_image, similar_images, generated_image)``:
        ``similar_images`` is a list of PIL images opened from the nearest
        entries of ``dataset_image_paths`` and ``generated_image`` is the PIL
        image opened from the nearest entry of ``generated_image_paths``.
        When ``user_image`` is ``None`` the result is ``(None, None, None)``.
    """
    # Pressing the button with an empty upload slot passes None; return empty
    # outputs instead of failing inside the feature extractor.
    if user_image is None:
        return None, None, None

    # kneighbors expects a 2-D batch, so wrap the single embedding in a list.
    query = [get_embedding(user_image)]

    # Find 3 most similar real images (Van Gogh)
    _, dataset_indices = nn_dataset.kneighbors(query)
    similar_images = [Image.open(dataset_image_paths[i]) for i in dataset_indices[0]]

    # Find 1 most similar generated image
    _, gen_indices = nn_generated.kneighbors(query)
    generated_image = Image.open(generated_image_paths[gen_indices[0][0]])

    return user_image, similar_images, generated_image
# --- UI ---
# Sample files offered as clickable examples: image1.jpg through image5.jpg.
example_images = [f"image{index}.jpg" for index in range(1, 6)]
# Build the Gradio interface: an upload column on the left and three result
# widgets on the right, all driven by recommend_image.
# NOTE(review): the emoji in the label/markdown strings below look mis-encoded
# (mojibake); they are kept as-is here -- restore them from the original file.
with gr.Blocks(title="Van Gogh Style Image Recommendation") as demo:
    gr.Markdown("## π¨ Van Gogh Style Image Recommendation")
    gr.Markdown("""Upload a painting-style image and get:
- πΌοΈ The uploaded image preview
- π― Top 3 similar real Van Gogh paintings
- π§ 1 AI-generated painting that best matches your input""")

    # The outputs are created unrendered so the examples widget in the left
    # column can reference them; they are rendered in the right column below.
    output_user_image = gr.Image(label="π₯ Your Image", render=False)
    output_similar = gr.Gallery(label="π― Top 3 Real Paintings", columns=3, render=False)
    output_generated = gr.Image(label="π§ Generated AI Painting", render=False)
    result_outputs = [output_user_image, output_similar, output_generated]

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="πΌοΈ Upload an Image")
            # gr.Examples loads the clicked sample into input_image and, with
            # run_on_click, runs the recommendation on it. The previous
            # Dataset.click handler read input_image without ever filling it,
            # so an example click ran on whatever was already there.
            examples = gr.Examples(
                examples=[[path] for path in example_images],
                inputs=input_image,
                outputs=result_outputs,
                fn=recommend_image,
                cache_examples=False,  # run_on_click only applies when not cached
                run_on_click=True,
                label="β¨ 1-Click Example Images",
            )
            submit_btn = gr.Button("π Recommend", variant="primary")
            clear_btn = gr.Button("β Clear")
        with gr.Column():
            output_user_image.render()
            output_similar.render()
            output_generated.render()

    submit_btn.click(fn=recommend_image,
                     inputs=input_image,
                     outputs=result_outputs)
    # Clearing resets only the three result widgets, not the upload slot.
    clear_btn.click(fn=lambda: (None, None, None),
                    inputs=[],
                    outputs=result_outputs)

demo.launch()