TLH01 commited on
Commit
fb5ea01
·
verified ·
1 Parent(s): f096369

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -87
app.py CHANGED
@@ -1,105 +1,51 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
- import tempfile
5
- import torch
6
- from TTS.api import TTS # Coqui TTS
7
  import os
 
 
8
 
9
- # ======================
10
- # Stage 1: Image Captioning
11
- # ======================
12
- @st.cache_resource
13
- def load_image_captioner():
14
- return pipeline(
15
- "image-to-text",
16
- model="Salesforce/blip-image-captioning-base",
17
- device="cuda" if torch.cuda.is_available() else "cpu"
18
- )
19
-
20
- def generate_caption(_pipeline, image):
21
- try:
22
- result = _pipeline(image, max_new_tokens=50)
23
- return result[0]['generated_text']
24
- except Exception as e:
25
- st.error(f"Caption generation failed: {str(e)}")
26
- return None
27
-
28
- # ======================
29
- # Stage 2: Story Generation
30
- # ======================
31
  @st.cache_resource
32
- def load_story_generator():
33
- return pipeline(
34
- "text-generation",
35
- model="pranavpsv/gpt2-genre-story-generator", # 可以替换为更强模型
36
- device="cuda" if torch.cuda.is_available() else "cpu"
37
- )
38
-
39
- def generate_story(_pipeline, caption):
40
- prompt = f"""You are a children's storyteller. Based on the following image description: "{caption}", write a short children's story (80 words max).
41
- The story should:
42
- - Use simple and friendly language
43
- - Be related to the content of the image
44
- - Include a magical or fun twist
45
- - End happily
46
 
47
- Story:"""
48
-
49
- try:
50
- story = _pipeline(prompt, max_length=200, temperature=0.7)[0]['generated_text']
51
- return story.replace(prompt, "").strip()
52
- except Exception as e:
53
- st.error(f"Story generation failed: {str(e)}")
54
- return None
55
-
56
- # ======================
57
- # Stage 3: Text-to-Speech using Coqui TTS
58
- # ======================
59
- @st.cache_resource
60
- def load_tts():
61
- return TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
62
 
63
- def text_to_speech(tts_model, story_text):
64
- try:
65
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
66
- tts_model.tts_to_file(text=story_text, file_path=f.name)
67
- return f.name
68
- except Exception as e:
69
- st.error(f"Audio generation failed: {str(e)}")
70
- return None
71
 
72
- # ======================
73
- # Main Streamlit App
74
- # ======================
75
- def main():
76
- st.set_page_config(page_title="Magic Story Generator", layout="wide")
77
- st.title("🧚 Magic Story Generator")
78
 
79
- uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
80
- if not uploaded_image:
81
- return
82
 
83
- image = Image.open(uploaded_image)
84
- st.image(image, use_container_width=True)
 
85
 
86
- with st.spinner("Processing your magical story..."):
87
- caption_pipe = load_image_captioner()
88
- story_pipe = load_story_generator()
89
- tts_model = load_tts()
90
 
91
- caption = generate_caption(caption_pipe, image)
92
- if caption:
93
- st.success(f"Image description: {caption}")
94
- story = generate_story(story_pipe, caption)
95
 
96
- if story:
97
- st.subheader("Your Magical Story")
98
- st.markdown(story)
 
99
 
100
- audio_path = text_to_speech(tts_model, story)
101
- if audio_path:
102
- st.audio(audio_path, format="audio/wav")
103
 
 
104
  if __name__ == "__main__":
105
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import requests
 
 
5
  import os
6
+ from io import BytesIO
7
+ import tempfile
8
 
9
+ # Load Hugging Face pipelines
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  @st.cache_resource
11
+ def load_pipelines():
12
+ image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
13
+ story_generator = pipeline("text-generation", model="Tevatron/tiny-stories-generator", max_length=200)
14
+ text_to_speech = pipeline("text-to-speech", model="espnet/kan-bayashi_ljspeech_vits", framework="pt")
15
+ return image_captioner, story_generator, text_to_speech
 
 
 
 
 
 
 
 
 
16
 
17
+ # Define main interface
18
+ def main():
19
+ st.set_page_config(page_title="Kids Story Maker 🧒📖", page_icon="📸", layout="centered")
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Child-friendly header
22
+ st.markdown("<h1 style='color:#FF69B4; font-family:Comic Sans MS;'>🎨 Welcome to Kids Story Maker!</h1>", unsafe_allow_html=True)
23
+ st.markdown("<h3 style='color:#4CAF50;'>Upload a picture, and we'll turn it into a magical story and voice! 🐻✨</h3>", unsafe_allow_html=True)
 
 
 
 
 
24
 
25
+ image_captioner, story_generator, text_to_speech = load_pipelines()
 
 
 
 
 
26
 
27
+ uploaded_file = st.file_uploader("🖼️ Upload an image:", type=["png", "jpg", "jpeg"])
 
 
28
 
29
+ if uploaded_file:
30
+ image = Image.open(uploaded_file)
31
+ st.image(image, caption="Your Image", use_column_width=True)
32
 
33
+ with st.spinner("🔍 Generating a description..."):
34
+ caption = image_captioner(image)[0]['generated_text']
35
+ st.success(f"📝 Description: {caption}")
 
36
 
37
+ # Prompt template for storytelling
38
+ story_prompt = f"Write a short story for children aged 3 to 10 based on this description: {caption}. The story should be creative, friendly, and use simple words."
 
 
39
 
40
+ with st.spinner("✍️ Creating your story..."):
41
+ story = story_generator(story_prompt)[0]['generated_text']
42
+ st.success("📖 Here's your story:")
43
+ st.write(story)
44
 
45
+ with st.spinner("🔊 Turning your story into voice..."):
46
+ speech = text_to_speech(story)[0]['audio']
47
+ st.audio(speech, format="audio/wav")
48
 
49
+ # Run the app
50
  if __name__ == "__main__":
51
  main()