# Hugging Face Space (page status at time of capture: Sleeping)
import io
import logging
import re

import numpy as np
import soundfile as sf
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import pipeline, AutoTokenizer
# Configure logging: send INFO-and-above records to the root handler (a no-op
# if the root logger already has handlers) and give this module its own
# named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# ==================== Model loading with caching ====================
@st.cache_resource(show_spinner=False)
def load_models():
    """Load the captioning, story and text-to-speech models once per process.

    ``st.cache_resource`` keeps the returned objects alive across reruns and
    sessions, so downloading and initialising the models happens only on the
    first call. ``st.cache_resource.clear()`` (wired to the Retry button)
    forces a reload. The caller shows its own spinner, so the decorator's
    built-in spinner is disabled.

    Returns:
        tuple: ``(caption_model, story_model, tts_model, tts_tokenizer)`` --
        three ``transformers`` pipelines plus the tokenizer of the TTS model.
    """
    use_cuda = torch.cuda.is_available()
    # transformers.pipeline(): 0 selects the first CUDA device, -1 the CPU.
    device = 0 if use_cuda else -1
    tts_model_id = "Chan-Y/speecht5_finetuned_tr_commonvoice"

    logger.info("Loading image captioning model...")
    caption_model = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )

    logger.info("Loading story generation model...")
    story_model = pipeline(
        task="text-generation",
        model="Tincando/fiction_story_generator",
        device=device,
        # bfloat16 halves GPU memory; stay in float32 on CPU.
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    )

    logger.info("Loading text-to-speech model...")
    tts_model = pipeline(
        task="text-to-audio",
        model=tts_model_id,
        device=device,
    )
    tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)

    return caption_model, story_model, tts_model, tts_tokenizer
# ==================== Streamlit page configuration ====================
# Browser-tab title and icon, full-width layout, sidebar open on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="π§Έ AI Story Generator Pro",
    page_icon="π",
)

# ==================== Sidebar settings ====================
# Generation knobs. Each widget's current value is bound to a module-level
# name that the generation step reads.
st.sidebar.title("βοΈ Generation Settings")
temperature = st.sidebar.slider(
    "Creativity", min_value=0.5, max_value=1.5, value=0.85, step=0.05
)
max_length = st.sidebar.slider(
    "Story Length", min_value=100, max_value=500, value=200
)
story_style = st.sidebar.selectbox(
    "Story Style", options=["Fairy Tale", "Sci-Fi", "Adventure"]
)
voice_speed = st.sidebar.slider(
    "Voice Speed", min_value=0.5, max_value=2.0, value=1.0
)

# ==================== Main interface ====================
st.title("πΌοΈ AI Story Generator")
st.write("Upload an image to get a customized story with audio narration.")
# ==================== Generation helpers ====================
@st.cache_resource(show_spinner=False)
def _load_speaker_embedding():
    """Return the narrator voice embedding, loaded once per process.

    Row 7306 of the ``Matthijs/cmu-arctic-xvectors`` validation split, given
    a leading batch dimension of 1 by ``unsqueeze(0)``.
    """
    xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)


def _trim_to_last_sentence(text):
    """Drop a trailing unfinished sentence so *text* ends on '.', '!' or '?'.

    If *text* has no sentence-ending punctuation at all, it is returned
    (whitespace-stripped) instead of being erased entirely.
    """
    trimmed = re.sub(r"[^.!?]+$", "", text).strip()
    return trimmed or text.strip()


def _split_sentences(text):
    """Split *text* after each '.', '!' or '?', dropping empty pieces."""
    return [
        part.strip()
        for part in re.split(r"(?<=[.!?])\s+", text)
        if part.strip()
    ]


def _adjust_speed(audio, speed):
    """Return *audio* resampled so that it plays *speed* times faster.

    SpeechT5 generation has no tempo option, so speed is applied afterwards
    by linear-interpolation resampling. This is simple and dependency-free,
    but pitch shifts along with tempo.

    Raises:
        ValueError: if *speed* is not positive.
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")
    if speed == 1.0 or audio.size < 2:
        return audio
    target_len = max(2, round(audio.size / speed))
    positions = np.linspace(0.0, audio.size - 1, num=target_len)
    return np.interp(positions, np.arange(audio.size), audio).astype(np.float32)


def _synthesize_wav(text, tts_model, speaker_emb, speed):
    """Narrate *text* sentence by sentence and return WAV-encoded bytes.

    Each sentence is passed to the text-to-audio pipeline, which returns a
    dict with ``"audio"`` and ``"sampling_rate"``. The pieces are concatenated,
    speed-adjusted and encoded in memory, so concurrent sessions never share a
    file on disk.

    Raises:
        ValueError: if *text* contains nothing to narrate.
    """
    segments = []
    sampling_rate = None
    for sentence in _split_sentences(text):
        result = tts_model(
            sentence, forward_params={"speaker_embeddings": speaker_emb}
        )
        # Flatten to 1-D; assumes a mono waveform -- TODO confirm for other TTS models.
        segments.append(np.asarray(result["audio"], dtype=np.float32).reshape(-1))
        sampling_rate = result["sampling_rate"]
    if not segments:
        raise ValueError("Generated story is empty; nothing to narrate.")

    waveform = _adjust_speed(np.concatenate(segments), speed)
    buffer = io.BytesIO()
    # soundfile needs an explicit format when writing to a file-like object.
    sf.write(buffer, waveform, samplerate=sampling_rate, format="WAV")
    return buffer.getvalue()


# ==================== File upload ====================
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
if uploaded_file:
    # ==================== Image display ====================
    col1, col2 = st.columns([1, 2])
    with col1:
        # Normalise RGBA/palette PNGs to RGB before captioning.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

    # ==================== Generation process ====================
    if st.button("Generate Story", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Load models and the speaker voice (both cached after first use)
            with st.spinner("π Loading models..."):
                caption_model, story_model, tts_model, _ = load_models()
                speaker_emb = _load_speaker_embedding()
            progress_bar.progress(20)

            # Generate image caption
            with st.spinner("π· Analyzing image content..."):
                caption = caption_model(image)[0]["generated_text"]
            progress_bar.progress(40)

            # Generate story
            with st.spinner("βοΈ Writing the story..."):
                prompt = f"Write a children's story in {story_style} style about: {caption}"
                story = story_model(
                    prompt,
                    temperature=temperature,
                    # "Story Length" counts generated tokens, not prompt tokens.
                    max_new_tokens=max_length,
                    do_sample=True,
                    # Return only the continuation, not the instruction prompt.
                    return_full_text=False,
                )[0]["generated_text"]
                # Ensure story ends with punctuation
                story = _trim_to_last_sentence(story)
            progress_bar.progress(70)

            # Text-to-speech synthesis
            with st.spinner("π Generating audio..."):
                audio_bytes = _synthesize_wav(story, tts_model, speaker_emb, voice_speed)
            progress_bar.progress(100)

            # ==================== Display results ====================
            with col2:
                st.subheader("π Generated Story")
                st.success(story)

                st.subheader("π Audio Narration")
                st.audio(audio_bytes, format="audio/wav")

                # Download buttons
                st.download_button(
                    label="Download Story Text",
                    data=story,
                    file_name="generated_story.txt",
                    mime="text/plain",
                )
                st.download_button(
                    label="Download Audio File",
                    data=audio_bytes,
                    file_name="story_audio.wav",
                    mime="audio/wav",
                )
        except Exception as e:
            # UI boundary: record the traceback, show the message, offer a
            # retry that drops cached models in case one failed to load.
            logger.exception("Story generation failed")
            st.error(f"Generation failed: {e}")
            st.button("Retry", on_click=st.cache_resource.clear)