File size: 4,206 Bytes
5bbcf1b
 
c460031
 
 
ef7e1aa
 
c460031
bed9467
146cc47
258921e
c460031
504a753
c460031
504a753
870428e
c460031
258921e
ef7e1aa
 
 
 
504a753
ef7e1aa
c460031
 
ef7e1aa
 
 
 
504a753
 
c75f8e2
504a753
 
c460031
c75f8e2
ef7e1aa
c75f8e2
 
ef7e1aa
504a753
ef7e1aa
c75f8e2
ef7e1aa
8f279a7
c75f8e2
 
 
 
 
 
 
 
 
ef7e1aa
 
c75f8e2
 
ef7e1aa
c75f8e2
 
 
ef7e1aa
8f279a7
c75f8e2
fb3ff7f
504a753
c460031
504a753
ef7e1aa
 
8f279a7
c75f8e2
 
8f279a7
 
 
 
 
 
 
 
504a753
 
 
 
146cc47
ef7e1aa
 
 
c75f8e2
8f279a7
c75f8e2
 
 
8f279a7
870428e
8f279a7
ef7e1aa
146cc47
bed9467
8f279a7
ef7e1aa
bed9467
8f279a7
c75f8e2
ef7e1aa
8f279a7
ef7e1aa
1394a8a
8f279a7
c75f8e2
 
8f279a7
 
c75f8e2
8f279a7
ef7e1aa
 
 
8f279a7
c75f8e2
 
 
 
 
 
 
 
 
fb3ff7f
1a64058
146cc47
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import streamlit as st
from PIL import Image
from transformers import (
    BlipProcessor, 
    BlipForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM
)
from gtts import gTTS
import io
import torch

# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP image-captioning processor and model pair."""
    model_id = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id)
    return processor, model

def stage1_process(uploaded_file):
    """Produce a natural-language caption for the uploaded image file."""
    processor, model = load_image_model()
    # BLIP expects 3-channel input; convert in case of RGBA/grayscale uploads.
    image = Image.open(uploaded_file).convert("RGB")
    encoded = processor(images=image, return_tensors="pt")
    generated_ids = model.generate(**encoded)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption

# ======================
# Stage 2: Story Generation (Optimized)
# ======================
@st.cache_resource
def load_story_model():
    """Load and cache the GPT-2 medium tokenizer and language model pair."""
    model_id = "gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model

def stage2_process(keyword):
    """Generate a short children's story themed around *keyword*.

    Returns the generated text that follows the "Story begins:" anchor
    in the prompt, stripped of surrounding whitespace.
    """
    tokenizer, model = load_story_model()
    
    # Prompt template constrains length, cast and moral so the model
    # produces a coherent, child-appropriate continuation.
    prompt = f"""Write a children's story in 100-150 words with these elements:
    - Theme: {keyword}
    - Characters: Friendly animals
    - Moral: Sharing is caring

    Story begins: One sunny morning, a little rabbit named Cotton discovered"""
    
    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
    # Fix: pass attention_mask and an explicit pad_token_id.  GPT-2 defines
    # no pad token, and omitting these makes generate() guess the mask from
    # padding heuristics (and emit a runtime warning).
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=300,
        temperature=0.9,
        top_k=50,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        do_sample=True
    )
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the instruction scaffold; keep only the story continuation.
    return full_text.split("Story begins:")[-1].strip()

# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_process(text):
    """Convert story text to spoken audio via gTTS.

    Returns an in-memory MP3 (io.BytesIO) rewound to the start, or
    None when the text is too short or synthesis fails.
    """
    try:
        # Flatten newlines and cap length at 300 chars for the TTS request.
        clean_text = text.strip().replace('\n', ' ')[:300]
        if len(clean_text) < 20:
            # Too little text to be worth narrating.
            return None
        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)
        return audio
    except Exception:
        # Audio is optional, so swallow synthesis/network errors rather than
        # crash the app.  Fix: was a bare `except:`, which also trapped
        # SystemExit and KeyboardInterrupt.
        return None

# ======================
# Main Application
# ======================
def main():
    """Streamlit entry point: image -> caption -> story -> audio pipeline."""
    st.title("📖 Children's Story Generator")
    
    # Initialize session state exactly once per session.
    # Fix: the original checked for 'processing' but never set it, so this
    # branch fired on EVERY Streamlit rerun and wiped caption/story/audio,
    # forcing the models to regenerate each interaction.
    if 'processing' not in st.session_state:
        st.session_state.update({
            'processing': True,
            'caption': None,
            'story': None,
            'audio': None
        })
    
    # File upload
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    
    if uploaded_file:
        # Permanent display of the source image.
        st.image(uploaded_file, width=300)
        
        # Stage 1: caption the image (cached in session state).
        if not st.session_state.caption:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
        st.success(f"Detected Theme: {st.session_state.caption}")
        
        # Stage 2: generate the story from the caption (cached).
        if not st.session_state.story:
            with st.spinner("Writing magical story..."):
                st.session_state.story = stage2_process(st.session_state.caption)
        
        # Display story
        if st.session_state.story:
            st.subheader("Generated Story")
            st.write(st.session_state.story)
            
            # Stage 3: synthesize narration (cached; may be None on failure).
            if not st.session_state.audio:
                with st.spinner("Generating audio..."):
                    st.session_state.audio = stage3_process(st.session_state.story)
            if st.session_state.audio:
                st.audio(st.session_state.audio, format="audio/mp3")
                st.download_button("Download Audio", 
                                 st.session_state.audio.getvalue(),
                                 "story.mp3",
                                 mime="audio/mp3")

if __name__ == "__main__":
    main()