sshenai commited on
Commit
9632f31
·
verified ·
1 Parent(s): 1589fa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -72
app.py CHANGED
@@ -1,74 +1,57 @@
1
- # 导入必要的库
2
- import gradio as gr
3
- from datasets import load_dataset
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer, util
6
  from transformers import pipeline
7
-
8
- # 安装依赖(在Hugging Face Spaces中可省略,若空间环境未预装相关库可保留)
9
- #!pip install datasets sentence-transformers transformers torch
10
-
11
- # 加载数据集
12
- dataset = load_dataset("Pradeep016/career-guidance-qa-dataset", split="train")
13
- # 过滤无效数据(确保question和answer非空)
14
- dataset = dataset.filter(lambda x: x["question"] and x["answer"])
15
-
16
- # 构建职位知识库(职位名称 + 问题-答案对)
17
- def build_knowledge_base(dataset):
18
- knowledge_base = []
19
- for item in dataset:
20
- role = item["role"]
21
- question = item["question"]
22
- answer = item["answer"]
23
- # 合并职位名称与问题,增强语义关联
24
- entry = f"{role} | {question}: {answer}"
25
- knowledge_base.append(entry)
26
- return knowledge_base
27
-
28
- knowledge_base = build_knowledge_base(dataset)
29
-
30
- # 初始化语义搜索模型
31
- embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
32
- # 预计算知识库嵌入向量
33
- knowledge_embeddings = embedder.encode(knowledge_base, convert_to_tensor=True)
34
-
35
- # 智能问答函数
36
- def career_qa(user_input):
37
- # 1. 语义搜索匹配相关职位
38
- input_embedding = embedder.encode(user_input, convert_to_tensor=True)
39
- # 计算余弦相似度
40
- cos_scores = util.cos_sim(input_embedding, knowledge_embeddings)[0]
41
- # 取前3个最相关条目
42
- top_indices = np.argsort(cos_scores)[-3:][::-1]
43
- top_matches = [knowledge_base[idx] for idx in top_indices]
44
 
45
- # 2. 从匹配条目中提取答案
46
- qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased-finetuned-squad2")
47
- results = []
48
- for match in top_matches:
49
- role = match.split(" | ")[0]
50
- context = match.split(" | ")[1]
51
- # 固定问题为“请介绍这个职位”
52
- result = qa_pipeline(question="请介绍这个职位", context=context)
53
- results.append({
54
- "职位名称": role,
55
- "简介": result["answer"],
56
- "置信度": result["score"]
57
- })
58
- return results
59
-
60
- # Gradio界面定义
61
- def demo(user_input):
62
- results = career_qa(user_input)
63
- output = "\n".join([f"📌 {res['职位名称']}\n{res['简介']}\n" for res in results])
64
- return output
65
-
66
- iface = gr.Interface(
67
- fn=demo,
68
- inputs=gr.Textbox(label="输入职业关键词(如:零售经理)"),
69
- outputs=gr.Textbox(label="职位介绍"),
70
- title="职业咨询智能问答",
71
- )
72
-
73
- if __name__ == "__main__":
74
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ # 导入库
2
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
3
+ from PIL import Image
4
+ import torch
 
5
  from transformers import pipeline
6
+ import requests
7
+ from io import BytesIO
8
+
9
+ # 1. 图像标题生成(使用指定模型)
10
+ def generate_caption(image_url):
11
+ model_name = "bipin/image-caption-generator"
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
13
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
14
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # 下载并预处理图像
19
+ response = requests.get(image_url)
20
+ img = Image.open(BytesIO(response.content)).convert("RGB")
21
+ pixel_values = feature_extractor(images=[img], return_tensors="pt").pixel_values.to(device)
22
+
23
+ # 生成标题(限制50字内)
24
+ output_ids = model.generate(pixel_values, num_beams=4, max_length=50)
25
+ caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
26
+ return caption
27
+
28
+ # 2. 标题扩写为宣传文案(使用文本生成模型)
29
+ def expand_to_copy(caption):
30
+ generator = pipeline("text-generation", model="gpt2", max_length=200)
31
+ prompt = f"根据以下图片标题生成宣传文案:{caption}\n要求:生动形象,突出产品优势,适合社交媒体传播。"
32
+ copy = generator(prompt, num_return_sequences=1)[0]['generated_text']
33
+ return copy.strip()
34
+
35
+ # 3. 文本转语音(使用TTS模型)
36
+ def text_to_speech(text, output_file="output.mp3"):
37
+ tts = pipeline("text-to-speech", model="facebook/t5-small")
38
+ speech = tts(text)
39
+ with open(output_file, "wb") as f:
40
+ f.write(speech["audio"])
41
+ return output_file
42
+
43
+ # 主函数
44
+ def marketing_pipeline(image_url):
45
+ # 生成标题
46
+ caption = generate_caption(image_url)
47
+ print(f"生成标题:{caption}")
48
+
49
+ # 扩写文案
50
+ copy = expand_to_copy(caption)
51
+ print(f"宣传文案:\n{copy}")
52
+
53
+ # 生成语音
54
+ audio_file = text_to_speech(copy)
55
+ print(f"语音文件已保存:{audio_file}")
56
+
57
+ return caption, copy, audio_file