Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| import pickle | |
| import io | |
# =============================================================================
# SETUP
# =============================================================================
# Module-level side effects: loads the CLIP model, the precomputed artwork
# embeddings, the per-artwork metadata table, and the pickled image bytes.
# Everything runs once at import time so the request handlers below can
# reuse this shared state.
print("Loading model and data...")

# Device selection: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the CLIP vision+text encoder and its matching preprocessor.
MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only — disables dropout / batch-norm updates
print("β CLIP model loaded")

# Precomputed artwork embeddings; row i of the CSV describes embedding i.
# NOTE(review): the similarity search below uses a plain dot product, which
# assumes these embeddings are already L2-normalized — TODO confirm against
# the script that produced artwork_embeddings.npy.
embeddings = np.load("artwork_embeddings.npy")
df = pd.read_csv("artwork_metadata.csv")
EMBEDDINGS_TENSOR = torch.tensor(embeddings).to(device)
print(f"β Loaded {len(embeddings)} embeddings")

# Pre-saved image bytes, index-aligned with the embeddings.
# NOTE(review): pickle.load is only acceptable because images_data.pkl ships
# with the app itself — never load a pickle from untrusted input.
print("Loading images...")
with open('images_data.pkl', 'rb') as f:
    images_data = pickle.load(f)
print(f"β Loaded {len(images_data)} images")
def get_image(idx):
    """Reconstruct the PIL image stored (as raw bytes) at position *idx*.

    The bytes come from the module-level ``images_data`` list loaded at
    startup; each entry is an encoded image file held in memory.
    """
    raw = images_data[idx]
    buffer = io.BytesIO(raw)
    return Image.open(buffer)
# =============================================================================
# CORE FUNCTIONS
# =============================================================================
def get_image_embedding(image):
    """Embed a PIL image into CLIP's joint space.

    Returns an L2-normalized feature tensor on ``device`` (shape (1, d)),
    directly comparable to the precomputed artwork embeddings.
    """
    rgb = image.convert("RGB")
    batch = processor(images=rgb, return_tensors="pt", padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():  # inference only — no autograd bookkeeping
        feats = model.get_image_features(**batch)
    return feats / feats.norm(dim=-1, keepdim=True)
def get_text_embedding(text):
    """Embed a text query into CLIP's joint space.

    Returns an L2-normalized feature tensor on ``device`` (shape (1, d)),
    directly comparable to the precomputed artwork embeddings.
    """
    batch = processor(text=text, return_tensors="pt", padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():  # inference only — no autograd bookkeeping
        feats = model.get_text_features(**batch)
    return feats / feats.norm(dim=-1, keepdim=True)
def get_recommendations(query_embedding, top_k=5):
    """Return the ``top_k`` artworks most similar to a query embedding.

    Args:
        query_embedding: (1, d) L2-normalized tensor (from
            ``get_image_embedding`` or ``get_text_embedding``).
        top_k: number of results requested. Clamped to the corpus size —
            ``torch.topk`` raises if k exceeds the number of candidates,
            which would crash the app whenever fewer artworks are loaded
            than the slider allows.

    Returns:
        List of dicts with keys: index, similarity, artist, genre, style,
        image (a PIL image).
    """
    query_embedding = query_embedding.to(device)
    # Both sides are L2-normalized, so cosine similarity is a dot product.
    similarities = torch.mm(query_embedding, EMBEDDINGS_TENSOR.T)[0]
    # Robustness: never ask topk for more entries than exist.
    k = min(int(top_k), similarities.shape[0])
    top_scores, top_indices = torch.topk(similarities, k)
    results = []
    for score, idx in zip(top_scores.cpu().numpy(), top_indices.cpu().numpy()):
        row = df.iloc[idx]
        results.append({
            "index": int(idx),
            "similarity": float(score),
            "artist": row["artist"],
            "genre": row["genre"],
            "style": row["style"],
            "image": get_image(int(idx)),
        })
    return results
# =============================================================================
# GRADIO FUNCTIONS
# =============================================================================
def recommend_from_text(text_query, num_results=5):
    """Gradio handler: text query -> (gallery items, details string).

    Returns an empty gallery plus a prompt message when the query is blank.
    """
    if not text_query.strip():
        return [], "Please enter a description"
    embedding = get_text_embedding(text_query)
    matches = get_recommendations(embedding, top_k=int(num_results))
    # Gallery entries are (image, caption) pairs; captions truncate long
    # artist names to keep the overlay readable.
    gallery = [
        (m["image"], f"{m['style']} | {m['artist'][:20]}")
        for m in matches
    ]
    lines = [f"Results for: \"{text_query}\"\n"]
    for rank, m in enumerate(matches, start=1):
        lines.append(f"{rank}. {m['style']} by {m['artist']} (Score: {m['similarity']:.3f})")
    return gallery, "\n".join(lines) + "\n"
def recommend_from_image(image, num_results=5):
    """Gradio handler: uploaded image -> (gallery items, details string).

    Returns an empty gallery plus a prompt message when no image was given.
    """
    if image is None:
        return [], "Please upload an image"
    if not isinstance(image, Image.Image):
        # Depending on component config, Gradio may deliver a numpy array.
        image = Image.fromarray(image)
    embedding = get_image_embedding(image)
    matches = get_recommendations(embedding, top_k=int(num_results))
    gallery = [
        (m["image"], f"{m['style']} | {m['artist'][:20]}")
        for m in matches
    ]
    lines = ["Similar artworks found:\n"]
    for rank, m in enumerate(matches, start=1):
        lines.append(f"{rank}. {m['style']} by {m['artist']} (Score: {m['similarity']:.3f})")
    return gallery, "\n".join(lines) + "\n"
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Two-tab UI (text search / image search); both tabs share the same output
# layout: a gallery plus a plain-text details box. Several label strings
# below contain mojibake'd emoji (e.g. "π¨") — presumably a transcoding
# artifact in this copy of the file; left byte-for-byte as found.
with gr.Blocks(title="WikiArt Recommendation System", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("""
    # π¨ WikiArt Artwork Recommendation System
    Find similar artworks using AI! You can either:
    - **Describe** what you're looking for in text
    - **Upload** an image to find similar artworks
    *Powered by CLIP embeddings on 15,000 artworks from WikiArt*
    """)
    with gr.Tabs():
        # --- Tab 1: free-text search -------------------------------------
        with gr.TabItem("π€ Search by Description"):
            with gr.Row():
                with gr.Column(scale=1):  # left column: query inputs
                    text_input = gr.Textbox(
                        label="Describe the artwork you're looking for",
                        placeholder="e.g., 'impressionist painting of a garden with flowers'",
                        lines=3
                    )
                    text_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    text_btn = gr.Button("π Find Artworks", variant="primary")
                with gr.Column(scale=2):  # right column: results
                    text_gallery = gr.Gallery(
                        label="Recommended Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    text_info = gr.Textbox(label="Details", lines=6)
            # Wire the button to the text handler defined above.
            text_btn.click(
                fn=recommend_from_text,
                inputs=[text_input, text_num_results],
                outputs=[text_gallery, text_info]
            )
            # Clickable example queries that populate the textbox.
            gr.Examples(
                examples=[
                    ["impressionist landscape with water and trees"],
                    ["dark moody portrait with dramatic lighting"],
                    ["abstract colorful geometric shapes"],
                    ["religious painting with angels"],
                    ["Japanese style artwork with nature"],
                ],
                inputs=text_input
            )
        # --- Tab 2: query-by-image ---------------------------------------
        with gr.TabItem("πΌοΈ Search by Image"):
            with gr.Row():
                with gr.Column(scale=1):  # left column: query inputs
                    image_input = gr.Image(
                        label="Upload an artwork image",
                        type="pil"  # handler receives a PIL.Image
                    )
                    image_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    image_btn = gr.Button("π Find Similar", variant="primary")
                with gr.Column(scale=2):  # right column: results
                    image_gallery = gr.Gallery(
                        label="Similar Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    image_info = gr.Textbox(label="Details", lines=6)
            # Wire the button to the image handler defined above.
            image_btn.click(
                fn=recommend_from_image,
                inputs=[image_input, image_num_results],
                outputs=[image_gallery, image_info]
            )
    # Embedded project-presentation video.
    gr.Markdown("""
    ---
    ### πΉ Project Presentation
    """)
    gr.HTML("""
    <iframe width="560" height="315"
    src="https://www.youtube.com/embed/0vXrQyuLWsA"
    title="YouTube video player"
    frameborder="0"
    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
    allowfullscreen>
    </iframe>
    """)
    # Footer with dataset/model attribution.
    gr.Markdown("""
    ---
    **Dataset:** WikiArt (15,000 artworks) | **Model:** CLIP ViT-B/32 | **Assignment 3 - ML Course**
    """)

# Start the local Gradio server when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()