# storytellingmodel / apptest.py
# TLH01's picture
# Create apptest.py
# 33de3ee verified
import streamlit as st
from PIL import Image
import tempfile
import numpy as np
from transformers import pipeline, set_seed
import soundfile as sf
# --- Model initialization (cached) ---
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    constructed once and reused instead of being rebuilt on every call.

    Returns:
        A ``(caption_pipeline, story_pipeline, tts_pipeline)`` tuple.
    """
    # Imported here (still lazily) so this function does not depend on the
    # ``__main__`` guard having bound ``torch`` as a module-level global.
    import torch

    # Resolve the target device once; all three pipelines share it.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    story_pipeline = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=device,
    )
    # NOTE(review): this is a SpeechBrain checkpoint - confirm that the
    # transformers "text-to-speech" pipeline can actually load it.
    tts_pipeline = pipeline(
        "text-to-speech",
        model="speechbrain/tts-tacotron2-ljspeech",
        device=device,
    )
    return caption_pipeline, story_pipeline, tts_pipeline
# --- Stage 1: Image → Caption ---
def generate_caption(image, pipeline):
    """Return the caption text the captioning pipeline produces for ``image``.

    Args:
        image: The image object handed straight to ``pipeline``.
        pipeline: A callable returning a sequence of dicts, each carrying a
            ``"generated_text"`` entry.

    Returns:
        The ``"generated_text"`` value of the first result.
    """
    results = pipeline(image)
    first_result = results[0]
    return first_result["generated_text"]
# --- Stage 2: Caption (keyword) → Story (strict word limit) ---
def generate_story(caption, pipeline):
    """Generate a short story from ``caption`` and trim it to five sentences.

    Args:
        caption: Text describing the image; interpolated into the prompt.
        pipeline: A callable taking the prompt plus generation keyword
            arguments and returning a sequence of dicts with a
            ``"generated_text"`` entry.

    Returns:
        At most five sentences joined with ``". "`` and ending in exactly one
        period, or an empty string when the model produced no usable text
        beyond the prompt.
    """
    prompt = f"Generate a children's story in 50-100 words about: {caption}"
    generated = pipeline(
        prompt,
        # NOTE(review): for decoder-only models these limits usually count
        # the prompt tokens too - confirm the budget is what is intended.
        max_length=150,  # token budget (roughly 100 words)
        min_length=80,  # roughly 50 words
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
    )[0]["generated_text"]

    # The model echoes the prompt; remove it before splitting into sentences.
    text = generated.replace(prompt, "").strip()

    # Drop blank fragments *before* truncating. Text ending in "." yields a
    # trailing empty piece from split("."), which previously produced a
    # doubled period ("A. B..") once the final "." was appended.
    sentences = [piece.strip() for piece in text.split(".")]
    sentences = [sentence for sentence in sentences if sentence][:5]

    if not sentences:
        return ""
    return ". ".join(sentences) + "."
# --- Stage 3: Story → Audio (Spaces-compatible) ---
def generate_audio(story_text, pipeline):
    """Synthesize ``story_text`` to speech and save it as a temporary WAV file.

    Args:
        story_text: The text to speak.
        pipeline: A callable returning a mapping with an ``"audio"`` waveform
            and a ``"sampling_rate"`` value.

    Returns:
        Path of the written ``.wav`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    speech = pipeline(story_text)
    audio = speech["audio"]

    # The transformers text-to-audio pipeline already returns a NumPy array,
    # on which ``.numpy()`` does not exist. Convert only when a torch tensor
    # is handed back, so both shapes of output are accepted.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    audio_array = np.squeeze(np.asarray(audio))
    sample_rate = speech["sampling_rate"]

    # Reserve a unique path, then close the handle before writing: reopening
    # a still-open NamedTemporaryFile by name is not portable across platforms.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        wav_path = f.name
    sf.write(wav_path, audio_array, sample_rate)
    return wav_path
# --- Streamlit UI ---
def main():
    """Render the Streamlit page: upload an image, caption it, tell a story.

    Runs the three stages in order (caption, story, audio) for the uploaded
    image and displays each result. Does nothing beyond drawing the title and
    uploader until an image has been provided.
    """
    # Local imports: needed only for temp-file cleanup below.
    import contextlib
    import os

    st.title("📖 AI Storyteller for Kids")
    caption_pipeline, story_pipeline, tts_pipeline = load_models()

    uploaded_image = st.file_uploader(
        "Upload a child-friendly image", type=["jpg", "jpeg", "png"]
    )
    if not uploaded_image:
        return

    image = Image.open(uploaded_image)
    # NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
    # favour of ``use_container_width`` - confirm against the pinned version.
    st.image(image, use_column_width=True)

    with st.spinner("🔍 Analyzing the image..."):
        caption = generate_caption(image, caption_pipeline)
    st.success(f"📝 Caption: {caption}")

    with st.spinner("✨ Creating a magical story..."):
        story = generate_story(caption, story_pipeline)
    st.subheader("📚 Your Story")
    st.write(story)
    st.info(f"Word count: {len(story.split())}")  # show the word count

    # Nothing speakable (empty, or punctuation only): skip the TTS stage
    # instead of synthesizing silence or failing inside the model.
    if not story.strip(" ."):
        st.warning("No story text was generated, so audio was skipped.")
        return

    with st.spinner("🔊 Generating audio..."):
        audio_path = generate_audio(story, tts_pipeline)

    # generate_audio() leaves a delete=False temp file behind; load it into
    # memory and remove it so files do not accumulate on every run.
    try:
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
    finally:
        with contextlib.suppress(OSError):
            os.remove(audio_path)
    st.audio(audio_bytes, format="audio/wav")
# Script entry point: bind ``torch`` and start the UI.
if __name__ == "__main__":
    # load_models() reads ``torch`` as a module-level global, so this import
    # has to run before main() is called.
    import torch  # deferred import to avoid preload problems on Spaces
    main()