LCNada committed on
Commit
a6719c7
·
verified ·
1 Parent(s): e8bbfee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -93
app.py CHANGED
@@ -1,101 +1,79 @@
 
1
  import streamlit as st
2
- from PIL import Image
3
  from transformers import pipeline
 
 
 
4
 
5
# ----------------------------
# Image captioning
# ----------------------------
@st.cache_resource
def _load_caption_pipeline():
    """Build the BLIP image-to-text pipeline once and reuse it across reruns."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def generate_caption(image_file):
    """Generate a caption for an image with a Hugging Face image-to-text pipeline.

    Args:
        image_file: The uploaded image (a file object or a file path).

    Returns:
        The generated caption text (the first result returned by the pipeline).
    """
    # Decode the upload into a PIL image; the pipeline accepts it directly.
    image = Image.open(image_file)
    # Constructing the pipeline loads the model, so it is cached instead of
    # being rebuilt on every call.
    caption_results = _load_caption_pipeline()(image)
    # The pipeline returns a list of dicts; take the first result.
    return caption_results[0]["generated_text"]
24
 
25
# ----------------------------
# Story generation from the image caption
# ----------------------------
@st.cache_resource
def _load_story_pipeline():
    """Build the GPT-2 text-generation pipeline once and reuse it across reruns."""
    return pipeline("text-generation", model="gpt2")


def generate_story(caption):
    """Generate a children's story of at least 100 words from an image caption.

    Args:
        caption: The image caption text.

    Returns:
        The generated story text, without the instruction prompt. This is best
        effort: if the model keeps producing short continuations, the story
        can still end up shorter than 100 words.
    """
    min_words = 100
    max_attempts = 3
    story_generator = _load_story_pipeline()
    prompt = f"Based on the following image caption: '{caption}', generate a complete fairy tale story for children with at least 100 words. "

    story = ""
    for _ in range(max_attempts):
        # Feed back what has been written so far so each round continues the
        # same story instead of starting an unrelated one.
        # max_new_tokens bounds only the generated part (max_length would also
        # count the prompt), and return_full_text=False keeps the prompt out
        # of the returned text.
        result = story_generator(
            prompt + story,
            max_new_tokens=200,
            num_return_sequences=1,
            return_full_text=False,
        )
        story += result[0]["generated_text"]
        if len(story.split()) >= min_words:
            break
    return story.strip()
50
 
51
# ----------------------------
# Text-to-speech (TTS)
# ----------------------------
def text_to_speech(text, output_file="output.mp3"):
    """Convert text to speech and save it as an mp3 file.

    Args:
        text: The text to convert.
        output_file: Name of the audio file to write.

    Returns:
        The path of the saved audio file (the ``output_file`` argument).
    """
    from gtts import gTTS

    # The language is fixed to English ("en"). For Chinese, change it to
    # lang="zh-cn"; the text-generation model must then produce Chinese too.
    gTTS(text=text, lang="en").save(output_file)
    return output_file
69
 
70
# ----------------------------
# Entry point: build the Streamlit interface
# ----------------------------
def main():
    """Render the page: upload an image, caption it, write a story, play it."""
    st.title("儿童故事生成应用")
    st.write("上传一张图片,我们将根据图片生成有趣的故事,并转换成语音播放!")

    upload = st.file_uploader("选择一张图片", type=["png", "jpg", "jpeg"])
    if upload is None:
        # Nothing to do until the user picks an image.
        return

    # Show the uploaded image.
    preview = Image.open(upload)
    st.image(preview, caption="上传的图片", use_column_width=True)

    # Step 1: describe the image.
    with st.spinner("正在生成图片描述..."):
        caption = generate_caption(upload)
    st.write("图片描述:", caption)

    # Step 2: turn the description into a full story.
    with st.spinner("正在生成故事..."):
        story = generate_story(caption)
    st.write("生成的故事:")
    st.write(story)

    # Step 3: convert the story to speech and play it.
    with st.spinner("正在转换成语音..."):
        audio_path = text_to_speech(story)
    st.audio(audio_path, format="audio/mp3")


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
  import streamlit as st
 
3
  from transformers import pipeline
4
+ from PIL import Image
5
+ from gtts import gTTS
6
+ from io import BytesIO
7
 
8
# Image captioning pipeline, wrapped in Streamlit's resource cache
@st.cache_resource
def get_image_captioner():
    """Return the BLIP (Salesforce/blip-image-captioning-base) image-to-text pipeline."""
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
    )
    return captioner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
# Story generation pipeline, wrapped in Streamlit's resource cache
@st.cache_resource
def get_story_generator():
    """Return the Qwen/Qwen2.5-1.5B-Instruct text-generation pipeline.

    NOTE(review): ``padding=True`` is passed to ``pipeline()`` as an extra
    keyword; how it is consumed depends on the installed transformers
    version -- confirm it has the intended effect.
    """
    generator = pipeline(
        task="text-generation",
        model="Qwen/Qwen2.5-1.5B-Instruct",
        padding=True,
    )
    return generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# Text-to-speech via gTTS, kept entirely in memory
def text_to_speech(text):
    """Synthesize English speech for ``text`` into an in-memory buffer.

    Args:
        text: The text to read aloud.

    Returns:
        A ``BytesIO`` holding the audio written by gTTS, rewound to the start.
    """
    buffer = BytesIO()
    gTTS(text=text, lang='en').write_to_fp(buffer)
    # Rewind so that consumers read from the first byte.
    buffer.seek(0)
    return buffer
 
 
 
 
 
 
 
 
 
 
 
25
 
26
# Streamlit UI
def _generate_story(story_gen, prompt, min_words=100, max_attempts=2):
    """Generate a story from ``prompt``, retrying while it is too short.

    Args:
        story_gen: A text-generation pipeline callable.
        prompt: The instruction prompt sent to the model.
        min_words: Word count at which a story is accepted immediately.
        max_attempts: Upper bound on generation calls.

    Returns:
        The longest story produced, without the prompt. It may still be
        shorter than ``min_words`` if every attempt came out short.
    """
    best = ""
    for _ in range(max_attempts):
        result = story_gen(
            prompt,
            # Bound only the newly generated tokens; max_length would also
            # count the prompt and leave almost no room for the story.
            max_new_tokens=300,
            num_return_sequences=1,
            # Sampling makes temperature take effect and lets a retry differ.
            do_sample=True,
            temperature=0.9,
            repetition_penalty=1.2,
            # Keep the instruction prompt out of the story shown and read aloud.
            return_full_text=False,
        )
        candidate = result[0]['generated_text'].strip()
        if len(candidate.split()) > len(best.split()):
            best = candidate
        if len(best.split()) >= min_words:
            break
    return best


st.title("📖 Kids' Storytelling App")
st.write("Upload an image and let the magic create a story!")

uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    try:
        # Open and preprocess the image (ensure it is in RGB format).
        image = Image.open(uploaded_image).convert("RGB")
        st.image(image, caption="Your Image", use_column_width=True)

        if st.button("Generate Story!"):
            with st.spinner("Creating your story..."):
                # Generate the image caption.
                captioner = get_image_captioner()
                caption_result = captioner(image)
                caption = caption_result[0]['generated_text']
                st.subheader("Image Caption")
                st.write(caption)

                # Generate the story from the caption.
                story_gen = get_story_generator()
                prompt = f"Create a fun, children's story based on this: {caption}. The story must be at least 100 words, imaginative, and suitable for kids aged 3-10. Story:"
                story = _generate_story(story_gen, prompt)

                if not story:
                    # Stop before text-to-speech when there is nothing to read.
                    st.warning("The model did not produce a story. Please try again.")
                else:
                    st.subheader("Your Story")
                    st.write(story)

                    # Convert the story to audio.
                    audio_bytes = text_to_speech(story)
                    st.subheader("Listen to the Story")
                    st.audio(audio_bytes, format="audio/mp3")
    except Exception as e:
        # Top-level UI boundary: show the failure to the user instead of a traceback.
        st.error(f"An error occurred: {e}")