smxxxxxxx committed on
Commit
012e95a
·
verified ·
1 Parent(s): 3754e53

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModel, pipeline
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import gradio as gr
9
+
10
+ # Load models
11
# Process-wide cache: the models are expensive to build, so they are created
# once and reused instead of being reloaded on every call.
_MODEL_CACHE = {}


def load_models():
    """Return the tokenizer, embedding model and summarization pipeline.

    The objects are built on the first call and stored in ``_MODEL_CACHE``;
    every later call returns the same instances.

    Returns:
        tuple: ``(tokenizer, model, summarizer)`` where ``tokenizer`` and
        ``model`` are loaded from ``sentence-transformers/all-MiniLM-L6-v2``
        and ``summarizer`` is a ``facebook/bart-large-cnn`` summarization
        pipeline.
    """
    if "models" not in _MODEL_CACHE:
        # For semantic search
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        # For summarization
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

        _MODEL_CACHE["models"] = (tokenizer, model, summarizer)

    return _MODEL_CACHE["models"]
21
+
22
+ # Load book data
23
def load_data(path="books.csv"):
    """Load the book table used for recommendations.

    Args:
        path: Location of the CSV file to read. Defaults to ``"books.csv"``
            in the current working directory.

    Returns:
        pandas.DataFrame: The ``title`` and ``description`` columns only,
        with every row that has a missing value in either column removed.

    Raises:
        KeyError: If the file lacks a ``title`` or ``description`` column.
    """
    books = pd.read_csv(path)
    # Keep only relevant columns and drop rows with missing values
    return books[['title', 'description']].dropna()
29
+
30
+ # Mean pooling for sentence embeddings
31
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the mask marks as 1.

    Args:
        model_output: Indexable model result whose first element holds the
            per-token embeddings.
        attention_mask: Tensor that is broadcast over the embedding axis;
            positions with value 0 contribute nothing to the average.

    Returns:
        torch.Tensor: One pooled vector per input sequence.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_total = (hidden_states * weights).sum(dim=1)
    # Clamp so a fully masked sequence divides by a tiny value instead of zero.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / token_counts
35
+
36
+ # Get embeddings for text
37
def get_embeddings(texts, tokenizer, model, batch_size=32):
    """Compute L2-normalized, mean-pooled embeddings for one or more texts.

    The texts are encoded in chunks of ``batch_size`` so that memory use does
    not grow with the total number of texts. Results should match a single
    large batch up to floating-point noise, since padded positions are
    excluded by the attention mask.

    Args:
        texts: A single string or an iterable of strings.
        tokenizer: Callable tokenizer; invoked with ``padding=True``,
            ``truncation=True`` and ``return_tensors='pt'``.
        model: Callable model; invoked with the tokenizer output as keyword
            arguments.
        batch_size: Maximum number of texts per forward pass.

    Returns:
        torch.Tensor: One normalized embedding row per input text, in input
        order.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # A bare string must become a one-element list, otherwise slicing below
    # would split it into characters.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        raise ValueError("texts must contain at least one string")

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = model(**encoded_input)
        pooled = mean_pooling(model_output, encoded_input['attention_mask'])
        chunks.append(F.normalize(pooled, p=2, dim=1))

    return torch.cat(chunks, dim=0)
44
+
45
+ # Find most similar books
46
def find_similar_books(keywords, books, tokenizer, model, top_k=5):
    """Rank books by cosine similarity to the averaged keyword embedding.

    Args:
        keywords: Keyword strings; their embeddings are averaged into a
            single query vector.
        books: DataFrame with ``title`` and ``description`` columns.
        tokenizer: Tokenizer forwarded to ``get_embeddings``.
        model: Model forwarded to ``get_embeddings``.
        top_k: Maximum number of books to return.

    Returns:
        pandas.DataFrame: A copy of the best-matching rows of ``books``,
        most similar first, with an added ``similarity`` column. Empty (but
        still carrying the ``similarity`` column) when ``books`` is empty or
        ``top_k`` is not positive.
    """
    # Guard first: ``[-0:]`` would select every row rather than none, and an
    # empty corpus cannot be embedded.
    if top_k <= 0 or books.empty:
        results = books.iloc[:0].copy()
        results['similarity'] = np.array([], dtype=float)
        return results

    # Get embeddings for keywords
    keyword_embedding = get_embeddings(keywords, tokenizer, model).mean(0).unsqueeze(0)

    # Get embeddings for book titles and descriptions
    book_texts = books['title'] + " " + books['description']
    book_embeddings = get_embeddings(book_texts.tolist(), tokenizer, model)

    # Calculate similarity
    similarities = cosine_similarity(keyword_embedding, book_embeddings)[0]

    # Get top matches: argsort is ascending, so take the tail and reverse it.
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    results = books.iloc[top_indices].copy()
    results['similarity'] = similarities[top_indices]

    return results
63
+
64
+ # Summarize book description
65
def summarize_description(description, summarizer):
    """Return a short summary of a long description, or the text unchanged.

    Args:
        description: Book description text.
        summarizer: Callable summarization pipeline returning a list of
            dicts with a ``summary_text`` key.

    Returns:
        str: The generated summary when the description has more than 100
        whitespace-separated words, otherwise the original description.
    """
    if len(description.split()) <= 100:  # Only summarize long descriptions
        return description

    # truncation=True keeps very long descriptions within the model's input
    # window instead of failing on over-length input.
    summary = summarizer(
        description,
        max_length=130,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return summary[0]['summary_text']
70
+
71
+ # Main function
72
+ def recommend_books(keywords):
73
+ # Split keywords by comma or space
74
+ keywords = [kw.strip() for kw in keywords.replace(',', ' ').split() if kw.strip()]
75
+ if len(keywords) < 3:
76
+ return "Please enter at least 3 keywords separated by commas or spaces."
77
+
78
+ # Load models and data
79
+ tokenizer, model, summarizer = load_models()
80
+ books = load_data()
81
+
82
+ # Find similar books
83
+ similar_books = find_similar_books(keywords, books, tokenizer, model)
84
+
85
+ # Generate output
86
+ output = []
87
+ for i, (_, row) in enumerate(similar_books.iterrows(), 1):
88
+ summary = summarize_description(row['description'], summarizer)
89
+ output.append(f"{i}. {row['title']}\n Summary: {summary}\n")
90
+
91
+ return "\n".join(output)
92
+
93
+ # Gradio interface
94
# Gradio front end: a single text box in, a single text box out, wired to
# recommend_books.
iface = gr.Interface(
    title="Book Recommendation Engine",
    description="Enter 3 or more keywords to find relevant books and get summaries of their plots.",
    fn=recommend_books,
    inputs=gr.Textbox(label="Enter at least 3 keywords (comma or space separated)"),
    outputs=gr.Textbox(label="Recommended Books"),
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()