# imagetotextapp / app.py
# Author: JenniferHJF — "Update app.py" (commit 8a76aba, verified)
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from PIL import Image
from gtts import gTTS
import torch
import tempfile
# Page configuration — must be the first Streamlit command executed in the script.
st.set_page_config(page_title="🧸 Story Generator for Kids", page_icon="📚")
st.title("🖼️ Image to Story Generator (Zephyr + BLIP)")
st.write("Upload an image and enjoy a magical story with voice, designed for kids aged 3–10!")
# Upload image — accepts common raster formats; returns None until a file is chosen.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
@st.cache_resource(show_spinner=False)
def _load_captioner():
    """Load the BLIP-large image-captioning pipeline once per server process.

    Uses GPU 0 when CUDA is available, otherwise falls back to CPU (-1);
    the original hard-coded ``device=0`` crashed on CPU-only hosts.
    """
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
        device=0 if torch.cuda.is_available() else -1,
    )


@st.cache_resource(show_spinner=False)
def _load_story_model():
    """Load the Zephyr-7B tokenizer and model once and reuse across reruns.

    Without caching, every button click reloaded the multi-GB model.
    Returns a ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model


if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Generate Story"):
        with st.spinner("📷 Generating caption..."):
            # Image description model (BLIP large), cached across reruns.
            captioner = _load_captioner()
            caption = captioner(image)[0]["generated_text"].strip()

        with st.spinner("✍️ Generating story with Zephyr..."):
            tokenizer, model = _load_story_model()

            # Instruction prompt in Zephyr's chat format.
            prompt = (
                "<|system|>\nYou are a friendly AI assistant who writes short stories for children.\n"
                "<|user|>\nWrite a short, vivid, and imaginative story (under 100 words) suitable for children aged 3 to 10, "
                f"based on this image description: {caption}\n<|assistant|>\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=180,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
            )
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the assistant's reply, dropping the echoed prompt.
            story = decoded.split("<|assistant|>")[-1].strip()

            # Enforce the "under 100 words" constraint on the model output.
            words = story.split()
            if len(words) > 100:
                story = " ".join(words[:100]) + "..."

        with st.spinner("🔊 Converting story to speech..."):
            # Text-to-speech. Close the temp file handle before gTTS writes to
            # the same path (an open NamedTemporaryFile cannot be reopened on
            # Windows); delete=False keeps the file alive for st.audio below.
            tts = gTTS(text=story, lang="en")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                audio_path = tmp.name
            tts.save(audio_path)

        # Present results: the story text and its audio rendition.
        st.subheader("📖 Generated Story")
        st.write(story)
        st.subheader("🔊 Listen to the Story")
        st.audio(audio_path, format="audio/mp3")