sshenai committed on
Commit
8fe9808
·
verified ·
1 Parent(s): 1c0fa61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -35
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from datasets import load_dataset
2
  import numpy as np
3
  from transformers import AutoTokenizer, AutoModel, pipeline
4
  import torch
@@ -7,67 +7,96 @@ from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import gradio as gr
9
 
10
- # Load models (loaded globally to improve performance)
11
- MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
- SUMMARIZER = pipeline("summarization", model="facebook/bart-large-cnn")
13
- TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
14
- MODEL = AutoModel.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
15
 
16
- # Load the bookcorpus dataset
17
  def load_data():
18
- dataset = load_dataset("bookcorpus", streaming=True) # Enable streaming reads
19
- books = dataset["train"].take(100_000) # Take the first 100,000 records
20
- return [{"text": x["text"]} for x in books if len(x["text"]) > 100] # Filter out short texts
 
 
21
 
22
- # Text embedding generation (reuses the original code logic)
23
  def mean_pooling(model_output, attention_mask):
24
  token_embeddings = model_output[0]
25
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
26
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
27
 
28
- def get_embeddings(texts):
29
- encoded_input = TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt')
 
30
  with torch.no_grad():
31
- model_output = MODEL(**encoded_input)
32
  embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
33
- return F.normalize(embeddings, p=2, dim=1)
 
34
 
35
- # Similarity computation adapted to the new dataset
36
- def find_similar_books(keywords, books, top_k=5):
37
- keyword_embedding = get_embeddings(keywords).mean(0).unsqueeze(0)
38
- book_embeddings = get_embeddings([book["text"] for book in books])
 
 
 
 
 
 
39
  similarities = cosine_similarity(keyword_embedding, book_embeddings)[0]
 
 
40
  top_indices = np.argsort(similarities)[-top_k:][::-1]
41
- return [books[i] for i in top_indices]
 
 
 
42
 
43
- # Summary generation adapted to long texts
44
- def summarize_description(text):
45
- if len(text.split()) > 500:
46
- return SUMMARIZER(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
47
- return text
 
48
 
49
- # Main logic flow
50
  def recommend_books(keywords):
 
51
  keywords = [kw.strip() for kw in keywords.replace(',', ' ').split() if kw.strip()]
52
  if len(keywords) < 3:
53
  return "Please enter at least 3 keywords separated by commas or spaces."
54
 
 
 
55
  books = load_data()
56
- similar_books = find_similar_books(keywords, books)
57
 
 
 
 
 
58
  output = []
59
- for i, book in enumerate(similar_books, 1):
60
- summary = summarize_description(book["text"])
61
- output.append(f"{i}. {summary}\n")
 
62
  return "\n".join(output)
63
 
64
- # Gradio interface stays the same
65
  iface = gr.Interface(
66
  fn=recommend_books,
67
- inputs=gr.Textbox(label="Enter 3+ keywords (comma/space separated)"),
68
- outputs=gr.Textbox(label="Recommended Book Passages"),
69
- title="Book Corpus Semantic Search",
70
- description="Search through 100,000 book passages from bookcorpus dataset"
71
  )
72
 
73
  if __name__ == "__main__":
 
1
+ import pandas as pd
2
  import numpy as np
3
  from transformers import AutoTokenizer, AutoModel, pipeline
4
  import torch
 
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import gradio as gr
9
 
10
# Model loading
def load_models():
    """Instantiate the embedding tokenizer/model and the summarization pipeline.

    Returns:
        A ``(tokenizer, model, summarizer)`` tuple. The tokenizer and model are
        loaded from the "sentence-transformers/all-MiniLM-L6-v2" checkpoint
        (used for semantic search); the summarizer is a "summarization"
        pipeline backed by "facebook/bart-large-cnn".
    """
    # Semantic-search encoder and its tokenizer share one checkpoint.
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    embedding_model = AutoModel.from_pretrained(checkpoint)

    return (
        embedding_tokenizer,
        embedding_model,
        # Summarization pipeline
        pipeline("summarization", model="facebook/bart-large-cnn"),
    )
21
 
22
# Load book data
def load_data(path="bookcorpus.csv"):
    """Read the book table from a CSV file and keep the columns the app uses.

    Args:
        path: Location of the CSV file. Defaults to "bookcorpus.csv".

    Returns:
        A DataFrame holding the ``title`` and ``author`` columns, plus
        ``description`` when the file provides one, with every row that has a
        missing value in any of those columns removed.

    Raises:
        KeyError: If the file has no ``title`` or ``author`` column.
    """
    books = pd.read_csv(path)

    columns = ['title', 'author']
    # recommend_books() reads row['description']; selecting only title/author
    # would throw that column away and make the lookup fail.
    if 'description' in books.columns:
        columns.append('description')

    # Keep only relevant columns and drop rows with missing values in them
    return books[columns].dropna()
29
 
30
# Mean pooling for sentence embeddings
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, weighted by the mask.

    Args:
        model_output: Indexable model result whose element 0 holds the token
            embeddings.
        attention_mask: Mask tensor; it is given a trailing axis and expanded
            to the size of the token embeddings before being applied.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total over
        dimension 1 (clamped to at least 1e-9 to avoid dividing by zero).
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = torch.sum(hidden_states * weights, 1)
    # A fully masked row would give a zero denominator; clamp keeps it finite.
    token_counts = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / token_counts
35
 
36
# Get embeddings for text
def get_embeddings(texts, tokenizer, model, batch_size=64):
    """Encode texts into L2-normalized, mean-pooled sentence embeddings.

    The texts are run through the model in slices of ``batch_size`` so that a
    large list is not pushed through a single forward pass.

    Args:
        texts: A single string or a sequence of strings.
        tokenizer: Callable invoked as
            ``tokenizer(batch, padding=True, truncation=True, return_tensors='pt')``;
            its result must be usable as ``model(**result)`` and contain an
            ``'attention_mask'`` entry.
        model: Callable whose output is accepted by ``mean_pooling``.
        batch_size: Maximum number of texts per forward pass.

    Returns:
        A tensor with one normalized embedding row per input text, in input
        order.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is smaller than 1.
    """
    # A bare string must stay one text, not be sliced into characters.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        raise ValueError("texts must contain at least one string")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    normalized_batches = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
            model_output = model(**encoded_input)
            pooled = mean_pooling(model_output, encoded_input['attention_mask'])
            # torch.nn.functional is reachable from the visible `import torch`,
            # so no separate alias is needed here.
            normalized_batches.append(torch.nn.functional.normalize(pooled, p=2, dim=1))
    return torch.cat(normalized_batches, dim=0)
44
 
45
# Find most similar books
def find_similar_books(keywords, books, tokenizer, model, top_k=5):
    """Rank books by cosine similarity between the keywords and "title author".

    Args:
        keywords: Keyword strings; their embeddings are averaged into a single
            query vector.
        books: DataFrame with ``title`` and ``author`` columns.
        tokenizer: Tokenizer forwarded to ``get_embeddings``.
        model: Model forwarded to ``get_embeddings``.
        top_k: Maximum number of rows to return.

    Returns:
        A copy of the best-matching rows of ``books``, most similar first,
        with an added ``similarity`` column. The result is empty when
        ``top_k`` is not positive or ``books`` has no rows.
    """
    # `[-0:]` would select every row, and an empty table cannot be embedded,
    # so both cases short-circuit to an empty result with the same columns.
    if top_k <= 0 or books.empty:
        results = books.iloc[:0].copy()
        results['similarity'] = np.empty(0)
        return results

    # Get embeddings for keywords (averaged into one query vector)
    keyword_embedding = get_embeddings(keywords, tokenizer, model).mean(0).unsqueeze(0)

    # Get embeddings for book titles and authors; astype(str) keeps the
    # concatenation working when a cell was parsed as a number.
    book_texts = books['title'].astype(str) + " " + books['author'].astype(str)
    book_embeddings = get_embeddings(book_texts.tolist(), tokenizer, model)

    # Calculate similarity
    similarities = cosine_similarity(keyword_embedding, book_embeddings)[0]

    # Get top matches, best first
    top_indices = np.argsort(similarities)[::-1][:top_k]
    results = books.iloc[top_indices].copy()
    results['similarity'] = similarities[top_indices]

    return results
63
 
64
# Summarize book description
def summarize_description(description, summarizer):
    """Return the description, summarized when it is longer than 100 words.

    Args:
        description: Description text. ``None`` and NaN (a missing CSV cell)
            are treated as "no description"; other non-string values are
            converted with ``str``.
        summarizer: Callable invoked as
            ``summarizer(text, max_length=130, min_length=30, do_sample=False,
            truncation=True)`` that returns a list whose first item has a
            ``'summary_text'`` key.

    Returns:
        An empty string for a missing description, the text itself when it
        has at most 100 whitespace-separated words, otherwise the summary.
    """
    # NaN is the only float that is not equal to itself.
    is_nan = isinstance(description, float) and description != description
    if description is None or is_nan:
        return ""

    text = description if isinstance(description, str) else str(description)
    if len(text.split()) <= 100:  # Only summarize long descriptions
        return text

    # truncation=True keeps over-long inputs within the model's input limit
    # instead of failing inside the pipeline.
    summary = summarizer(
        text,
        max_length=130,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return summary[0]['summary_text']
70
 
71
import functools


# Shared resources
@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the models and the book table once and reuse them on later calls."""
    tokenizer, model, summarizer = load_models()
    return tokenizer, model, summarizer, load_data()


# Main function
def recommend_books(keywords):
    """Build the recommendation text for a keyword query.

    Args:
        keywords: Keywords separated by commas and/or whitespace. ``None`` is
            treated like an empty string.

    Returns:
        A usage message when fewer than 3 keywords are given; otherwise one
        numbered entry per matching book with its title and a summary of its
        description, joined by newlines.
    """
    # Split keywords by comma or space; split() already drops empty pieces.
    keywords = (keywords or "").replace(',', ' ').split()
    if len(keywords) < 3:
        return "Please enter at least 3 keywords separated by commas or spaces."

    # Models and data are cached so each request does not reload them.
    tokenizer, model, summarizer, books = _load_resources()

    # Find similar books
    similar_books = find_similar_books(keywords, books, tokenizer, model)

    # Generate output
    output = []
    for i, (_, row) in enumerate(similar_books.iterrows(), 1):
        # load_data() selects only title/author, so the description may be
        # absent; row.get avoids a KeyError in that case.
        description = row.get('description')
        if isinstance(description, str) and description.strip():
            summary = summarize_description(description, summarizer)
        else:
            summary = "No description available."
        output.append(f"{i}. {row['title']}\n Summary: {summary}\n")

    return "\n".join(output)
92
 
93
# Gradio UI: one text box for the keywords, one for the recommendations,
# wired to recommend_books.
iface = gr.Interface(
    fn=recommend_books,
    inputs=gr.Textbox(
        label="Enter at least 3 keywords (comma or space separated)",
    ),
    outputs=gr.Textbox(
        label="Recommended Books",
    ),
    title="Book Recommendation Engine",
    description="Enter 3 or more keywords to find relevant books and get summaries of their plots.",
)
101
 
102
  if __name__ == "__main__":