Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import pipeline, CLIPProcessor, CLIPModel
|
| 5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import pickle
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
# -------------------------------
|
| 11 |
+
# BOOK RECOMMENDATION SYSTEM CLASS
|
| 12 |
+
# -------------------------------
|
| 13 |
+
class BookRecommendationSystem:
|
| 14 |
+
def __init__(self, csv_path='cleaned_complete_book_dataset.csv',
|
| 15 |
+
image_embeddings_path='image_embeddings.pkl'):
|
| 16 |
+
self.df = None
|
| 17 |
+
self.text_model = None
|
| 18 |
+
self.text_embeddings = None
|
| 19 |
+
self.image_model = None
|
| 20 |
+
self.image_processor = None
|
| 21 |
+
self.image_embeddings = None
|
| 22 |
+
self.image_post_ids = None
|
| 23 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
+
self.load_text_data(csv_path)
|
| 25 |
+
self.load_image_embeddings(image_embeddings_path)
|
| 26 |
+
self.initialize_text_model()
|
| 27 |
+
self.initialize_image_model()
|
| 28 |
+
|
| 29 |
+
def load_text_data(self, filepath):
|
| 30 |
+
try:
|
| 31 |
+
self.df = pd.read_csv(filepath)
|
| 32 |
+
print(f"Dataset loaded successfully. Shape: {self.df.shape}")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Error loading dataset: {e}")
|
| 35 |
+
self.df = pd.DataFrame()
|
| 36 |
+
|
| 37 |
+
def load_image_embeddings(self, embeddings_path):
|
| 38 |
+
try:
|
| 39 |
+
with open(embeddings_path, 'rb') as f:
|
| 40 |
+
data = pickle.load(f)
|
| 41 |
+
self.image_embeddings = data['embeddings']
|
| 42 |
+
self.image_post_ids = data['post_ids']
|
| 43 |
+
print(f"Image embeddings loaded: {len(self.image_post_ids)} posts")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
print(f"Error loading image embeddings: {e}")
|
| 46 |
+
self.image_embeddings = None
|
| 47 |
+
self.image_post_ids = None
|
| 48 |
+
|
| 49 |
+
def initialize_text_model(self):
|
| 50 |
+
if self.text_model is None:
|
| 51 |
+
try:
|
| 52 |
+
self.text_model = pipeline(
|
| 53 |
+
"feature-extraction",
|
| 54 |
+
model="sentence-transformers/all-MiniLM-L6-v2",
|
| 55 |
+
device=self.device
|
| 56 |
+
)
|
| 57 |
+
self._compute_text_embeddings()
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"Error initializing text model: {e}")
|
| 60 |
+
|
| 61 |
+
def initialize_image_model(self):
|
| 62 |
+
if self.image_model is None and self.image_embeddings is not None:
|
| 63 |
+
try:
|
| 64 |
+
self.image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 65 |
+
self.image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"Error initializing image model: {e}")
|
| 68 |
+
|
| 69 |
+
def _compute_text_embeddings(self):
|
| 70 |
+
if self.df.empty:
|
| 71 |
+
return
|
| 72 |
+
self.df['text_for_embedding'] = (
|
| 73 |
+
self.df['description'].fillna('').astype(str) + ' ' +
|
| 74 |
+
self.df['title'].fillna('').astype(str)
|
| 75 |
+
).str.strip()
|
| 76 |
+
embeddings_list = [
|
| 77 |
+
self.text_model(text, truncation=True, max_length=512)[0][0]
|
| 78 |
+
if text and not text.isspace()
|
| 79 |
+
else np.zeros(384)
|
| 80 |
+
for text in self.df['text_for_embedding']
|
| 81 |
+
]
|
| 82 |
+
self.text_embeddings = np.array(embeddings_list)
|
| 83 |
+
|
| 84 |
+
def get_text_similarity(self, text_query):
|
| 85 |
+
if self.text_model is None or self.text_embeddings is None:
|
| 86 |
+
return np.zeros(len(self.df))
|
| 87 |
+
try:
|
| 88 |
+
query_out = self.text_model(text_query, truncation=True, max_length=512)
|
| 89 |
+
query_emb = np.array(query_out[0][0]).reshape(1, -1)
|
| 90 |
+
return cosine_similarity(query_emb, self.text_embeddings)[0]
|
| 91 |
+
except:
|
| 92 |
+
return np.zeros(len(self.df))
|
| 93 |
+
|
| 94 |
+
def get_image_similarity(self, user_image):
|
| 95 |
+
if self.image_model is None or self.image_embeddings is None:
|
| 96 |
+
return np.zeros(len(self.df))
|
| 97 |
+
try:
|
| 98 |
+
img = user_image.convert("RGB")
|
| 99 |
+
inputs = self.image_processor(images=img, return_tensors="pt").to(self.device)
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
user_emb = self.image_model.get_image_features(**inputs)
|
| 102 |
+
user_emb /= user_emb.norm(p=2, dim=-1, keepdim=True)
|
| 103 |
+
user_emb = user_emb.cpu().numpy()
|
| 104 |
+
image_sims = cosine_similarity(user_emb, self.image_embeddings)[0]
|
| 105 |
+
|
| 106 |
+
df_similarities = np.zeros(len(self.df))
|
| 107 |
+
id_to_idx = {post_id: i for i, post_id in enumerate(self.image_post_ids)}
|
| 108 |
+
mask = self.df['id'].isin(id_to_idx)
|
| 109 |
+
indices = self.df.index[mask]
|
| 110 |
+
map_ids = self.df['id'][mask].map(id_to_idx)
|
| 111 |
+
df_similarities[indices] = image_sims[map_ids.values]
|
| 112 |
+
return df_similarities
|
| 113 |
+
except:
|
| 114 |
+
return np.zeros(len(self.df))
|
| 115 |
+
|
| 116 |
+
def recommend_multimodal(self, text_query=None, user_image=None,
|
| 117 |
+
weights=(0.6, 0.4), top_k=5, genre=None):
|
| 118 |
+
if self.df.empty:
|
| 119 |
+
return ["Dataset not loaded."]
|
| 120 |
+
df = self.df.copy()
|
| 121 |
+
if genre:
|
| 122 |
+
df = df[df["genre"].str.lower() == genre.lower()]
|
| 123 |
+
if df.empty:
|
| 124 |
+
return ["No books found for this genre."]
|
| 125 |
+
|
| 126 |
+
text_sim = self.get_text_similarity(text_query) if text_query else np.zeros(len(self.df))
|
| 127 |
+
image_sim = self.get_image_similarity(user_image) if user_image is not None else np.zeros(len(self.df))
|
| 128 |
+
combined_sim = weights[0] * text_sim + weights[1] * image_sim
|
| 129 |
+
df['similarity'] = combined_sim
|
| 130 |
+
|
| 131 |
+
df = df.sort_values("similarity", ascending=False).head(top_k)
|
| 132 |
+
recommendations = []
|
| 133 |
+
for _, row in df.iterrows():
|
| 134 |
+
if pd.notna(row['top_one_book_title']):
|
| 135 |
+
first_title = str(row['top_one_book_title']).split(" and ")[0].split("\n")[0].strip()
|
| 136 |
+
recommendations.append((first_title, row.get("genre", "")))
|
| 137 |
+
return recommendations[:top_k]
|
| 138 |
+
|
| 139 |
+
# -------------------------------
|
| 140 |
+
# INITIALIZE SYSTEM
|
| 141 |
+
# -------------------------------
|
| 142 |
+
recommender = BookRecommendationSystem()
|
| 143 |
+
|
| 144 |
+
# -------------------------------
|
| 145 |
+
# GRADIO UI
|
| 146 |
+
# -------------------------------
|
| 147 |
+
def get_recommendations(text_query, image_input, weight, selected_genre):
|
| 148 |
+
if not text_query.strip():
|
| 149 |
+
text_query = None
|
| 150 |
+
user_image = Image.fromarray(image_input) if image_input is not None else None
|
| 151 |
+
recommendations = recommender.recommend_multimodal(
|
| 152 |
+
text_query=text_query,
|
| 153 |
+
user_image=user_image,
|
| 154 |
+
weights=(weight, 1-weight),
|
| 155 |
+
top_k=5,
|
| 156 |
+
genre=selected_genre
|
| 157 |
+
)
|
| 158 |
+
if not recommendations:
|
| 159 |
+
return "<p style='color:red'>β No matching books found. Try a different query or image.</p>"
|
| 160 |
+
|
| 161 |
+
# Create HTML cards
|
| 162 |
+
html = "<div style='display:grid; gap:12px;'>"
|
| 163 |
+
for i, (title, genre) in enumerate(recommendations, start=1):
|
| 164 |
+
genre_html = f"<p style='color:#555; font-size:0.9em; margin:0;'>π Genre: {genre}</p>" if genre else ""
|
| 165 |
+
html += f"""
|
| 166 |
+
<div style="background:#f9fafb; border-radius:10px; padding:12px; box-shadow:0 1px 4px rgba(0,0,0,0.1)">
|
| 167 |
+
<h3 style="margin:0;">π {i}. {title}</h3>
|
| 168 |
+
{genre_html}
|
| 169 |
+
</div>
|
| 170 |
+
"""
|
| 171 |
+
html += "</div>"
|
| 172 |
+
return html
|
| 173 |
+
|
| 174 |
+
with gr.Blocks(theme=gr.themes.Soft()) as iface:
|
| 175 |
+
gr.Markdown(
|
| 176 |
+
"# π **BookMatch.AI**\n_Discover your next favorite read using text + image search_"
|
| 177 |
+
)
|
| 178 |
+
with gr.Row():
|
| 179 |
+
with gr.Column(scale=1):
|
| 180 |
+
text_input = gr.Textbox(
|
| 181 |
+
lines=3,
|
| 182 |
+
placeholder="Describe the book vibe (e.g. 'dark fantasy with magic and dragons')",
|
| 183 |
+
label="π Describe Your Ideal Book"
|
| 184 |
+
)
|
| 185 |
+
image_input = gr.Image(type="numpy", label="πΌοΈ Upload an Image for Inspiration (Optional)")
|
| 186 |
+
weight_slider = gr.Slider(0, 1, value=0.6, step=0.05, label="βοΈ Text vs Image Weight")
|
| 187 |
+
genre_dropdown = gr.Dropdown(
|
| 188 |
+
choices=sorted(recommender.df['genre'].dropna().unique()) if 'genre' in recommender.df.columns else [],
|
| 189 |
+
label="π Filter by Genre (Optional)",
|
| 190 |
+
value=None
|
| 191 |
+
)
|
| 192 |
+
submit_btn = gr.Button("β¨ Get Recommendations", variant="primary")
|
| 193 |
+
with gr.Column(scale=1):
|
| 194 |
+
output_html = gr.HTML(label="π― Your Top Matches")
|
| 195 |
+
|
| 196 |
+
gr.Examples(
|
| 197 |
+
examples=[
|
| 198 |
+
["Dark fantasy adventure with mythical creatures", "https://images.unsplash.com/photo-1528372444006-1bfc81acab02", 0.6, None],
|
| 199 |
+
["Cozy romance set in a small town cafΓ©", "https://images.unsplash.com/photo-1519681393784-d120267933ba", 0.6, None],
|
| 200 |
+
["Space opera with political intrigue", "https://images.unsplash.com/photo-1462331940025-496dfbfc7564", 0.6, None],
|
| 201 |
+
],
|
| 202 |
+
inputs=[text_input, image_input, weight_slider, genre_dropdown]
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
submit_btn.click(
|
| 206 |
+
fn=get_recommendations,
|
| 207 |
+
inputs=[text_input, image_input, weight_slider, genre_dropdown],
|
| 208 |
+
outputs=output_html
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
iface.launch()
|