# Spaces: Sleeping
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from transformers import pipeline, CLIPProcessor, CLIPModel | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from PIL import Image | |
| import pickle | |
| import gradio as gr | |
# -------------------------------
# BOOK RECOMMENDATION SYSTEM CLASS
# -------------------------------
class BookRecommendationSystem:
    """Multimodal (text + image) book recommender over a CSV catalogue.

    Text similarity: each row's ``description`` + ``title`` is embedded with
    ``sentence-transformers/all-MiniLM-L6-v2`` (transformers feature-extraction
    pipeline, mean-pooled over tokens). It is then compared to the query by
    cosine similarity.

    Image similarity: a CLIP image embedding of the user's picture is compared
    with precomputed embeddings loaded from a pickle. The scores are mapped
    back to catalogue rows through the ``id`` column.

    Loading is best-effort. Errors are printed, and the affected similarity
    falls back to an all-zero score array.
    """

    # Embedding width of all-MiniLM-L6-v2; used as the zero vector for empty texts.
    _TEXT_EMBEDDING_DIM = 384

    def __init__(self, csv_path='cleaned_complete_book_dataset.csv',
                 image_embeddings_path='image_embeddings.pkl'):
        """Load the catalogue and image embeddings, then initialise both models.

        Args:
            csv_path: path to the book catalogue CSV.
            image_embeddings_path: path to a pickle holding a dict with
                ``'embeddings'`` and ``'post_ids'`` entries.
        """
        self.df = None
        self.text_model = None
        self.text_embeddings = None
        self.image_model = None
        self.image_processor = None
        self.image_embeddings = None
        self.image_post_ids = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_text_data(csv_path)
        self.load_image_embeddings(image_embeddings_path)
        self.initialize_text_model()
        self.initialize_image_model()

    def load_text_data(self, filepath):
        """Read the catalogue CSV into ``self.df`` (an empty DataFrame on failure)."""
        try:
            self.df = pd.read_csv(filepath)
            print(f"Dataset loaded successfully. Shape: {self.df.shape}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.df = pd.DataFrame()

    def load_image_embeddings(self, embeddings_path):
        """Load precomputed image embeddings and their post ids from a pickle.

        Row ``i`` of ``'embeddings'`` is treated as belonging to
        ``post_ids[i]`` (that is how ``get_image_similarity`` maps them).
        On any failure both attributes are reset to ``None``, which disables
        image search.
        """
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only load an
            # embeddings file produced by a trusted pipeline.
            with open(embeddings_path, 'rb') as f:
                data = pickle.load(f)
            self.image_embeddings = data['embeddings']
            self.image_post_ids = data['post_ids']
            print(f"Image embeddings loaded: {len(self.image_post_ids)} posts")
        except Exception as e:
            print(f"Error loading image embeddings: {e}")
            self.image_embeddings = None
            self.image_post_ids = None

    def initialize_text_model(self):
        """Create the text feature-extraction pipeline and embed the catalogue."""
        if self.text_model is None:
            try:
                self.text_model = pipeline(
                    "feature-extraction",
                    model="sentence-transformers/all-MiniLM-L6-v2",
                    device=self.device
                )
                self._compute_text_embeddings()
            except Exception as e:
                print(f"Error initializing text model: {e}")

    def initialize_image_model(self):
        """Load CLIP; skipped when no precomputed image embeddings are available."""
        if self.image_model is None and self.image_embeddings is not None:
            try:
                self.image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
                self.image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            except Exception as e:
                print(f"Error initializing image model: {e}")

    def _embed_text(self, text):
        """Return one sentence embedding for ``text``.

        The feature-extraction pipeline yields one vector per token, nested as
        ``[batch][token][hidden]`` for a single input string. The
        all-MiniLM-L6-v2 model card prescribes mean pooling over tokens, so the
        token vectors are averaged instead of keeping only the first token.
        """
        token_vectors = np.asarray(
            self.text_model(text, truncation=True, max_length=512)[0]
        )
        return token_vectors.mean(axis=0)

    def _compute_text_embeddings(self):
        """Embed every catalogue row's description + title into ``self.text_embeddings``."""
        if self.df.empty:
            return
        self.df['text_for_embedding'] = (
            self.df['description'].fillna('').astype(str) + ' ' +
            self.df['title'].fillna('').astype(str)
        ).str.strip()
        # Texts are already stripped, so an empty string means "nothing to embed".
        # Such rows get a zero vector, which scores 0 in cosine similarity.
        self.text_embeddings = np.array([
            self._embed_text(text) if text else np.zeros(self._TEXT_EMBEDDING_DIM)
            for text in self.df['text_for_embedding']
        ])

    def get_text_similarity(self, text_query):
        """Return the cosine similarity of ``text_query`` to every catalogue row.

        Returns:
            Array of length ``len(self.df)``. All zeros if the text model or
            embeddings are unavailable, or if embedding the query fails.
        """
        if self.text_model is None or self.text_embeddings is None:
            return np.zeros(len(self.df))
        try:
            query_emb = self._embed_text(text_query).reshape(1, -1)
            return cosine_similarity(query_emb, self.text_embeddings)[0]
        except Exception as e:
            print(f"Error computing text similarity: {e}")
            return np.zeros(len(self.df))

    def get_image_similarity(self, user_image):
        """Return the CLIP similarity of ``user_image`` to every catalogue row.

        Args:
            user_image: a PIL image (converted to RGB before encoding).

        Returns:
            Array of length ``len(self.df)``. Rows whose ``id`` is not among
            the loaded post ids score 0. All zeros if image search is
            unavailable or fails.
        """
        n_rows = len(self.df)
        if self.image_model is None or self.image_embeddings is None:
            return np.zeros(n_rows)
        try:
            img = user_image.convert("RGB")
            inputs = self.image_processor(images=img, return_tensors="pt").to(self.device)
            with torch.no_grad():
                user_emb = self.image_model.get_image_features(**inputs)
                user_emb = user_emb / user_emb.norm(p=2, dim=-1, keepdim=True)
            image_sims = cosine_similarity(user_emb.cpu().numpy(), self.image_embeddings)[0]

            id_to_idx = {post_id: i for i, post_id in enumerate(self.image_post_ids)}
            ids = self.df['id']
            # Positional boolean mask, so assignment does not depend on the
            # DataFrame's index labels being 0..n-1.
            mask = ids.isin(id_to_idx).to_numpy()
            df_similarities = np.zeros(n_rows)
            df_similarities[mask] = image_sims[ids[mask].map(id_to_idx).to_numpy(dtype=int)]
            return df_similarities
        except Exception as e:
            print(f"Error computing image similarity: {e}")
            return np.zeros(n_rows)

    def recommend_multimodal(self, text_query=None, user_image=None,
                             weights=(0.6, 0.4), top_k=5, genre=None):
        """Rank catalogue rows by weighted text + image similarity.

        Args:
            text_query: optional free-text query (falsy means "no text").
            user_image: optional PIL image.
            weights: ``(text_weight, image_weight)`` applied to the two scores.
            top_k: maximum number of recommendations to return.
            genre: optional genre filter (case-insensitive exact match).

        Returns:
            A list of ``(title, genre)`` tuples, best match first. The genre is
            ``""`` when missing. When the dataset is empty or the genre matches
            no rows, a single-element list holding a message string is returned
            instead.
        """
        if self.df.empty:
            return ["Dataset not loaded."]
        n_rows = len(self.df)

        if genre:
            keep = (self.df["genre"].str.lower() == genre.lower()).to_numpy()
            if not keep.any():
                return ["No books found for this genre."]
        else:
            keep = np.ones(n_rows, dtype=bool)

        # Both score arrays are aligned with the full self.df, so combine them
        # first and apply the genre mask afterwards.
        text_sim = self.get_text_similarity(text_query) if text_query else np.zeros(n_rows)
        image_sim = (self.get_image_similarity(user_image)
                     if user_image is not None else np.zeros(n_rows))
        combined_sim = weights[0] * text_sim + weights[1] * image_sim

        scored = self.df.loc[keep].copy()
        scored['similarity'] = combined_sim[keep]
        # Drop untitled rows before ranking so they don't use up top_k slots.
        scored = scored[scored['top_one_book_title'].notna()]
        # Stable sort keeps ties (e.g. all-zero scores) in catalogue order.
        scored = scored.sort_values("similarity", ascending=False, kind="stable").head(top_k)

        recommendations = []
        for _, row in scored.iterrows():
            # Keep only the first title when the field lists several.
            # NOTE: splitting on " and " also truncates titles that contain
            # " and " themselves.
            first_title = str(row['top_one_book_title']).split(" and ")[0].split("\n")[0].strip()
            row_genre = row.get("genre", "")
            recommendations.append((first_title, "" if pd.isna(row_genre) else row_genre))
        return recommendations
# -------------------------------
# INITIALIZE SYSTEM
# -------------------------------
# Single shared instance, built once at import time with the default file paths.
# The Gradio callback and the genre dropdown below both read from it.
recommender = BookRecommendationSystem()
# -------------------------------
# GRADIO UI
# -------------------------------
def get_recommendations(text_query, image_input, weight, selected_genre):
    """Gradio callback: run the recommender and render the results as HTML cards.

    Args:
        text_query: text from the description box. ``None`` or blank means
            no text query.
        image_input: optional image as a numpy array (from ``gr.Image(type="numpy")``).
        weight: weight given to text similarity; the image gets ``1 - weight``.
        selected_genre: optional genre filter (``None`` for no filter).

    Returns:
        An HTML string for the ``gr.HTML`` output component.
    """
    from html import escape  # local import: only this callback builds HTML

    if text_query is None or not text_query.strip():
        text_query = None
    user_image = Image.fromarray(image_input) if image_input is not None else None
    recommendations = recommender.recommend_multimodal(
        text_query=text_query,
        user_image=user_image,
        weights=(weight, 1 - weight),
        top_k=5,
        genre=selected_genre,
    )
    if not recommendations:
        return "<p style='color:red'>❌ No matching books found. Try a different query or image.</p>"
    # The recommender reports problems (no dataset, empty genre) as a single
    # message string instead of (title, genre) tuples. Show the message rather
    # than failing to unpack it below.
    if isinstance(recommendations[0], str):
        return f"<p style='color:red'>❌ {escape(recommendations[0])}</p>"

    cards = []
    for i, (title, genre) in enumerate(recommendations, start=1):
        # Titles and genres come from the CSV, so escape them before embedding in HTML.
        genre_text = "" if pd.isna(genre) else str(genre)
        genre_html = (
            f"<p style='color:#555; font-size:0.9em; margin:0;'>📚 Genre: {escape(genre_text)}</p>"
            if genre_text else ""
        )
        cards.append(f"""
    <div style="background:#f9fafb; border-radius:10px; padding:12px; box-shadow:0 1px 4px rgba(0,0,0,0.1)">
        <h3 style="margin:0;">📖 {i}. {escape(str(title))}</h3>
        {genre_html}
    </div>
    """)
    return "<div style='display:grid; gap:12px;'>" + "".join(cards) + "</div>"
# Gradio layout: inputs in the left column, the rendered recommendations in the right.
# The emoji/accented labels below were repaired from mojibake (UTF-8 text that
# had been mis-decoded as ISO-8859-7).
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        "# 📚 **BookMatch.AI**\n_Discover your next favorite read using text + image search_"
    )
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                lines=3,
                placeholder="Describe the book vibe (e.g. 'dark fantasy with magic and dragons')",
                label="📝 Describe Your Ideal Book"
            )
            image_input = gr.Image(type="numpy", label="🖼️ Upload an Image for Inspiration (Optional)")
            # Slider value is the text weight; the image weight is 1 - value
            # (see get_recommendations).
            weight_slider = gr.Slider(0, 1, value=0.6, step=0.05, label="⚖️ Text vs Image Weight")
            genre_dropdown = gr.Dropdown(
                # The dataset may have failed to load (empty DataFrame, no 'genre' column).
                choices=sorted(recommender.df['genre'].dropna().unique()) if 'genre' in recommender.df.columns else [],
                label="🎭 Filter by Genre (Optional)",
                value=None
            )
            submit_btn = gr.Button("✨ Get Recommendations", variant="primary")
        with gr.Column(scale=1):
            output_html = gr.HTML(label="🎯 Your Top Matches")
    gr.Examples(
        examples=[
            ["Dark fantasy adventure with mythical creatures", "https://images.unsplash.com/photo-1528372444006-1bfc81acab02", 0.6, None],
            ["Cozy romance set in a small town café", "https://images.unsplash.com/photo-1519681393784-d120267933ba", 0.6, None],
            ["Space opera with political intrigue", "https://images.unsplash.com/photo-1462331940025-496dfbfc7564", 0.6, None],
        ],
        inputs=[text_input, image_input, weight_slider, genre_dropdown]
    )
    submit_btn.click(
        fn=get_recommendations,
        inputs=[text_input, image_input, weight_slider, genre_dropdown],
        outputs=output_html
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported as a module.
    iface.launch()