WRX020510 committed on
Commit
7730cd8
·
verified ·
1 Parent(s): 46e6eb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -11
app.py CHANGED
@@ -1,6 +1,8 @@
1
  #Import part
2
  from transformers import pipeline
3
  import streamlit as st
 
 
4
 
5
  # Use function for the implementation
6
 
@@ -14,13 +16,33 @@ def img2text(img):
14
 
15
  # text2story
16
  def text2story(text):
17
- story_text = "" # to be completed; see 2025-02-22_class.ipynb
 
 
 
 
 
18
  return story_text
19
 
 
20
  # text2audio
21
  def text2audio(story_text):
22
- audio_data = "" # to be completed; specify it directly in the task
23
- return audio_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # program main part
26
 
@@ -45,19 +67,19 @@ if uploaded_file is not None:
45
 
46
  #Stage 2: Text to Story
47
  st.text('Generating a story...')
48
- #story = text2story(scenario)
49
- #st.write(story)
50
 
51
  #Stage 3: Story to Audio data
52
- #st.text('Generating audio data...')
53
- #audio_data =text2audio(story)
54
 
55
 
56
  # Play button
57
  if st.button("Play Audio"):
58
- #st.audio(audio_data['audio'],
59
- # format="audio/wav",
60
- # start_time=0,
61
- # sample_rate = audio_data['sampling_rate'])
62
  st.audio("kids_playing_audio.wav")
63
 
 
1
  #Import part
2
  from transformers import pipeline
3
  import streamlit as st
4
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech
5
+ import torch
6
 
7
  # Use function for the implementation
8
 
 
16
 
17
  # text2story
18
  def text2story(text):
19
+ generator = pipeline("text-to-story",
20
+ model="distilbert/distilgpt2")
21
+ story_text = generator(text,
22
+ min_length=100,
23
+ max_length=150,
24
+ num_return_sequences=1)
25
  return story_text
26
 
27
+
28
  # text2audio
29
  def text2audio(story_text):
30
+
31
+ processor = SpeechT5Processor.from_pretrained("facebook/fastspeech2-en-ljspeech")
32
+ model = SpeechT5ForTextToSpeech.from_pretrained("facebook/fastspeech2-en-ljspeech")
33
+
34
+ inputs = processor(story_text, return_tensors="pt")
35
+ with torch.no_grad():
36
+ speech = model.generate_speech(inputs["input_ids"], model.config.vocoder)
37
+
38
+ audio_buffer = io.BytesIO()
39
+ sf.write(audio_buffer, speech.numpy(), samplerate=22050, format='WAV')
40
+ audio_buffer.seek(0)
41
+
42
+ return {
43
+ 'audio': audio_buffer.getvalue(),
44
+ 'sampling_rate': 22050
45
+ }
46
 
47
  # program main part
48
 
 
67
 
68
  #Stage 2: Text to Story
69
  st.text('Generating a story...')
70
+ story = text2story(scenario)
71
+ st.write(story)
72
 
73
  #Stage 3: Story to Audio data
74
+ st.text('Generating audio data...')
75
+ audio_data =text2audio(story)
76
 
77
 
78
  # Play button
79
  if st.button("Play Audio"):
80
+ st.audio(audio_data['audio'],
81
+ format="audio/wav",
82
+ start_time=0,
83
+ sample_rate = audio_data['sampling_rate'])
84
  st.audio("kids_playing_audio.wav")
85