TLH01 committed on
Commit
504a753
·
verified ·
1 Parent(s): 48b9452

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -32
app.py CHANGED
@@ -1,52 +1,102 @@
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  from gtts import gTTS
6
  import io
7
 
 
 
 
8
  @st.cache_resource
9
- def load_models():
10
- return (
11
- BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base"),
12
- BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base"),
13
- GPT2Tokenizer.from_pretrained("gpt2"),
14
- GPT2LMHeadModel.from_pretrained("gpt2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  )
 
 
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def main():
18
- st.title("Stable Story Maker")
19
-
20
- img_processor, img_model, text_tokenizer, text_model = load_models()
21
 
22
- uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
23
 
24
  if uploaded_file:
 
25
  st.image(uploaded_file, use_container_width=True)
 
 
26
 
27
- with st.status("Processing"):
28
- # Stage 1
29
- img = Image.open(uploaded_file).convert("RGB")
30
- inputs = img_processor(images=img, return_tensors="pt")
31
- caption = img_processor.decode(img_model.generate(**inputs)[0], skip_special_tokens=True)
32
-
33
- # Stage 2
34
- prompt = f"Children's story about {caption}:"
35
- inputs = text_tokenizer(prompt, return_tensors="pt")
36
- story = text_tokenizer.decode(
37
- text_model.generate(inputs.input_ids, max_length=200)[0],
38
- skip_special_tokens=True
39
- ).replace(prompt, "")
40
-
41
- # Stage 3
42
- tts = gTTS(text=story[:250], lang='en')
43
- audio = io.BytesIO()
44
- tts.write_to_fp(audio)
45
- audio.seek(0)
46
 
47
- st.write(f"**Caption:** {caption}")
48
- st.write(f"**Story:** {story}")
49
  st.audio(audio, format="audio/mp3")
 
 
50
 
51
  if __name__ == "__main__":
52
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from gtts import gTTS
6
  import io
7
 
8
+ # ======================
9
+ # Stage1: Image Captioning
10
+ # ======================
11
  @st.cache_resource
12
+ def load_stage1_model():
13
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
14
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
15
+ return processor, model
16
+
17
def stage1_generate_caption(uploaded_file):
    """Return a caption generated by the stage-1 BLIP model for an image.

    Args:
        uploaded_file: Anything ``PIL.Image.open`` accepts (a path or a
            binary file-like object such as a Streamlit upload).

    Returns:
        The decoded caption string with special tokens removed.
    """
    blip_processor, blip_model = load_stage1_model()

    # The image is converted to RGB before being handed to the processor.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    encoded = blip_processor(images=rgb_image, return_tensors="pt", padding=True)

    # Only the first generated sequence is decoded.
    caption_ids = blip_model.generate(**encoded)[0]
    return blip_processor.decode(caption_ids, skip_special_tokens=True)
23
+
24
+ # ======================
25
+ # Stage2: Story Generation
26
+ # ======================
27
@st.cache_resource
def load_stage2_model():
    """Load the tokenizer and causal language model used for story generation.

    Both objects come from the same Hugging Face checkpoint. The
    ``st.cache_resource`` decorator lets Streamlit reuse the loaded pair
    across script reruns instead of loading it again each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "pranavpsv/gpt-genre-story-generator"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )
32
+
33
def stage2_generate_story(keyword):
    """Generate a short children's story themed around ``keyword``.

    A fixed prompt (theme, characters, plot, moral, length hint) ending in a
    seeded opening sentence is fed to the stage-2 language model. The model's
    continuation is appended to that opening sentence and the result is
    split into paragraphs by ``_format_story``.

    Args:
        keyword: Theme text inserted into the prompt (the image caption in
            this app).

    Returns:
        The formatted story text, starting with the seeded opening sentence.
    """
    tokenizer, model = load_stage2_model()

    # Professional prompt template. The opening sentence is kept separately
    # so it can be put back in front of the model's continuation; otherwise
    # the returned story would start mid-sentence.
    story_opening = "Once upon a time, there was a little rabbit named Fluffy who loved"
    prompt_template = "\n".join([
        "Generate a children's story in English with these elements:",
        f"- Main theme: {keyword}",
        "- Characters: Friendly animals",
        "- Plot: Daily adventure",
        "- Moral lesson: Sharing is caring",
        "- Word count: 50-100 words",
        "",
        f"Story: {story_opening}",
    ])

    inputs = tokenizer(prompt_template, return_tensors="pt")
    prompt_length = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs,  # passes input_ids together with the attention mask
        max_length=300,
        # temperature/top_k are only honoured when sampling is enabled;
        # without do_sample=True generation is greedy and they are ignored.
        do_sample=True,
        temperature=0.85,
        top_k=50,
        repetition_penalty=1.2,
        num_return_sequences=1,
        # Explicit padding id, using the end-of-sequence token.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Extract the generated part by dropping the prompt tokens. This does not
    # depend on the decoded text reproducing the prompt character for
    # character, which a string replace would.
    continuation = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # Normalize the format: opening sentence + continuation, then paragraphs.
    return _format_story((story_opening + continuation).strip())
60
+
61
def _format_story(raw_text, sentences_per_paragraph=3):
    """Post-process generated text by grouping sentences into paragraphs.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace,
    and each sentence keeps its own terminator. A paragraph only gets a
    trailing period when it does not already end in sentence punctuation
    (for example when generation stopped mid-sentence), so text that already
    ends with ``.`` is not turned into ``..``.

    Args:
        raw_text: The unformatted story text.
        sentences_per_paragraph: Number of sentences per paragraph
            (default 3).

    Returns:
        Paragraphs joined by blank lines, or ``""`` for empty or
        whitespace-only input.

    Raises:
        ValueError: If ``sentences_per_paragraph`` is smaller than 1.
    """
    import re  # local import: the module's import block does not include re

    if sentences_per_paragraph < 1:
        raise ValueError("sentences_per_paragraph must be at least 1")

    text = raw_text.strip()
    if not text:
        return ""

    sentences = [
        part.strip()
        for part in re.split(r"(?<=[.!?])\s+", text)
        if part.strip()
    ]

    paragraphs = []
    for start in range(0, len(sentences), sentences_per_paragraph):
        paragraph = " ".join(sentences[start:start + sentences_per_paragraph])
        # Ignore closing quotes/brackets when checking for a terminator.
        if not paragraph.rstrip("\"'\u201d\u2019)").endswith((".", "!", "?")):
            paragraph += "."
        paragraphs.append(paragraph)

    return "\n\n".join(paragraphs)
65
 
66
+ # ======================
67
+ # Stage3: Text-to-Speech
68
+ # ======================
69
def stage3_generate_audio(story_text):
    """Synthesize English speech for ``story_text`` with gTTS.

    Args:
        story_text: The text to be spoken.

    Returns:
        An ``io.BytesIO`` holding the audio written by gTTS, rewound to the
        start so it can be read immediately.
    """
    mp3_stream = io.BytesIO()
    gTTS(text=story_text, lang='en').write_to_fp(mp3_stream)
    mp3_stream.seek(0)  # rewind so consumers read from the beginning
    return mp3_stream
75
+
76
+ # ======================
77
+ # Main Application
78
+ # ======================
79
  def main():
80
+ st.title("📚 Smart Story Generator")
 
 
81
 
82
+ uploaded_file = st.file_uploader("Upload children's photo", type=["jpg", "png"])
83
 
84
  if uploaded_file:
85
+ # Stage1
86
  st.image(uploaded_file, use_container_width=True)
87
+ caption = stage1_generate_caption(uploaded_file)
88
+ st.write(f"✨ Detected Theme: **{caption}**")
89
 
90
+ # Stage2
91
+ story = stage2_generate_story(caption)
92
+ st.subheader("Magic Story")
93
+ st.write(story)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # Stage3
96
+ audio = stage3_generate_audio(story[:500]) # Limit for TTS
97
  st.audio(audio, format="audio/mp3")
98
+ st.download_button("Download Story", story, "story.txt")
99
+ st.download_button("Download Audio", audio.getvalue(), "story.mp3")
100
 
101
  if __name__ == "__main__":
102
  main()