TLH01 committed on
Commit
7f66618
·
verified ·
1 Parent(s): 6f03cd8

Create apptest.py

Browse files
Files changed (1) hide show
  1. apptest.py +84 -0
apptest.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import tempfile
4
+ import numpy as np
5
+ from transformers import pipeline, set_seed
6
+ import soundfile as sf
7
+
8
# --- Model initialisation (cached) ---
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    Returns:
        A ``(caption_pipeline, story_pipeline, tts_pipeline)`` tuple holding
        the image-to-text, text-generation and text-to-speech pipelines,
        in that order.
    """
    # Imported locally so this function does not depend on the ``import torch``
    # that only executes under the ``__main__`` guard; without it, importing
    # this module (instead of running it as a script) raises NameError here.
    import torch

    # Pick the device once and share it across all three pipelines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    story_pipeline = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=device,
    )
    # NOTE(review): confirm this SpeechBrain checkpoint can actually be loaded
    # through the transformers "text-to-speech" pipeline.
    tts_pipeline = pipeline(
        "text-to-speech",
        model="speechbrain/tts-tacotron2-ljspeech",
        device=device,
    )
    return caption_pipeline, story_pipeline, tts_pipeline
27
+
28
# --- Stage 1: Image → Caption ---
def generate_caption(image, pipeline):
    """Return the caption text produced for ``image``.

    ``pipeline`` is called with the image, and the ``'generated_text'`` entry
    of its first result is returned.
    """
    first_result = pipeline(image)[0]
    return first_result['generated_text']
32
+
33
# --- Stage 2: Caption → Story (length strictly limited) ---
def generate_story(caption, pipeline):
    """Generate a short children's story from an image caption.

    Args:
        caption: Caption text interpolated into the generation prompt.
        pipeline: Callable invoked as ``pipeline(prompt, **generation_kwargs)``
            that returns a sequence whose first item has a
            ``'generated_text'`` entry.

    Returns:
        At most the first five non-empty sentences of the generated text with
        the prompt removed, terminated by exactly one period. An empty string
        is returned when nothing but the prompt was generated.
    """
    prompt = f"Generate a children's story in 50-100 words about: {caption}"
    generated = pipeline(
        prompt,
        max_length=150,  # token budget (roughly 100 words)
        min_length=80,   # roughly 50 words
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
    )[0]['generated_text']

    # Remove the prompt text from the output before truncating.
    body = generated.replace(prompt, "").strip()

    # Splitting on "." leaves an empty trailing fragment when the text already
    # ends with a period; dropping blank fragments avoids a doubled "..".
    sentences = [part for part in body.split(".") if part.strip()][:5]
    if not sentences:
        return ""
    return ".".join(sentences) + "."  # guarantee a single closing period
48
+
49
# --- Stage 3: Story → Audio (Spaces-compatible) ---
def generate_audio(story_text, pipeline):
    """Synthesize speech for ``story_text`` and save it as a WAV file.

    Args:
        story_text: Text passed to the text-to-speech pipeline.
        pipeline: Callable returning a mapping with ``"audio"`` and
            ``"sampling_rate"`` entries.

    Returns:
        Path of the temporary ``.wav`` file. The file is created with
        ``delete=False``, so it is not removed automatically.
    """
    speech = pipeline(story_text)
    audio = speech["audio"]

    # The transformers text-to-audio pipeline is documented to return a NumPy
    # array, which has no ``.numpy()`` method; only tensor-like values are
    # converted explicitly.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    audio_array = np.squeeze(np.asarray(audio))
    sample_rate = speech["sampling_rate"]

    # Reserve the file name, then close the handle before writing: reopening a
    # still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio_array, sample_rate)
    return wav_path
57
+
58
# --- Streamlit UI ---
def main():
    """Render the app: upload an image, then show its caption, story and audio."""
    st.title("📖 AI Storyteller for Kids")
    caption_pipeline, story_pipeline, tts_pipeline = load_models()

    uploaded_image = st.file_uploader(
        "Upload a child-friendly image",
        type=["jpg", "jpeg", "png"],
    )
    if not uploaded_image:
        # Nothing to process until a file has been provided.
        return

    image = Image.open(uploaded_image)
    st.image(image, use_column_width=True)

    with st.spinner("🔍 Analyzing the image..."):
        caption = generate_caption(image, caption_pipeline)
    st.success(f"📝 Caption: {caption}")

    with st.spinner("✨ Creating a magical story..."):
        story = generate_story(caption, story_pipeline)
    st.subheader("📚 Your Story")
    st.write(story)
    word_count = len(story.split())
    st.info(f"Word count: {word_count}")  # show the word count

    with st.spinner("🔊 Generating audio..."):
        audio_path = generate_audio(story, tts_pipeline)
    st.audio(audio_path, format="audio/wav")
81
+
82
if __name__ == "__main__":
    # Deferred import to avoid preloading problems on Spaces.
    import torch

    main()