TLH01 committed on
Commit
c460031
·
verified ·
1 Parent(s): 002777f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -66
app.py CHANGED
@@ -1,102 +1,133 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from transformers import BlipProcessor, BlipForConditionalGeneration
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
5
  from gtts import gTTS
6
  import io
 
 
 
 
 
7
 
8
  # ======================
9
- # Stage1: Image Captioning
10
  # ======================
11
  @st.cache_resource
12
- def load_stage1_model():
13
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
14
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
15
- return processor, model
 
 
 
 
 
 
16
 
17
  def stage1_generate_caption(uploaded_file):
18
- processor, model = load_stage1_model()
19
- img = Image.open(uploaded_file).convert("RGB")
20
- inputs = processor(images=img, return_tensors="pt", padding=True)
21
- outputs = model.generate(**inputs)
22
- return processor.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
23
 
24
  # ======================
25
- # Stage2: Story Generation
26
  # ======================
27
  @st.cache_resource
28
- def load_stage2_model():
29
- tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt-genre-story-generator")
30
- model = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt-genre-story-generator")
31
- return tokenizer, model
 
 
 
 
 
 
32
 
33
  def stage2_generate_story(keyword):
34
- tokenizer, model = load_stage2_model()
35
-
36
- # 专业prompt模板
37
- prompt_template = f"""Generate a children's story in English with these elements:
38
- - Main theme: {keyword}
39
- - Characters: Friendly animals
40
- - Plot: Daily adventure
41
- - Moral lesson: Sharing is caring
42
- - Word count: 50-100 words
43
 
44
- Story: Once upon a time, there was a little rabbit named Fluffy who loved"""
 
 
 
 
45
 
46
- inputs = tokenizer(prompt_template, return_tensors="pt")
47
- outputs = model.generate(
48
- inputs.input_ids,
49
- max_length=300,
50
- temperature=0.85,
51
- top_k=50,
52
- repetition_penalty=1.2,
53
- num_return_sequences=1
54
- )
55
- full_story = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
 
57
- # 提取生成部分并标准化格式
58
- generated_part = full_story.replace(prompt_template, "").strip()
59
- return _format_story(generated_part)
60
-
61
- def _format_story(raw_text):
62
- # 后处理:添加段落结构
63
- sentences = raw_text.split(". ")
64
- return "\n\n".join([". ".join(sentences[i:i+3]) + "." for i in range(0, len(sentences), 3)])
 
 
 
 
 
 
 
65
 
66
  # ======================
67
- # Stage3: Text-to-Speech
68
  # ======================
69
- def stage3_generate_audio(story_text):
70
- tts = gTTS(text=story_text, lang='en')
71
- audio_buffer = io.BytesIO()
72
- tts.write_to_fp(audio_buffer)
73
- audio_buffer.seek(0)
74
- return audio_buffer
 
 
 
 
 
75
 
76
  # ======================
77
  # Main Application
78
  # ======================
79
  def main():
80
- st.title("📚 Smart Story Generator")
81
 
82
- uploaded_file = st.file_uploader("Upload children's photo", type=["jpg", "png"])
83
 
84
  if uploaded_file:
85
- # Stage1
86
  st.image(uploaded_file, use_container_width=True)
87
- caption = stage1_generate_caption(uploaded_file)
88
- st.write(f"✨ Detected Theme: **{caption}**")
 
89
 
90
- # Stage2
91
- story = stage2_generate_story(caption)
92
- st.subheader("Magic Story")
93
- st.write(story)
 
94
 
95
- # Stage3
96
- audio = stage3_generate_audio(story[:500]) # Limit for TTS
97
- st.audio(audio, format="audio/mp3")
98
- st.download_button("Download Story", story, "story.txt")
99
- st.download_button("Download Audio", audio.getvalue(), "story.mp3")
 
 
100
 
101
  if __name__ == "__main__":
102
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
+ from transformers import (
4
+ BlipProcessor,
5
+ BlipForConditionalGeneration,
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM
8
+ )
9
  from gtts import gTTS
10
  import io
11
+ import logging
12
+
13
# Configure logging: INFO level so the model-load progress messages emitted
# by this module's logger are visible in the app's console/Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger shared by all stages
16
 
17
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP image-captioning model and its processor.

    Cached by Streamlit (`st.cache_resource`) so the weights are downloaded
    and instantiated only once per server process.

    Returns:
        tuple: ``(processor, model)`` for
        ``Salesforce/blip-image-captioning-base``.

    Raises:
        Exception: re-raised after showing a Streamlit error when the
            download/load fails (e.g. no network access).
    """
    try:
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        logger.info("Stage 1 model loaded")
        return processor, model
    except Exception:
        # Fix: the original bound the exception as `e` but never used or
        # logged it, so the real cause was lost; record the traceback
        # before surfacing the user-facing error and re-raising.
        logger.exception("Failed to load image captioning model")
        st.error("❌ 图像模型加载失败,请检查网络连接")
        raise
31
 
32
def stage1_generate_caption(uploaded_file):
    """Produce a short caption describing the uploaded image.

    On any failure the error is surfaced in the UI and the generic
    fallback caption ``"children playing"`` is returned so the rest of
    the pipeline can still run.
    """
    processor, model = load_image_model()
    try:
        image = Image.open(uploaded_file).convert("RGB")
        # Downscale in place to speed up inference; aspect ratio is kept.
        image.thumbnail((512, 512))
        model_inputs = processor(images=image, return_tensors="pt", padding=True)
        generated = model.generate(**model_inputs, max_length=30)
        caption = processor.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"图像处理失败: {str(e)}")
        caption = "children playing"
    return caption
44
 
45
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_model():
    """Load and cache the DialoGPT model used for story generation.

    NOTE(review): microsoft/DialoGPT-medium is an English conversational
    model, while the prompt in `stage2_generate_story` is Chinese — output
    quality is likely poor; confirm the model choice is intentional.

    Returns:
        tuple: ``(tokenizer, model)`` for ``microsoft/DialoGPT-medium``.

    Raises:
        Exception: re-raised after showing a Streamlit error when the
            download/load fails.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
        logger.info("Stage 2 model loaded")
        return tokenizer, model
    except Exception:
        # Fix: the original discarded the bound exception `e` unused, so the
        # real cause never reached the logs; record it before re-raising.
        logger.exception("Failed to load story generation model")
        st.error("❌ 故事模型加载失败,请检查模型名称")
        raise
59
 
60
def stage2_generate_story(keyword):
    """Generate a short children's story seeded by *keyword*.

    Builds a Chinese prompt around the keyword, samples a continuation
    from the cached causal LM, and returns only the newly generated text.
    On failure the error is shown in the UI and a one-line fallback story
    is returned.

    Args:
        keyword: theme string (here, the Stage-1 image caption).

    Returns:
        str: the generated story continuation (may be empty).
    """
    tokenizer, model = load_story_model()

    # Optimized prompt template
    prompt = f"""写一个儿童故事,包含以下要素:
- 主题: {keyword}
- 角色: 小动物
- 字数: 100字左右

故事开头: 有一天,小熊嘟嘟在公园里发现"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
        outputs = model.generate(
            inputs.input_ids,
            max_length=300,
            # Fix: without do_sample=True, transformers decodes greedily and
            # silently ignores temperature/top_k — the sampling knobs below
            # had no effect in the original.
            do_sample=True,
            temperature=0.9,
            top_k=50,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
        # Fix: the original did full_text.replace(prompt, ""), which fails
        # silently whenever the tokenizer truncated the prompt to 100 tokens
        # (the decoded text then no longer contains the original prompt
        # string). Slicing off the prompt's token span is exact.
        generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    except Exception as e:
        st.error(f"故事生成失败: {str(e)}")
        return "小熊和朋友们玩得很开心!"
87
 
88
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_generate_audio(text):
    """Synthesize spoken audio (MP3) for *text* via Google TTS.

    Only the first 300 characters are synthesized to keep the request
    small. Returns an in-memory buffer rewound to the start, or ``None``
    if synthesis fails (the error is shown in the UI).
    """
    try:
        speech = gTTS(text=text[:300], lang='zh-CN')  # Chinese support
        buffer = io.BytesIO()
        speech.write_to_fp(buffer)
        buffer.seek(0)
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        return None
    return buffer
102
 
103
# ======================
# Main Application
# ======================
def main():
    """Streamlit entry point: photo -> caption -> story -> speech.

    Runs the three stages in sequence once a photo is uploaded, showing a
    spinner during each slow step.
    """
    st.title("📚 智能故事生成器")

    uploaded_file = st.file_uploader("上传儿童照片", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        # Stage 1: show the photo and derive a theme caption from it.
        st.image(uploaded_file, use_container_width=True)
        with st.spinner("正在分析图片..."):
            caption = stage1_generate_caption(uploaded_file)
        st.write(f"✨ 识别主题: **{caption}**")

        # Stage 2: turn the caption into a short story.
        with st.spinner("正在生成故事..."):
            story = stage2_generate_story(caption)
        st.subheader("生成故事")
        st.write(story)
        # Fix: restore the story-text download that the previous version
        # offered but this rewrite dropped while keeping the audio download.
        st.download_button("下载故事", story, "story.txt")

        # Stage 3: synthesize speech only for non-trivial stories.
        if len(story) > 10:  # Minimum length check
            with st.spinner("正在生成语音..."):
                audio = stage3_generate_audio(story)
            if audio:
                st.audio(audio, format="audio/mp3")
                st.download_button("下载语音", audio.getvalue(), "story.mp3")
131
 
132
# Run the app when executed directly (streamlit run app.py).
if __name__ == "__main__":
    main()