# testtest / app.py
# TLH01's picture
# Update app.py
# ef7e1aa verified
# raw
# history blame
# 3.83 kB
import streamlit as st
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForCausalLM
)
from gtts import gTTS
import io
import torch
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load the BLIP image-captioning processor and model.

    Both objects come from the ``Salesforce/blip-image-captioning-base``
    checkpoint. The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return blip_processor, blip_model
def stage1_process(uploaded_file):
    """Produce a text caption for an uploaded image.

    Args:
        uploaded_file: Anything ``PIL.Image.open`` accepts (here, the object
            returned by the Streamlit file uploader).

    Returns:
        The caption decoded from the first generated sequence, with special
        tokens removed.
    """
    blip_processor, blip_model = load_image_model()

    # The image is converted to RGB before it is handed to the processor.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    model_inputs = blip_processor(images=rgb_image, return_tensors="pt")

    first_sequence = blip_model.generate(**model_inputs)[0]
    caption = blip_processor.decode(first_sequence, skip_special_tokens=True)
    return caption
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_model():
    """Load the tokenizer and causal language model used for stories.

    Both objects come from the ``prpappas/fairytale-gpt2`` checkpoint. The
    function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "prpappas/fairytale-gpt2"
    story_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    story_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return story_tokenizer, story_model
def stage2_process(keyword):
    """Generate a children's story about *keyword*.

    The keyword is embedded in a fixed instruction prompt, the prompt is
    tokenized (truncated to at most 50 tokens) and the story model continues
    it up to a total length of 200 tokens (prompt included).

    Args:
        keyword: Theme of the story; in this app, the image caption.

    Returns:
        Only the newly generated text, i.e. the model output without the
        prompt, with special tokens removed.
    """
    tokenizer, model = load_story_model()
    prompt = f"Write a children's story about {keyword} in 100 words:\n"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=50, truncation=True)

    # Number of prompt tokens actually fed to the model (after truncation).
    prompt_length = inputs.input_ids.shape[1]

    # Fall back to EOS for padding when the tokenizer defines no pad token,
    # so generate() does not have to guess one.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=200,
        # temperature / top_k are sampling parameters; without do_sample=True
        # decoding is greedy and they have no effect.
        do_sample=True,
        temperature=0.85,
        top_k=50,
        repetition_penalty=1.2,
        pad_token_id=pad_token_id,
    )

    # Drop the prompt by token position rather than by string replacement:
    # a truncated prompt (or one that does not round-trip exactly through
    # decode) would not match the original prompt string and would leak into
    # the story.
    story_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(story_tokens, skip_special_tokens=True)
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_process(text):
    """Turn text into spoken English audio held in memory.

    Only the first 200 characters of *text* are passed to gTTS.

    Args:
        text: The text to speak.

    Returns:
        An ``io.BytesIO`` holding the audio written by gTTS, rewound to
        position 0 so it can be read from the start.
    """
    spoken_excerpt = text[:200]
    speech = gTTS(text=spoken_excerpt, lang='en')

    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)  # rewind so consumers read from the beginning
    return buffer
# ======================
# Main Application
# ======================
def main():
    """Run the Streamlit page: upload image -> caption -> story -> audio.

    Results of each stage are kept in ``st.session_state`` so that Streamlit
    reruns (which re-execute this function) do not recompute them:

    * ``caption`` / ``stage1_done`` -- image caption from ``stage1_process``
    * ``story`` / ``stage2_done``   -- story text from ``stage2_process``
    * ``audio_bytes``               -- audio bytes from ``stage3_process``
    * ``upload_key``                -- identity of the file the above belong to

    When a different file is uploaded, all cached results are discarded so
    the new image is actually processed instead of showing stale output.
    """
    st.title("📖 Children's Story Generator")

    # Initialize session state
    if 'stage1_done' not in st.session_state:
        st.session_state.stage1_done = False
    if 'stage2_done' not in st.session_state:
        st.session_state.stage2_done = False
    if 'audio_bytes' not in st.session_state:
        st.session_state.audio_bytes = None

    # File upload section
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
    if not uploaded_file:
        return

    # Identify the upload by name and size; if it differs from the file the
    # cached results were computed for, invalidate every stage.
    upload_key = (uploaded_file.name, uploaded_file.size)
    if ('upload_key' not in st.session_state
            or st.session_state.upload_key != upload_key):
        st.session_state.upload_key = upload_key
        st.session_state.stage1_done = False
        st.session_state.stage2_done = False
        st.session_state.audio_bytes = None

    # Always show image and Stage 1 result
    st.image(uploaded_file, width=300)

    # Stage 1 Processing
    if not st.session_state.stage1_done:
        with st.spinner("Analyzing image..."):
            caption = stage1_process(uploaded_file)
        st.session_state.caption = caption
        st.session_state.stage1_done = True
    st.success(f"Detected Theme: {st.session_state.caption}")

    # Stage 2 Processing
    if not st.session_state.stage2_done:
        with st.spinner("Creating story..."):
            story = stage2_process(st.session_state.caption)
        st.session_state.story = story
        st.session_state.stage2_done = True
        # A new story invalidates any previously synthesized audio.
        st.session_state.audio_bytes = None

    st.subheader("Generated Story")
    st.write(st.session_state.story)

    # Stage 3 Processing
    # Synthesize once per story; widget interactions such as clicking the
    # download button trigger a rerun and must not repeat the TTS call.
    if st.session_state.audio_bytes is None:
        with st.spinner("Generating audio..."):
            audio = stage3_process(st.session_state.story)
        st.session_state.audio_bytes = audio.getvalue()

    audio_bytes = st.session_state.audio_bytes
    st.audio(audio_bytes, format="audio/mp3")
    st.download_button("Download Audio",
                       data=audio_bytes,
                       file_name="story.mp3",
                       mime="audio/mp3")
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()