TLH01 committed on
Commit
5e5ea3c
·
verified ·
1 Parent(s): 0311f5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -74
app.py CHANGED
@@ -1,39 +1,38 @@
1
- """
2
- 儿童故事生成器 (Children's Story Generator)
3
- 功能:上传图片 → 生成描述 → 创作故事 → 语音朗读
4
- """
5
-
6
- # ============ 导入模块 ============
7
  import streamlit as st
 
8
  from PIL import Image
9
  import tempfile
10
- from transformers import pipeline
11
  import torch
12
- import os
13
 
14
- # ============ 第一阶段:图片描述生成 ============
15
- @st.cache_resource # 缓存模型避免重复加载
 
 
16
  def load_image_captioner():
17
- """加载图片描述模型(BLIP模型)"""
18
  return pipeline(
19
  "image-to-text",
20
  model="Salesforce/blip-image-captioning-base",
21
- device="cuda" if torch.cuda.is_available() else "cpu" # 自动检测GPU
22
  )
23
 
24
  def generate_caption(_pipeline, image):
25
- """生成图片英文描述"""
26
  try:
27
- result = _pipeline(image, max_new_tokens=50) # 限制生成长度
28
  return result[0]['generated_text']
29
  except Exception as e:
30
- st.error(f"生成描述失败: {str(e)}")
31
  return None
32
 
33
- # ============ 第二阶段:故事创作 ============
 
 
34
  @st.cache_resource
35
  def load_story_generator():
36
- """加载儿童故事生成模型"""
37
  return pipeline(
38
  "text-generation",
39
  model="pranavpsv/gpt2-genre-story-generator",
@@ -41,10 +40,10 @@ def load_story_generator():
41
  )
42
 
43
  def generate_story(_pipeline, keywords):
44
- """根据关键词生成儿童故事"""
45
- prompt = f"""Generate a children's story (60-80 words) in English about: {keywords}
46
  Requirements:
47
- - Use simple words
48
  - Include magical elements
49
  - Happy ending
50
  Story:"""
@@ -53,17 +52,19 @@ def generate_story(_pipeline, keywords):
53
  story = _pipeline(
54
  prompt,
55
  max_length=200,
56
- temperature=0.7 # 控制创意程度
57
  )[0]['generated_text']
58
  return story.replace(prompt, "").strip()
59
  except Exception as e:
60
- st.error(f"生成故事失败: {str(e)}")
61
  return None
62
 
63
- # ============ 第三阶段:语音合成 ============
 
 
64
  @st.cache_resource
65
  def load_tts():
66
- """加载文本转语音模型"""
67
  return pipeline(
68
  "text-to-speech",
69
  model="facebook/mms-tts-eng",
@@ -71,74 +72,49 @@ def load_tts():
71
  )
72
 
73
  def text_to_speech(_pipeline, text):
74
- """将文本转为语音"""
75
  try:
76
  audio = _pipeline(text)
77
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
78
- import soundfile as sf
79
- sf.write(f.name, audio["audio"].squeeze().numpy(), audio["sampling_rate"])
80
  return f.name
81
  except Exception as e:
82
- st.error(f"语音生成失败: {str(e)}")
83
  return None
84
 
85
- # ============ 主界面 ============
86
  def main():
87
- # 界面设置
88
- st.set_page_config(
89
- page_title="魔法故事生成器",
90
- page_icon="🧚",
91
- layout="wide"
92
- )
93
-
94
- # 儿童风格CSS
95
- st.markdown("""
96
- <style>
97
- .main { background-color: #FFF5E6 }
98
- h1 { color: #FF6B6B; font-family: 'Comic Sans MS' }
99
- .stButton>button { background-color: #4CAF50; border-radius: 20px }
100
- </style>
101
- """, unsafe_allow_html=True)
102
-
103
- st.title("🧚 魔法故事生成器")
104
- st.write("上传小朋友的照片,AI会生成专属故事并朗读!")
105
-
106
- # 图片上传
107
- uploaded_file = st.file_uploader("选择照片", type=["jpg", "png"])
108
 
109
- if not uploaded_file:
110
- st.info("请先上传照片")
111
  return
112
-
113
- image = Image.open(uploaded_file)
114
- st.image(image, use_column_width=True)
115
 
116
- # 加载模型
117
- with st.spinner("正在准备魔法..."):
 
 
 
118
  caption_pipe = load_image_captioner()
119
  story_pipe = load_story_generator()
120
  tts_pipe = load_tts()
121
-
122
- # 第一阶段
123
- with st.spinner("正在分析图片..."):
124
  caption = generate_caption(caption_pipe, image)
125
  if caption:
126
- st.success(f"图片描述: {caption}")
127
-
128
- # 第二阶段
129
- if caption:
130
- with st.spinner("正在创作故事..."):
131
  story = generate_story(story_pipe, caption)
132
  if story:
133
- st.subheader("你的故事")
134
- st.markdown(f'<div style="background-color:#FFF0F5; padding:20px; border-radius:15px">{story}</div>', unsafe_allow_html=True)
135
-
136
- # 第三阶段
137
- with st.spinner("正在生成语音..."):
138
- audio_path = text_to_speech(tts_pipe, story)
139
- if audio_path:
140
- st.audio(audio_path, format="audio/wav")
141
 
142
  if __name__ == "__main__":
143
- os.environ["HF_HUB_CACHE"] = "/tmp/huggingface" # 设置缓存路径
144
  main()
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from PIL import Image
4
  import tempfile
5
+ import numpy as np
6
  import torch
7
+ import soundfile as sf
8
 
9
+ # ======================
10
+ # Stage 1: Image Captioning
11
+ # ======================
12
+ @st.cache_resource
13
  def load_image_captioner():
14
+ """Load BLIP model for image caption generation"""
15
  return pipeline(
16
  "image-to-text",
17
  model="Salesforce/blip-image-captioning-base",
18
+ device="cuda" if torch.cuda.is_available() else "cpu"
19
  )
20
 
21
  def generate_caption(_pipeline, image):
22
+ """Generate English description from image"""
23
  try:
24
+ result = _pipeline(image, max_new_tokens=50)
25
  return result[0]['generated_text']
26
  except Exception as e:
27
+ st.error(f"Caption generation failed: {str(e)}")
28
  return None
29
 
30
+ # ======================
31
+ # Stage 2: Story Generation
32
+ # ======================
33
  @st.cache_resource
34
  def load_story_generator():
35
+ """Load fine-tuned story generator"""
36
  return pipeline(
37
  "text-generation",
38
  model="pranavpsv/gpt2-genre-story-generator",
 
40
  )
41
 
42
  def generate_story(_pipeline, keywords):
43
+ """Generate children's story based on keywords"""
44
+ prompt = f"""Generate a children's story (60-80 words) about: {keywords}
45
  Requirements:
46
+ - Use simple English
47
  - Include magical elements
48
  - Happy ending
49
  Story:"""
 
52
  story = _pipeline(
53
  prompt,
54
  max_length=200,
55
+ temperature=0.7
56
  )[0]['generated_text']
57
  return story.replace(prompt, "").strip()
58
  except Exception as e:
59
+ st.error(f"Story generation failed: {str(e)}")
60
  return None
61
 
62
+ # ======================
63
+ # Stage 3: Text-to-Speech
64
+ # ======================
65
  @st.cache_resource
66
  def load_tts():
67
+ """Load TTS model for audio generation"""
68
  return pipeline(
69
  "text-to-speech",
70
  model="facebook/mms-tts-eng",
 
72
  )
73
 
74
  def text_to_speech(_pipeline, text):
75
+ """Convert text to speech audio"""
76
  try:
77
  audio = _pipeline(text)
78
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
79
+ sf.write(f.name, audio["audio"], audio["sampling_rate"])
 
80
  return f.name
81
  except Exception as e:
82
+ st.error(f"Audio generation failed: {str(e)}")
83
  return None
84
 
85
+ # Main App
86
  def main():
87
+ st.set_page_config(page_title="Magic Story Generator", layout="wide")
88
+ st.title("🧚 Magic Story Generator")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "png"])
91
+ if not uploaded_image:
92
  return
 
 
 
93
 
94
+ image = Image.open(uploaded_image)
95
+ st.image(image, use_container_width=True) # Fixed deprecated parameter
96
+
97
+ # Process stages
98
+ with st.spinner("Processing..."):
99
  caption_pipe = load_image_captioner()
100
  story_pipe = load_story_generator()
101
  tts_pipe = load_tts()
102
+
103
+ # Stage 1
 
104
  caption = generate_caption(caption_pipe, image)
105
  if caption:
106
+ st.success(f"Image description: {caption}")
107
+
108
+ # Stage 2
 
 
109
  story = generate_story(story_pipe, caption)
110
  if story:
111
+ st.subheader("Your Story")
112
+ st.markdown(f'<div class="story-box">{story}</div>', unsafe_allow_html=True)
113
+
114
+ # Stage 3
115
+ audio_path = text_to_speech(tts_pipe, story)
116
+ if audio_path:
117
+ st.audio(audio_path, format="audio/wav")
 
118
 
119
  if __name__ == "__main__":
 
120
  main()