"""Streamlit app: generate a children's story (with narration) from an image.

Pipeline:
    Stage 1 - BLIP captions the uploaded image to extract a theme.
    Stage 2 - GPT-2 medium expands the theme into a short structured story.
    Stage 3 - gTTS converts the story text to in-memory MP3 audio.
"""

import io

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)

# ======================
# Stage 1: Image Captioning
# ======================


@st.cache_resource
def load_image_model():
    """Load and cache the BLIP image-captioning processor and model."""
    return (
        BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base"),
        BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ),
    )


def stage1_process(uploaded_file):
    """Generate a caption for *uploaded_file* (any PIL-openable image source).

    Returns:
        str: The decoded caption, used downstream as the story theme.
    """
    processor, model = load_image_model()
    img = Image.open(uploaded_file).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():  # pure inference; skip autograd bookkeeping
        outputs = model.generate(**inputs)
    return processor.decode(outputs[0], skip_special_tokens=True)


# ======================
# Stage 2: Story Generation (Optimized)
# ======================


@st.cache_resource
def load_story_model():
    """Load and cache the GPT-2 medium tokenizer and causal LM."""
    return (
        AutoTokenizer.from_pretrained("gpt2-medium"),
        AutoModelForCausalLM.from_pretrained("gpt2-medium"),
    )


def stage2_process(keyword):
    """Generate a short children's story themed around *keyword*.

    Returns:
        str: The sampled story text with the instruction preamble stripped.
    """
    tokenizer, model = load_story_model()

    # Enhanced prompt template: pins the structure, characters, and moral so
    # the sampled continuation stays on-topic.
    prompt = f"""Write a children's story in 100-150 words with these elements:
- Theme: {keyword}
- Characters: Friendly animals
- Moral: Sharing is caring

Story begins: One sunny morning, a little rabbit named Cotton discovered"""

    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # BUGFIX: pass the attention mask and an explicit pad token.
            # GPT-2 defines no pad token, and generate() without these emits a
            # warning and can attend to padding positions.
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=300,
            temperature=0.9,
            top_k=50,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            do_sample=True,
        )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the generated continuation, not the instruction preamble.
    return full_text.split("Story begins:")[-1].strip()


# ======================
# Stage 3: Text-to-Speech
# ======================


def stage3_process(text):
    """Convert *text* to spoken MP3 audio via gTTS.

    Returns:
        io.BytesIO | None: In-memory MP3 stream, or None when the cleaned text
        is too short (< 20 chars) or synthesis fails (gTTS needs network
        access; best-effort by design).
    """
    try:
        # Flatten newlines and cap length; gTTS handles short clips best.
        clean_text = text.strip().replace('\n', ' ')[:300]
        if len(clean_text) < 20:
            return None
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)
        return audio
    except Exception:  # BUGFIX: was a bare `except:` (also caught SystemExit)
        return None


# ======================
# Main Application
# ======================


def main():
    """Streamlit entry point: upload image -> caption -> story -> audio."""
    st.title("📖 Children's Story Generator")

    # Initialize session state once per session.
    # BUGFIX: the original gated initialization on a 'processing' key that was
    # never created, so caption/story/audio were wiped on every rerun and the
    # whole pipeline re-executed each time the page refreshed.
    for key in ("caption", "story", "audio"):
        if key not in st.session_state:
            st.session_state[key] = None

    # File upload
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])

    if uploaded_file:
        # Permanent display
        st.image(uploaded_file, width=300)

        # Stage 1: caption the image (only once per session).
        if not st.session_state.caption:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
        # Shown on every rerun so the theme stays visible, not just the first
        # time the caption is computed.
        st.success(f"Detected Theme: {st.session_state.caption}")

        # Stage 2: write the story (only once per session).
        if not st.session_state.story:
            with st.spinner("Writing magical story..."):
                st.session_state.story = stage2_process(st.session_state.caption)

        # Display story
        if st.session_state.story:
            st.subheader("Generated Story")
            st.write(st.session_state.story)

            # Stage 3: narrate the story (only once per session).
            if not st.session_state.audio:
                with st.spinner("Generating audio..."):
                    st.session_state.audio = stage3_process(st.session_state.story)

            if st.session_state.audio:
                st.audio(st.session_state.audio, format="audio/mp3")
                st.download_button(
                    "Download Audio",
                    st.session_state.audio.getvalue(),
                    "story.mp3",
                    mime="audio/mp3",
                )


if __name__ == "__main__":
    main()