shingguy1 commited on
Commit
19b4c5e
·
verified ·
1 Parent(s): 58abad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -59
app.py CHANGED
@@ -1,63 +1,128 @@
 
 
 
1
  import streamlit as st
 
2
  from PIL import Image
3
- from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
4
- from gtts import gTTS
5
- import os
6
- import tempfile
7
 
8
- # Load models
9
  @st.cache_resource
10
- def load_models():
11
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
12
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
13
- gpt2_pipeline = pipeline("text-generation", model="gpt2")
14
- return processor, blip_model, gpt2_pipeline
15
-
16
- processor, blip_model, gpt2 = load_models()
17
-
18
- # UI
19
- st.title("πŸ–ΌοΈπŸ“– Storyteller for Kids")
20
- st.write("Upload an image and let the app create and read a magical story just for kids!")
21
-
22
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
23
-
24
- if uploaded_file:
25
- image = Image.open(uploaded_file).convert("RGB")
26
- st.image(image, caption="Uploaded Image", use_column_width=True)
27
-
28
- with st.spinner("Generating image caption..."):
29
- inputs = processor(images=image, return_tensors="pt")
30
- out = blip_model.generate(**inputs)
31
- caption = processor.decode(out[0], skip_special_tokens=True)
32
- st.success("Caption generated!")
33
- st.write(f"**Caption:** {caption}")
34
-
35
- with st.spinner("Writing a children's story..."):
36
- prompt = f"Write a short, imaginative story for children aged 3-10 about this: {caption}"
37
- story_output = gpt2(
38
- prompt,
39
- max_length=100,
40
- num_return_sequences=1,
41
- do_sample=True,
42
- temperature=0.9,
43
- top_p=0.95,
44
- top_k=50,
45
- repetition_penalty=1.2,
46
- pad_token_id=50256,
47
- eos_token_id=50256,
48
- )[0]["generated_text"]
49
- story = story_output.strip().replace('\n', ' ')
50
- # Truncate to ~100 words for safety
51
- story = " ".join(story.split()[:100])
52
- st.success("Story created!")
53
- st.write(f"**Story:**\n\n{story}")
54
-
55
- with st.spinner("Converting story to audio..."):
56
- try:
57
- tts = gTTS(text=story, lang='en')
58
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
59
- tts.save(fp.name)
60
- st.audio(fp.name, format="audio/mp3")
61
- st.success("Audio playback ready!")
62
- except Exception as e:
63
- st.error(f"Text-to-speech failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import wave
4
  import streamlit as st
5
+ from transformers import pipeline
6
  from PIL import Image
7
+ import numpy as np
 
 
 
8
 
9
# ——— 1) MODEL LOADING (cached) ————————————————
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Build (once, via Streamlit's resource cache) a CPU image-to-text pipeline."""
    captioner = pipeline("image-to-text", model=model_name, device="cpu")
    return captioner
13
+
14
@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Build (once, via Streamlit's resource cache) a CPU text2text-generation pipeline."""
    story_pipe = pipeline("text2text-generation", model=model_name, device="cpu")
    return story_pipe
17
+
18
@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Build (once, via Streamlit's resource cache) a CPU text-to-speech pipeline."""
    tts = pipeline("text-to-speech", model=model_name, device="cpu")
    return tts
21
+
22
# ——— 2) TRANSFORM FUNCTIONS ————————————————
def part1_image_to_text(pil_img, captioner):
    """Caption *pil_img* with the given image-to-text pipeline.

    Returns the first result's "generated_text", or "" when the pipeline
    yields no results (or a result missing that key).
    """
    results = captioner(pil_img)
    if not results:
        return ""
    return results[0].get("generated_text", "")
26
+
27
def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4
) -> str:
    """Expand an image *caption* into a short story via a text2text pipeline.

    The prompt asks for roughly *target_words* words; the remaining keyword
    arguments are passed straight to the pipeline as sampling controls.
    Returns "" when the model produces only whitespace. If the model echoes
    the prompt, the echo is stripped; the result is then cut at its last
    period so it does not end mid-sentence (left untouched if there is none).
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    sampling = dict(
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False,
    )
    raw = story_pipe(prompt, **sampling)[0].get("generated_text", "").strip()
    if not raw:
        return ""
    # Drop a leading echo of the prompt, if the model repeated it.
    story = raw[len(prompt):].strip() if raw.lower().startswith(prompt.lower()) else raw
    # Trim any trailing fragment after the final period.
    last_period = story.rfind(".")
    return story[:last_period + 1] if last_period != -1 else story
69
+
70
def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with a TTS pipeline and return the audio as WAV bytes.

    The pipeline is expected to return {"audio": np.ndarray, "sampling_rate": int}
    (possibly wrapped in a list). A 2-D array is assumed to be
    (channels, samples) — TODO confirm against the pipeline in use.

    Fixes over the previous version:
    - float samples are clipped to [-1, 1] before the int16 conversion, so
      out-of-range model output cannot wrap around into loud artifacts;
    - wave.open is used as a context manager so the writer is always closed
      (and the WAV header finalized) even if writing raises.
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]
    audio_array = out["audio"]
    rate = out["sampling_rate"]  # samples per second
    # Transpose to (samples, channels) so frames are interleaved for WAV.
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    # Clip first: values just outside [-1, 1] would otherwise overflow int16.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()
89
+
90
# ——— 3) STREAMLIT UI ————————————————————————————
st.set_page_config(
    page_title="Image→Story→Speech",
    page_icon="πŸ–ΌοΈπŸŽ€",
    layout="centered"
)
st.title("πŸ–ΌοΈ ➑️ πŸ“– ➑️ πŸŽ™οΈ Image β†’ Story β†’ Speech")

uploaded = st.file_uploader("1️⃣ Upload an image", type=["jpg","jpeg","png"])
if not uploaded:
    st.info("Please upload an image to begin.")
    st.stop()

# Show image.
with st.spinner("Rendering image..."):
    # Fix: force RGB — RGBA/palette PNGs can break the captioning pipeline
    # (the previous version of this app also converted).
    pil_img = Image.open(uploaded).convert("RGB")
    st.image(pil_img, use_container_width=True)

# Generate caption eagerly so the user sees it before asking for a story.
captioner = get_image_captioner()
with st.spinner("Generating caption..."):
    caption = part1_image_to_text(pil_img, captioner)
st.markdown(f"**Caption:** {caption}")

# Generate story & play audio only on demand (story + TTS are the slow steps).
if st.button("πŸ“ Generate Story & Play Audio"):
    # Story
    story_pipe = get_story_pipe()
    with st.spinner("Generating story..."):
        story = part2_text_to_story(caption, story_pipe)
    st.markdown("**Story:**")
    st.write(story)

    # TTS
    tts_pipe = get_tts_pipe()
    with st.spinner("Synthesizing speech..."):
        audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
    st.audio(audio_bytes, format="audio/wav")
    st.success("All done!")