TLH01 committed on
Commit
870428e
·
verified ·
1 Parent(s): bed9467

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -73
app.py CHANGED
@@ -7,93 +7,86 @@ import torch
7
  from gtts import gTTS
8
  import io
9
 
10
- # ======================
11
- # Stage 1: Image Captioning
12
- # ======================
13
- def image_to_caption(uploaded_image):
14
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
15
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
16
 
17
- try:
18
- img = Image.open(uploaded_image).convert("RGB")
19
- inputs = processor(
20
- images=img,
21
- return_tensors="pt",
22
- padding=True,
23
- truncation=True,
24
- max_length=30
25
- )
26
- outputs = model.generate(**inputs)
27
- return processor.decode(outputs[0], skip_special_tokens=True)
28
- except:
29
- return "a happy scene with children" # Fallback caption
30
-
31
- # ======================
32
- # Stage 2: Story Generation
33
- # ======================
34
- def generate_story(caption):
35
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
36
- model = GPT2LMHeadModel.from_pretrained("gpt2")
37
 
38
- prompt = f"""Create a children's story (3-6 years old) about {caption} with:
39
- 1. Friendly animals
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  2. Happy ending
41
  3. 50-100 words
42
  Story:"""
43
 
44
- try:
45
- inputs = tokenizer(prompt, return_tensors="pt")
46
- outputs = model.generate(
47
- inputs.input_ids,
48
- max_length=300,
49
- num_return_sequences=1,
50
- no_repeat_ngram_size=2,
51
- early_stopping=True
52
- )
53
- story = tokenizer.decode(outputs[0], skip_special_tokens=True)
54
- return story.replace(prompt, "").strip()[:500] # Length control
55
- except:
56
- return """Once upon a time, there was a friendly bear who loved playing with children.
57
- They had wonderful adventures every day, always ending with big hugs and happy smiles!"""
58
 
59
- # ======================
60
- # Stage 3: Text-to-Speech
61
- # ======================
62
- def create_audio(story_text):
63
- try:
64
- tts = gTTS(text=story_text[:500], lang='en', slow=False)
65
- audio_buffer = io.BytesIO()
66
- tts.write_to_fp(audio_buffer)
67
- audio_buffer.seek(0)
68
- return audio_buffer
69
- except:
70
- return None # Silent fallback
71
 
72
- # ======================
73
- # Main Application
74
- # ======================
75
  def main():
76
- st.title("🎈 Children's Story Maker")
77
 
78
- uploaded_file = st.file_uploader("Upload a child's photo", type=["jpg", "png"])
 
 
 
79
 
80
  if uploaded_file:
81
- img = Image.open(uploaded_file)
82
- st.image(img, use_column_width=True)
83
 
84
  # Processing pipeline
85
- caption = image_to_caption(uploaded_file)
86
- story = generate_story(caption)
87
-
88
- st.subheader("Generated Story")
89
- st.write(story)
90
-
91
- if audio_data := create_audio(story):
92
- st.audio(audio_data, format="audio/mp3")
93
- st.download_button("Download Audio",
94
- data=audio_data,
95
- file_name="story.mp3",
96
- mime="audio/mp3")
 
 
 
 
 
 
 
97
 
98
  if __name__ == "__main__":
99
  main()
 
7
  from gtts import gTTS
8
  import io
9
 
10
# Pre-load models during app initialization
@st.cache_resource
def load_models():
    """Load the captioning and story-generation models once per session.

    Returns:
        tuple: (BLIP processor, BLIP caption model, GPT-2 tokenizer,
        GPT-2 language model). ``st.cache_resource`` caches the result,
        so Streamlit reruns reuse the same loaded objects instead of
        re-downloading/re-instantiating them.
    """
    # Image captioning model (repo id shared by processor and model)
    blip_id = "Salesforce/blip-image-captioning-base"
    caption_processor = BlipProcessor.from_pretrained(blip_id)
    caption_model = BlipForConditionalGeneration.from_pretrained(blip_id)

    # Story generation model
    story_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    story_model = GPT2LMHeadModel.from_pretrained("gpt2")

    return caption_processor, caption_model, story_tokenizer, story_model
22
+
23
def generate_caption(uploaded_image, processor, model):
    """Produce a short natural-language caption for an uploaded image.

    Args:
        uploaded_image: A file-like object (e.g. a Streamlit
            ``UploadedFile``) or path accepted by ``PIL.Image.open``.
        processor: BLIP processor turning the image into model tensors.
        model: BLIP conditional-generation model.

    Returns:
        str: The decoded caption with special tokens stripped.
    """
    # Rewind the stream in case a previous consumer (e.g. st.image)
    # already read it; Image.open would otherwise see EOF.
    if hasattr(uploaded_image, "seek"):
        uploaded_image.seek(0)
    img = Image.open(uploaded_image).convert("RGB")
    # padding/truncation are text-tokenizer options; for an image-only
    # call they do nothing, so they are dropped here.
    inputs = processor(images=img, return_tensors="pt")
    outputs = model.generate(**inputs)
    return processor.decode(outputs[0], skip_special_tokens=True)
33
+
34
def create_story(caption, tokenizer, model):
    """Generate a short children's story seeded by an image caption.

    Args:
        caption: Scene description inserted into the prompt.
        tokenizer: GPT-2 tokenizer.
        model: GPT-2 language model.

    Returns:
        str: The generated continuation only (prompt removed), stripped
        of surrounding whitespace.
    """
    prompt = f"""Create a children's story about {caption} with:
1. Friendly characters
2. Happy ending
3. 50-100 words
Story:"""

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_length=300,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # GPT-2 defines no pad token; reuse EOS to avoid the
        # "Setting pad_token_id" warning on every generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Strip the prompt by token count rather than string replace():
    # decode() may normalize whitespace, in which case replace() would
    # silently fail and leak the prompt into the displayed story.
    prompt_len = inputs.input_ids.shape[1]
    story_ids = outputs[0][prompt_len:]
    return tokenizer.decode(story_ids, skip_special_tokens=True).strip()
 
 
 
 
 
50
 
51
def text_to_audio(text):
    """Synthesize the given text as spoken English audio.

    Args:
        text: The story text to read aloud.

    Returns:
        io.BytesIO: An in-memory MP3 stream, rewound to position 0 so
        it is ready for ``st.audio`` or a download button.
    """
    speech = gTTS(text=text, lang='en')
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return buffer
 
 
 
 
 
 
57
 
 
 
 
58
def main():
    """Streamlit entry point: photo upload -> caption -> story -> audio."""
    st.title("Children's Story Generator")

    # Load models once at startup (cached across reruns by load_models)
    img_processor, img_model, text_tokenizer, text_model = load_models()

    uploaded_file = st.file_uploader("Upload a photo", type=["jpg", "png", "jpeg"])

    if uploaded_file:
        # Display image with corrected parameter
        st.image(uploaded_file, use_container_width=True)

        # NOTE(review): st.image may advance the upload buffer; rewind
        # before handing the same object to PIL so Image.open sees the
        # full stream — confirm against the Streamlit version in use.
        uploaded_file.seek(0)

        # Processing pipeline
        with st.spinner("Analyzing image..."):
            caption = generate_caption(uploaded_file, img_processor, img_model)
            st.subheader("Image Analysis")
            st.write(f"Detected scene: {caption}")

        with st.spinner("Writing story..."):
            story = create_story(caption, text_tokenizer, text_model)
            st.subheader("Generated Story")
            st.write(story)

        with st.spinner("Creating audio..."):
            audio = text_to_audio(story)
            st.audio(audio, format="audio/mp3")
            st.download_button(
                "Download Audio",
                data=audio,
                file_name="story.mp3",
                mime="audio/mp3",
            )


if __name__ == "__main__":
    main()