# Spaces status: Sleeping
# Import part
import os
import tempfile

import streamlit as st
import torch
from transformers import pipeline
# Function part: each stage of the app is implemented as its own function.
# img2text
@st.cache_resource(show_spinner=False)
def _load_image_captioner():
    """Build the image-captioning pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so
    constructing the pipeline inside ``img2text`` reloaded the model on each
    rerun. ``st.cache_resource`` keeps a single instance alive instead.
    """
    return pipeline("image-to-text",
                    model="Salesforce/blip-image-captioning-base")


def img2text(img):
    """Return the caption the image-to-text pipeline generates for ``img``.

    Args:
        img: Image input forwarded unchanged to the pipeline; the script in
            this file passes the path of an image file on disk.

    Returns:
        The ``"generated_text"`` field of the first pipeline result.
    """
    image_to_text_model = _load_image_captioner()
    return image_to_text_model(img)[0]["generated_text"]
# text2story
@st.cache_resource(show_spinner=False)
def _load_story_generator():
    """Build the text-generation pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so
    constructing the pipeline inside ``text2story`` reloaded the model on
    each rerun. ``st.cache_resource`` keeps a single instance alive instead.
    """
    return pipeline("text-generation", model="distilbert/distilgpt2")


def text2story(text):
    """Expand ``text`` into a longer story with the text-generation pipeline.

    Args:
        text: Prompt handed to the generator (the image caption in this
            script).

    Returns:
        The ``"generated_text"`` field of the single returned sequence.
    """
    generator = _load_story_generator()
    generated = generator(
        text,
        min_length=150,           # lower bound on the generated length
        max_length=300,           # upper bound on the generated length
        num_return_sequences=1,
        no_repeat_ngram_size=2,   # prevent repetition
        early_stopping=False,     # prohibit early stopping
    )
    return generated[0]["generated_text"]
# text2audio
@st.cache_resource(show_spinner=False)
def _load_tts_pipeline():
    """Build the text-to-speech pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so
    constructing the pipeline inside ``text2audio`` reloaded the model on
    each rerun. ``st.cache_resource`` keeps a single instance alive instead.
    """
    return pipeline("text-to-speech", model="facebook/mms-tts-eng")


def text2audio(story_text):
    """Synthesize speech for ``story_text``.

    Args:
        story_text: Text handed to the text-to-speech pipeline.

    Returns:
        The pipeline output, unchanged. The script in this file reads its
        ``"audio"`` and ``"sampling_rate"`` entries.
    """
    tts_pipeline = _load_tts_pipeline()
    # The pipeline result is returned directly, without any conversion.
    return tts_pipeline(story_text)

# NOTE: two commented-out experiments that used to follow this function were
# dropped as dead code: one based on "suno/bark-small" and one based on
# "microsoft/speecht5_tts" (16 kHz sample rate). Both wrote the waveform into
# an in-memory WAV buffer and returned its bytes with the sampling rate.
# program main part
st.set_page_config(page_title="Your Image to Audio Story",
                   page_icon="🦜")
st.header("Turn Your Image to Audio Story")
uploaded_file = st.file_uploader("Select an Image...")

if uploaded_file is not None:
    print(uploaded_file)
    bytes_data = uploaded_file.getvalue()
    st.image(uploaded_file, caption="Uploaded Image",
             use_column_width=True)

    # Every widget interaction (including the "Play Audio" click below)
    # reruns this script from the top. Keep the results for the current
    # upload in the session so a click does not recompute the caption,
    # story and audio -- and does not replace the story already on screen.
    upload_key = (uploaded_file.name, hash(bytes_data))
    cached = st.session_state.get("story_result")

    if cached is not None and cached["upload_key"] == upload_key:
        scenario = cached["scenario"]
        story = cached["story"]
        audio_data = cached["audio_data"]
        st.write(scenario)
        st.write(story)
    else:
        # Stage 1: Image to Text
        st.text('Processing img2text...')
        # Write the upload to a temporary file instead of a file named after
        # the client-supplied filename: that name is untrusted input and the
        # file was never cleaned up. Only the extension is reused.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix,
                                         delete=False) as tmp_file:
            tmp_file.write(bytes_data)
            image_path = tmp_file.name
        try:
            scenario = img2text(image_path)
        finally:
            os.remove(image_path)
        st.write(scenario)

        # Stage 2: Text to Story
        st.text('Generating a story...')
        story = text2story(scenario)
        st.write(story)

        # Stage 3: Story to Audio data
        st.text('Generating audio data...')
        audio_data = text2audio(story)

        st.session_state["story_result"] = {
            "upload_key": upload_key,
            "scenario": scenario,
            "story": story,
            "audio_data": audio_data,
        }

    # Play button
    if st.button("Play Audio"):
        st.audio(audio_data['audio'],
                 format="audio/wav",
                 start_time=0,
                 sample_rate=audio_data['sampling_rate'])