# testtest / app.py
# TLH01's picture
# Update app.py
# c75f8e2 verified
import streamlit as st
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForCausalLM
)
from gtts import gTTS
import io
import torch
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load the BLIP captioning processor and model.

    Both objects come from the same pretrained checkpoint. The function is
    decorated with ``st.cache_resource``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model
def stage1_process(uploaded_file):
    """Generate a caption for an uploaded image.

    Args:
        uploaded_file: Anything ``PIL.Image.open`` accepts; ``main`` passes
            the object returned by ``st.file_uploader``.

    Returns:
        The decoded caption string with special tokens removed.
    """
    blip_processor, blip_model = load_image_model()

    # The image is converted to RGB before being handed to the processor.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    model_inputs = blip_processor(images=rgb_image, return_tensors="pt")

    # Only the first generated sequence is decoded.
    caption_ids = blip_model.generate(**model_inputs)[0]
    return blip_processor.decode(caption_ids, skip_special_tokens=True)
# ======================
# Stage 2: Story Generation (Optimized)
# ======================
@st.cache_resource
def load_story_model():
    """Load the tokenizer and causal language model used for story writing.

    Both objects come from the ``gpt2-medium`` checkpoint. The function is
    decorated with ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "gpt2-medium"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model
def stage2_process(keyword):
    """Generate a short children's story themed on ``keyword``.

    A fixed prompt template is filled with the keyword, fed to the story
    model, and sampled for up to 300 new tokens.

    Args:
        keyword: Theme inserted into the prompt; ``main`` passes the image
            caption produced by stage 1.

    Returns:
        The decoded text that follows the last ``"Story begins:"`` marker,
        stripped of surrounding whitespace. This includes the opening
        sentence supplied by the prompt.
    """
    tokenizer, model = load_story_model()

    # Enhanced prompt template
    prompt = f"""Write a children's story in 100-150 words with these elements:
- Theme: {keyword}
- Characters: Friendly animals
- Moral: Sharing is caring
Story begins: One sunny morning, a little rabbit named Cotton discovered"""

    inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)

    outputs = model.generate(
        # Unpack the full tokenizer output so the attention mask is passed
        # along with the input ids instead of being inferred by generate().
        **inputs,
        max_new_tokens=300,
        temperature=0.9,
        top_k=50,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        do_sample=True,
        # GPT-2 defines no pad token; use EOS explicitly for padding.
        pad_token_id=tokenizer.eos_token_id,
    )

    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return full_text.split("Story begins:")[-1].strip()
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_process(text):
    """Convert story text to spoken audio with gTTS.

    The text is stripped, newlines are replaced by spaces, and the result is
    cut to its first 300 characters before synthesis.

    Args:
        text: Story text to read aloud.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 containing the audio written by
        gTTS, or ``None`` when the cleaned text is shorter than 20 characters
        or the conversion fails.
    """
    import logging  # local import keeps this block self-contained

    try:
        clean_text = text.strip().replace('\n', ' ')[:300]
        # Too little text to be worth synthesizing.
        if len(clean_text) < 20:
            return None

        tts = gTTS(text=clean_text, lang='en')
        audio = io.BytesIO()
        tts.write_to_fp(audio)
        audio.seek(0)  # rewind so callers read from the start
        return audio
    except Exception:
        # Audio is best-effort: callers treat None as "no audio available".
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate, and record the failure instead of hiding it.
        logging.getLogger(__name__).exception("Text-to-speech conversion failed")
        return None
# ======================
# Main Application
# ======================
def main():
    """Render the Streamlit page: upload an image, then caption, story, audio.

    Results of each stage are kept in ``st.session_state`` under the keys
    ``caption``, ``story`` and ``audio`` so that a script rerun does not
    recompute them. They are cleared only when a different file is uploaded.
    """
    st.title("📖 Children's Story Generator")

    # Initialize each result slot only if it is missing. The previous guard
    # tested a 'processing' key that was never stored, so every rerun wiped
    # the results and re-ran all three stages.
    result_keys = ('caption', 'story', 'audio')
    for key in result_keys:
        if key not in st.session_state:
            st.session_state[key] = None

    # File upload
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    if not uploaded_file:
        return

    # Results now persist across reruns, so drop them explicitly when the
    # user uploads a different file (identified by name and size).
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('source_file') != file_key:
        st.session_state['source_file'] = file_key
        for key in result_keys:
            st.session_state[key] = None

    # Permanent display
    st.image(uploaded_file, width=300)

    # Stage 1
    if not st.session_state.caption:
        with st.spinner("Analyzing image..."):
            st.session_state.caption = stage1_process(uploaded_file)
    st.success(f"Detected Theme: {st.session_state.caption}")

    # Stage 2
    if not st.session_state.story:
        with st.spinner("Writing magical story..."):
            st.session_state.story = stage2_process(st.session_state.caption)

    # Display story
    if st.session_state.story:
        st.subheader("Generated Story")
        st.write(st.session_state.story)

    # Stage 3 (retried on later reruns if a previous attempt returned None)
    if not st.session_state.audio:
        with st.spinner("Generating audio..."):
            st.session_state.audio = stage3_process(st.session_state.story)

    if st.session_state.audio:
        st.audio(st.session_state.audio, format="audio/mp3")
        st.download_button(
            "Download Audio",
            st.session_state.audio.getvalue(),
            "story.mp3",
            mime="audio/mp3",
        )
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()