Spaces:
Sleeping
Sleeping
| # storygen_tts_final.py | |
| import streamlit as st | |
| from transformers import ( | |
| BlipForConditionalGeneration, | |
| BlipProcessor, | |
| AutoProcessor, | |
| SpeechT5ForTextToSpeech, | |
| SpeechT5HifiGan, | |
| pipeline | |
| ) | |
| from datasets import load_dataset | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
# Model initialization (CPU-optimized version).
@st.cache_resource
def load_models():
    """Load every AI model the app needs.

    Decorated with ``st.cache_resource`` so the models are loaded once per
    server process and then reused. Streamlit re-executes ``main`` (and so
    this function) on every interaction and on each explicit ``st.rerun()``.
    Without the cache, every click would reload all the models.

    Returns:
        Tuple ``(blip_processor, blip_model, story_generator, tts_processor,
        tts_model, vocoder, embeddings_dataset)``.

    Raises:
        Exception: any loading error is shown with ``st.error`` and re-raised.
        Streamlit does not cache a call that raises, so the next rerun
        retries the load.
    """
    try:
        # Image-captioning model.
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        # Text-generation pipeline.
        # NOTE(review): device_map="auto" relies on the `accelerate` package
        # and will use a GPU when one is available — confirm it is installed.
        story_generator = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
        )

        # Speech-synthesis models.
        tts_processor = AutoProcessor.from_pretrained("microsoft/speecht5_tts")
        tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        # Speaker-embedding (x-vector) dataset.
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

        return blip_processor, blip_model, story_generator, tts_processor, tts_model, vocoder, embeddings_dataset
    except Exception as e:
        st.error(f"模型加载失败: {str(e)}")
        raise
def generate_story(image, blip_processor, blip_model, story_generator):
    """Write a short children's story inspired by an image.

    The image is captioned with BLIP. The caption is embedded in an
    instruction prompt, and the GPT-2 text-generation pipeline continues
    that prompt.

    Args:
        image: Image handed to ``blip_processor`` (the caller passes an RGB
            PIL image).
        blip_processor: BLIP processor used to encode the image and decode
            the caption.
        blip_model: BLIP captioning model.
        story_generator: Hugging Face ``text-generation`` pipeline.

    Returns:
        The story text: at most 600 characters, cut back to the last complete
        sentence when one exists, with any verbatim copy of the caption
        removed. May be empty if the model produced nothing usable.
    """
    inputs = blip_processor(image, return_tensors="pt")

    # Caption the image with beam search. No temperature here: it only
    # affects sampling, so it had no effect alongside num_beams.
    caption_ids = blip_model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        early_stopping=True,
    )
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

    # Build the story-generation prompt.
    prompt = f"""Based on this image: {caption}
Write a magical story for children with:
1. Talking animals
2. Happy ending
3. Sound effects (*whoosh*, *giggle*)
4. 50-100 words
Story:"""

    generated = story_generator(
        prompt,
        # Budget tokens for the continuation only. max_length/min_length also
        # count the prompt, which consumed most (or all) of the old budget.
        max_new_tokens=150,
        min_new_tokens=60,
        num_return_sequences=1,
        # temperature is ignored unless sampling is enabled.
        do_sample=True,
        temperature=0.85,
        repetition_penalty=2.0,
        # Return only the newly generated text, not the prompt.
        return_full_text=False,
    )

    full_text = generated[0]["generated_text"]
    # If the model echoes "Story:", keep only what follows its last occurrence.
    story = full_text.split("Story:")[-1]
    # GPT-2 sometimes repeats the caption verbatim; drop it.
    if caption:
        story = story.replace(caption, "")
    story = story.strip()[:600]

    # Generation stops on a token budget, so the text usually ends
    # mid-sentence; cut back to the last complete sentence when there is one.
    last_stop = max(story.rfind(mark) for mark in ".!?")
    if last_stop > 0:
        story = story[: last_stop + 1]
    return story.strip()
def text_to_speech(text, processor, model, vocoder, embeddings_dataset):
    """Synthesize speech for ``text`` with SpeechT5 and a HiFi-GAN vocoder.

    Args:
        text: Text to speak. Must contain at least one non-whitespace
            character.
        processor: SpeechT5 processor used to tokenize ``text``.
        model: ``SpeechT5ForTextToSpeech`` model.
        vocoder: Vocoder passed to ``model.generate_speech``.
        embeddings_dataset: Dataset whose rows carry an ``"xvector"`` speaker
            embedding. Row 7306 is used.

    Returns:
        ``(audio, 16000)``: a NumPy waveform scaled so its peak magnitude is 1
        (left unscaled if it is silent), and the sample rate it is played at.

    Raises:
        ValueError: if ``text`` is empty or whitespace-only.
        Exception: any synthesis error. Every error is first shown with
            ``st.error`` and then re-raised.
    """
    try:
        if not text or not text.strip():
            raise ValueError("No text to speak")

        inputs = processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch.int64)

        # SpeechT5 can only encode up to config.max_text_positions text tokens.
        # Its tokenizer is roughly character-level, so a 600-character story
        # can exceed that limit and crash inside the model. Truncate instead.
        max_positions = getattr(model.config, "max_text_positions", None)
        if max_positions:
            input_ids = input_ids[:, :max_positions]

        # Fixed speaker x-vector (row 7306), so every story uses the same voice.
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids=input_ids,
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )

        audio_array = speech.numpy()
        # Peak-normalize, but never divide by zero on silent/empty output.
        peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
        if peak > 0:
            audio_array = audio_array / peak
        return audio_array, 16000
    except Exception as e:
        st.error(f"语音生成失败: {str(e)}")
        raise
def _create_story_from_upload(uploaded_file, models):
    """Turn an uploaded image into a story plus narration, and store both in session state."""
    (blip_proc, blip_model, story_gen,
     tts_proc, tts_model, vocoder, embeddings) = models

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Your Magic Picture ✨", use_container_width=True)

    with st.status("Creating Magic...", expanded=True) as status:
        # Caption the image and write the story.
        st.write("🔍 Reading the image...")
        story = generate_story(image, blip_proc, blip_model, story_gen)

        # Narrate the story.
        st.write("🔊 Adding sounds...")
        audio_array, sr = text_to_speech(story, tts_proc, tts_model, vocoder, embeddings)

        # Keep the results so they survive the rerun that displays them.
        st.session_state.story = story
        st.session_state.audio = (audio_array, sr)
        status.update(label="Ready!", state="complete", expanded=False)

    st.session_state.generated = True


def _render_results():
    """Show the stored story and audio, plus a button that resets the app for a new image."""
    import html

    st.markdown("---")
    st.subheader("Your Story 📖")
    # The story is free-form model output. Escape it before embedding it in
    # raw HTML. Newlines become <br> so a blank line cannot end the <div>
    # HTML block early.
    story_html = html.escape(st.session_state.story).replace("\n", "<br>")
    st.markdown(
        f'<div style="background:#fff3e6; padding:20px; border-radius:10px;">{story_html}</div>',
        unsafe_allow_html=True,
    )

    st.markdown("---")
    st.subheader("Listen 🎧")
    audio_data, sr = st.session_state.audio
    st.audio(audio_data, sample_rate=sr)

    st.markdown("---")
    if st.button("Create New Story", use_container_width=True):
        st.session_state.generated = False
        # A file_uploader's value cannot be assigned through session state.
        # Bumping the key makes Streamlit render a fresh, empty uploader.
        st.session_state.uploader_key += 1
        st.rerun()


def main():
    """Build the Magic Story Box page: upload an image, get a narrated story."""
    # Page configuration.
    st.set_page_config(
        page_title="Magic Story Box",
        page_icon="🧙",
        layout="centered",
    )
    st.title("🧚♀️ Magic Story Box")
    st.markdown("---")
    st.write("Upload an image to get your magical story!")

    # Initialize session state.
    if "generated" not in st.session_state:
        st.session_state.generated = False
    if "uploader_key" not in st.session_state:
        st.session_state.uploader_key = 0

    # Load models. load_models reports failures itself via st.error. Catch
    # Exception only: Streamlit's rerun/stop signals derive from
    # BaseException and must keep propagating.
    try:
        models = load_models()
    except Exception:
        return

    # File-upload widget. The key includes a counter so "Create New Story"
    # can reset it.
    uploaded_file = st.file_uploader(
        "Choose your magic image",
        type=["jpg", "png", "jpeg"],
        help="Upload photos of pets, toys or adventures!",
        key=f"uploader_{st.session_state.uploader_key}",
    )

    # Process a freshly uploaded file.
    if uploaded_file is not None and not st.session_state.generated:
        try:
            _create_story_from_upload(uploaded_file, models)
        except Exception as e:
            st.error(f"Magic failed: {str(e)}")
        else:
            st.rerun()

    # Display the results.
    if st.session_state.generated:
        _render_results()


if __name__ == "__main__":
    main()