"""Streamlit app: turn an uploaded photo into a narrated children's story.

The app runs three stages in sequence: image captioning, story generation
from the caption, and text-to-speech for the generated story.
"""

import os
import tempfile

import streamlit as st
import torch
from PIL import Image
from transformers import pipeline
from TTS.api import TTS  # Coqui TTS


# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_captioner():
    """Build the BLIP image-captioning pipeline (cached by Streamlit).

    Returns:
        A transformers "image-to-text" pipeline placed on CUDA when
        available, otherwise on the CPU.
    """
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )


def generate_caption(_pipeline, image):
    """Caption ``image`` with the given image-to-text pipeline.

    Args:
        _pipeline: Pipeline returned by ``load_image_captioner``.
        image: Image to describe (a PIL image in this app).

    Returns:
        The generated caption, or ``None`` if captioning failed; the
        failure is reported to the user through ``st.error``.
    """
    try:
        result = _pipeline(image, max_new_tokens=50)
        return result[0]['generated_text']
    except Exception as e:
        # UI boundary: show the problem instead of crashing the page.
        st.error(f"Caption generation failed: {e}")
        return None


# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_generator():
    """Build the GPT-2 story-generation pipeline (cached by Streamlit).

    Returns:
        A transformers "text-generation" pipeline placed on CUDA when
        available, otherwise on the CPU.
    """
    return pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",  # can be swapped for a stronger model
        device="cuda" if torch.cuda.is_available() else "cpu",
    )


def generate_story(_pipeline, caption):
    """Generate a short children's story from an image caption.

    Args:
        _pipeline: Pipeline returned by ``load_story_generator``.
        caption: Text description of the uploaded image.

    Returns:
        The generated story without the prompt, stripped of surrounding
        whitespace, or ``None`` if generation failed; the failure is
        reported to the user through ``st.error``.
    """
    # NOTE(review): the prompt text is kept exactly as it appears in the
    # source, where the bullet list sits on a single line -- confirm whether
    # line breaks between the bullets were intended.
    prompt = (
        "You are a children's storyteller. Based on the following image "
        f"description: \"{caption}\", write a short children's story "
        "(80 words max). \n"
        "The story should: - Use simple and friendly language "
        "- Be related to the content of the image "
        "- Include a magical or fun twist "
        "- End happily Story:"
    )
    try:
        # return_full_text=False makes the pipeline return only the
        # continuation, so the prompt does not have to be stripped by string
        # replacement (which would also remove any later repetition of it).
        # Note that max_length counts the prompt tokens as well.
        story = _pipeline(
            prompt,
            max_length=200,
            temperature=0.7,
            return_full_text=False,
        )[0]['generated_text']
        return story.strip()
    except Exception as e:
        st.error(f"Story generation failed: {e}")
        return None


# ======================
# Stage 3: Text-to-Speech using Coqui TTS
# ======================
@st.cache_resource
def load_tts():
    """Load the Coqui Tacotron2-DDC English voice (cached by Streamlit).

    Returns:
        A ``TTS`` instance, using the GPU when CUDA is available.
    """
    return TTS(
        model_name="tts_models/en/ljspeech/tacotron2-DDC",
        progress_bar=False,
        gpu=torch.cuda.is_available(),
    )


def _remove_quietly(path):
    """Delete ``path``, ignoring the error if it cannot be removed."""
    try:
        os.remove(path)
    except OSError:
        pass  # best-effort cleanup of a temporary file


def text_to_speech(tts_model, story_text):
    """Synthesize ``story_text`` into a temporary WAV file.

    Args:
        tts_model: Model returned by ``load_tts``.
        story_text: Text to be spoken.

    Returns:
        Path of the generated ``.wav`` file, or ``None`` if synthesis
        failed (reported through ``st.error``). The caller owns the file
        and is responsible for deleting it.
    """
    wav_path = None
    try:
        # Only reserve a unique path here; the handle is closed before the
        # TTS engine writes, because an open NamedTemporaryFile cannot be
        # reopened by name on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            wav_path = f.name
        tts_model.tts_to_file(text=story_text, file_path=wav_path)
        return wav_path
    except Exception as e:
        if wav_path is not None:
            # Do not leave an empty or partial file behind on failure.
            _remove_quietly(wav_path)
        st.error(f"Audio generation failed: {e}")
        return None


# ======================
# Main Streamlit App
# ======================
def main():
    """Render the page and run the caption -> story -> audio stages."""
    st.set_page_config(page_title="Magic Story Generator", layout="wide")
    st.title("🧚 Magic Story Generator")

    uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
    if uploaded_image is None:
        return

    try:
        image = Image.open(uploaded_image)
    except OSError as e:
        # Pillow's UnidentifiedImageError is an OSError subclass, so this
        # also covers files that are not valid images.
        st.error(f"Could not read the uploaded image: {e}")
        return
    st.image(image, use_container_width=True)

    with st.spinner("Processing your magical story..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_model = load_tts()

        caption = generate_caption(caption_pipe, image)
        if caption:
            st.success(f"Image description: {caption}")
            story = generate_story(story_pipe, caption)
            if story:
                st.subheader("Your Magical Story")
                st.markdown(story)
                audio_path = text_to_speech(tts_model, story)
                if audio_path:
                    # Load the audio into memory and delete the temp file
                    # right away; otherwise one WAV is leaked per rerun.
                    try:
                        with open(audio_path, "rb") as audio_file:
                            audio_bytes = audio_file.read()
                    finally:
                        _remove_quietly(audio_path)
                    st.audio(audio_bytes, format="audio/wav")


if __name__ == "__main__":
    main()