scmlewis committed on
Commit
7af7299
·
verified ·
1 Parent(s): 3ddf191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -42
app.py CHANGED
@@ -1,59 +1,84 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
- import time
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def generate_caption(image):
7
- image_to_text = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
8
- caption = image_to_text(image)[0]["generated_text"]
9
- return caption
 
 
 
 
10
 
11
  def generate_story(caption):
12
- pipe = pipeline("text-generation", model="pranavpsv/genre-story-generator-v2")
13
- story = pipe(caption)[0]['generated_text']
14
- return story
 
 
 
 
 
15
 
16
  def generate_audio(story):
17
- pipe = pipeline("text-to-speech", model="facebook/mms-tts-eng")
18
- audio = pipe(story)
19
- return audio
 
 
 
 
20
 
21
  # Streamlit UI
22
-
23
- # Title of the Streamlit app
24
- st.title("Upload your image for an instant storytelling!")
25
-
26
- # Write a description
27
- st.write("This app is designed for 3-10 year-old kids by allowing them uploading image for fun storytelling entertainment.")
28
 
29
  # File uploader for image
30
- uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
31
 
32
  if uploaded_file is not None:
33
  # Display the uploaded image
34
  image = Image.open(uploaded_file)
35
- st.image(image, caption="Uploaded Image", use_container_width=True)
36
-
37
- # Generate Image Caption
38
- image_caption = generate_caption(image)
39
-
40
- # Display results
41
- st.subheader("Image Caption:")
42
- st.write(f"{image_caption}")
43
-
44
- # Generate Story
45
- story_telling = generate_story(image_caption)
46
-
47
- # Display results
48
- st.subheader("Story:")
49
- st.write(f"{story_telling}")
50
-
51
- # Generate Audio
52
- audio = generate_audio(story_telling)
53
-
54
- # Display an audio file with a spinner effect
55
- if st.button("Play Audio"):
56
- st.audio(audio['audio'],
57
- format="audio/wav",
58
- start_time=0,
59
- sample_rate = audio['sampling_rate'])
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
 
4
 
5
+ # Cache model loading for performance
6
@st.cache_resource
def load_image_to_text():
    """Build the BLIP image-captioning pipeline.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused instead of being rebuilt on every call.
    """
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
    )
    return captioner
9
+
10
@st.cache_resource
def load_story_generator():
    """Build the text-generation pipeline used to write stories.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused instead of being rebuilt on every call.
    """
    generator = pipeline(
        "text-generation",
        model="pranavpsv/genre-story-generator-v2",
    )
    return generator
13
+
14
@st.cache_resource
def load_tts():
    """Build the text-to-speech pipeline used to narrate stories.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused instead of being rebuilt on every call.
    """
    synthesizer = pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
    )
    return synthesizer
17
+
18
# Generation functions with error handling


def generate_caption(image):
    """Return a caption for ``image``, or ``None`` if captioning fails.

    Args:
        image: The image handed to the image-to-text pipeline (the UI passes
            the object returned by ``PIL.Image.open``).

    Returns:
        The ``"generated_text"`` of the first pipeline result, or ``None``
        after showing an ``st.error`` message when anything goes wrong.
    """
    try:
        # Load inside the try block so a failed model download or
        # initialisation is reported the same way as an inference failure.
        image_to_text = load_image_to_text()
        return image_to_text(image)[0]["generated_text"]
    except Exception as e:  # UI boundary: report the problem, do not crash the app
        st.error(f"Oops! Something went wrong while generating the caption: {e}")
        return None


def generate_story(caption):
    """Return a story seeded by ``caption``, or ``None`` if none was produced.

    Args:
        caption: Text describing the picture; it may have been edited by the
            user, so it can be blank.

    Returns:
        The ``"generated_text"`` of the first pipeline result, or ``None``
        when the caption is blank or generation fails (a message is shown in
        the UI in both cases).
    """
    # A blank caption would yield the prompt "there was . ", so stop early.
    if not caption or not caption.strip():
        st.warning("Please enter a caption before generating a story.")
        return None

    try:
        story_generator = load_story_generator()
        prompt = f"Once upon a time, there was {caption.strip()}. "
        results = story_generator(
            prompt,
            max_length=200,
            do_sample=True,
            temperature=0.7,
        )
        return results[0]['generated_text']
    except Exception as e:  # UI boundary: report the problem, do not crash the app
        st.error(f"Oops! Something went wrong while generating the story: {e}")
        return None


def generate_audio(story):
    """Return the text-to-speech pipeline output for ``story``, or ``None``.

    Args:
        story: The text to narrate.

    Returns:
        Whatever the text-to-speech pipeline returns (the UI reads its
        ``'audio'`` and ``'sampling_rate'`` entries), or ``None`` after
        showing an ``st.error`` message when anything goes wrong.
    """
    try:
        tts = load_tts()
        return tts(story)
    except Exception as e:  # UI boundary: report the problem, do not crash the app
        st.error(f"Oops! Something went wrong while generating the audio: {e}")
        return None
46
 
47
  # Streamlit UI
48
st.markdown("<h1 style='text-align: center; color: blue;'>📸✨ Storyteller for Kids! ✨📸</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center;'>Upload a fun picture and I’ll tell you a magical story about it!</p>", unsafe_allow_html=True)


def _audio_to_wav_bytes(samples, sampling_rate):
    """Encode a float waveform as 16-bit PCM WAV bytes for downloading.

    ``st.download_button`` needs str/bytes/file data, so the raw waveform
    array produced by the text-to-speech pipeline cannot be passed directly.
    """
    import io
    import wave

    import numpy as np

    # NOTE(review): assumes a mono waveform with samples in [-1.0, 1.0];
    # confirm against the text-to-speech model's actual output.
    pcm = np.clip(np.asarray(samples, dtype=np.float32).reshape(-1), -1.0, 1.0)
    pcm = (pcm * 32767.0).astype("<i2")

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_file.setframerate(int(sampling_rate))
        wav_file.writeframes(pcm.tobytes())
    return buffer.getvalue()


# File uploader for image
uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file)
    # use_column_width is deprecated; use_container_width is its replacement.
    st.image(image, caption="Your Uploaded Image", use_container_width=True)

    # Streamlit reruns this script on every widget interaction, and a button
    # is True only on the rerun right after its click. Results are therefore
    # kept in session_state, and reset when a different image is uploaded.
    image_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("image_key") != image_key:
        st.session_state["image_key"] = image_key
        st.session_state["generated_caption"] = None
        st.session_state["story"] = None
        st.session_state["audio"] = None

    # Generate the caption once per image instead of on every rerun
    if not st.session_state.get("generated_caption"):
        with st.spinner("Creating a caption..."):
            st.session_state["generated_caption"] = generate_caption(image)
    generated_caption = st.session_state["generated_caption"]

    if generated_caption:
        # Allow user to edit the caption
        caption_input = st.text_area("Caption:", value=generated_caption, height=100)
        st.write("Feel free to change the caption to make your own story!")

        # Generate story when button is clicked
        if st.button("Generate Story"):
            with st.spinner("Writing a magical story..."):
                st.session_state["story"] = generate_story(caption_input)
            # Any earlier audio narrated the previous story
            st.session_state["audio"] = None

        # Read the story from session_state (not from inside the button
        # branch) so it survives the reruns caused by the widgets below.
        story = st.session_state.get("story")
        if story:
            st.subheader("Your Story:")
            st.write(story)
            st.download_button("Download Story", story, file_name="my_story.txt")

            # Generate audio on demand
            if st.button("Generate Audio"):
                with st.spinner("Turning your story into sound..."):
                    st.session_state["audio"] = generate_audio(story)

            audio = st.session_state.get("audio")
            if audio:
                st.audio(audio['audio'], format="audio/wav", start_time=0, sample_rate=audio['sampling_rate'])
                st.download_button(
                    "Download Audio",
                    _audio_to_wav_bytes(audio['audio'], audio['sampling_rate']),
                    file_name="my_story.wav",
                    mime="audio/wav",
                )