aeresd commited on
Commit
aa2ae39
·
verified ·
1 Parent(s): a274a68

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer
3
+ import torch
4
+ import re
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+ import logging
10
+
11
+ # 配置日志系统
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ==================== 模型缓存加载 ====================
16
+ @st.cache_resource(show_spinner=False)
17
+ def load_models():
18
+ """预加载所有模型并缓存"""
19
+ logger.info("Loading caption model...")
20
+ caption_model = pipeline("image-to-text",
21
+ model="Salesforce/blip-image-captioning-base",
22
+ device=0 if torch.cuda.is_available() else -1)
23
+
24
+ logger.info("Loading story model...")
25
+ story_model = pipeline(
26
+ "text-generation",
27
+ model="Tincando/fiction_story_generator",
28
+ device=0 if torch.cuda.is_available() else -1,
29
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
30
+ )
31
+
32
+ logger.info("Loading TTS model...")
33
+ tts_model = pipeline("text-to-audio",
34
+ model="Chan-Y/speecht5_finetuned_tr_commonvoice",
35
+ device=0 if torch.cuda.is_available() else -1)
36
+ tts_tokenizer = AutoTokenizer.from_pretrained("Chan-Y/speecht5_finetuned_tr_commonvoice")
37
+
38
+ return caption_model, story_model, tts_model, tts_tokenizer
39
+
40
+ # ==================== Streamlit 界面配置 ====================
41
+ st.set_page_config(
42
+ page_title="🧸 AI Story Generator Pro",
43
+ page_icon="📖",
44
+ layout="wide",
45
+ initial_sidebar_state="expanded"
46
+ )
47
+
48
+ # ==================== 侧边栏参数设置 ====================
49
+ with st.sidebar:
50
+ st.title("⚙️ 生成参数")
51
+ temperature = st.slider("创意度", 0.5, 1.5, 0.85, step=0.05)
52
+ max_length = st.slider("故事长度", 100, 500, 200)
53
+ story_style = st.selectbox("故事风格", ["童话", "科幻", "冒险"])
54
+ voice_speed = st.slider("语音速度", 0.5, 2.0, 1.0)
55
+
56
+ # ==================== 主界面 ====================
57
+ st.title("🖼️ AI 智能故事生成器")
58
+ st.write("上传图片即可获得定制化故事与语音朗读")
59
+
60
+ # ==================== 文件上传 ====================
61
+ uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "jpeg", "png"])
62
+
63
+ if uploaded_file:
64
+ # ==================== 图像处理 ====================
65
+ col1, col2 = st.columns([1, 2])
66
+ with col1:
67
+ image = Image.open(uploaded_file)
68
+ st.image(image, caption="上传图片", use_column_width=True)
69
+
70
+ # ==================== 生成流程 ====================
71
+ if st.button("开始生成", type="primary"):
72
+ try:
73
+ progress_bar = st.progress(0)
74
+ status_text = st.empty()
75
+
76
+ # 加载模型
77
+ with st.spinner("🔄 正在加载模型..."):
78
+ caption_model, story_model, tts_model, tts_tokenizer = load_models()
79
+ speaker_emb = torch.tensor(
80
+ load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
81
+ ).unsqueeze(0)
82
+ progress_bar.progress(20)
83
+
84
+ # 图像描述生成
85
+ with st.spinner("📷 正在分析图片内容..."):
86
+ caption_result = caption_model(image)
87
+ caption = caption_result[0]['generated_text']
88
+ progress_bar.progress(40)
89
+
90
+ # 故事生成
91
+ with st.spinner("✍️ 正在创作精彩故事..."):
92
+ prompt = f"以{story_style}风格创作儿童故事,主题:{caption}"
93
+ story = story_model(
94
+ prompt,
95
+ temperature=temperature,
96
+ max_length=max_length,
97
+ do_sample=True
98
+ )[0]['generated_text']
99
+ story = re.sub(r'[^.!?]+$', '', story) # 确保完整结尾
100
+ progress_bar.progress(70)
101
+
102
+ # 语音合成
103
+ with st.spinner("🔊 正在生成语音..."):
104
+ chunks = re.split(r'(?<=[.!?]) +', story)
105
+ audio_arrays = []
106
+
107
+ for chunk in chunks:
108
+ inputs = tts_tokenizer(chunk, return_tensors="pt")
109
+ speech = tts_model.generate(
110
+ inputs["input_ids"],
111
+ forward_params={
112
+ "speaker_embeddings": speaker_emb,
113
+ "speed": voice_speed
114
+ }
115
+ )
116
+ audio_arrays.append(speech.numpy())
117
+
118
+ combined = np.concatenate(audio_arrays)
119
+ sf.write("output.wav", combined, samplerate=16000)
120
+ progress_bar.progress(100)
121
+
122
+ # ==================== 结果展示 ====================
123
+ with col2:
124
+ st.subheader("📖 生成故事")
125
+ st.success(story)
126
+
127
+ st.subheader("🔊 语音朗读")
128
+ st.audio("output.wav", format="audio/wav")
129
+
130
+ # 下载功能
131
+ st.download_button(
132
+ label="下载故事文本",
133
+ data=story,
134
+ file_name="generated_story.txt",
135
+ mime="text/plain"
136
+ )
137
+ st.download_button(
138
+ label="下载语音文件",
139
+ data=open("output.wav", "rb"),
140
+ file_name="story_audio.wav",
141
+ mime="audio/wav"
142
+ )
143
+
144
+ except Exception as e:
145
+ st.error(f"生成失败:{str(e)}")
146
+ st.button("重试", on_click=st.cache_resource.clear)