# Assignment1 / app.py
# Author: shingguy1 -- commit 226a292 ("Update app.py")
import io
import os
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
# Load models
@st.cache_resource
def load_models():
    """Build the captioning model pair and the story-writing pipeline.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the returned objects
    between script reruns rather than reloading the weights each time.

    Returns:
        A 3-tuple ``(processor, blip_model, gpt2_pipeline)``: the BLIP
        processor and BLIP caption model (both from the same checkpoint), and
        a GPT-2 ``"text-generation"`` pipeline.
    """
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    # Tuple elements are evaluated left to right, so load order is unchanged.
    return (
        BlipProcessor.from_pretrained(blip_checkpoint),
        BlipForConditionalGeneration.from_pretrained(blip_checkpoint),
        pipeline("text-generation", model="gpt2"),
    )
processor, blip_model, gpt2 = load_models()

# UI
st.title("🖼️📖 Storyteller for Kids")
st.write("Upload an image and let the app create and read a magical story just for kids!")

# Upper bound on the number of words shown on screen and sent to text-to-speech.
MAX_STORY_WORDS = 100

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # The upload is user-supplied: a corrupt or mislabelled file makes PIL raise
    # OSError (UnidentifiedImageError subclasses it). Report it in the UI
    # instead of crashing the whole script.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the deployed version supports it.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Generating image caption..."):
        inputs = processor(images=image, return_tensors="pt")
        out = blip_model.generate(**inputs)
        caption = processor.decode(out[0], skip_special_tokens=True)
    st.success("Caption generated!")
    st.write(f"**Caption:** {caption}")

    with st.spinner("Writing a children's story..."):
        prompt = f"Write a short, imaginative story for children aged 3-10 about this: {caption}"
        # GPT-2 defines no pad token, so padding reuses the tokenizer's EOS id.
        eos_id = gpt2.tokenizer.eos_token_id
        generated = gpt2(
            prompt,
            # Token budget for the continuation only; `max_length` would also
            # count the prompt tokens and cut the story well short of the cap.
            max_new_tokens=150,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.2,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
            # Otherwise the returned text begins with the instruction prompt,
            # which would be displayed and read aloud as part of the story.
            return_full_text=False,
        )[0]["generated_text"]
        # split()/join collapses newlines and whitespace runs and applies the word cap.
        story = " ".join(generated.split()[:MAX_STORY_WORDS])

    if not story:
        # Sampling can stop immediately on EOS, and gTTS refuses empty text.
        st.warning("The story generator returned no text. Please try again.")
        st.stop()

    st.success("Story created!")
    st.write(f"**Story:**\n\n{story}")

    with st.spinner("Converting story to audio..."):
        # Render the MP3 in memory. A NamedTemporaryFile(delete=False) would
        # leave a file on disk after every rerun.
        audio_buffer = io.BytesIO()
        try:
            gTTS(text=story, lang="en").write_to_fp(audio_buffer)
        except Exception as e:  # UI boundary: surface any TTS failure to the user
            st.error(f"Text-to-speech failed: {e}")
        else:
            st.audio(audio_buffer.getvalue(), format="audio/mpeg")
            st.success("Audio playback ready!")