Spaces:
Build error
Build error
| """ | |
| 儿童故事生成器 (Children's Story Generator) | |
| 功能:上传图片 → 生成描述 → 创作故事 → 语音朗读 | |
| """ | |
| # ============ 导入模块 ============ | |
| import streamlit as st | |
| from PIL import Image | |
| import tempfile | |
| from transformers import pipeline | |
| import torch | |
| import os | |
| # ============ 第一阶段:图片描述生成 ============ | |
| # 缓存模型避免重复加载 | |
@st.cache_resource
def load_image_captioner():
    """Load the BLIP image-captioning pipeline once and reuse it.

    Streamlit re-executes the whole script on every interaction.
    ``st.cache_resource`` keeps the returned pipeline between reruns, so the
    model is not downloaded and initialised again each time.

    Returns:
        A transformers ``image-to-text`` pipeline on ``"cuda"`` when a GPU is
        available, otherwise on ``"cpu"``.
    """
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cuda" if torch.cuda.is_available() else "cpu",  # auto-detect GPU
    )
def generate_caption(_pipeline, image):
    """Produce an English caption for *image*.

    Args:
        _pipeline: An ``image-to-text`` pipeline; it is called as
            ``_pipeline(image, max_new_tokens=50)``.
        image: The image to describe.

    Returns:
        The ``generated_text`` of the first result, or ``None`` if the
        pipeline call fails (the error is reported with ``st.error``).
    """
    try:
        # Keep captions short by capping the number of generated tokens.
        outputs = _pipeline(image, max_new_tokens=50)
        first_result = outputs[0]
        caption_text = first_result['generated_text']
    except Exception as exc:
        st.error(f"生成描述失败: {str(exc)}")
        return None
    return caption_text
| # ============ 第二阶段:故事创作 ============ | |
@st.cache_resource
def load_story_generator():
    """Load the GPT-2 story-generation pipeline once and reuse it.

    ``st.cache_resource`` keeps the pipeline between Streamlit reruns, so the
    model is not reloaded on every interaction.

    Returns:
        A transformers ``text-generation`` pipeline on ``"cuda"`` when a GPU
        is available, otherwise on ``"cpu"``.
    """
    return pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def generate_story(_pipeline, keywords):
    """Generate a short children's story about *keywords*.

    Args:
        _pipeline: A ``text-generation`` pipeline.
        keywords: Topic text for the story (in this app, the image caption).

    Returns:
        The generated continuation with the prompt removed and surrounding
        whitespace stripped, or ``None`` if generation fails (the error is
        reported with ``st.error``).
    """
    prompt = f"""Generate a children's story (60-80 words) in English about: {keywords}
Requirements:
- Use simple words
- Include magical elements
- Happy ending
Story:"""
    try:
        story = _pipeline(
            prompt,
            max_length=200,
            # ``temperature`` only takes effect when sampling is enabled;
            # without do_sample=True, decoding is greedy and it is ignored.
            do_sample=True,
            temperature=0.7,  # controls creativity
            # Ask for the continuation only instead of prompt + continuation.
            return_full_text=False,
        )[0]['generated_text']
    except Exception as e:
        st.error(f"生成故事失败: {str(e)}")
        return None
    # Defensive: drop the prompt in case a pipeline echoes it anyway.
    return story.replace(prompt, "").strip()
| # ============ 第三阶段:语音合成 ============ | |
@st.cache_resource
def load_tts():
    """Load the MMS English text-to-speech pipeline once and reuse it.

    ``st.cache_resource`` keeps the pipeline between Streamlit reruns, so the
    model is not reloaded on every interaction.

    Returns:
        A transformers ``text-to-speech`` pipeline on ``"cuda"`` when a GPU is
        available, otherwise on ``"cpu"``.
    """
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
def text_to_speech(_pipeline, text):
    """Synthesize *text* to speech and save it as a temporary WAV file.

    Args:
        _pipeline: A ``text-to-speech`` pipeline returning a dict with
            ``"audio"`` (waveform) and ``"sampling_rate"`` keys.
        text: The text to read aloud.

    Returns:
        Path to the written ``.wav`` file, or ``None`` if synthesis or writing
        fails (the error is reported with ``st.error``). The file is created
        with ``delete=False`` and is not removed by this function.
    """
    try:
        import numpy as np
        import soundfile as sf

        audio = _pipeline(text)
        waveform = audio["audio"]
        # The transformers TTS pipeline returns a NumPy array, which has no
        # ``.numpy()`` method; only convert when we actually got a tensor.
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()
        waveform = np.asarray(waveform, dtype=np.float32).squeeze()

        # Reserve a file name, close our handle, then write by path: writing to
        # a path that is still open is not portable (it fails on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            wav_path = f.name
        sf.write(wav_path, waveform, audio["sampling_rate"])
        return wav_path
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        return None
| # ============ 主界面 ============ | |
def main():
    """Render the app: upload a photo, caption it, write a story, read it aloud.

    Each stage runs only when the previous one produced a result. The stage
    helpers report their own errors in the UI and return ``None`` on failure.
    """
    import html

    # Page setup
    st.set_page_config(
        page_title="魔法故事生成器",
        page_icon="🧚",
        layout="wide",
    )
    # Child-friendly CSS
    st.markdown("""
<style>
.main { background-color: #FFF5E6 }
h1 { color: #FF6B6B; font-family: 'Comic Sans MS' }
.stButton>button { background-color: #4CAF50; border-radius: 20px }
</style>
""", unsafe_allow_html=True)
    st.title("🧚 魔法故事生成器")
    st.write("上传小朋友的照片,AI会生成专属故事并朗读!")

    # Image upload; "jpeg" is listed so *.jpeg files are accepted as well.
    uploaded_file = st.file_uploader("选择照片", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        st.info("请先上传照片")
        return

    # Normalise to RGB: PNG uploads may be RGBA or palette images.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, use_column_width=True)

    # Load models
    with st.spinner("正在准备魔法..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_pipe = load_tts()

    # Stage 1: image caption
    with st.spinner("正在分析图片..."):
        caption = generate_caption(caption_pipe, image)
    if not caption:
        return
    st.success(f"图片描述: {caption}")

    # Stage 2: story
    with st.spinner("正在创作故事..."):
        story = generate_story(story_pipe, caption)
    if not story:
        return
    st.subheader("你的故事")
    # The story is model output rendered as raw HTML: escape it so stray
    # "<", ">" or "&" cannot inject markup, and turn newlines into <br> so a
    # blank line does not end the HTML block early.
    safe_story = html.escape(story).replace("\n", "<br>")
    st.markdown(
        f'<div style="background-color:#FFF0F5; padding:20px; border-radius:15px">{safe_story}</div>',
        unsafe_allow_html=True,
    )

    # Stage 3: speech
    with st.spinner("正在生成语音..."):
        audio_path = text_to_speech(tts_pipe, story)
    if audio_path:
        st.audio(audio_path, format="audio/wav")
if __name__ == "__main__":
    # Direct the Hugging Face Hub cache to /tmp/huggingface, then start the app.
    # NOTE(review): huggingface_hub appears to read HF_HUB_CACHE once, when it is
    # first imported, and `from transformers import pipeline` at the top of this
    # file has already imported it. This assignment therefore likely does not
    # change where models are cached. To take effect it probably has to run
    # before the transformers import — TODO confirm.
    os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"  # set cache path
    main()