sshenai committed on
Commit
039f26a
·
verified ·
1 Parent(s): d51c0e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -92
app.py CHANGED
@@ -1,103 +1,67 @@
1
- import pandas as pd
 
 
 
 
2
  import numpy as np
3
- from transformers import AutoTokenizer, AutoModel, pipeline
4
- import torch
5
- import torch.nn.functional as F
6
- from sentence_transformers import SentenceTransformer
7
- from sklearn.metrics.pairwise import cosine_similarity
8
- import gradio as gr
9
 
10
# Cache for the (tokenizer, model, summarizer) triple so repeated calls reuse it.
_LOADED_MODELS = None


def load_models():
    """Return the tokenizer, encoder model and summarization pipeline.

    The first call loads the "sentence-transformers/all-mpnet-base-v2"
    tokenizer and model plus a "facebook/bart-large-cnn" summarization
    pipeline. Later calls return the same cached tuple instead of loading
    everything again.

    Returns:
        A ``(tokenizer, model, summarizer)`` tuple.
    """
    global _LOADED_MODELS
    if _LOADED_MODELS is None:
        # For semantic search
        model_name = "sentence-transformers/all-mpnet-base-v2"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        # For summarization
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

        _LOADED_MODELS = (tokenizer, model, summarizer)
    return _LOADED_MODELS
21
 
22
def load_data(path="books.csv"):
    """Load the book table and keep only usable title/description rows.

    Args:
        path: Location of the CSV file to read. Defaults to ``"books.csv"``,
            the path the original code hard-coded.

    Returns:
        A DataFrame with only the ``title`` and ``description`` columns,
        with every row that has a missing value in either column removed.
    """
    books = pd.read_csv(path)
    # Keep only relevant columns and drop rows with missing values.
    return books[['title', 'description']].dropna()
 
 
 
 
29
 
30
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the mask marks as 1.

    Args:
        model_output: Indexable model output; element 0 is used as the token
            embeddings (assumed to be shaped (batch, seq, dim) -- TODO confirm).
        attention_mask: Mask with one fewer dimension than the token
            embeddings; it is expanded along a new last axis to match them.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp the denominator so an all-zero mask yields zeros, not a division by zero.
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts
35
 
36
def get_embeddings(texts, tokenizer, model):
    """Encode texts into L2-normalized, mean-pooled embeddings.

    Args:
        texts: Text input passed straight to ``tokenizer``.
        tokenizer: Callable tokenizer; invoked with padding and truncation
            enabled and PyTorch tensors requested.
        model: Model called with the tokenizer output as keyword arguments.

    Returns:
        A tensor of embeddings normalized to unit L2 norm along dimension 1.
    """
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        model_output = model(**encoded_input)
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings
44
 
45
def find_similar_books(keywords, books, tokenizer, model, top_k=5):
    """Return the ``top_k`` books most similar to the given keywords.

    The keyword embeddings are averaged into a single query vector, which is
    compared by cosine similarity against the embedding of each book's
    ``title + " " + description`` text.

    Args:
        keywords: Keyword texts passed to ``get_embeddings``.
        books: DataFrame with ``title`` and ``description`` columns.
        tokenizer: Tokenizer forwarded to ``get_embeddings``.
        model: Model forwarded to ``get_embeddings``.
        top_k: Maximum number of rows to return.

    Returns:
        A copy of the selected rows of ``books`` with an added ``similarity``
        column, ordered from most to least similar. The result is empty when
        ``top_k`` is not positive or ``books`` has no rows.
    """
    # Guard first: a ``[-0:]`` slice would select every row, and there is
    # nothing to embed when the frame is empty.
    if top_k <= 0 or books.empty:
        empty = books.iloc[0:0].copy()
        empty['similarity'] = np.empty(0, dtype=float)
        return empty

    # Get embeddings for keywords and collapse them into one query vector.
    keyword_embedding = get_embeddings(keywords, tokenizer, model).mean(0).unsqueeze(0)

    # Get embeddings for book titles and descriptions.
    book_texts = books['title'] + " " + books['description']
    book_embeddings = get_embeddings(book_texts.tolist(), tokenizer, model)

    # Calculate similarity of the single query row against every book.
    similarities = cosine_similarity(keyword_embedding, book_embeddings)[0]

    # argsort is ascending, so take the tail and reverse it for best-first order.
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    results = books.iloc[top_indices].copy()
    results['similarity'] = similarities[top_indices]
    return results
63
 
64
def summarize_description(description, summarizer):
    """Summarize a description when it has more than 100 words.

    Args:
        description: The description text.
        summarizer: Callable invoked as ``summarizer(text, ...)`` that returns
            a sequence whose first item has a ``'summary_text'`` key.

    Returns:
        The ``'summary_text'`` of the first result for descriptions longer
        than 100 whitespace-separated words; otherwise the description
        unchanged.
    """
    if len(description.split()) <= 100:  # Only summarize long descriptions
        return description
    # truncation=True asks the pipeline to cut inputs that exceed the model's
    # maximum input length instead of failing on them.
    summary = summarizer(
        description,
        max_length=130,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return summary[0]['summary_text']
70
-
71
def recommend_books(keywords):
    """Recommend books for a keyword string and format them as text.

    Args:
        keywords: Keywords separated by commas and/or whitespace.

    Returns:
        A prompt asking for more keywords when fewer than three were given;
        otherwise a numbered list of recommended titles, each followed by
        its (possibly summarized) description.
    """
    # Commas become spaces; str.split() with no argument already discards
    # empty and whitespace-only tokens, so no extra strip/filter is needed.
    keywords = keywords.replace(',', ' ').split()
    if len(keywords) < 3:
        return "Please enter at least 3 keywords separated by commas or spaces."

    # Load models and data
    tokenizer, model, summarizer = load_models()
    books = load_data()

    # Find similar books
    similar_books = find_similar_books(keywords, books, tokenizer, model)

    # Generate output, numbering the entries from 1.
    output = []
    for i, (_, row) in enumerate(similar_books.iterrows(), 1):
        summary = summarize_description(row['description'], summarizer)
        output.append(f"{i}. {row['title']}\n Summary: {summary}\n")

    return "\n".join(output)
92
 
93
# Gradio interface: one text box in, one text box out, backed by recommend_books.
iface = gr.Interface(
    fn=recommend_books,
    inputs=gr.Textbox(label="Enter at least 3 keywords (comma or space separated)"),
    outputs=gr.Textbox(label="Recommended Books"),
    title="Book Recommendation Engine",
    description="Enter 3 or more keywords to find relevant books and get summaries of their plots.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
 
1
# Dependencies (install them outside this module, e.g. from a shell or a
# requirements.txt file):
#   pip install datasets sentence-transformers transformers torch
# The original `!pip install ...` line is IPython/notebook shell syntax and is a
# SyntaxError when this file is run as a plain Python script.
3
+
4
+ # 导入库
5
+ from datasets import load_dataset
6
  import numpy as np
7
+ from sentence_transformers import SentenceTransformer, util
8
+ from transformers import pipeline
 
 
 
 
9
 
10
# Load the dataset.
dataset = load_dataset("Pradeep016/career-guidance-qa-dataset", split="train")
# Drop invalid rows: keep only records whose question and answer are both non-empty.
dataset = dataset.filter(lambda row: bool(row["question"] and row["answer"]))
 
 
 
 
 
 
 
14
 
15
def build_knowledge_base(dataset):
    """Build the role knowledge base (role name + question/answer pair).

    Args:
        dataset: Iterable of records exposing ``"role"``, ``"question"`` and
            ``"answer"`` keys.

    Returns:
        A list with one string per record, formatted as
        ``"<role> | <question>: <answer>"``.
    """
    # The role name is merged with the question and answer so each entry
    # carries the role alongside its text.
    return [
        f"{item['role']} | {item['question']}: {item['answer']}"
        for item in dataset
    ]
26
 
27
knowledge_base = build_knowledge_base(dataset)

# Initialize the semantic-search model.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Precompute the embeddings of every knowledge-base entry once, at start-up.
knowledge_embeddings = embedder.encode(knowledge_base, convert_to_tensor=True)
 
 
 
 
33
 
34
# Lazily created question-answering pipeline, shared by every career_qa call.
_qa_pipeline = None


def _get_qa_pipeline():
    """Build the question-answering pipeline on first use and reuse it afterwards."""
    global _qa_pipeline
    if _qa_pipeline is None:
        # NOTE(review): confirm this checkpoint id resolves on the Hugging Face Hub.
        _qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-uncased-finetuned-squad2",
        )
    return _qa_pipeline


def career_qa(user_input, top_k=3):
    """Find the knowledge-base entries closest to ``user_input`` and query them.

    Args:
        user_input: Free-text query, e.g. a career keyword.
        top_k: Number of best-matching knowledge-base entries to use.
            Defaults to 3, the value the original code hard-coded.

    Returns:
        A list of dicts, best match first, each holding the role name under
        ``"职位名称"``, the extracted answer under ``"简介"`` and the pipeline
        score under ``"置信度"``. Empty when the query is blank or ``top_k``
        is not positive.
    """
    if top_k <= 0 or not user_input or not user_input.strip():
        return []

    # 1. Semantic search: embed the query and score it against every entry.
    input_embedding = embedder.encode(user_input, convert_to_tensor=True)
    cos_scores = util.cos_sim(input_embedding, knowledge_embeddings)[0]
    # Move the scores to host memory first; NumPy cannot read a tensor that
    # lives on a GPU device.
    scores = cos_scores.detach().cpu().numpy()
    # argsort is ascending, so take the tail and reverse it for best-first order.
    top_indices = np.argsort(scores)[-top_k:][::-1]
    top_matches = [knowledge_base[idx] for idx in top_indices]

    # 2. Extract an answer from each matched entry.
    qa_pipeline = _get_qa_pipeline()
    results = []
    for match in top_matches:
        # Split on the first separator only, so a " | " inside the question or
        # answer text stays part of the context.
        role, _, context = match.partition(" | ")
        # The question is fixed ("Please introduce this position").
        # NOTE(review): the question is Chinese while the checkpoint name says
        # "uncased" English -- confirm this pairing is intended.
        result = qa_pipeline(question="请介绍这个职位", context=context)
        results.append({
            "职位名称": role,
            "简介": result["answer"],
            "置信度": result["score"],
        })
    return results
57
 
58
# Career keyword entered by the user.
user_query = "零售经理"
results = career_qa(user_query)

# Print the results: role name, extracted description, and confidence score.
for res in results:
    print(f"职位:{res['职位名称']}")
    print(f"简介:{res['简介']}")
    print(f"置信度:{res['置信度']:.2f}\n")
 
 
 
67