sshenai committed on
Commit
bb78694
·
verified ·
1 Parent(s): ee47335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -78
app.py CHANGED
@@ -1,103 +1,126 @@
1
- import pandas as pd
 
2
  import numpy as np
3
- from transformers import AutoTokenizer, AutoModel, pipeline
4
  import torch
5
  import torch.nn.functional as F
6
- from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import gradio as gr
 
9
 
10
- # Load models
11
- def load_models():
12
- # For semantic search
13
- model_name = "sentence-transformers/all-MiniLM-L6-v2"
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModel.from_pretrained(model_name)
16
-
17
- # For summarization
18
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
19
-
20
- return tokenizer, model, summarizer
21
 
22
- # Load book data
23
- def load_data():
24
- # Load the Goodreads dataset (adjust path as needed)
25
- books = pd.read_csv("bookcorpus.csv")
26
- # Keep only relevant columns and drop rows with missing descriptions
27
- books = books[['title', 'author']].dropna()
 
 
 
 
 
 
28
  return books
29
 
30
- # Mean pooling for sentence embeddings
 
 
 
 
 
 
 
 
31
  def mean_pooling(model_output, attention_mask):
32
- token_embeddings = model_output[0]
33
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
34
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
35
 
36
- # Get embeddings for text
37
- def get_embeddings(texts, tokenizer, model):
38
- encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
39
- with torch.no_grad():
40
- model_output = model(**encoded_input)
41
- embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
42
- embeddings = F.normalize(embeddings, p=2, dim=1)
43
- return embeddings
 
 
 
 
 
 
 
 
 
44
 
45
- # Find most similar books
46
- def find_similar_books(keywords, books, tokenizer, model, top_k=5):
47
- # Get embeddings for keywords
48
- keyword_embedding = get_embeddings(keywords, tokenizer, model).mean(0).unsqueeze(0)
 
49
 
50
- # Get embeddings for book titles and descriptions
51
- book_texts = books['title'] + " " + books['author']
52
- book_embeddings = get_embeddings(book_texts.tolist(), tokenizer, model)
53
 
54
- # Calculate similarity
55
- similarities = cosine_similarity(keyword_embedding, book_embeddings)[0]
 
56
 
57
- # Get top matches
58
- top_indices = np.argsort(similarities)[-top_k:][::-1]
59
- results = books.iloc[top_indices].copy()
60
- results['similarity'] = similarities[top_indices]
61
 
 
 
 
 
 
 
 
 
 
 
62
  return results
63
 
64
- # Summarize book description
65
- def summarize_description(description, summarizer):
66
- if len(description.split()) > 100: # Only summarize long descriptions
67
- summary = summarizer(description, max_length=130, min_length=30, do_sample=False)
68
- return summary[0]['summary_text']
69
- return description
70
-
71
- # Main function
72
- def recommend_books(keywords):
73
- # Split keywords by comma or space
74
- keywords = [kw.strip() for kw in keywords.replace(',', ' ').split() if kw.strip()]
75
- if len(keywords) < 3:
76
- return "Please enter at least 3 keywords separated by commas or spaces."
77
-
78
- # Load models and data
79
- tokenizer, model, summarizer = load_models()
80
- books = load_data()
81
 
82
- # Find similar books
83
- similar_books = find_similar_books(keywords, books, tokenizer, model)
 
84
 
85
- # Generate output
86
- output = []
87
- for i, (_, row) in enumerate(similar_books.iterrows(), 1):
88
- summary = summarize_description(row['description'], summarizer)
89
- output.append(f"{i}. {row['title']}\n Summary: {summary}\n")
 
 
 
90
 
91
- return "\n".join(output)
 
 
 
 
92
 
93
- # Gradio interface
94
- iface = gr.Interface(
95
- fn=recommend_books,
96
- inputs=gr.Textbox(label="Enter at least 3 keywords (comma or space separated)"),
97
- outputs=gr.Textbox(label="Recommended Books"),
98
- title="Book Recommendation Engine",
99
- description="Enter 3 or more keywords to find relevant books and get summaries of their plots."
100
- )
101
 
102
  if __name__ == "__main__":
103
- iface.launch()
 
1
+ # app.py
2
+ from datasets import load_dataset
3
  import numpy as np
 
4
  import torch
5
  import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModel, pipeline
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import gradio as gr
9
+ import re
10
 
11
# Global configuration
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"  # stronger semantic model
SUMMARIZER_NAME = "facebook/bart-large-cnn"
DATASET_NAME = "bookcorpus"
# NOTE(review): CACHE_DIR is defined here but nothing in this file reads it.
CACHE_DIR = "./data-cache"

# Preloaded resources: created once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
summarizer = pipeline(task="summarization", model=SUMMARIZER_NAME)
 
21
 
22
# Matches the first double-quoted phrase; compiled once instead of per record.
_QUOTED_TITLE = re.compile(r'"([^"]*)"')


# Load and preprocess the book data
def load_books(limit=50000, min_chars=500, dataset=None):
    """Collect book records from a dataset of ``{"text": ...}`` examples.

    Args:
        limit: Maximum number of examples to read from the dataset.
        min_chars: Examples whose stripped text is not longer than this many
            characters are skipped.
        dataset: Optional iterable of mappings with a ``"text"`` key. When
            omitted, ``DATASET_NAME`` is streamed via ``load_dataset``.

    Returns:
        A list of ``{"text": str, "title": str}`` dicts. The title is the
        first double-quoted phrase found in the first 200 characters of the
        text, or ``"Untitled Book"`` when there is none (or it is blank).
    """
    from itertools import islice

    if dataset is None:
        dataset = load_dataset(DATASET_NAME, split='train', streaming=True)

    books = []
    # islice works for streaming datasets and plain iterables alike, and
    # stops pulling examples as soon as the limit is reached.
    for record in islice(dataset, max(int(limit), 0)):
        text = record['text'].strip()
        if len(text) <= min_chars:  # filter out short texts
            continue
        match = _QUOTED_TITLE.search(text[:200])  # try to extract a title
        title = match.group(1).strip() if match else ""
        books.append({
            "text": text,
            "title": title or "Untitled Book",
        })
    return books
35
 
36
# Generate semantic embeddings
def get_embeddings(texts, batch_size=32):
    """Encode texts into L2-normalised, mean-pooled sentence embeddings.

    Args:
        texts: A string or an iterable of strings.
        batch_size: Number of texts encoded per forward pass. Encoding in
            slices keeps memory bounded when many texts are passed at once.

    Returns:
        A tensor with one row per input text, in input order.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # A bare string is one text, not a sequence of characters.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        # NOTE(review): assumes the model config exposes hidden_size - confirm.
        return torch.empty((0, model.config.hidden_size))

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        pooled = mean_pooling(outputs, inputs['attention_mask'])
        chunks.append(F.normalize(pooled, p=2, dim=1))
    return torch.cat(chunks, dim=0)
43
+
44
# Mean pooling
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Either an object exposing ``last_hidden_state`` or a
            sequence whose first element is the token-embedding tensor.
        attention_mask: Tensor with one entry per token; positions with 0
            are excluded from the average.

    Returns:
        A tensor with the token axis (dim 1) averaged away.
    """
    token_embeddings = getattr(model_output, "last_hidden_state", None)
    if token_embeddings is None:
        token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp so that a fully masked row divides by a tiny number, not zero.
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts
49
 
50
# Summary generation
def generate_summary(text):
    """Summarise ``text`` with the summarisation pipeline's model.

    The text is encoded and the result decoded with the summariser's own
    tokenizer; token ids from the embedding model's tokenizer belong to a
    different vocabulary and must not be fed to the summarisation model.

    Args:
        text: The text to summarise. Input beyond 1024 tokens is truncated.

    Returns:
        The generated summary, or ``""`` when ``text`` is empty or blank.
    """
    if not text or not text.strip():
        return ""

    summary_tokenizer = summarizer.tokenizer
    # No "summarize: " task prefix: that is a T5 convention and would only
    # add noise to the input of the BART checkpoint configured here.
    inputs = summary_tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    ).to(summarizer.model.device)

    with torch.no_grad():
        summary_ids = summarizer.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
    return summary_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
67
 
68
# Characters removed from the raw query. Hyphens are kept so that terms such
# as "sci-fi" survive sanitisation.
_KEYWORD_JUNK = re.compile(r'[^\w\s,\-]')
_EMBED_BATCH_SIZE = 64
# Book embeddings are computed once and reused across requests.
_book_embedding_cache = {"count": None, "embeddings": None}


def _get_book_embeddings():
    """Return embeddings for the global ``books`` list, computing them once."""
    # Invalidation is by length only: replacing books with a list of the
    # same size would not refresh the cache.
    if _book_embedding_cache["count"] != len(books):
        texts = [f"{b['title']} {b['text']}" for b in books]
        # Embedding in modest slices keeps each forward pass small.
        chunks = [
            get_embeddings(texts[start:start + _EMBED_BATCH_SIZE])
            for start in range(0, len(texts), _EMBED_BATCH_SIZE)
        ]
        _book_embedding_cache["embeddings"] = torch.cat(chunks, dim=0)
        _book_embedding_cache["count"] = len(books)
    return _book_embedding_cache["embeddings"]


# Core recommendation logic
def recommend_books(keywords, top_k=5):
    """Recommend the books most similar to a comma-separated keyword query.

    Args:
        keywords: Comma-separated keywords typed by the user.
        top_k: Maximum number of recommendations to return.

    Returns:
        A list of ``{"title", "summary", "score"}`` dicts ordered from most
        to least similar, or ``{"error": message}`` when fewer than two
        keywords were supplied. The error is a dict rather than a bare
        string so that it is valid for a JSON output component.
    """
    # Sanitise the input
    cleaned = _KEYWORD_JUNK.sub('', keywords or '').lower()
    stripped = (part.strip().strip('-').strip() for part in cleaned.split(','))
    terms = [term for term in stripped if term]

    if len(terms) < 2:
        return {"error": "❗ Please enter at least 2 keywords (e.g. 'fantasy, magic')"}

    # Never ask for more results than there are books; nothing to rank if 0.
    top_k = min(int(top_k), len(books))
    if top_k <= 0:
        return []

    # Compute embeddings
    keyword_emb = get_embeddings([" ".join(terms)])
    book_embs = _get_book_embeddings()

    # Compute similarities
    sim_scores = cosine_similarity(
        keyword_emb.cpu().numpy(), book_embs.cpu().numpy()
    )[0]
    top_indices = np.argsort(sim_scores)[::-1][:top_k]

    # Build the results
    results = []
    for idx in top_indices:
        book = books[idx]
        results.append({
            "title": book['title'],
            "summary": generate_summary(book['text']),
            "score": f"{sim_scores[idx]:.2f}",
        })
    return results
96
 
97
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 智能图书推荐系统")

    # Query box and result view share one row.
    with gr.Row():
        keyword_box = gr.Textbox(
            label="输入关键词(用逗号分隔)",
            placeholder="例如:sci-fi, time travel",
        )
        results_view = gr.JSON(label="推荐结果")

    # Sample queries offered as examples for the query box.
    sample_queries = [
        ["romance, paris"],
        ["mystery, detective"],
        ["science fiction, space opera"],
    ]
    examples = gr.Examples(examples=sample_queries, inputs=[keyword_box])

    # Submitting the textbox runs the recommender and shows its output.
    keyword_box.submit(fn=recommend_books, inputs=keyword_box, outputs=results_view)
119
 
120
# Initialise the data: the book list is built once, when the module loads.
print("Loading book data...")
books = load_books()
book_total = len(books)
print(f"Loaded {book_total} books")


if __name__ == "__main__":
    # Listen on every interface, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)