TLH01 committed on
Commit
1a64058
·
verified ·
1 Parent(s): 3cc1c44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -4,42 +4,51 @@ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoToken
4
  import torch
5
  from TTS.api import TTS
6
 
7
- # Set page config
8
- st.set_page_config(page_title="Image Storytelling for Kids", layout="wide")
9
-
10
- st.title("🧒📖 AI Image Storytelling")
11
- st.write("Upload an image, and let AI generate a story with voice for children aged 3–10.")
12
 
13
  # Load models
14
  @st.cache_resource
15
  def load_models():
 
16
  vision_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
17
  processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
18
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 
19
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
20
  return vision_model, processor, tokenizer, tts
21
 
22
- vision_model, processor, tokenizer, tts_model = load_models()
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Upload image
25
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
26
 
27
- if uploaded_file:
28
- image = Image.open(uploaded_file).convert("RGB")
29
- st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
30
 
31
- if st.button("Generate Story"):
32
- with st.spinner("Generating description..."):
33
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
34
- output_ids = vision_model.generate(pixel_values, max_length=50, num_beams=4)
35
- caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
36
- st.success("Image Description: " + caption)
37
 
38
- with st.spinner("Creating story..."):
39
- story_prompt = f"Tell a short, friendly children's story based on: {caption}"
40
- story = caption + " Once upon a time, " + caption.lower() + " went on an adventure and made new friends in a magical forest."
41
- st.success("Story: " + story)
42
 
43
- with st.spinner("Generating voice..."):
44
- tts_model.tts_to_file(text=story, file_path="story.wav")
45
- st.audio("story.wav", format="audio/wav")
 
4
  import torch
5
  from TTS.api import TTS
6
 
7
+ # Set page configuration
8
+ st.set_page_config(page_title="Children's Image Storytelling", layout="wide")
 
 
 
9
 
10
  # Load models
11
  @st.cache_resource
12
  def load_models():
13
+ # Load image captioning model
14
  vision_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
15
  processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
16
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
17
+ # Load text-to-speech model
18
  tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
19
  return vision_model, processor, tokenizer, tts
20
 
21
+ # Main function
22
+ def main():
23
+ # Display title
24
+ st.title("🧒📖 AI Image Storytelling")
25
+ st.write("Upload an image, and let AI generate a story for children aged 3–10 with voice narration.")
26
+
27
+ # Upload image
28
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
29
+
30
+ if uploaded_file:
31
+ image = Image.open(uploaded_file).convert("RGB")
32
+ st.image(image, caption="Uploaded Image", use_column_width=True)
33
 
34
+ if st.button("Generate Story"):
35
+ vision_model, processor, tokenizer, tts_model = load_models()
36
 
37
+ with st.spinner("Generating description..."):
38
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
39
+ output_ids = vision_model.generate(pixel_values, max_length=50, num_beams=4)
40
+ caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
41
+ st.success("Image Description: " + caption)
42
 
43
+ with st.spinner("Generating story..."):
44
+ story_prompt = f"Based on the following description, tell me a short children's story: {caption}"
45
+ story = caption + " Once upon a time, " + caption.lower() + " entered a magical forest and met many new friends."
46
+ st.success("Story: " + story)
 
 
47
 
48
+ with st.spinner("Generating voice..."):
49
+ tts_model.tts_to_file(text=story, file_path="story.wav")
50
+ st.audio("story.wav", format="audio/wav")
 
51
 
52
+ # Run the main program
53
+ if __name__ == "__main__":
54
+ main()