|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
from transformers import ( |
|
|
BlipProcessor, |
|
|
BlipForConditionalGeneration, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM |
|
|
) |
|
|
from gtts import gTTS |
|
|
import io |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP captioning processor and model (one per session)."""
    model_id = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id)
    return processor, model
|
|
|
|
|
def stage1_process(uploaded_file):
    """Caption the uploaded image and return the caption string."""
    processor, model = load_image_model()
    image = Image.open(uploaded_file).convert("RGB")
    batch = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(**batch)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_story_model():
    """Load and cache the GPT-2 medium tokenizer and language model."""
    model_name = "gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
|
|
|
|
|
def stage2_process(keyword):
    """Generate a short children's story themed on *keyword*.

    Builds a structured prompt, samples a continuation from GPT-2 medium,
    and returns the story text with the prompt header stripped (the seeded
    opening sentence is kept).
    """
    tokenizer, model = load_story_model()

    prompt = f"""Write a children's story in 100-150 words with these elements:
- Theme: {keyword}
- Characters: Friendly animals
- Moral: Sharing is caring

Story begins: One sunny morning, a little rabbit named Cotton discovered"""

    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    outputs = model.generate(
        inputs.input_ids,
        # Fix: pass the mask explicitly instead of letting generate() infer it,
        # and set pad_token_id — GPT-2 has no pad token, so transformers would
        # otherwise warn and fall back to ambiguous padding behavior.
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=300,
        temperature=0.9,
        top_k=50,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        do_sample=True
    )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop everything up to the last "Story begins:" marker; the remainder is
    # the seeded opening sentence plus the generated continuation.
    return full_text.split("Story begins:")[-1].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stage3_process(text):
    """Convert story text to narrated audio.

    Returns an in-memory MP3 buffer (io.BytesIO) positioned at the start,
    or None when the text is too short to narrate or synthesis fails.
    """
    try:
        # Flatten newlines and cap at 300 chars to keep the TTS request small.
        clean_text = text.strip().replace('\n', ' ')[:300]
        if len(clean_text) < 20:
            # Too short to be a real story — skip narration.
            return None
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)
        return audio
    except Exception:
        # Fix: narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). Audio is optional, so synthesis
        # errors (e.g. network failures) degrade to "no audio" by design.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Streamlit entry point: upload image → caption → story → narrated audio."""
    st.title("📖 Children's Story Generator")

    # Initialize per-session state exactly once. Bug fix: the original checked
    # for 'processing' but never stored it, so this guard was True on every
    # Streamlit rerun and the cached caption/story/audio were wiped on each
    # interaction, forcing a full (slow) regeneration every time.
    if 'processing' not in st.session_state:
        st.session_state.update({
            'processing': False,
            'caption': None,
            'story': None,
            'audio': None
        })

    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])

    if uploaded_file:
        st.image(uploaded_file, width=300)

        # Stage 1: caption the image (runs once; result kept in session state).
        if not st.session_state.caption:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
        st.success(f"Detected Theme: {st.session_state.caption}")

        # Stage 2: generate a story themed on the caption.
        if not st.session_state.story:
            with st.spinner("Writing magical story..."):
                st.session_state.story = stage2_process(st.session_state.caption)

        if st.session_state.story:
            st.subheader("Generated Story")
            st.write(st.session_state.story)

            # Stage 3: narrate the story (stage3_process may return None
            # when synthesis fails — the audio widgets are then skipped).
            if not st.session_state.audio:
                with st.spinner("Generating audio..."):
                    st.session_state.audio = stage3_process(st.session_state.story)
            if st.session_state.audio:
                st.audio(st.session_state.audio, format="audio/mp3")
                st.download_button("Download Audio",
                                   st.session_state.audio.getvalue(),
                                   "story.mp3",
                                   mime="audio/mp3")
|
|
|
|
|
# Entry guard: run the app only when this file is executed directly
# (e.g. via `streamlit run`), not when imported as a module.
if __name__ == "__main__":
    main()