TLH01 committed on
Commit
146cc47
·
verified ·
1 Parent(s): 23456e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -45
app.py CHANGED
@@ -1,54 +1,145 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
 
 
4
  import torch
5
- from TTS.api import TTS
 
 
 
6
 
7
- # Set page configuration
8
- st.set_page_config(page_title="Children's Image Storytelling", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Load models
11
- @st.cache_resource
12
- def load_models():
13
- # Load image captioning model
14
- vision_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
15
- processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
16
- tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
17
- # Load text-to-speech model
18
- tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
19
- return vision_model, processor, tokenizer, tts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Main function
22
- def main():
23
- # Display title
24
- st.title("🧒📖 AI Image Storytelling")
25
- st.write("Upload an image, and let AI generate a story for children aged 3–10 with voice narration.")
26
-
27
- # Upload image
28
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
29
-
30
- if uploaded_file:
31
- image = Image.open(uploaded_file).convert("RGB")
32
- st.image(image, caption="Uploaded Image", use_column_width=True)
33
-
34
- if st.button("Generate Story"):
35
- vision_model, processor, tokenizer, tts_model = load_models()
 
 
 
 
 
36
 
37
- with st.spinner("Generating description..."):
38
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
39
- output_ids = vision_model.generate(pixel_values, max_length=50, num_beams=4)
40
- caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
41
- st.success("Image Description: " + caption)
42
-
43
- with st.spinner("Generating story..."):
44
- story_prompt = f"Based on the following description, tell me a short children's story: {caption}"
45
- story = caption + " Once upon a time, " + caption.lower() + " entered a magical forest and met many new friends."
46
- st.success("Story: " + story)
47
-
48
- with st.spinner("Generating voice..."):
49
- tts_model.tts_to_file(text=story, file_path="story.wav")
50
- st.audio("story.wav", format="audio/wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Run the main program
53
  if __name__ == "__main__":
54
- main()
 
1
  import streamlit as st
2
  from PIL import Image
3
+ import requests
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
  import torch
7
+ import io
8
+ import soundfile as sf
9
+ from speechbrain.pretrained import Tacotron2
10
+ from speechbrain.pretrained import HIFIGAN
11
 
12
# Stage 1: Image to Keyword/Caption
@st.cache_resource
def _load_caption_models():
    # Cached so the BLIP weights are downloaded/loaded once per session,
    # not on every button click / script rerun.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model


def image_to_keyword(uploaded_image):
    """Generate a short caption for *uploaded_image* with BLIP.

    Args:
        uploaded_image: a file-like object readable by PIL (e.g. a
            Streamlit ``UploadedFile``).

    Returns:
        The decoded caption string, or ``None`` on failure (the error is
        surfaced to the user via ``st.error``).
    """
    try:
        processor, model = _load_caption_models()

        # Force RGB so grayscale/RGBA uploads don't break the processor.
        raw_image = Image.open(uploaded_image).convert('RGB')
        inputs = processor(raw_image, return_tensors="pt")

        # Generate and decode the caption.
        out = model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)

        return caption
    except Exception as e:
        # UI boundary: report the error instead of crashing the app run.
        st.error(f"Error in image captioning: {str(e)}")
        return None
31
 
32
# Stage 2: Keyword to Story
@st.cache_resource
def _load_story_models():
    # Cached so GPT-2 weights are loaded once per session.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so generate() can pad
    # without emitting "pad_token_id not set" warnings.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model


def keyword_to_story(keyword):
    """Generate a ~50-100 word story from *keyword* using GPT-2.

    Args:
        keyword: caption/keyword string to base the story on.

    Returns:
        The generated story text (clipped to at most 100 words), or
        ``None`` on failure (error shown via ``st.error``).
    """
    try:
        tokenizer, model = _load_story_models()

        # Create prompt
        prompt = f"Write a short story between 50-100 words based on: {keyword}\n\nStory:"
        inputs = tokenizer(prompt, return_tensors="pt")

        # max_new_tokens (not max_length) so the token budget is spent on
        # the story rather than shared with the prompt; passing the
        # attention mask and pad_token_id silences HF warnings.
        # (early_stopping was dropped: it only applies to beam search.)
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=150,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
        )

        story = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up the story (remove prompt if it appears)
        story = story.replace(prompt, "").strip()

        # Keep the story length between roughly 50 and 100 words.
        words = story.split()
        if len(words) > 100:
            story = " ".join(words[:100])
        elif len(words) < 50:
            # Too short: resample once with temperature for a longer draft.
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=150,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                temperature=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
            story = tokenizer.decode(outputs[0], skip_special_tokens=True)
            story = story.replace(prompt, "").strip()

        return story
    except Exception as e:
        st.error(f"Error in story generation: {str(e)}")
        return None
79
 
80
# Stage 3: Story to Audio
@st.cache_resource
def _load_tts_models():
    # Cached so the Tacotron2 + HiFi-GAN checkpoints are fetched once
    # per session instead of on every synthesis request.
    tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmp_tts")
    hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="tmp_vocoder")
    return tacotron2, hifi_gan


def story_to_audio(story_text):
    """Synthesize *story_text* to WAV audio with Tacotron2 + HiFi-GAN.

    Args:
        story_text: the story string to narrate.

    Returns:
        An ``io.BytesIO`` positioned at 0 containing WAV data, or
        ``None`` on failure (error shown via ``st.error``).
    """
    try:
        tacotron2, hifi_gan = _load_tts_models()

        # Generate mel spectrogram, then vocode to a waveform batch.
        mel_output, mel_length, alignment = tacotron2.encode_text(story_text)
        waveforms = hifi_gan.decode_batch(mel_output)

        # Flatten to a 1-D sample array: soundfile treats a 2-D array as
        # (frames, channels), so a (1, N) batch would be written as one
        # N-channel frame. 22050 Hz matches the LJSpeech-trained models.
        samples = waveforms.squeeze().cpu().numpy()
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, samples, 22050, format='WAV')
        audio_bytes.seek(0)

        return audio_bytes
    except Exception as e:
        st.error(f"Error in audio generation: {str(e)}")
        return None
100
 
101
# Main App Function
def main():
    """Streamlit entry point: upload image -> caption -> story -> audio.

    Each stage only runs if the previous one succeeded; stage helpers
    report their own errors via st.error, so failures simply stop the
    pipeline here (guard-clause returns).
    """
    st.title("Image to Story Generator")
    st.write("Upload an image to generate a story and audio narration")

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return  # nothing to do until the user uploads an image

    # Display image (use_container_width replaces the deprecated
    # use_column_width flag).
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_container_width=True)

    # Stage 1: Image to Keyword.
    # Image.open above consumed the UploadedFile stream; rewind it so the
    # captioner's own Image.open can re-read the bytes.
    uploaded_file.seek(0)
    st.write("Generating caption from image...")
    caption = image_to_keyword(uploaded_file)
    if not caption:
        return  # error already shown by image_to_keyword

    st.success(f"Generated Caption: {caption}")

    # Stage 2: Keyword to Story
    st.write("Generating story from caption...")
    story = keyword_to_story(caption)
    if not story:
        return

    st.subheader("Generated Story")
    st.write(story)

    # Stage 3: Story to Audio
    st.write("Converting story to audio...")
    audio_bytes = story_to_audio(story)
    if not audio_bytes:
        return

    st.audio(audio_bytes, format='audio/wav')

    # Download button for audio
    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_story.wav",
        mime="audio/wav",
    )
143
 
 
144
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()