TLH01 commited on
Commit
8327766
·
verified ·
1 Parent(s): eaea916

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
5
 
6
- # Stage 1: Image to Text (Captioning)
7
  @st.cache_resource
8
  def load_image_caption_model():
9
  return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
@@ -13,21 +13,21 @@ def generate_caption(image):
13
  result = caption_model(image)
14
  return result[0]['generated_text']
15
 
16
- # Stage 2: Text to Story (Children-friendly)
17
  @st.cache_resource
18
  def load_story_generator():
19
- return pipeline("text2text-generation", model="google/flan-t5-small", max_length=100)
20
 
21
  def text2story(description):
22
  story_gen = load_story_generator()
23
- prompt = f"Generate a short children's story about {description}."
24
  story = story_gen(prompt)[0]['generated_text']
25
  return story
26
 
27
- # Stage 3: Story to Audio (Lightweight & Compatible)
28
  @st.cache_resource
29
  def load_tts():
30
- return pipeline("text-to-speech", model="coqui/tts-en-simply-tts")
31
 
32
  def story_to_audio(story_text):
33
  tts = load_tts()
@@ -40,7 +40,6 @@ def story_to_audio(story_text):
40
  def main():
41
  st.set_page_config(page_title="Kids Story Creator", layout="centered")
42
  st.title("🧒 Kids Story Creator 📖")
43
- st.write("Upload a picture and let us create a short story with voice for children aged 3–10!")
44
 
45
  uploaded_image = st.file_uploader("Upload an image (jpg/jpeg/png):", type=["jpg", "jpeg", "png"])
46
 
@@ -48,19 +47,18 @@ def main():
48
  image = Image.open(uploaded_image).convert("RGB")
49
  st.image(image, caption="Uploaded Image", use_column_width=True)
50
 
51
- if st.button("Generate Story"):
52
- with st.spinner("Step 1: Generating image description..."):
53
- caption = generate_caption(image)
54
- st.write(f"**Caption:** {caption}")
55
 
56
- with st.spinner("Step 2: Creating children's story..."):
57
- story = text2story(caption)
58
- st.write("**Generated Story:**")
59
- st.write(story)
60
 
61
- with st.spinner("Step 3: Generating audio..."):
62
- audio, sample_rate = story_to_audio(story)
63
- st.audio(audio, format="audio/wav", sample_rate=sample_rate)
64
 
65
  if __name__ == "__main__":
66
  main()
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # Stage 1: Image to Caption
7
  @st.cache_resource
8
  def load_image_caption_model():
9
  return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
13
  result = caption_model(image)
14
  return result[0]['generated_text']
15
 
16
+ # Stage 2: Caption to Story
17
  @st.cache_resource
18
  def load_story_generator():
19
+ return pipeline("text-generation", model="pranavpsv/gpt2-genre-story-generator", max_length=120)
20
 
21
  def text2story(description):
22
  story_gen = load_story_generator()
23
+ prompt = f"A short and fun children's story about {description}."
24
  story = story_gen(prompt)[0]['generated_text']
25
  return story
26
 
27
+ # Stage 3: Story to Speech
28
  @st.cache_resource
29
  def load_tts():
30
+ return pipeline("text-to-speech", model="espnet/kan-bayashi_ljspeech_vits", framework="espnet")
31
 
32
  def story_to_audio(story_text):
33
  tts = load_tts()
 
40
  def main():
41
  st.set_page_config(page_title="Kids Story Creator", layout="centered")
42
  st.title("🧒 Kids Story Creator 📖")
 
43
 
44
  uploaded_image = st.file_uploader("Upload an image (jpg/jpeg/png):", type=["jpg", "jpeg", "png"])
45
 
 
47
  image = Image.open(uploaded_image).convert("RGB")
48
  st.image(image, caption="Uploaded Image", use_column_width=True)
49
 
50
+ with st.spinner("Step 1: Generating description..."):
51
+ caption = generate_caption(image)
52
+ st.success(f"Caption: {caption}")
 
53
 
54
+ with st.spinner("Step 2: Generating a short story..."):
55
+ story = text2story(caption)
56
+ st.success("Here's your story:")
57
+ st.write(story)
58
 
59
+ with st.spinner("Step 3: Converting story to audio..."):
60
+ audio, sample_rate = story_to_audio(story)
61
+ st.audio(audio, format="audio/wav", sample_rate=sample_rate)
62
 
63
  if __name__ == "__main__":
64
  main()