# NOTE: a source-viewer export header (file size, git blame hashes, line-number
# column) was fused onto this file; it is not Python and has been neutralized.
import streamlit as st
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForCausalLM
)
from gtts import gTTS
import io
import torch
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP captioning processor/model pair."""
    checkpoint = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model
def stage1_process(uploaded_file):
    """Caption the uploaded image file and return the decoded text."""
    processor, model = load_image_model()
    image = Image.open(uploaded_file).convert("RGB")
    batch = processor(images=image, return_tensors="pt")
    token_ids = model.generate(**batch)
    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption
# ======================
# Stage 2: Story Generation (Optimized)
# ======================
@st.cache_resource
def load_story_model():
    """Load and cache the GPT-2 medium tokenizer/model pair for story text."""
    checkpoint = "gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model
def stage2_process(keyword):
    """Generate a short children's story themed around *keyword*.

    Builds a structured prompt, samples a continuation from GPT-2 medium,
    and returns only the story portion (everything after "Story begins:",
    which includes the seeded opening sentence).
    """
    tokenizer, model = load_story_model()
    # GPT-2 ships without a pad token; reuse EOS so generate() can pad and
    # terminate cleanly instead of emitting a runtime warning.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Enhanced prompt template: fixed characters/moral keep output on-theme.
    prompt = f"""Write a children's story in 100-150 words with these elements:
- Theme: {keyword}
- Characters: Friendly animals
- Moral: Sharing is caring
Story begins: One sunny morning, a little rabbit named Cotton discovered"""
    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    # Inference only — no_grad() skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask avoids pad ambiguity
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=300,
            temperature=0.9,
            top_k=50,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            do_sample=True
        )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The model echoes the prompt back; keep only the story continuation.
    return full_text.split("Story begins:")[-1].strip()
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_process(text):
    """Convert story text to narrated audio.

    Returns an in-memory MP3 stream (io.BytesIO) rewound to the start and
    ready for st.audio, or None when the text is too short or synthesis
    fails (deliberately best-effort: audio is an optional enhancement).
    """
    try:
        # Flatten newlines and cap length; gTTS handles short plain text best.
        clean_text = text.strip().replace('\n', ' ')[:300]
        if len(clean_text) < 20:
            return None  # too little text to be worth narrating
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)  # rewind so callers read from the beginning
        return audio
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; Exception preserves best-effort behavior without that.
        return None
# ======================
# Main Application
# ======================
def main():
    """Streamlit entry point: image upload → caption → story → narration."""
    st.title("📖 Children's Story Generator")
    # Initialize session state once per session.
    # BUG FIX: the guard previously tested 'processing', a key that was never
    # written — so every Streamlit rerun re-entered this branch and wiped the
    # cached caption/story/audio, forcing all three stages to re-run. Guard on
    # a key this block actually sets.
    if 'caption' not in st.session_state:
        st.session_state.update({
            'caption': None,
            'story': None,
            'audio': None
        })
    # File upload
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    if uploaded_file:
        # Permanent display
        st.image(uploaded_file, width=300)
        # Stage 1: caption the image (only once per session)
        if not st.session_state.caption:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
        st.success(f"Detected Theme: {st.session_state.caption}")
        # Stage 2: generate the story from the caption
        if not st.session_state.story:
            with st.spinner("Writing magical story..."):
                st.session_state.story = stage2_process(st.session_state.caption)
        # Display story
        if st.session_state.story:
            st.subheader("Generated Story")
            st.write(st.session_state.story)
            # Stage 3: synthesize narration (stage3_process may return None)
            if not st.session_state.audio:
                with st.spinner("Generating audio..."):
                    st.session_state.audio = stage3_process(st.session_state.story)
            if st.session_state.audio:
                st.audio(st.session_state.audio, format="audio/mp3")
                st.download_button("Download Audio",
                                   st.session_state.audio.getvalue(),
                                   "story.mp3",
                                   mime="audio/mp3")
if __name__ == "__main__":
main() |