TLH01 commited on
Commit
b895c35
·
verified ·
1 Parent(s): 78e742e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -29
app.py CHANGED
@@ -19,14 +19,14 @@ logger = logging.getLogger(__name__)
19
  # ======================
20
  @st.cache_resource
21
  def load_image_model():
22
- """Load official Hugging Face image captioning model"""
23
  try:
24
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
25
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
26
  logger.info("Stage 1 model loaded")
27
  return processor, model
28
  except Exception as e:
29
- st.error("❌ 图像模型加载失败,请检查网络连接")
30
  raise
31
 
32
  def stage1_generate_caption(uploaded_file):
@@ -34,12 +34,12 @@ def stage1_generate_caption(uploaded_file):
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
- img.thumbnail((512, 512)) # Resize for speed
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
41
  except Exception as e:
42
- st.error(f"图像处理失败: {str(e)}")
43
  return "children playing"
44
 
45
  # ======================
@@ -47,27 +47,27 @@ def stage1_generate_caption(uploaded_file):
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
- """Load Microsoft DialoGPT model"""
51
  try:
52
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
53
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
54
  logger.info("Stage 2 model loaded")
55
  return tokenizer, model
56
  except Exception as e:
57
- st.error("❌ 故事模型加载失败,请检查模型名称")
58
  raise
59
 
60
  def stage2_generate_story(keyword):
61
- """Generate children's story"""
62
  tokenizer, model = load_story_model()
63
 
64
  # Optimized prompt template
65
- prompt = f"""写一个儿童故事,包含以下要素:
66
- - 主题: {keyword}
67
- - 角色: 小动物
68
- - 字数: 100字左右
69
 
70
- 故事开头: 有一天,小熊嘟嘟在公园里发现"""
71
 
72
  try:
73
  inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
@@ -76,14 +76,13 @@ def stage2_generate_story(keyword):
76
  max_length=300,
77
  temperature=0.9,
78
  top_k=50,
79
- repetition_penalty=1.2,
80
- pad_token_id=tokenizer.eos_token_id
81
  )
82
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
83
  return full_text.replace(prompt, "").strip()
84
  except Exception as e:
85
- st.error(f"故事生成失败: {str(e)}")
86
- return "小熊和朋友们玩得很开心!"
87
 
88
  # ======================
89
  # Stage 3: Text-to-Speech
@@ -91,43 +90,43 @@ def stage2_generate_story(keyword):
91
  def stage3_generate_audio(text):
92
  """Convert text to audio"""
93
  try:
94
- tts = gTTS(text=text[:300], lang='zh-CN') # Chinese support
95
  audio_buffer = io.BytesIO()
96
  tts.write_to_fp(audio_buffer)
97
  audio_buffer.seek(0)
98
  return audio_buffer
99
  except Exception as e:
100
- st.error(f"语音生成失败: {str(e)}")
101
  return None
102
 
103
  # ======================
104
  # Main Application
105
  # ======================
106
  def main():
107
- st.title("📚 智能故事生成器")
108
 
109
- uploaded_file = st.file_uploader("上传儿童照片", type=["jpg", "png", "jpeg"])
110
 
111
  if uploaded_file:
112
  # Stage 1
113
  st.image(uploaded_file, use_container_width=True)
114
- with st.spinner("正在分析图片..."):
115
  caption = stage1_generate_caption(uploaded_file)
116
- st.write(f"✨ 识别主题: **{caption}**")
117
 
118
  # Stage 2
119
- with st.spinner("正在生成故事..."):
120
  story = stage2_generate_story(caption)
121
- st.subheader("生成故事")
122
  st.write(story)
123
 
124
  # Stage 3
125
- if len(story) > 10: # Minimum length check
126
- with st.spinner("正在生成语音..."):
127
  audio = stage3_generate_audio(story)
128
  if audio:
129
  st.audio(audio, format="audio/mp3")
130
- st.download_button("下载语音", audio.getvalue(), "story.mp3")
131
 
132
  if __name__ == "__main__":
133
  main()
 
19
  # ======================
20
  @st.cache_resource
21
  def load_image_model():
22
+ """Load official image captioning model"""
23
  try:
24
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
25
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
26
  logger.info("Stage 1 model loaded")
27
  return processor, model
28
  except Exception as e:
29
+ st.error("❌ Image model failed to load")
30
  raise
31
 
32
  def stage1_generate_caption(uploaded_file):
 
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
+ img.thumbnail((512, 512))
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
41
  except Exception as e:
42
+ st.error(f"Image processing failed: {str(e)}")
43
  return "children playing"
44
 
45
  # ======================
 
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
+ """Load story generation model"""
51
  try:
52
+ tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt-genre-story-generator")
53
+ model = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt-genre-story-generator", use_auth_token=True)
54
  logger.info("Stage 2 model loaded")
55
  return tokenizer, model
56
  except Exception as e:
57
+ st.error("❌ Story model failed to load")
58
  raise
59
 
60
  def stage2_generate_story(keyword):
61
+ """Generate structured story"""
62
  tokenizer, model = load_story_model()
63
 
64
  # Optimized prompt template
65
+ prompt = f"""Generate a children's story with:
66
+ - Theme: {keyword}
67
+ - Characters: Animals
68
+ - Word count: 100 words
69
 
70
+ Story: Once upon a time, a little bear named Honey discovered"""
71
 
72
  try:
73
  inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
 
76
  max_length=300,
77
  temperature=0.9,
78
  top_k=50,
79
+ repetition_penalty=1.2
 
80
  )
81
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
82
  return full_text.replace(prompt, "").strip()
83
  except Exception as e:
84
+ st.error(f"Story generation failed: {str(e)}")
85
+ return "The animals had a wonderful adventure!"
86
 
87
  # ======================
88
  # Stage 3: Text-to-Speech
 
90
  def stage3_generate_audio(text):
91
  """Convert text to audio"""
92
  try:
93
+ tts = gTTS(text=text[:300], lang='en')
94
  audio_buffer = io.BytesIO()
95
  tts.write_to_fp(audio_buffer)
96
  audio_buffer.seek(0)
97
  return audio_buffer
98
  except Exception as e:
99
+ st.error(f"Audio generation failed: {str(e)}")
100
  return None
101
 
102
  # ======================
103
  # Main Application
104
  # ======================
105
  def main():
106
+ st.title("📚 Smart Story Generator")
107
 
108
+ uploaded_file = st.file_uploader("Upload Photo (JPG/PNG)", type=["jpg", "png", "jpeg"])
109
 
110
  if uploaded_file:
111
  # Stage 1
112
  st.image(uploaded_file, use_container_width=True)
113
+ with st.spinner("Analyzing image..."):
114
  caption = stage1_generate_caption(uploaded_file)
115
+ st.write(f"✨ Detected Theme: **{caption}**")
116
 
117
  # Stage 2
118
+ with st.spinner("Generating story..."):
119
  story = stage2_generate_story(caption)
120
+ st.subheader("Generated Story")
121
  st.write(story)
122
 
123
  # Stage 3
124
+ if len(story) > 20:
125
+ with st.spinner("Creating audio..."):
126
  audio = stage3_generate_audio(story)
127
  if audio:
128
  st.audio(audio, format="audio/mp3")
129
+ st.download_button("Download Audio", audio.getvalue(), "story.mp3")
130
 
131
  if __name__ == "__main__":
132
  main()