| import streamlit as st |
| from PIL import Image |
| from transformers import ( |
| BlipProcessor, |
| BlipForConditionalGeneration, |
| AutoTokenizer, |
| AutoModelForCausalLM |
| ) |
| from gtts import gTTS |
| import io |
| import torch |
|
|
| |
| |
| |
@st.cache_resource
def load_image_model():
    """Return the ``(processor, model)`` pair used for image captioning.

    Both objects come from the ``Salesforce/blip-image-captioning-base``
    checkpoint. The function is wrapped in ``st.cache_resource`` so Streamlit
    can reuse the loaded objects instead of reloading them on every rerun.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    caption_processor = BlipProcessor.from_pretrained(checkpoint)
    caption_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return caption_processor, caption_model
|
|
def stage1_process(uploaded_file):
    """Caption an uploaded image (pipeline stage 1).

    Args:
        uploaded_file: Any file-like object or path that ``PIL.Image.open``
            accepts (in this app, the Streamlit upload).

    Returns:
        The caption produced by the BLIP model, with special tokens removed.
    """
    processor, model = load_image_model()
    # BLIP expects 3-channel input, so normalise PNG alpha / greyscale images.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    model_inputs = processor(images=rgb_image, return_tensors="pt")
    first_sequence = model.generate(**model_inputs)[0]
    caption = processor.decode(first_sequence, skip_special_tokens=True)
    return caption
|
|
| |
| |
| |
@st.cache_resource
def load_story_model():
    """Return the ``(tokenizer, model)`` pair used for story generation.

    Both objects come from the ``prpappas/fairytale-gpt2`` checkpoint. The
    function is wrapped in ``st.cache_resource`` so Streamlit can reuse the
    loaded objects instead of reloading them on every rerun.
    """
    checkpoint = "prpappas/fairytale-gpt2"
    story_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    story_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return story_tokenizer, story_model
|
|
def stage2_process(keyword):
    """Generate a short children's story about ``keyword`` (pipeline stage 2).

    The story is sampled (``do_sample=True``), so repeated calls with the
    same keyword can return different stories.

    Args:
        keyword: Topic inserted into the prompt (in this app, the image
            caption from stage 1).

    Returns:
        Only the newly generated continuation, without the prompt, with
        surrounding whitespace stripped.
    """
    tokenizer, model = load_story_model()
    prompt = f"Write a children's story about {keyword} in 100 words:\n"
    # Long captions are truncated to 50 prompt tokens.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=50, truncation=True)
    prompt_len = inputs.input_ids.shape[1]

    # GPT-2 style tokenizers often define no pad token; fall back to EOS so
    # generate() does not have to guess (and warn).
    pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=200,  # total length, prompt tokens included
        do_sample=True,  # without this, temperature/top_k are ignored (greedy)
        temperature=0.85,
        top_k=50,
        repetition_penalty=1.2,
        pad_token_id=pad_id,
    )
    # Drop the prompt by token position rather than by string replacement:
    # the prompt may have been truncated, and decoding need not reproduce it
    # byte-for-byte, in which case str.replace would leave it in the story.
    story_ids = outputs[0][prompt_len:]
    return tokenizer.decode(story_ids, skip_special_tokens=True).strip()
|
|
| |
| |
| |
def stage3_process(text, max_chars=None):
    """Convert ``text`` to English speech (pipeline stage 3).

    Args:
        text: The text to narrate.
        max_chars: Optional cap on the number of characters narrated.
            ``None`` (the default) narrates the whole text. Passing ``200``
            reproduces the previous behaviour of speaking only the first 200
            characters.

    Returns:
        An ``io.BytesIO`` rewound to position 0 and holding the MP3 audio
        written by gTTS.

    Raises:
        ValueError: If nothing but whitespace is left to speak.
    """
    if max_chars is not None:
        text = text[:max_chars]
    # Fail with a clear error instead of an internal gTTS assertion.
    if not text.strip():
        raise ValueError("text must contain something to speak")

    tts = gTTS(text=text, lang='en')
    audio = io.BytesIO()
    tts.write_to_fp(audio)
    audio.seek(0)  # so callers can read the MP3 from the start
    return audio
|
|
| |
| |
| |
def main():
    """Render the Streamlit page and run the three pipeline stages.

    Flow:
        1. Upload an image.
        2. Caption it (stage 1).
        3. Write a story from the caption (stage 2).
        4. Narrate the story (stage 3).

    Each stage's result is kept in ``st.session_state``. This means the model
    and TTS calls are not repeated when the script reruns, for example after
    the download button is clicked. All cached results are discarded when a
    different image is uploaded.
    """
    import hashlib  # only needed here, to fingerprint the uploaded image

    st.title("📖 Children's Story Generator")

    # Initialise per-session state on the first run.
    defaults = {
        "stage1_done": False,
        "stage2_done": False,
        "upload_id": None,
        "audio_bytes": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    if not uploaded_file:
        return

    st.image(uploaded_file, width=300)

    # Without this reset the *_done flags stay True after the first image.
    # A newly uploaded image would then keep showing the old caption and story.
    upload_id = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    if st.session_state.upload_id != upload_id:
        st.session_state.upload_id = upload_id
        st.session_state.stage1_done = False
        st.session_state.stage2_done = False
        st.session_state.audio_bytes = None

    # Stage 1: image -> caption.
    if not st.session_state.stage1_done:
        with st.spinner("Analyzing image..."):
            st.session_state.caption = stage1_process(uploaded_file)
            st.session_state.stage1_done = True
    st.success(f"Detected Theme: {st.session_state.caption}")

    # Stage 2: caption -> story.
    if not st.session_state.stage2_done:
        with st.spinner("Creating story..."):
            st.session_state.story = stage2_process(st.session_state.caption)
            st.session_state.stage2_done = True

    st.subheader("Generated Story")
    st.write(st.session_state.story)

    # Stage 3: story -> audio. TTS fails on empty text, so warn instead.
    if not st.session_state.story.strip():
        st.warning("The generated story is empty, so there is nothing to narrate.")
        return

    # Cache the MP3 bytes. Otherwise every rerun (including the one triggered
    # by clicking the download button) would call the TTS service again.
    if st.session_state.audio_bytes is None:
        with st.spinner("Generating audio..."):
            audio = stage3_process(st.session_state.story)
            st.session_state.audio_bytes = audio.getvalue()
    audio_bytes = st.session_state.audio_bytes
    st.audio(audio_bytes, format="audio/mp3")
    st.download_button(
        "Download Audio",
        data=audio_bytes,
        file_name="story.mp3",
        mime="audio/mp3",
    )


# Entry point when the script is executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()