aeresd committed on
Commit
dae1dcd
Β·
verified Β·
1 Parent(s): 6f0ab42

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer
3
+ import torch
4
+ import re
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+ import logging
10
+
11
# Configure logging
# Root-level INFO logging so model-loading progress messages are visible
# in the server console; module-scoped logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
14
+
15
# ==================== Model loading with caching ====================
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and cache every model the app needs.

    Returns a 4-tuple:
        (image-captioning pipeline, story-generation pipeline,
         text-to-audio pipeline, TTS tokenizer)

    Decorated with ``st.cache_resource`` so the (slow) downloads and
    instantiations happen once per server process, not per rerun.
    """
    # Hoisted once: same device / dtype choice for every pipeline.
    gpu_available = torch.cuda.is_available()
    device = 0 if gpu_available else -1
    story_dtype = torch.bfloat16 if gpu_available else torch.float32

    logger.info("Loading image captioning model...")
    caption_model = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )

    logger.info("Loading story generation model...")
    story_model = pipeline(
        task="text-generation",
        model="Tincando/fiction_story_generator",
        device=device,
        torch_dtype=story_dtype,
    )

    logger.info("Loading text-to-speech model...")
    tts_model = pipeline(
        task="text-to-audio",
        model="Chan-Y/speecht5_finetuned_tr_commonvoice",
        device=device,
    )
    tts_tokenizer = AutoTokenizer.from_pretrained(
        "Chan-Y/speecht5_finetuned_tr_commonvoice"
    )

    return caption_model, story_model, tts_model, tts_tokenizer
45
+
46
# ==================== Streamlit page configuration ====================
# Must run before any other st.* call; collected into a dict for readability.
_PAGE_CONFIG = dict(
    page_title="🧸 AI Story Generator Pro",
    page_icon="📖",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
53
+
54
# ==================== Sidebar settings ====================
# Widget order defines the on-screen layout — do not reorder.
with st.sidebar:
    st.title("⚙️ Generation Settings")
    # Sampling temperature for the story model (higher = more random).
    temperature = st.slider("Creativity", 0.5, 1.5, 0.85, step=0.05)
    # Maximum token length passed to the text-generation pipeline.
    max_length = st.slider("Story Length", 100, 500, 200)
    story_style = st.selectbox("Story Style", ["Fairy Tale", "Sci-Fi", "Adventure"])
    # NOTE(review): voice_speed is collected here but forwarded as a "speed"
    # param the TTS model may not support — confirm downstream usage.
    voice_speed = st.slider("Voice Speed", 0.5, 2.0, 1.0)
61
+
62
# ==================== Main interface ====================
st.title("🖼️ AI Story Generator")
st.write("Upload an image to get a customized story with audio narration.")

# ==================== File upload ====================
# Restrict uploads to raster formats PIL.Image.open handles natively.
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
68
+
69
if uploaded_file:
    # ==================== Image display ====================
    col1, col2 = st.columns([1, 2])
    with col1:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

    # ==================== Generation process ====================
    # Pipeline: caption the image -> expand caption into a story ->
    # synthesize narration -> show text/audio with download buttons.
    if st.button("Generate Story", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Load (cached) models and the SpeechT5 speaker embedding.
            with st.spinner("🔄 Loading models..."):
                caption_model, story_model, tts_model, tts_tokenizer = load_models()
                # x-vector speaker embedding required by SpeechT5-family TTS.
                speaker_emb = torch.tensor(
                    load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
                ).unsqueeze(0)
            progress_bar.progress(20)

            # Generate image caption.
            with st.spinner("📷 Analyzing image content..."):
                caption = caption_model(image)[0]["generated_text"]
            progress_bar.progress(40)

            # Generate story from the caption.
            with st.spinner("✍️ Writing the story..."):
                prompt = f"Write a children's story in {story_style} style about: {caption}"
                story = story_model(
                    prompt,
                    temperature=temperature,
                    max_length=max_length,
                    do_sample=True,
                )[0]["generated_text"]
                # Trim a trailing incomplete sentence so the story ends
                # on terminal punctuation.
                story = re.sub(r'[^.!?]+$', '', story)
            progress_bar.progress(70)

            # Text-to-speech synthesis, sentence by sentence.
            with st.spinner("🔊 Generating audio..."):
                chunks = re.split(r'(?<=[.!?]) +', story)
                audio_arrays = []
                sample_rate = 16000  # fallback; replaced by pipeline output below
                for chunk in chunks:
                    # Skip empty fragments re.split can produce.
                    if not chunk.strip():
                        continue
                    # BUG FIX: a transformers pipeline has no .generate()
                    # method — call it directly with raw text; it tokenizes
                    # internally and returns {"audio": ndarray,
                    # "sampling_rate": int}.
                    # NOTE(review): "speed" is not a SpeechT5 forward
                    # parameter, so it was dropped; voice_speed is currently
                    # unused — confirm whether speed control is required.
                    result = tts_model(
                        chunk,
                        forward_params={"speaker_embeddings": speaker_emb},
                    )
                    audio_arrays.append(np.asarray(result["audio"]).squeeze())
                    sample_rate = result["sampling_rate"]
                combined = np.concatenate(audio_arrays)
                # BUG FIX: use the model-reported sampling rate rather than a
                # hard-coded 16000, so playback speed/pitch is correct.
                sf.write("output.wav", combined, samplerate=sample_rate)
            progress_bar.progress(100)

            # ==================== Display results ====================
            with col2:
                st.subheader("📖 Generated Story")
                st.success(story)

                st.subheader("🔊 Audio Narration")
                st.audio("output.wav", format="audio/wav")

                # Download buttons
                st.download_button(
                    label="Download Story Text",
                    data=story,
                    file_name="generated_story.txt",
                    mime="text/plain",
                )
                # BUG FIX: read the bytes inside a context manager instead of
                # handing Streamlit an open file handle that is never closed.
                with open("output.wav", "rb") as wav_file:
                    st.download_button(
                        label="Download Audio File",
                        data=wav_file.read(),
                        file_name="story_audio.wav",
                        mime="audio/wav",
                    )

        except Exception as e:
            # Surface the failure and offer a retry that clears cached models.
            st.error(f"Generation failed: {str(e)}")
            st.button("Retry", on_click=st.cache_resource.clear)