aeresd committed on
Commit
1634b47
·
verified ·
1 Parent(s): 80ebd89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -54
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import streamlit as st
2
  from transformers import pipeline, AutoTokenizer
3
  import torch
4
  import re
@@ -8,36 +8,42 @@ from PIL import Image
8
  from datasets import load_dataset
9
  import logging
10
 
11
- # 配置日志系统
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
- # ==================== 模型缓存加载 ====================
16
  @st.cache_resource(show_spinner=False)
17
  def load_models():
18
- """预加载所有模型并缓存"""
19
- logger.info("Loading caption model...")
20
- caption_model = pipeline("image-to-text",
21
- model="Salesforce/blip-image-captioning-base",
22
- device=0 if torch.cuda.is_available() else -1)
 
 
23
 
24
- logger.info("Loading story model...")
25
  story_model = pipeline(
26
- "text-generation",
27
  model="Tincando/fiction_story_generator",
28
  device=0 if torch.cuda.is_available() else -1,
29
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
30
  )
31
 
32
- logger.info("Loading TTS model...")
33
- tts_model = pipeline("text-to-audio",
34
- model="Chan-Y/speecht5_finetuned_tr_commonvoice",
35
- device=0 if torch.cuda.is_available() else -1)
36
- tts_tokenizer = AutoTokenizer.from_pretrained("Chan-Y/speecht5_finetuned_tr_commonvoice")
 
 
 
 
37
 
38
  return caption_model, story_model, tts_model, tts_tokenizer
39
 
40
- # ==================== Streamlit 界面配置 ====================
41
  st.set_page_config(
42
  page_title="🧸 AI Story Generator Pro",
43
  page_icon="📖",
@@ -45,65 +51,65 @@ st.set_page_config(
45
  initial_sidebar_state="expanded"
46
  )
47
 
48
- # ==================== 侧边栏参数设置 ====================
49
  with st.sidebar:
50
- st.title("⚙️ 生成参数")
51
- temperature = st.slider("创意度", 0.5, 1.5, 0.85, step=0.05)
52
- max_length = st.slider("故事长度", 100, 500, 200)
53
- story_style = st.selectbox("故事风格", ["童话", "科幻", "冒险"])
54
- voice_speed = st.slider("语音速度", 0.5, 2.0, 1.0)
55
 
56
- # ==================== 主界面 ====================
57
- st.title("🖼️ AI 智能故事生成器")
58
- st.write("上传图片即可获得定制化故事与语音朗读")
59
 
60
- # ==================== 文件上传 ====================
61
- uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "jpeg", "png"])
62
 
63
  if uploaded_file:
64
- # ==================== 图像处理 ====================
65
  col1, col2 = st.columns([1, 2])
66
  with col1:
67
  image = Image.open(uploaded_file)
68
- st.image(image, caption="上传图片", use_column_width=True)
69
 
70
- # ==================== 生成流程 ====================
71
- if st.button("开始生成", type="primary"):
72
  try:
73
  progress_bar = st.progress(0)
74
  status_text = st.empty()
75
 
76
- # 加载模型
77
- with st.spinner("🔄 正在加载模型..."):
78
  caption_model, story_model, tts_model, tts_tokenizer = load_models()
79
  speaker_emb = torch.tensor(
80
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
81
  ).unsqueeze(0)
82
  progress_bar.progress(20)
83
 
84
- # 图像描述生成
85
- with st.spinner("📷 正在分析图片内容..."):
86
  caption_result = caption_model(image)
87
  caption = caption_result[0]['generated_text']
88
- progress_bar.progress(40)
89
 
90
- # 故事生成
91
- with st.spinner("✍️ 正在创作精彩故事..."):
92
- prompt = f"{story_style}风格创作儿童故事,主题:{caption}"
93
  story = story_model(
94
  prompt,
95
  temperature=temperature,
96
  max_length=max_length,
97
  do_sample=True
98
  )[0]['generated_text']
99
- story = re.sub(r'[^.!?]+$', '', story) # 确保完整结尾
100
- progress_bar.progress(70)
 
101
 
102
- # 语音合成
103
- with st.spinner("🔊 正在生成语音..."):
104
  chunks = re.split(r'(?<=[.!?]) +', story)
105
  audio_arrays = []
106
-
107
  for chunk in chunks:
108
  inputs = tts_tokenizer(chunk, return_tensors="pt")
109
  speech = tts_model.generate(
@@ -114,33 +120,32 @@ if uploaded_file:
114
  }
115
  )
116
  audio_arrays.append(speech.numpy())
117
-
118
  combined = np.concatenate(audio_arrays)
119
  sf.write("output.wav", combined, samplerate=16000)
120
- progress_bar.progress(100)
121
 
122
- # ==================== 结果展示 ====================
123
  with col2:
124
- st.subheader("📖 生成故事")
125
  st.success(story)
126
 
127
- st.subheader("🔊 语音朗读")
128
  st.audio("output.wav", format="audio/wav")
129
 
130
- # 下载功能
131
  st.download_button(
132
- label="下载故事文本",
133
  data=story,
134
  file_name="generated_story.txt",
135
  mime="text/plain"
136
  )
137
  st.download_button(
138
- label="下载语音文件",
139
  data=open("output.wav", "rb"),
140
  file_name="story_audio.wav",
141
  mime="audio/wav"
142
  )
143
 
144
  except Exception as e:
145
- st.error(f"生成失败:{str(e)}")
146
- st.button("重试", on_click=st.cache_resource.clear)
 
1
+ import streamlit as st
2
  from transformers import pipeline, AutoTokenizer
3
  import torch
4
  import re
 
8
  from datasets import load_dataset
9
  import logging
10
 
11
+ # Configure logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
+ # ==================== Model loading with caching ====================
16
  @st.cache_resource(show_spinner=False)
17
  def load_models():
18
+ """Pre-load and cache all models"""
19
+ logger.info("Loading image captioning model...")
20
+ caption_model = pipeline(
21
+ task="image-to-text",
22
+ model="Salesforce/blip-image-captioning-base",
23
+ device=0 if torch.cuda.is_available() else -1
24
+ )
25
 
26
+ logger.info("Loading story generation model...")
27
  story_model = pipeline(
28
+ task="text-generation",
29
  model="Tincando/fiction_story_generator",
30
  device=0 if torch.cuda.is_available() else -1,
31
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
32
  )
33
 
34
+ logger.info("Loading text-to-speech model...")
35
+ tts_model = pipeline(
36
+ task="text-to-audio",
37
+ model="Chan-Y/speecht5_finetuned_tr_commonvoice",
38
+ device=0 if torch.cuda.is_available() else -1
39
+ )
40
+ tts_tokenizer = AutoTokenizer.from_pretrained(
41
+ "Chan-Y/speecht5_finetuned_tr_commonvoice"
42
+ )
43
 
44
  return caption_model, story_model, tts_model, tts_tokenizer
45
 
46
+ # ==================== Streamlit page configuration ====================
47
  st.set_page_config(
48
  page_title="🧸 AI Story Generator Pro",
49
  page_icon="📖",
 
51
  initial_sidebar_state="expanded"
52
  )
53
 
54
+ # ==================== Sidebar settings ====================
55
  with st.sidebar:
56
+ st.title("⚙️ Generation Settings")
57
+ temperature = st.slider("Creativity", 0.5, 1.5, 0.85, step=0.05)
58
+ max_length = st.slider("Story Length", 100, 500, 200)
59
+ story_style = st.selectbox("Story Style", ["Fairy Tale", "Sci-Fi", "Adventure"])
60
+ voice_speed = st.slider("Voice Speed", 0.5, 2.0, 1.0)
61
 
62
+ # ==================== Main interface ====================
63
+ st.title("🖼️ AI Story Generator")
64
+ st.write("Upload an image to get a customized story with audio narration.")
65
 
66
+ # ==================== File upload ====================
67
+ uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
68
 
69
  if uploaded_file:
70
+ # ==================== Image display ====================
71
  col1, col2 = st.columns([1, 2])
72
  with col1:
73
  image = Image.open(uploaded_file)
74
+ st.image(image, caption="Uploaded Image", use_column_width=True)
75
 
76
+ # ==================== Generation process ====================
77
+ if st.button("Generate Story", type="primary"):
78
  try:
79
  progress_bar = st.progress(0)
80
  status_text = st.empty()
81
 
82
+ # Load models
83
+ with st.spinner("🔄 Loading models..."):
84
  caption_model, story_model, tts_model, tts_tokenizer = load_models()
85
  speaker_emb = torch.tensor(
86
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
87
  ).unsqueeze(0)
88
  progress_bar.progress(20)
89
 
90
+ # Generate image caption
91
+ with st.spinner("📷 Analyzing image content..."):
92
  caption_result = caption_model(image)
93
  caption = caption_result[0]['generated_text']
94
+ progress_bar.progress(40)
95
 
96
+ # Generate story
97
+ with st.spinner("✍️ Writing the story..."):
98
+ prompt = f"Write a children's story in {story_style} style about: {caption}"
99
  story = story_model(
100
  prompt,
101
  temperature=temperature,
102
  max_length=max_length,
103
  do_sample=True
104
  )[0]['generated_text']
105
+ # Ensure story ends with punctuation
106
+ story = re.sub(r'[^.!?]+$', '', story)
107
+ progress_bar.progress(70)
108
 
109
+ # Text-to-speech synthesis
110
+ with st.spinner("🔊 Generating audio..."):
111
  chunks = re.split(r'(?<=[.!?]) +', story)
112
  audio_arrays = []
 
113
  for chunk in chunks:
114
  inputs = tts_tokenizer(chunk, return_tensors="pt")
115
  speech = tts_model.generate(
 
120
  }
121
  )
122
  audio_arrays.append(speech.numpy())
 
123
  combined = np.concatenate(audio_arrays)
124
  sf.write("output.wav", combined, samplerate=16000)
125
+ progress_bar.progress(100)
126
 
127
+ # ==================== Display results ====================
128
  with col2:
129
+ st.subheader("📖 Generated Story")
130
  st.success(story)
131
 
132
+ st.subheader("🔊 Audio Narration")
133
  st.audio("output.wav", format="audio/wav")
134
 
135
+ # Download buttons
136
  st.download_button(
137
+ label="Download Story Text",
138
  data=story,
139
  file_name="generated_story.txt",
140
  mime="text/plain"
141
  )
142
  st.download_button(
143
+ label="Download Audio File",
144
  data=open("output.wav", "rb"),
145
  file_name="story_audio.wav",
146
  mime="audio/wav"
147
  )
148
 
149
  except Exception as e:
150
+ st.error(f"Generation failed: {str(e)}")
151
+ st.button("Retry", on_click=st.cache_resource.clear)