"""WikiArt recommendation web app.

Serves a Gradio UI that finds visually/semantically similar artworks using
CLIP embeddings: users either type a description (text -> image search) or
upload an image (image -> image search). Requires three local data files:
``artwork_embeddings.npy``, ``artwork_metadata.csv``, ``images_data.pkl``.
"""

import os
import io
import pickle

import numpy as np
import pandas as pd
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# =============================================================================
# SETUP
# =============================================================================

print("Loading model and data...")

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model
MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()
print("✓ CLIP model loaded")

# Load embeddings and metadata.
# Force float32 so torch.mm against CLIP's float32 features never hits a
# dtype mismatch (a float64 .npy would otherwise raise at query time).
embeddings = np.load("artwork_embeddings.npy")
df = pd.read_csv("artwork_metadata.csv")
EMBEDDINGS_TENSOR = torch.tensor(embeddings, dtype=torch.float32).to(device)
print(f"✓ Loaded {len(embeddings)} embeddings")

# Load pre-saved images.
# NOTE(review): pickle.load executes arbitrary code if the file is tampered
# with — only load images_data.pkl from a trusted source.
print("Loading images...")
with open('images_data.pkl', 'rb') as f:
    images_data = pickle.load(f)
print(f"✓ Loaded {len(images_data)} images")


def get_image(idx):
    """Get PIL image from saved data.

    Args:
        idx: Integer index into ``images_data`` (raw encoded image bytes).

    Returns:
        A lazily-decoded ``PIL.Image.Image``.
    """
    img_bytes = images_data[idx]
    return Image.open(io.BytesIO(img_bytes))


# =============================================================================
# CORE FUNCTIONS
# =============================================================================

def get_image_embedding(image):
    """Convert PIL image to CLIP embedding.

    Args:
        image: A PIL image (any mode; converted to RGB).

    Returns:
        L2-normalized feature tensor of shape (1, embed_dim) on ``device``.
    """
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Normalize so dot products below are cosine similarities.
    features = features / features.norm(dim=-1, keepdim=True)
    return features


def get_text_embedding(text):
    """Convert text to CLIP embedding.

    Args:
        text: Query string.

    Returns:
        L2-normalized feature tensor of shape (1, embed_dim) on ``device``.
    """
    inputs = processor(text=text, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    features = features / features.norm(dim=-1, keepdim=True)
    return features


def get_recommendations(query_embedding, top_k=5):
    """Get top-k similar artworks.

    Args:
        query_embedding: Normalized CLIP embedding of shape (1, embed_dim).
        top_k: Number of results to return (clamped to the corpus size).

    Returns:
        List of dicts with keys: index, similarity, artist, genre, style,
        image — sorted by descending cosine similarity.
    """
    query_embedding = query_embedding.to(device)
    # Cosine similarity against every stored embedding (both sides normalized).
    similarities = torch.mm(query_embedding, EMBEDDINGS_TENSOR.T)[0]
    # Clamp: torch.topk raises if k exceeds the number of candidates.
    top_k = min(int(top_k), EMBEDDINGS_TENSOR.shape[0])
    top_scores, top_indices = torch.topk(similarities, top_k)

    results = []
    for score, idx in zip(top_scores.cpu().numpy(), top_indices.cpu().numpy()):
        artwork_info = df.iloc[idx]
        results.append({
            "index": int(idx),
            "similarity": float(score),
            "artist": artwork_info["artist"],
            "genre": artwork_info["genre"],
            "style": artwork_info["style"],
            "image": get_image(int(idx)),
        })
    return results


# =============================================================================
# GRADIO FUNCTIONS
# =============================================================================

def recommend_from_text(text_query, num_results=5):
    """Gradio handler: recommend artworks matching a text description.

    Returns a (gallery, info_text) pair for the Gallery/Textbox outputs.
    """
    if not text_query.strip():
        return [], "Please enter a description"

    query_emb = get_text_embedding(text_query)
    recommendations = get_recommendations(query_emb, top_k=int(num_results))

    gallery_images = []
    info_text = f"Results for: \"{text_query}\"\n\n"
    for i, rec in enumerate(recommendations):
        gallery_images.append((rec["image"], f"{rec['style']} | {rec['artist'][:20]}"))
        info_text += f"{i+1}. {rec['style']} by {rec['artist']} (Score: {rec['similarity']:.3f})\n"

    return gallery_images, info_text


def recommend_from_image(image, num_results=5):
    """Gradio handler: recommend artworks similar to an uploaded image.

    Accepts either a PIL image or a numpy array (Gradio may send either).
    Returns a (gallery, info_text) pair for the Gallery/Textbox outputs.
    """
    if image is None:
        return [], "Please upload an image"

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    query_emb = get_image_embedding(image)
    recommendations = get_recommendations(query_emb, top_k=int(num_results))

    gallery_images = []
    info_text = "Similar artworks found:\n\n"
    for i, rec in enumerate(recommendations):
        gallery_images.append((rec["image"], f"{rec['style']} | {rec['artist'][:20]}"))
        info_text += f"{i+1}. {rec['style']} by {rec['artist']} (Score: {rec['similarity']:.3f})\n"

    return gallery_images, info_text


# =============================================================================
# GRADIO INTERFACE
# =============================================================================

with gr.Blocks(title="WikiArt Recommendation System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 WikiArt Artwork Recommendation System

    Find similar artworks using AI! You can either:
    - **Describe** what you're looking for in text
    - **Upload** an image to find similar artworks

    *Powered by CLIP embeddings on 15,000 artworks from WikiArt*
    """)

    with gr.Tabs():
        with gr.TabItem("🔤 Search by Description"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Describe the artwork you're looking for",
                        placeholder="e.g., 'impressionist painting of a garden with flowers'",
                        lines=3
                    )
                    text_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    text_btn = gr.Button("🔍 Find Artworks", variant="primary")
                with gr.Column(scale=2):
                    text_gallery = gr.Gallery(
                        label="Recommended Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    text_info = gr.Textbox(label="Details", lines=6)

            text_btn.click(
                fn=recommend_from_text,
                inputs=[text_input, text_num_results],
                outputs=[text_gallery, text_info]
            )

            gr.Examples(
                examples=[
                    ["impressionist landscape with water and trees"],
                    ["dark moody portrait with dramatic lighting"],
                    ["abstract colorful geometric shapes"],
                    ["religious painting with angels"],
                    ["Japanese style artwork with nature"],
                ],
                inputs=text_input
            )

        with gr.TabItem("🖼️ Search by Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload an artwork image",
                        type="pil"
                    )
                    image_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    image_btn = gr.Button("🔍 Find Similar", variant="primary")
                with gr.Column(scale=2):
                    image_gallery = gr.Gallery(
                        label="Similar Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    image_info = gr.Textbox(label="Details", lines=6)

            image_btn.click(
                fn=recommend_from_image,
                inputs=[image_input, image_num_results],
                outputs=[image_gallery, image_info]
            )

    gr.Markdown("""
    ---
    ### 📹 Project Presentation
    """)

    # Placeholder for an embedded presentation video (intentionally empty).
    gr.HTML("""
    """)

    gr.Markdown("""
    ---
    **Dataset:** WikiArt (15,000 artworks) | **Model:** CLIP ViT-B/32 | **Assignment 3 - ML Course**
    """)

if __name__ == "__main__":
    demo.launch()