sshenai committed on
Commit
4505c1f
·
verified ·
1 Parent(s): 575f4b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -83
app.py CHANGED
@@ -1,91 +1,65 @@
1
- # 导入必要库
2
- from datasets import load_dataset
3
- import pandas as pd
4
- import torch
5
- from sentence_transformers import SentenceTransformer, util
6
- from transformers import pipeline
7
 
8
- # ----------------------
9
- # 1. 加载数据集
10
- # ----------------------
11
- def load_book_data():
12
- # 加载 bookcorpus 数据集(仅保留标题和摘要)
13
- dataset = load_dataset("bookcorpus", split="train")
14
- books = pd.DataFrame(dataset)[["title", "text"]].rename(columns={"text": "description"})
15
-
16
- # 过滤空值并截断长文本(可选)
17
- books = books.dropna().head(1000) # 取前1000条数据便于演示
18
- books["description"] = books["description"].apply(lambda x: x[:5000]) # 截断至5000字以内
19
- return books
20
 
21
- # ----------------------
22
- # 2. 初始化模型
23
- # ----------------------
24
- def initialize_models():
25
- # 语义搜索模型
26
- embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
27
-
28
- # 摘要生成模型
29
- summarizer = pipeline(
30
- "summarization",
31
- model="facebook/bart-large-cnn",
32
- max_length=150,
33
- min_length=30,
34
- do_sample=False
35
- )
36
- return embedder, summarizer
37
 
38
- # ----------------------
39
- # 3. 关键词搜索与推荐
40
- # ----------------------
41
- def search_similar_books(keywords, books, embedder, top_k=5):
42
- # 生成关键词嵌入
43
- keyword_embedding = embedder.encode(keywords, convert_to_tensor=True)
44
-
45
- # 生成书籍嵌入(批量处理)
46
- book_embeddings = torch.stack([
47
- embedder.encode(title + " " + desc, convert_to_tensor=True)
48
- for title, desc in zip(books["title"], books["description"])
49
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # 计算余弦相似度
52
- cos_scores = util.cos_sim(keyword_embedding, book_embeddings)[0]
53
 
54
- # 获取 top-k 结果
55
- top_results = torch.topk(cos_scores, k=top_k).indices.tolist()
56
- return books.iloc[top_results]
 
 
57
 
58
- # ----------------------
59
- # 4. 生成摘要并输出
60
- # ----------------------
61
- def generate_book_summaries(books, summarizer):
62
- results = []
63
- for idx, row in books.iterrows():
64
- summary = summarizer(row["description"], max_length=150)[0]["summary_text"]
65
- results.append({
66
- "title": row["title"],
67
- "summary": summary,
68
- "similarity": "{:.2f}".format(float(cos_scores[idx])) # 可选:添加相似度分数
69
- })
70
- return results
71
 
72
- # ----------------------
73
- # 5. 主函数与交互
74
- # ----------------------
75
  if __name__ == "__main__":
76
- # 加载数据与模型
77
- books = load_book_data()
78
- embedder, summarizer = initialize_models()
79
-
80
- # 用户输入关键词
81
- user_keywords = "fantasy adventure magic" # 示例关键词,可替换为用户输入
82
-
83
- # 执行搜索与摘要生成
84
- similar_books = search_similar_books(user_keywords, books, embedder)
85
- summaries = generate_book_summaries(similar_books, summarizer)
86
-
87
- # 打印结果
88
- for i, book in enumerate(summaries, 1):
89
- print(f"📚 Book {i}: {book['title']}")
90
- print(f"🌟 Similarity: {book['similarity']}")
91
- print(f"📝 Summary: {book['summary']}\n")
 
1
# Imports — the previous revision deleted the whole import block, so every
# name below (pipeline, AutoTokenizer, AutoModel, gr, torch, ...) raised
# NameError at import time. Restore them here.
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer, pipeline

# Load models once at module import (global load improves per-request latency)
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
SUMMARIZER = pipeline("summarization", model="facebook/bart-large-cnn")
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
MODEL = AutoModel.from_pretrained(MODEL_NAME)
 
6
 
7
# Cache for the streamed corpus: streaming + filtering 100k records is
# expensive, and recommend_books() calls load_data() on EVERY request —
# memoise the result once per process instead.
_BOOK_CACHE = None


# Load the bookcorpus dataset
def load_data():
    """Return up to 100k bookcorpus passages, each longer than 100 chars.

    Streaming avoids materialising the full dataset; the filtered list is
    memoised in _BOOK_CACHE after the first call. NOTE(review): filtering
    happens after take(), so the result may hold fewer than 100k items.
    """
    global _BOOK_CACHE
    if _BOOK_CACHE is None:
        dataset = load_dataset("bookcorpus", streaming=True)  # streamed read
        books = dataset["train"].take(100_000)  # first 100k records
        _BOOK_CACHE = [{"text": x["text"]} for x in books if len(x["text"]) > 100]
    return _BOOK_CACHE
 
 
 
 
 
 
 
12
 
13
# Text-embedding pooling (reused from the model card's reference code)
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, weighted by the mask.

    model_output[0] is the token-embedding tensor; padding positions
    (mask == 0) contribute nothing, and the clamp guards against a
    division by zero when a row's mask is all zeros.
    """
    tokens = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    summed = torch.sum(tokens * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
 
 
 
 
 
 
 
 
 
 
 
18
 
19
def get_embeddings(texts):
    """Return L2-normalised, mean-pooled sentence embeddings for *texts*.

    One row per input string; inference runs under no_grad since the
    embeddings are only read, never backpropagated through.
    """
    batch = TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = MODEL(**batch)
    pooled = mean_pooling(outputs, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)
25
+
26
# Similarity search adapted to the new dataset
def find_similar_books(keywords, books, top_k=5):
    """Return the *top_k* passages from *books* most similar to *keywords*.

    keywords: list of keyword strings (embedded individually, then averaged).
    books:    list of {"text": str} dicts.

    Fix: the previous version called sklearn's cosine_similarity, which this
    file never imports. get_embeddings() returns L2-normalised rows, so
    cosine similarity reduces to a dot product (the query's constant norm
    cannot change the ranking) — compute it with torch directly. Also clamp
    k so top_k > len(books) no longer raises.
    """
    query = get_embeddings(keywords).mean(0, keepdim=True)  # (1, dim)
    corpus = get_embeddings([book["text"] for book in books])  # (n, dim)
    scores = (query @ corpus.T).squeeze(0)  # dot product == cosine ranking
    k = min(top_k, len(books))
    top_indices = torch.topk(scores, k=k).indices.tolist()
    return [books[i] for i in top_indices]
33
+
34
# Summarisation adapted to long texts
def summarize_description(text):
    """Summarise *text* when it exceeds ~500 words; short texts pass through.

    Fix: BART accepts at most 1024 input tokens, and bookcorpus passages can
    exceed that — truncation=True makes the pipeline clip the input instead
    of raising an IndexError on very long passages.
    """
    if len(text.split()) > 500:
        return SUMMARIZER(
            text, max_length=150, min_length=50,
            do_sample=False, truncation=True,
        )[0]['summary_text']
    return text
39
+
40
# Main pipeline
def recommend_books(keywords):
    """End-to-end flow: parse keywords, rank passages, summarise the top hits.

    keywords: raw user string, comma- and/or space-separated. Requires at
    least three keywords; otherwise returns an instruction message.
    """
    tokens = [tok.strip() for tok in keywords.replace(',', ' ').split() if tok.strip()]
    if len(tokens) < 3:
        return "Please enter at least 3 keywords separated by commas or spaces."

    corpus = load_data()
    matches = find_similar_books(tokens, corpus)

    lines = []
    for rank, book in enumerate(matches, 1):
        lines.append(f"{rank}. {summarize_description(book['text'])}\n")
    return "\n".join(lines)
54
 
55
# Gradio interface (kept identical to the previous revision):
# one free-text input, one text output produced by recommend_books().
iface = gr.Interface(
    fn=recommend_books,
    inputs=gr.Textbox(label="Enter 3+ keywords (comma/space separated)"),
    outputs=gr.Textbox(label="Recommended Book Passages"),
    title="Book Corpus Semantic Search",
    description="Search through 100,000 book passages from bookcorpus dataset"
)
 
 
 
 
 
63
 
 
 
 
64
# Start the Gradio app only when executed as a script.
if __name__ == "__main__":
    iface.launch()