Update app.py
Browse files
app.py
CHANGED
|
@@ -1,91 +1,65 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
from transformers import pipeline
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
dataset = load_dataset("bookcorpus", split="train")
|
| 14 |
-
books = pd.DataFrame(dataset)[["title", "text"]].rename(columns={"text": "description"})
|
| 15 |
-
|
| 16 |
-
# 过滤空值并截断长文本(可选)
|
| 17 |
-
books = books.dropna().head(1000) # 取前1000条数据便于演示
|
| 18 |
-
books["description"] = books["description"].apply(lambda x: x[:5000]) # 截断至5000字以内
|
| 19 |
-
return books
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 27 |
-
|
| 28 |
-
# 摘要生成模型
|
| 29 |
-
summarizer = pipeline(
|
| 30 |
-
"summarization",
|
| 31 |
-
model="facebook/bart-large-cnn",
|
| 32 |
-
max_length=150,
|
| 33 |
-
min_length=30,
|
| 34 |
-
do_sample=False
|
| 35 |
-
)
|
| 36 |
-
return embedder, summarizer
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
"title": row["title"],
|
| 67 |
-
"summary": summary,
|
| 68 |
-
"similarity": "{:.2f}".format(float(cos_scores[idx])) # 可选:添加相似度分数
|
| 69 |
-
})
|
| 70 |
-
return results
|
| 71 |
|
| 72 |
-
# ----------------------
|
| 73 |
-
# 5. 主函数与交互
|
| 74 |
-
# ----------------------
|
| 75 |
if __name__ == "__main__":
|
| 76 |
-
|
| 77 |
-
books = load_book_data()
|
| 78 |
-
embedder, summarizer = initialize_models()
|
| 79 |
-
|
| 80 |
-
# 用户输入关键词
|
| 81 |
-
user_keywords = "fantasy adventure magic" # 示例关键词,可替换为用户输入
|
| 82 |
-
|
| 83 |
-
# 执行搜索与摘要生成
|
| 84 |
-
similar_books = search_similar_books(user_keywords, books, embedder)
|
| 85 |
-
summaries = generate_book_summaries(similar_books, summarizer)
|
| 86 |
-
|
| 87 |
-
# 打印结果
|
| 88 |
-
for i, book in enumerate(summaries, 1):
|
| 89 |
-
print(f"📚 Book {i}: {book['title']}")
|
| 90 |
-
print(f"🌟 Similarity: {book['similarity']}")
|
| 91 |
-
print(f"📝 Summary: {book['summary']}\n")
|
|
|
|
# Global model setup — loaded once at import time so every request reuses them.
# NOTE(review): no import block is visible in this diff view; `pipeline`,
# `AutoTokenizer` and `AutoModel` (transformers) must be imported at file top —
# confirm against the full file.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
MODEL = AutoModel.from_pretrained(MODEL_NAME)
SUMMARIZER = pipeline("summarization", model="facebook/bart-large-cnn")
|
|
|
| 6 |
|
# Load the bookcorpus dataset (streamed, so the full corpus is never downloaded).
def load_data():
    """Stream bookcorpus and keep up to the first 100k passages longer than 100 chars."""
    stream = load_dataset("bookcorpus", streaming=True)  # streaming avoids a full download
    sample = stream["train"].take(100_000)
    passages = []
    for record in sample:
        text = record["text"]
        if len(text) > 100:  # drop very short fragments
            passages.append({"text": text})
    return passages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
# Sentence-embedding pooling (same contract as the original helper).
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighting each position by its attention mask.

    `model_output[0]` is the token-embedding tensor; positions with mask 0 are
    excluded from the mean. The clamp guards against an all-zero mask row.
    """
    token_vecs = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_vecs.size()).float()
    weighted_sum = torch.sum(token_vecs * weights, 1)
    token_counts = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / token_counts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
def get_embeddings(texts):
    """Encode `texts` into L2-normalized sentence embeddings via the global model."""
    batch = TOKENIZER(texts, padding=True, truncation=True, return_tensors='pt')
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = MODEL(**batch)
    pooled = mean_pooling(outputs, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)
| 25 |
+
|
# Cosine-similarity search over the loaded corpus.
def find_similar_books(keywords, books, top_k=5):
    """Return the `top_k` book entries closest to the averaged keyword embedding."""
    query_vec = get_embeddings(keywords).mean(0).unsqueeze(0)
    corpus_vecs = get_embeddings([entry["text"] for entry in books])
    scores = cosine_similarity(query_vec, corpus_vecs)[0]
    # argsort ascending, take the last top_k, then reverse for best-first order.
    best_indices = np.argsort(scores)[-top_k:][::-1]
    return [books[i] for i in best_indices]
| 33 |
+
|
# Summarization step, only applied to long passages.
def summarize_description(text):
    """Summarize passages longer than 500 words with BART; shorter text passes through."""
    if len(text.split()) <= 500:
        return text
    result = SUMMARIZER(text, max_length=150, min_length=50, do_sample=False)
    return result[0]['summary_text']
| 39 |
+
|
# Main pipeline: parse keywords, search the corpus, summarize the hits.
def recommend_books(keywords):
    """Return a numbered list of summarized passages matching the user's keywords.

    Requires at least 3 keywords (comma- or space-separated); otherwise an
    instructional message is returned instead.
    """
    tokens = [tok.strip() for tok in keywords.replace(',', ' ').split() if tok.strip()]
    if len(tokens) < 3:
        return "Please enter at least 3 keywords separated by commas or spaces."

    corpus = load_data()
    matches = find_similar_books(tokens, corpus)

    lines = []
    for rank, match in enumerate(matches, 1):
        lines.append(f"{rank}. {summarize_description(match['text'])}\n")
    return "\n".join(lines)
| 54 |
|
# Gradio UI wiring (interface unchanged from the previous revision).
iface = gr.Interface(
    fn=recommend_books,
    title="Book Corpus Semantic Search",
    description="Search through 100,000 book passages from bookcorpus dataset",
    inputs=gr.Textbox(label="Enter 3+ keywords (comma/space separated)"),
    outputs=gr.Textbox(label="Recommended Book Passages"),
)

if __name__ == "__main__":
    iface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|