Shrutikp70 committed on
Commit
0abaaea
·
verified ·
1 Parent(s): a863605

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ from transformers import pipeline, CLIPProcessor, CLIPModel
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from PIL import Image
7
+ import pickle
8
+ import gradio as gr
9
+
10
+ # -------------------------------
11
+ # BOOK RECOMMENDATION SYSTEM CLASS
12
+ # -------------------------------
13
class BookRecommendationSystem:
    """Rank books against a free-text query and/or an image.

    Rows come from a CSV; the code reads the columns ``description``,
    ``title``, ``id``, ``genre`` and ``top_one_book_title``. Text similarity
    uses a ``feature-extraction`` pipeline, image similarity compares a CLIP
    embedding of the user image with precomputed embeddings loaded from a
    pickle file. Loading/initialisation failures are logged and leave the
    corresponding attribute as ``None`` (or an empty frame), in which case
    that modality contributes a zero score.
    """

    def __init__(self, csv_path='cleaned_complete_book_dataset.csv',
                 image_embeddings_path='image_embeddings.pkl'):
        """Load the dataset and embeddings, then initialise both models.

        Args:
            csv_path: Path of the CSV holding the book/post rows.
            image_embeddings_path: Path of the pickle with a dict containing
                ``'embeddings'`` and ``'post_ids'``.
        """
        self.df = None
        self.text_model = None
        self.text_embeddings = None
        self.image_model = None
        self.image_processor = None
        self.image_embeddings = None
        self.image_post_ids = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_text_data(csv_path)
        self.load_image_embeddings(image_embeddings_path)
        self.initialize_text_model()
        self.initialize_image_model()

    def load_text_data(self, filepath):
        """Read the CSV at *filepath* into ``self.df``; use an empty frame on failure."""
        try:
            self.df = pd.read_csv(filepath)
            print(f"Dataset loaded successfully. Shape: {self.df.shape}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.df = pd.DataFrame()

    def load_image_embeddings(self, embeddings_path):
        """Load precomputed image embeddings and their post ids from a pickle.

        On any failure both ``image_embeddings`` and ``image_post_ids`` are
        reset to ``None`` so image similarity is disabled.
        """
        try:
            with open(embeddings_path, 'rb') as f:
                # NOTE(review): pickle.load executes arbitrary code from the
                # file; only ever point this at a trusted, locally built file.
                data = pickle.load(f)
            self.image_embeddings = data['embeddings']
            self.image_post_ids = data['post_ids']
            print(f"Image embeddings loaded: {len(self.image_post_ids)} posts")
        except Exception as e:
            print(f"Error loading image embeddings: {e}")
            self.image_embeddings = None
            self.image_post_ids = None

    def initialize_text_model(self):
        """Create the text pipeline (once) and embed every dataset row."""
        if self.text_model is not None:
            return
        try:
            self.text_model = pipeline(
                "feature-extraction",
                model="sentence-transformers/all-MiniLM-L6-v2",
                device=self.device
            )
            self._compute_text_embeddings()
        except Exception as e:
            print(f"Error initializing text model: {e}")

    def initialize_image_model(self):
        """Load CLIP (once), but only when image embeddings are available."""
        if self.image_model is not None or self.image_embeddings is None:
            return
        try:
            self.image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        except Exception as e:
            print(f"Error initializing image model: {e}")

    def _embed_text(self, text):
        """Return the embedding the pipeline yields for the first token of *text*.

        NOTE(review): sentence-transformers models are usually mean-pooled
        over tokens; this keeps the original first-token behaviour so stored
        and query embeddings stay comparable. Confirm which is intended.
        """
        return self.text_model(text, truncation=True, max_length=512)[0][0]

    def _compute_text_embeddings(self):
        """Embed ``description + ' ' + title`` for every row into ``self.text_embeddings``."""
        if self.df.empty:
            return
        self.df['text_for_embedding'] = (
            self.df['description'].fillna('').astype(str) + ' ' +
            self.df['title'].fillna('').astype(str)
        ).str.strip()
        # Blank rows get a zero vector (384 matches the fallback size used by
        # the original code) so the matrix stays aligned with the rows.
        embeddings_list = [
            self._embed_text(text) if text and not text.isspace() else np.zeros(384)
            for text in self.df['text_for_embedding']
        ]
        self.text_embeddings = np.array(embeddings_list)

    def get_text_similarity(self, text_query):
        """Return cosine similarity of *text_query* to every row.

        Returns:
            1-D array of length ``len(self.df)``; all zeros when the model or
            embeddings are missing or the computation fails.
        """
        if self.text_model is None or self.text_embeddings is None:
            return np.zeros(len(self.df))
        try:
            query_emb = np.array(self._embed_text(text_query)).reshape(1, -1)
            return cosine_similarity(query_emb, self.text_embeddings)[0]
        except Exception as e:
            # Best-effort: a failing modality must not break the whole request.
            print(f"Error computing text similarity: {e}")
            return np.zeros(len(self.df))

    def get_image_similarity(self, user_image):
        """Return cosine similarity of *user_image* to every row's image.

        Rows whose ``id`` has no stored image embedding score 0.

        Returns:
            1-D array of length ``len(self.df)``; all zeros when the model or
            embeddings are missing or the computation fails.
        """
        if self.image_model is None or self.image_embeddings is None:
            return np.zeros(len(self.df))
        try:
            rgb = user_image.convert("RGB")
            inputs = self.image_processor(images=rgb, return_tensors="pt").to(self.device)
            with torch.no_grad():
                user_emb = self.image_model.get_image_features(**inputs)
                user_emb = user_emb / user_emb.norm(p=2, dim=-1, keepdim=True)
            user_emb = user_emb.cpu().numpy()
            image_sims = cosine_similarity(user_emb, self.image_embeddings)[0]

            # Map each row's id to its position in the embedding matrix.
            id_to_pos = {post_id: i for i, post_id in enumerate(self.image_post_ids)}
            emb_pos = self.df['id'].map(id_to_pos)
            has_emb = emb_pos.notna().to_numpy()

            # Write by row *position* (boolean mask), not by index label, so
            # this stays correct even if the frame's index is not 0..n-1.
            scores = np.zeros(len(self.df))
            scores[has_emb] = image_sims[emb_pos[has_emb].astype(int).to_numpy()]
            return scores
        except Exception as e:
            print(f"Error computing image similarity: {e}")
            return np.zeros(len(self.df))

    def recommend_multimodal(self, text_query=None, user_image=None,
                             weights=(0.6, 0.4), top_k=5, genre=None):
        """Return up to *top_k* ``(title, genre)`` recommendations.

        Args:
            text_query: Free-text description, or ``None``/empty to skip text.
            user_image: PIL image, or ``None`` to skip the image modality.
            weights: ``(text_weight, image_weight)`` applied to the two scores.
            top_k: Number of best-scoring rows to consider.
            genre: Optional case-insensitive exact genre filter.

        Returns:
            A list of ``(title, genre)`` tuples, or a one-element list holding
            a plain status string when the dataset is empty or no row matches
            the genre. Top rows without a ``top_one_book_title`` are skipped,
            so fewer than *top_k* items may be returned.
        """
        if self.df.empty:
            return ["Dataset not loaded."]

        genre_mask = None
        if genre:
            genre_mask = (self.df["genre"].str.lower() == genre.lower()).to_numpy()
            if not genre_mask.any():
                return ["No books found for this genre."]

        n_rows = len(self.df)
        text_sim = self.get_text_similarity(text_query) if text_query else np.zeros(n_rows)
        image_sim = self.get_image_similarity(user_image) if user_image is not None else np.zeros(n_rows)
        combined_sim = weights[0] * text_sim + weights[1] * image_sim

        # Scores are aligned with the FULL frame, so attach them before
        # filtering; assigning them to an already-filtered frame raises a
        # length-mismatch ValueError.
        ranked = self.df.assign(similarity=combined_sim)
        if genre_mask is not None:
            ranked = ranked[genre_mask]

        # Stable sort keeps dataset order among ties (e.g. all-zero scores).
        ranked = ranked.sort_values("similarity", ascending=False, kind="mergesort").head(top_k)

        recommendations = []
        for _, row in ranked.iterrows():
            raw_title = row['top_one_book_title']
            if pd.notna(raw_title):
                # Keep only the first title when several are joined by
                # " and " or spread over multiple lines.
                first_title = str(raw_title).split(" and ")[0].split("\n")[0].strip()
                recommendations.append((first_title, row.get("genre", "")))
        return recommendations[:top_k]
138
+
139
+ # -------------------------------
140
+ # INITIALIZE SYSTEM
141
+ # -------------------------------
142
# Built once at module import so the UI callback below reuses the same
# loaded dataset and models on every request.
recommender = BookRecommendationSystem()
143
+
144
+ # -------------------------------
145
+ # GRADIO UI
146
+ # -------------------------------
147
def get_recommendations(text_query, image_input, weight, selected_genre):
    """Gradio callback: run the recommender and render the result as HTML.

    Args:
        text_query: Text from the textbox; blank or ``None`` disables text search.
        image_input: Image as a NumPy array, or ``None`` when nothing was uploaded.
        weight: Text weight; the image weight is ``1 - weight``.
        selected_genre: Genre chosen in the dropdown, or ``None``.

    Returns:
        An HTML string: one card per recommended book, or a red message when
        there is nothing to show or the recommender reports a status message.
    """
    from html import escape  # local import: keeps this block self-contained

    # Treat None and whitespace-only input alike.
    if not (text_query and text_query.strip()):
        text_query = None
    user_image = Image.fromarray(image_input) if image_input is not None else None

    recommendations = recommender.recommend_multimodal(
        text_query=text_query,
        user_image=user_image,
        weights=(weight, 1 - weight),
        top_k=5,
        genre=selected_genre
    )
    if not recommendations:
        return "<p style='color:red'>❌ No matching books found. Try a different query or image.</p>"

    parts = ["<div style='display:grid; gap:12px;'>"]
    rank = 0
    for item in recommendations:
        # The recommender signals problems ("Dataset not loaded.", ...) as
        # plain strings; show them as messages instead of unpacking them.
        if isinstance(item, str):
            parts.append(f"<p style='color:red'>❌ {escape(item)}</p>")
            continue

        title, genre = item
        rank += 1
        # Titles/genres come from the dataset; escape them so stray markup
        # in the data cannot break or inject into the page.
        genre_html = (
            f"<p style='color:#555; font-size:0.9em; margin:0;'>🎭 Genre: {escape(str(genre))}</p>"
            if genre else ""
        )
        parts.append(f"""
<div style="background:#f9fafb; border-radius:10px; padding:12px; box-shadow:0 1px 4px rgba(0,0,0,0.1)">
<h3 style="margin:0;">📖 {rank}. {escape(str(title))}</h3>
{genre_html}
</div>
""")
    parts.append("</div>")
    return "".join(parts)
173
+
174
# Two-column layout: inputs on the left, rendered recommendations on the right.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        "# 📚 **BookMatch.AI**\n_Discover your next favorite read using text + image search_"
    )
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                lines=3,
                placeholder="Describe the book vibe (e.g. 'dark fantasy with magic and dragons')",
                label="🔎 Describe Your Ideal Book"
            )
            image_input = gr.Image(type="numpy", label="🖼️ Upload an Image for Inspiration (Optional)")
            # Slider value is the text weight; the callback uses 1 - value for the image.
            weight_slider = gr.Slider(0, 1, value=0.6, step=0.05, label="⚖️ Text vs Image Weight")
            genre_dropdown = gr.Dropdown(
                # Offer no choices when the dataset failed to load or has no genre column.
                choices=sorted(recommender.df['genre'].dropna().unique()) if 'genre' in recommender.df.columns else [],
                label="🎭 Filter by Genre (Optional)",
                value=None
            )
            submit_btn = gr.Button("✨ Get Recommendations", variant="primary")
        with gr.Column(scale=1):
            output_html = gr.HTML(label="🎯 Your Top Matches")

    gr.Examples(
        examples=[
            ["Dark fantasy adventure with mythical creatures", "https://images.unsplash.com/photo-1528372444006-1bfc81acab02", 0.6, None],
            ["Cozy romance set in a small town café", "https://images.unsplash.com/photo-1519681393784-d120267933ba", 0.6, None],
            ["Space opera with political intrigue", "https://images.unsplash.com/photo-1462331940025-496dfbfc7564", 0.6, None],
        ],
        inputs=[text_input, image_input, weight_slider, genre_dropdown]
    )

    submit_btn.click(
        fn=get_recommendations,
        inputs=[text_input, image_input, weight_slider, genre_dropdown],
        outputs=output_html
    )

if __name__ == "__main__":
    iface.launch()