# app.py — WikiArt artwork recommendation demo (Hugging Face Space, commit aaa7f14)
import os
import numpy as np
import pandas as pd
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import pickle
import io
# =============================================================================
# SETUP
# =============================================================================
# One-time, import-time initialisation: the CLIP model, precomputed artwork
# embeddings, the metadata table, and raw image bytes are loaded into module
# globals so the Gradio callbacks below can reuse them without reloading.
print("Loading model and data...")
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load CLIP model
MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only — disable dropout / training-mode behaviour
print("βœ“ CLIP model loaded")
# Load embeddings and metadata
# embeddings: (N, D) array of CLIP image features; presumably L2-normalised
# offline — TODO confirm with the notebook that produced the .npy file.
embeddings = np.load("artwork_embeddings.npy")
df = pd.read_csv("artwork_metadata.csv")  # row i describes artwork i (artist/genre/style)
EMBEDDINGS_TENSOR = torch.tensor(embeddings).to(device)
print(f"βœ“ Loaded {len(embeddings)} embeddings")
# Load pre-saved images
# NOTE(review): pickle.load is only safe on trusted, locally produced files.
print("Loading images...")
with open('images_data.pkl', 'rb') as f:
    images_data = pickle.load(f)
print(f"βœ“ Loaded {len(images_data)} images")
def get_image(idx):
    """Decode and return the artwork stored at position *idx* as a PIL image.

    Looks the raw PNG/JPEG bytes up in the module-level ``images_data``
    mapping and wraps them in an in-memory buffer for ``Image.open``.
    """
    raw_bytes = images_data[idx]
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
# =============================================================================
# CORE FUNCTIONS
# =============================================================================
def get_image_embedding(image):
    """Encode a PIL image into a unit-length CLIP feature vector.

    Returns a (1, D) tensor on the module-level ``device``, L2-normalised so
    that dot products with other normalised embeddings are cosine similarities.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt", padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        embedding = model.get_image_features(**batch)
    # Normalise to unit length so similarity search reduces to a dot product.
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding
def get_text_embedding(text):
    """Encode a text query into a unit-length CLIP feature vector.

    Returns a (1, D) tensor on the module-level ``device``, L2-normalised to
    match the image embeddings used for similarity search.
    """
    batch = processor(text=text, return_tensors="pt", padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        embedding = model.get_text_features(**batch)
    # Normalise to unit length so similarity search reduces to a dot product.
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding
def get_recommendations(query_embedding, top_k=5):
    """Return the ``top_k`` artworks most similar to a query embedding.

    Parameters
    ----------
    query_embedding : torch.Tensor
        (1, D) L2-normalised CLIP feature vector (image or text).
    top_k : int
        Desired number of results; clamped to ``[1, corpus size]``.

    Returns
    -------
    list[dict]
        One dict per hit — index, similarity score, artist, genre, style and
        the decoded PIL image — ordered from most to least similar.
    """
    query_embedding = query_embedding.to(device)
    # Dot product against all corpus embeddings; equals cosine similarity
    # assuming the rows of EMBEDDINGS_TENSOR are pre-normalised — TODO confirm
    # with the pipeline that produced artwork_embeddings.npy.
    similarities = torch.mm(query_embedding, EMBEDDINGS_TENSOR.T)[0]
    # Fix: torch.topk raises if k exceeds the corpus size (or is < 1), so
    # clamp the request instead of crashing the UI callback.
    k = max(1, min(int(top_k), similarities.shape[0]))
    top_scores, top_indices = torch.topk(similarities, k)
    results = []
    for score, idx in zip(top_scores.cpu().numpy(), top_indices.cpu().numpy()):
        artwork_info = df.iloc[idx]
        results.append({
            "index": int(idx),
            "similarity": float(score),
            "artist": artwork_info["artist"],
            "genre": artwork_info["genre"],
            "style": artwork_info["style"],
            "image": get_image(int(idx)),
        })
    return results
# =============================================================================
# GRADIO FUNCTIONS
# =============================================================================
def recommend_from_text(text_query, num_results=5):
    """Gradio callback: text description -> (gallery items, details text).

    Returns a list of ``(PIL image, caption)`` pairs for the Gallery component
    and a plain-text summary for the Details textbox.
    """
    # Fix: guard against a None payload from the client as well as an empty
    # box — the original crashed with AttributeError on None.strip().
    if not text_query or not text_query.strip():
        return [], "Please enter a description"
    query_emb = get_text_embedding(text_query)
    recommendations = get_recommendations(query_emb, top_k=int(num_results))
    gallery_images = []
    info_text = f"Results for: \"{text_query}\"\n\n"
    for i, rec in enumerate(recommendations):
        # Caption: style plus a truncated artist name so it fits under a tile.
        gallery_images.append((rec["image"], f"{rec['style']} | {rec['artist'][:20]}"))
        info_text += f"{i+1}. {rec['style']} by {rec['artist']} (Score: {rec['similarity']:.3f})\n"
    return gallery_images, info_text
def recommend_from_image(image, num_results=5):
    """Gradio callback: uploaded artwork -> (gallery items, details text).

    Returns a list of ``(PIL image, caption)`` pairs for the Gallery component
    and a plain-text summary for the Details textbox.
    """
    if image is None:
        return [], "Please upload an image"
    # Gradio may hand over a numpy array rather than a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    query_emb = get_image_embedding(image)
    hits = get_recommendations(query_emb, top_k=int(num_results))
    gallery_images = [
        (hit["image"], f"{hit['style']} | {hit['artist'][:20]}") for hit in hits
    ]
    info_lines = ["Similar artworks found:", ""]
    for rank, hit in enumerate(hits, start=1):
        info_lines.append(
            f"{rank}. {hit['style']} by {hit['artist']} (Score: {hit['similarity']:.3f})"
        )
    info_text = "\n".join(info_lines) + "\n"
    return gallery_images, info_text
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Two-tab UI sharing one output layout (gallery + details textbox):
#   tab 1 searches by free-text description, tab 2 by an uploaded image.
with gr.Blocks(title="WikiArt Recommendation System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 WikiArt Artwork Recommendation System
Find similar artworks using AI! You can either:
- **Describe** what you're looking for in text
- **Upload** an image to find similar artworks
*Powered by CLIP embeddings on 15,000 artworks from WikiArt*
""")
    with gr.Tabs():
        # --- Tab 1: search by text description -------------------------------
        with gr.TabItem("πŸ”€ Search by Description"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Describe the artwork you're looking for",
                        placeholder="e.g., 'impressionist painting of a garden with flowers'",
                        lines=3
                    )
                    text_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    text_btn = gr.Button("πŸ” Find Artworks", variant="primary")
                with gr.Column(scale=2):
                    text_gallery = gr.Gallery(
                        label="Recommended Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    text_info = gr.Textbox(label="Details", lines=6)
            # Wire the button to the text-search callback.
            text_btn.click(
                fn=recommend_from_text,
                inputs=[text_input, text_num_results],
                outputs=[text_gallery, text_info]
            )
            # Clickable example queries shown under the input.
            gr.Examples(
                examples=[
                    ["impressionist landscape with water and trees"],
                    ["dark moody portrait with dramatic lighting"],
                    ["abstract colorful geometric shapes"],
                    ["religious painting with angels"],
                    ["Japanese style artwork with nature"],
                ],
                inputs=text_input
            )
        # --- Tab 2: search by uploaded image ---------------------------------
        with gr.TabItem("πŸ–ΌοΈ Search by Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Upload an artwork image",
                        type="pil"
                    )
                    image_num_results = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of results"
                    )
                    image_btn = gr.Button("πŸ” Find Similar", variant="primary")
                with gr.Column(scale=2):
                    image_gallery = gr.Gallery(
                        label="Similar Artworks",
                        columns=5,
                        height=400,
                        object_fit="contain"
                    )
                    image_info = gr.Textbox(label="Details", lines=6)
            # Wire the button to the image-search callback.
            image_btn.click(
                fn=recommend_from_image,
                inputs=[image_input, image_num_results],
                outputs=[image_gallery, image_info]
            )
    # Footer: embedded presentation video and dataset/model credits.
    gr.Markdown("""
---
### πŸ“Ή Project Presentation
""")
    gr.HTML("""
<iframe width="560" height="315"
src="https://www.youtube.com/embed/0vXrQyuLWsA"
title="YouTube video player"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
""")
    gr.Markdown("""
---
**Dataset:** WikiArt (15,000 artworks) | **Model:** CLIP ViT-B/32 | **Assignment 3 - ML Course**
""")
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()