KalsusEvening committed on
Commit
2a99ba2
·
verified ·
1 Parent(s): f7da9dc

Upload 5 files

Browse files
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from transformers import CLIPProcessor, CLIPModel
9
+ from datasets import load_dataset
10
+ import random
11
+
12
+ # =============================================================================
13
+ # SETUP
14
+ # =============================================================================
15
+
16
print("Loading model and data...")

# Device: use the GPU when PyTorch reports one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model
MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model.eval()  # inference only; nothing in this script trains the model
print("✓ CLIP model loaded")

# Load embeddings and metadata
embeddings = np.load("artwork_embeddings.npy")
df = pd.read_csv("artwork_metadata.csv")
# Pin the dtype: torch.tensor() would otherwise inherit whatever dtype the
# .npy file was saved with, and torch.mm() in get_recommendations() requires
# both operands to share a dtype (the query features are presumably float32,
# the default for a model loaded this way).
EMBEDDINGS_TENSOR = torch.tensor(embeddings, dtype=torch.float32).to(device)
print(f"✓ Loaded {len(embeddings)} embeddings")

# Load dataset for images
print("Loading WikiArt dataset (this may take a moment)...")
full_dataset = load_dataset("huggan/wikiart", split="train")
sample_indices = np.load("sample_indices.npy")
dataset = full_dataset.select(sample_indices.tolist())
print(f"✓ Dataset loaded: {len(dataset)} artworks")

# get_recommendations() uses one row position to index the embedding matrix,
# the metadata frame and the image dataset, so the three must line up.
# Fail at startup instead of returning mismatched artwork at query time.
if not (len(embeddings) == len(df) == len(dataset)):
    raise ValueError(
        "Embeddings, metadata and dataset are misaligned: "
        f"{len(embeddings)} embeddings, {len(df)} metadata rows, "
        f"{len(dataset)} images"
    )
41
+
42
+ # =============================================================================
43
+ # CORE FUNCTIONS
44
+ # =============================================================================
45
+
46
def get_image_embedding(image):
    """Encode a PIL image with CLIP and return unit-length features.

    Args:
        image: PIL image; it is converted to RGB before preprocessing.

    Returns:
        Tensor of image features on ``device``, divided by its L2 norm along
        the last dimension (presumably shape ``(1, dim)`` for a single
        image -- confirm against the model config).
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt", padding=True)
    batch_on_device = {name: tensor.to(device) for name, tensor in batch.items()}

    # No gradients are needed for retrieval.
    with torch.no_grad():
        raw_features = model.get_image_features(**batch_on_device)

    norms = raw_features.norm(dim=-1, keepdim=True)
    return raw_features / norms
55
+
56
def get_text_embedding(text):
    """Encode text with CLIP and return unit-length features.

    Args:
        text: Query text handed to the CLIP processor.

    Returns:
        Tensor of text features on ``device``, divided by its L2 norm along
        the last dimension.
    """
    # CLIP's text encoder has a fixed context window (77 tokens for the
    # OpenAI checkpoints). Without truncation, a long user description
    # produces more tokens than the model has position embeddings for and
    # the forward pass raises; truncating keeps the leading tokens instead.
    inputs = processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    return features / features.norm(dim=-1, keepdim=True)
64
+
65
def get_recommendations(query_embedding, top_k=5):
    """Return the ``top_k`` stored artworks most similar to a query embedding.

    Similarity is the dot product between the query and every row of
    ``EMBEDDINGS_TENSOR`` (cosine similarity when both sides are unit-length;
    the query helpers in this file normalise their output, the stored
    embeddings are presumably normalised too -- confirm).

    Args:
        query_embedding: 2-D tensor whose first row is the query vector.
        top_k: Maximum number of results. Values larger than the number of
            stored embeddings are clamped; values <= 0 yield an empty list.

    Returns:
        List of dicts, best match first, with keys ``index``, ``similarity``,
        ``artist``, ``genre``, ``style`` and ``image``.
    """
    k = int(top_k)
    if k <= 0:
        return []

    # Align device and dtype with the embedding matrix; torch.mm rejects
    # mixed dtypes.
    query = query_embedding.to(
        device=EMBEDDINGS_TENSOR.device, dtype=EMBEDDINGS_TENSOR.dtype
    )
    similarities = torch.mm(query, EMBEDDINGS_TENSOR.T)[0]

    # torch.topk raises if k exceeds the number of candidates.
    k = min(k, similarities.shape[0])
    if k == 0:
        return []

    top_scores, top_indices = torch.topk(similarities, k)

    results = []
    for score, idx in zip(top_scores.tolist(), top_indices.tolist()):
        artwork_info = df.iloc[idx]
        results.append({
            "index": int(idx),
            "similarity": float(score),
            "artist": artwork_info["artist"],
            "genre": artwork_info["genre"],
            "style": artwork_info["style"],
            "image": dataset[int(idx)]["image"],
        })
    return results
83
+
84
+ # =============================================================================
85
+ # GRADIO FUNCTIONS
86
+ # =============================================================================
87
+
88
def _format_recommendations(recommendations, header):
    """Turn recommendation dicts into (gallery items, details text).

    Args:
        recommendations: Dicts with ``image``, ``style``, ``artist`` and
            ``similarity`` keys, as produced by ``get_recommendations``.
        header: Text placed before the numbered result lines.

    Returns:
        Tuple of a list of ``(image, caption)`` pairs and the details string.
    """
    gallery_images = []
    detail_lines = []
    for rank, rec in enumerate(recommendations, start=1):
        # Metadata comes from a CSV, so a missing artist can arrive as a
        # non-string (e.g. NaN); coerce before slicing for the caption.
        artist = str(rec["artist"])
        gallery_images.append((rec["image"], f"{rec['style']} | {artist[:20]}"))
        detail_lines.append(
            f"{rank}. {rec['style']} by {artist} (Score: {rec['similarity']:.3f})\n"
        )
    return gallery_images, header + "".join(detail_lines)


def recommend_from_text(text_query, num_results=5):
    """Gradio callback: find artworks matching a text description.

    Args:
        text_query: Description typed by the user; ``None`` or blank text
            returns a prompt message instead of results.
        num_results: Number of artworks to return (coerced with ``int``).

    Returns:
        Tuple of gallery items (``(image, caption)`` pairs) and details text.
    """
    if not text_query or not text_query.strip():
        return [], "Please enter a description"

    query_emb = get_text_embedding(text_query)
    recommendations = get_recommendations(query_emb, top_k=int(num_results))
    return _format_recommendations(
        recommendations, f"Results for: \"{text_query}\"\n\n"
    )


def recommend_from_image(image, num_results=5):
    """Gradio callback: find artworks visually similar to an uploaded image.

    Args:
        image: PIL image or array from the upload widget; ``None`` returns a
            prompt message instead of results.
        num_results: Number of artworks to return (coerced with ``int``).

    Returns:
        Tuple of gallery items (``(image, caption)`` pairs) and details text.
    """
    if image is None:
        return [], "Please upload an image"

    # Anything that is not already a PIL image is treated as an array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    query_emb = get_image_embedding(image)
    recommendations = get_recommendations(query_emb, top_k=int(num_results))
    return _format_recommendations(recommendations, "Similar artworks found:\n\n")
122
+
123
+ # =============================================================================
124
+ # GRADIO INTERFACE
125
+ # =============================================================================
126
+
127
+ with gr.Blocks(title="WikiArt Recommendation System", theme=gr.themes.Soft()) as demo:
128
+
129
+ gr.Markdown("""
130
+ # 🎨 WikiArt Artwork Recommendation System
131
+
132
+ Find similar artworks using AI! You can either:
133
+ - **Describe** what you're looking for in text
134
+ - **Upload** an image to find similar artworks
135
+
136
+ *Powered by CLIP embeddings on 15,000 artworks from WikiArt*
137
+ """)
138
+
139
+ with gr.Tabs():
140
+ with gr.TabItem("🔤 Search by Description"):
141
+ with gr.Row():
142
+ with gr.Column(scale=1):
143
+ text_input = gr.Textbox(
144
+ label="Describe the artwork you're looking for",
145
+ placeholder="e.g., 'impressionist painting of a garden with flowers'",
146
+ lines=3
147
+ )
148
+ text_num_results = gr.Slider(
149
+ minimum=1, maximum=10, value=5, step=1,
150
+ label="Number of results"
151
+ )
152
+ text_btn = gr.Button("🔍 Find Artworks", variant="primary")
153
+
154
+ with gr.Column(scale=2):
155
+ text_gallery = gr.Gallery(
156
+ label="Recommended Artworks",
157
+ columns=5,
158
+ height=400,
159
+ object_fit="contain"
160
+ )
161
+
162
+ text_info = gr.Textbox(label="Details", lines=6)
163
+
164
+ text_btn.click(
165
+ fn=recommend_from_text,
166
+ inputs=[text_input, text_num_results],
167
+ outputs=[text_gallery, text_info]
168
+ )
169
+
170
+ gr.Examples(
171
+ examples=[
172
+ ["impressionist landscape with water and trees"],
173
+ ["dark moody portrait with dramatic lighting"],
174
+ ["abstract colorful geometric shapes"],
175
+ ["religious painting with angels"],
176
+ ["Japanese style artwork with nature"],
177
+ ],
178
+ inputs=text_input
179
+ )
180
+
181
+ with gr.TabItem("🖼️ Search by Image"):
182
+ with gr.Row():
183
+ with gr.Column(scale=1):
184
+ image_input = gr.Image(
185
+ label="Upload an artwork image",
186
+ type="pil"
187
+ )
188
+ image_num_results = gr.Slider(
189
+ minimum=1, maximum=10, value=5, step=1,
190
+ label="Number of results"
191
+ )
192
+ image_btn = gr.Button("🔍 Find Similar", variant="primary")
193
+
194
+ with gr.Column(scale=2):
195
+ image_gallery = gr.Gallery(
196
+ label="Similar Artworks",
197
+ columns=5,
198
+ height=400,
199
+ object_fit="contain"
200
+ )
201
+
202
+ image_info = gr.Textbox(label="Details", lines=6)
203
+
204
+ image_btn.click(
205
+ fn=recommend_from_image,
206
+ inputs=[image_input, image_num_results],
207
+ outputs=[image_gallery, image_info]
208
+ )
209
+
210
+ gr.Markdown("""
211
+ ---
212
+ **Dataset:** WikiArt (15,000 artworks) | **Model:** CLIP ViT-B/32 | **Assignment 3 - ML Course**
213
+ """)
214
+
215
+ if __name__ == "__main__":
216
+ demo.launch()
artwork_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:599d408174194866e68dda6775c012c7360e4fd39f35a79d52a45869f94d0c72
3
+ size 30720128
artwork_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch
3
+ transformers
4
+ datasets
5
+ numpy
6
+ pandas
7
+ Pillow
sample_indices.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:927c21f714b4d8807380dcf7b9ca1b1d919859d15b5ed1274607a337a64f9153
3
+ size 120128