Spaces:
Build error
Build error
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import tempfile | |
| import torch | |
| from TTS.api import TTS # Coqui TTS | |
| import os | |
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_captioner():
    """Load the BLIP image-captioning pipeline.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    server process. Without it, every Streamlit rerun (each widget
    interaction re-executes the whole script) would rebuild the pipeline.

    Returns:
        A ``transformers`` ``image-to-text`` pipeline, placed on CUDA when
        ``torch.cuda.is_available()`` is true, otherwise on CPU.
    """
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def generate_caption(_pipeline, image):
    """Describe ``image`` with an image-to-text pipeline.

    Args:
        _pipeline: Callable ``image-to-text`` pipeline. It is invoked as
            ``_pipeline(image, max_new_tokens=50)``.
        image: The image handed straight to the pipeline (a PIL image in
            this app).

    Returns:
        The ``generated_text`` of the first pipeline result. Returns ``None``
        if anything raises; the error is then shown with ``st.error``.
    """
    try:
        outputs = _pipeline(image, max_new_tokens=50)
        caption = outputs[0]['generated_text']
    except Exception as err:
        st.error(f"Caption generation failed: {str(err)}")
        return None
    return caption
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_generator():
    """Load the story-writing text-generation pipeline.

    Decorated with ``st.cache_resource`` so the weights are loaded once per
    server process instead of on every Streamlit rerun.

    Returns:
        A ``transformers`` ``text-generation`` pipeline, placed on CUDA when
        ``torch.cuda.is_available()`` is true, otherwise on CPU.
    """
    return pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",  # can be replaced with a stronger model
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def generate_story(_pipeline, caption, *, max_new_tokens=150, temperature=0.7):
    """Write a short children's story seeded by an image caption.

    Args:
        _pipeline: A ``text-generation`` pipeline (or any callable with the
            same calling convention), e.g. from ``load_story_generator``.
        caption: Image description embedded in the prompt.
        max_new_tokens: Upper bound on generated tokens, excluding the prompt.
        temperature: Sampling temperature. It takes effect because sampling
            is explicitly enabled below.

    Returns:
        The generated continuation with the prompt removed and surrounding
        whitespace stripped. Returns ``None`` if the pipeline call raised;
        the error is then shown with ``st.error``.
    """
    prompt = (
        "You are a children's storyteller. Based on the following image "
        f'description: "{caption}", write a short children\'s story (80 words max).\n'
        "The story should:\n"
        "- Use simple and friendly language\n"
        "- Be related to the content of the image\n"
        "- Include a magical or fun twist\n"
        "- End happily\n"
        "Story:"
    )
    try:
        outputs = _pipeline(
            prompt,
            # Budget only the *new* tokens. `max_length` also counts the
            # prompt, which alone consumes much of a 200-token limit.
            max_new_tokens=max_new_tokens,
            # Without sampling, generation is greedy and `temperature` is ignored.
            do_sample=True,
            temperature=temperature,
            # Ask the pipeline for the continuation only, rather than
            # string-replacing the prompt out of the full text.
            return_full_text=False,
        )
        story = outputs[0]['generated_text']
    except Exception as e:
        st.error(f"Story generation failed: {str(e)}")
        return None
    # Defensive: drop the prompt if a pipeline echoes it anyway.
    if story.startswith(prompt):
        story = story[len(prompt):]
    return story.strip()
# ======================
# Stage 3: Text-to-Speech using Coqui TTS
# ======================
@st.cache_resource
def load_tts():
    """Load the Coqui TTS Tacotron2-DDC (LJSpeech) English voice.

    Decorated with ``st.cache_resource`` so the TTS model is loaded once per
    server process instead of on every Streamlit rerun.

    Returns:
        A ``TTS`` instance with the progress bar disabled. GPU is used when
        ``torch.cuda.is_available()`` is true.
    """
    # NOTE(review): newer Coqui TTS releases deprecate `gpu=` in favour of
    # `TTS(...).to(device)`; confirm against the pinned TTS version.
    return TTS(
        model_name="tts_models/en/ljspeech/tacotron2-DDC",
        progress_bar=False,
        gpu=torch.cuda.is_available(),
    )
def text_to_speech(tts_model, story_text):
    """Synthesize ``story_text`` to a temporary WAV file.

    Args:
        tts_model: Object exposing ``tts_to_file(text=..., file_path=...)``,
            e.g. from ``load_tts``.
        story_text: Text to narrate.

    Returns:
        Path to the written ``.wav`` file. The file is left on disk and the
        caller is responsible for deleting it. Returns ``None`` if anything
        raised; the error is then shown with ``st.error`` and no file is
        left behind.
    """
    path = None
    try:
        # Create the file and close our handle right away. TTS opens the path
        # itself, and on Windows a still-open NamedTemporaryFile cannot be
        # reopened by another writer.
        fd, path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        tts_model.tts_to_file(text=story_text, file_path=path)
        return path
    except Exception as e:
        if path is not None:
            # Best-effort cleanup of the empty/partial file on failure.
            try:
                os.remove(path)
            except OSError:
                pass
        st.error(f"Audio generation failed: {str(e)}")
        return None
# ======================
# Main Streamlit App
# ======================
def main():
    """Run the Magic Story Generator Streamlit app.

    Flow: upload a photo, caption it, write a story from the caption, then
    narrate the story. Each stage runs only if the previous one produced
    output. Stage helpers report their own failures via ``st.error``.
    """
    # set_page_config must be the first Streamlit command of the script run.
    st.set_page_config(page_title="Magic Story Generator", layout="wide")
    st.title("🧚 Magic Story Generator")

    uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
    if not uploaded_image:
        return

    try:
        image = Image.open(uploaded_image)
        # Image.open is lazy. Force a decode so a corrupt or truncated upload
        # fails here with a readable message instead of a traceback later.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        return
    st.image(image, use_container_width=True)

    with st.spinner("Processing your magical story..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_model = load_tts()

        caption = generate_caption(caption_pipe, image)
        if not caption:
            return
        st.success(f"Image description: {caption}")

        story = generate_story(story_pipe, caption)
        if not story:
            return
        st.subheader("Your Magical Story")
        st.markdown(story)

        audio_path = text_to_speech(tts_model, story)
        if not audio_path:
            return
        # text_to_speech leaves the WAV on disk for the caller. Load it into
        # memory and delete it, rather than leaking one temp file per run.
        try:
            with open(audio_path, "rb") as fh:
                audio_bytes = fh.read()
        except OSError as e:
            st.error(f"Audio generation failed: {e}")
            return
        finally:
            try:
                os.remove(audio_path)
            except OSError:
                pass
        st.audio(audio_bytes, format="audio/wav")
if __name__ == "__main__":
    # `streamlit run <this file>` executes the script as __main__, so this
    # guard is also the Streamlit entry point.
    main()