Spaces:
Sleeping
Sleeping
"""Streamlit app: upload an image, caption it with BLIP, turn the caption into a
kid-friendly short story with Zephyr-7B, and read the story aloud via gTTS."""

import tempfile

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Page configuration
st.set_page_config(page_title="🧸 Story Generator for Kids", page_icon="📚")
st.title("🖼️ Image to Story Generator (Zephyr + BLIP)")
st.write("Upload an image and enjoy a magical story with voice, designed for kids aged 3–10!")


@st.cache_resource(show_spinner=False)
def load_captioner():
    """Load the BLIP image-captioning pipeline once per process.

    Cached with st.cache_resource so the model is NOT re-downloaded/re-loaded
    on every button click (the original reloaded it per request).
    """
    # device=0 alone crashes on CPU-only hosts; fall back to CPU (-1) when CUDA is absent.
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
        device=0 if torch.cuda.is_available() else -1,
    )


@st.cache_resource(show_spinner=False)
def load_story_model():
    """Load the Zephyr-7B tokenizer and model once per process (cached).

    Returns:
        (tokenizer, model) tuple. device_map="auto" lets accelerate place the
        fp16 weights on whatever devices are available.
    """
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model


# Upload image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Generate Story"):
        with st.spinner("📷 Generating caption..."):
            # Image description model (BLIP Large)
            captioner = load_captioner()
            caption = captioner(image)[0]["generated_text"].strip()

        with st.spinner("✍️ Generating story with Zephyr..."):
            tokenizer, model = load_story_model()

            # Instruction prompt in Zephyr's chat format
            prompt = (
                "<|system|>\nYou are a friendly AI assistant who writes short stories for children.\n"
                "<|user|>\nWrite a short, vivid, and imaginative story (under 100 words) suitable for children aged 3 to 10, "
                f"based on this image description: {caption}\n<|assistant|>\n"
            )
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=180,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
            )
            # Decode ONLY the newly generated tokens. Splitting the full decoded
            # text on "<|assistant|>" (as before) is brittle: it depends on the
            # literal marker surviving decoding and on it not appearing in the story.
            prompt_len = inputs["input_ids"].shape[1]
            story = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

            # Enforce the "under 100 words" constraint on the model output.
            words = story.split()
            if len(words) > 100:
                story = " ".join(words[:100]) + "..."

        with st.spinner("🔊 Converting story to speech..."):
            # Text-to-speech: reserve a temp path, close the handle, then let
            # gTTS write to it (keeping the handle open breaks on Windows and leaks an fd).
            tts = gTTS(text=story, lang="en")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                audio_path = tmp.name
            tts.save(audio_path)

        # Present results
        st.subheader("📖 Generated Story")
        st.write(story)
        st.subheader("🔊 Listen to the Story")
        st.audio(audio_path, format="audio/mp3")