TLH01 commited on
Commit
a7bd32c
·
verified ·
1 Parent(s): 5e5ea3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -46
app.py CHANGED
@@ -2,16 +2,15 @@ import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
  import tempfile
5
- import numpy as np
6
  import torch
7
- import soundfile as sf
 
8
 
9
  # ======================
10
  # Stage 1: Image Captioning
11
  # ======================
12
  @st.cache_resource
13
  def load_image_captioner():
14
- """Load BLIP model for image caption generation"""
15
  return pipeline(
16
  "image-to-text",
17
  model="Salesforce/blip-image-captioning-base",
@@ -19,7 +18,6 @@ def load_image_captioner():
19
  )
20
 
21
  def generate_caption(_pipeline, image):
22
- """Generate English description from image"""
23
  try:
24
  result = _pipeline(image, max_new_tokens=50)
25
  return result[0]['generated_text']
@@ -32,89 +30,76 @@ def generate_caption(_pipeline, image):
32
  # ======================
33
  @st.cache_resource
34
  def load_story_generator():
35
- """Load fine-tuned story generator"""
36
  return pipeline(
37
  "text-generation",
38
- model="pranavpsv/gpt2-genre-story-generator",
39
  device="cuda" if torch.cuda.is_available() else "cpu"
40
  )
41
 
42
- def generate_story(_pipeline, keywords):
43
- """Generate children's story based on keywords"""
44
- prompt = f"""Generate a children's story (60-80 words) about: {keywords}
45
- Requirements:
46
- - Use simple English
47
- - Include magical elements
48
- - Happy ending
49
- Story:"""
50
-
 
51
  try:
52
- story = _pipeline(
53
- prompt,
54
- max_length=200,
55
- temperature=0.7
56
- )[0]['generated_text']
57
  return story.replace(prompt, "").strip()
58
  except Exception as e:
59
  st.error(f"Story generation failed: {str(e)}")
60
  return None
61
 
62
  # ======================
63
- # Stage 3: Text-to-Speech
64
  # ======================
65
  @st.cache_resource
66
  def load_tts():
67
- """Load TTS model for audio generation"""
68
- return pipeline(
69
- "text-to-speech",
70
- model="facebook/mms-tts-eng",
71
- device="cuda" if torch.cuda.is_available() else "cpu"
72
- )
73
 
74
- def text_to_speech(_pipeline, text):
75
- """Convert text to speech audio"""
76
  try:
77
- audio = _pipeline(text)
78
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
79
- sf.write(f.name, audio["audio"], audio["sampling_rate"])
80
  return f.name
81
  except Exception as e:
82
  st.error(f"Audio generation failed: {str(e)}")
83
  return None
84
 
85
- # Main App
 
 
86
  def main():
87
  st.set_page_config(page_title="Magic Story Generator", layout="wide")
88
  st.title("🧚 Magic Story Generator")
89
-
90
- uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "png"])
91
  if not uploaded_image:
92
  return
93
 
94
  image = Image.open(uploaded_image)
95
- st.image(image, use_container_width=True) # Fixed deprecated parameter
96
 
97
- # Process stages
98
- with st.spinner("Processing..."):
99
  caption_pipe = load_image_captioner()
100
  story_pipe = load_story_generator()
101
- tts_pipe = load_tts()
102
 
103
- # Stage 1
104
  caption = generate_caption(caption_pipe, image)
105
  if caption:
106
  st.success(f"Image description: {caption}")
107
-
108
- # Stage 2
109
  story = generate_story(story_pipe, caption)
 
110
  if story:
111
- st.subheader("Your Story")
112
- st.markdown(f'<div class="story-box">{story}</div>', unsafe_allow_html=True)
113
 
114
- # Stage 3
115
- audio_path = text_to_speech(tts_pipe, story)
116
  if audio_path:
117
  st.audio(audio_path, format="audio/wav")
118
 
119
  if __name__ == "__main__":
120
- main()
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import tempfile
 
5
  import torch
6
+ from TTS.api import TTS # Coqui TTS
7
+ import os
8
 
9
  # ======================
10
  # Stage 1: Image Captioning
11
  # ======================
12
  @st.cache_resource
13
  def load_image_captioner():
 
14
  return pipeline(
15
  "image-to-text",
16
  model="Salesforce/blip-image-captioning-base",
 
18
  )
19
 
20
  def generate_caption(_pipeline, image):
 
21
  try:
22
  result = _pipeline(image, max_new_tokens=50)
23
  return result[0]['generated_text']
 
30
  # ======================
31
  @st.cache_resource
32
  def load_story_generator():
 
33
  return pipeline(
34
  "text-generation",
35
+ model="pranavpsv/gpt2-genre-story-generator", # 可以替换为更强模型
36
  device="cuda" if torch.cuda.is_available() else "cpu"
37
  )
38
 
39
+ def generate_story(_pipeline, caption):
40
+ prompt = f"""You are a children's storyteller. Based on the following image description: "{caption}", write a short children's story (80 words max).
41
+ The story should:
42
+ - Use simple and friendly language
43
+ - Be related to the content of the image
44
+ - Include a magical or fun twist
45
+ - End happily
46
+
47
+ Story:"""
48
+
49
  try:
50
+ story = _pipeline(prompt, max_length=200, temperature=0.7)[0]['generated_text']
 
 
 
 
51
  return story.replace(prompt, "").strip()
52
  except Exception as e:
53
  st.error(f"Story generation failed: {str(e)}")
54
  return None
55
 
56
  # ======================
57
+ # Stage 3: Text-to-Speech using Coqui TTS
58
  # ======================
59
  @st.cache_resource
60
  def load_tts():
61
+ return TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
 
 
 
 
 
62
 
63
+ def text_to_speech(tts_model, story_text):
 
64
  try:
 
65
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
66
+ tts_model.tts_to_file(text=story_text, file_path=f.name)
67
  return f.name
68
  except Exception as e:
69
  st.error(f"Audio generation failed: {str(e)}")
70
  return None
71
 
72
+ # ======================
73
+ # Main Streamlit App
74
+ # ======================
75
  def main():
76
  st.set_page_config(page_title="Magic Story Generator", layout="wide")
77
  st.title("🧚 Magic Story Generator")
78
+
79
+ uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
80
  if not uploaded_image:
81
  return
82
 
83
  image = Image.open(uploaded_image)
84
+ st.image(image, use_container_width=True)
85
 
86
+ with st.spinner("Processing your magical story..."):
 
87
  caption_pipe = load_image_captioner()
88
  story_pipe = load_story_generator()
89
+ tts_model = load_tts()
90
 
 
91
  caption = generate_caption(caption_pipe, image)
92
  if caption:
93
  st.success(f"Image description: {caption}")
 
 
94
  story = generate_story(story_pipe, caption)
95
+
96
  if story:
97
+ st.subheader("Your Magical Story")
98
+ st.markdown(story)
99
 
100
+ audio_path = text_to_speech(tts_model, story)
 
101
  if audio_path:
102
  st.audio(audio_path, format="audio/wav")
103
 
104
  if __name__ == "__main__":
105
+ main()