# Scraped Hugging Face Spaces page header: Space status "Sleeping".
import io
import tempfile

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Page setup. Keep set_page_config as the first Streamlit command
# (older Streamlit versions require it to precede every other st.* call).
st.set_page_config(page_title="🧸 Story Generator (CPU Phi-2)", page_icon="📚")
st.title("🖼️ Kid-Friendly Story Generator (CPU Version)")
st.write("Upload a picture and receive a simple, magical story for young children with narration!")

CAPTION_MODEL = "nlpconnect/vit-gpt2-image-captioning"
STORY_MODEL = "microsoft/phi-2"  # chosen because it is suitable for CPU inference
MAX_STORY_WORDS = 100


@st.cache_resource(show_spinner=False)
def _load_captioner():
    """Build the ViT-GPT2 image-to-text pipeline once and reuse it across reruns."""
    return pipeline("image-to-text", model=CAPTION_MODEL)


@st.cache_resource(show_spinner=False)
def _load_story_model():
    """Load the Phi-2 tokenizer and model once; return them as a (tokenizer, model) pair."""
    tokenizer = AutoTokenizer.from_pretrained(STORY_MODEL)
    model = AutoModelForCausalLM.from_pretrained(STORY_MODEL)
    return tokenizer, model


def _truncate_words(text, limit=MAX_STORY_WORDS):
    """Return *text* unchanged if it has at most *limit* words.

    Otherwise return its first *limit* words joined by single spaces,
    followed by "...".
    """
    words = text.split()
    if len(words) <= limit:
        return text
    return " ".join(words[:limit]) + "..."


def _generate_story(caption):
    """Generate a short, kid-friendly story for the scene *caption* with Phi-2.

    Returns the stripped story text, truncated to MAX_STORY_WORDS words.
    The result may be an empty string if the model produces nothing.
    """
    tokenizer, model = _load_story_model()
    # Short, friendly prompt that constrains language style and length.
    prompt = (
        "Write a short and imaginative story (under 100 words) for a 4 to 8 year-old child.\n"
        f"The story should be based on this scene: {caption}.\n"
        "Use simple language, cheerful tone, and mention children playing, toys, or nature.\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly: pad_token_id is set to EOS, so generate()
        # cannot reliably infer which positions are padding on its own.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=150,
        temperature=0.8,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation for a causal LM. Decode only the
    # new token ids: slicing the decoded string by len(prompt) goes wrong whenever
    # decoding does not reproduce the prompt text character-for-character.
    prompt_len = inputs["input_ids"].shape[1]
    story = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _truncate_words(story.strip())


def _story_to_mp3(story):
    """Narrate *story* with gTTS and return the MP3 audio as bytes.

    Writing to an in-memory buffer avoids leaking an undeleted temp file
    (and its open handle) on every request.
    """
    buffer = io.BytesIO()
    gTTS(text=story, lang="en").write_to_fp(buffer)
    return buffer.getvalue()


# Image upload.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Normalise to 3-channel RGB. PNG uploads may be RGBA or palette images,
    # which the captioning image processor may not handle.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Generate Story"):
        with st.spinner("📷 Generating image caption..."):
            # Image-to-text with ViT-GPT2 (loaded once, cached).
            caption = _load_captioner()(image)[0]["generated_text"].strip()

        with st.spinner("✍️ Generating story with Phi-2..."):
            story = _generate_story(caption)

        if not story:
            # gTTS refuses empty text, so skip narration instead of crashing.
            st.warning("The model did not produce a story this time. Please try again.")
        else:
            with st.spinner("🔊 Converting story to speech..."):
                audio_bytes = _story_to_mp3(story)

            # Show the final result.
            st.subheader("📖 Generated Story")
            st.write(story)
            st.subheader("🔊 Listen to the Story")
            st.audio(audio_bytes, format="audio/mp3")