TLH01 committed on
Commit
1394a8a
·
verified ·
1 Parent(s): 796dba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -51
app.py CHANGED
@@ -1,92 +1,73 @@
1
- # app.py
2
  import streamlit as st
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration
5
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
- import torch
7
  from gtts import gTTS
8
  import io
9
 
10
- # Pre-load models during app initialization
11
  @st.cache_resource
12
  def load_models():
13
- # Image captioning model
14
  img_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
15
  img_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
16
-
17
- # Story generation model
18
  text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
19
  text_model = GPT2LMHeadModel.from_pretrained("gpt2")
20
-
21
  return img_processor, img_model, text_tokenizer, text_model
22
 
23
- def generate_caption(uploaded_image, processor, model):
24
- img = Image.open(uploaded_image).convert("RGB")
25
- inputs = processor(
26
- images=img,
27
- return_tensors="pt",
28
- padding=True,
29
- truncation=True
30
- )
31
  outputs = model.generate(**inputs)
32
  return processor.decode(outputs[0], skip_special_tokens=True)
33
 
34
- def create_story(caption, tokenizer, model):
35
- prompt = f"""Create a children's story about {caption} with:
36
- 1. Friendly characters
37
- 2. Happy ending
38
- 3. 50-100 words
39
- Story:"""
40
-
41
- inputs = tokenizer(prompt, return_tensors="pt")
42
  outputs = model.generate(
43
  inputs.input_ids,
44
  max_length=300,
45
  num_return_sequences=1,
46
- no_repeat_ngram_size=2,
47
- early_stopping=True
48
  )
49
- return tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()
50
 
51
- def text_to_audio(text):
52
  audio_buffer = io.BytesIO()
53
- tts = gTTS(text=text, lang='en')
54
  tts.write_to_fp(audio_buffer)
55
  audio_buffer.seek(0)
56
  return audio_buffer
57
 
58
  def main():
59
- st.title("Children's Story Generator")
60
-
61
- # Load models once at startup
62
  img_processor, img_model, text_tokenizer, text_model = load_models()
63
 
64
- uploaded_file = st.file_uploader("Upload a photo", type=["jpg", "png", "jpeg"])
65
 
66
  if uploaded_file:
67
- # Display image with corrected parameter
68
  st.image(uploaded_file, use_container_width=True)
69
 
70
- # Processing pipeline
71
- with st.spinner("Analyzing image..."):
72
- caption = generate_caption(uploaded_file, img_processor, img_model)
73
- st.subheader("Image Analysis")
74
- st.write(f"Detected scene: {caption}")
75
 
76
- with st.spinner("Writing story..."):
77
- story = create_story(caption, text_tokenizer, text_model)
78
- st.subheader("Generated Story")
79
- st.write(story)
80
 
81
- with st.spinner("Creating audio..."):
82
- audio = text_to_audio(story)
83
- st.audio(audio, format="audio/mp3")
84
- st.download_button(
85
- "Download Audio",
86
- data=audio,
87
- file_name="story.mp3",
88
- mime="audio/mp3"
89
- )
 
 
 
90
 
91
  if __name__ == "__main__":
92
  main()
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
5
  from gtts import gTTS
6
  import io
7
 
8
# Model loading with cache
@st.cache_resource
def load_models():
    """Load and cache the captioning and story-generation models.

    Returns a 4-tuple ``(blip_processor, blip_model, gpt2_tokenizer,
    gpt2_model)``. Wrapped in ``st.cache_resource`` so the model
    downloads happen only once per server process.
    """
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    caption_processor = BlipProcessor.from_pretrained(blip_checkpoint)
    caption_model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)

    story_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    story_model = GPT2LMHeadModel.from_pretrained("gpt2")

    return caption_processor, caption_model, story_tokenizer, story_model
16
 
17
def process_image(uploaded_file, processor, model):
    """Caption an uploaded image with the BLIP model.

    uploaded_file: file-like object accepted by ``PIL.Image.open``
    (e.g. a Streamlit ``UploadedFile``).
    Returns the decoded caption string.
    """
    picture = Image.open(uploaded_file).convert('RGB')
    batch = processor(images=picture, return_tensors="pt", padding=True)
    generated_ids = model.generate(**batch)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
22
 
23
def generate_story(caption, tokenizer, model):
    """Generate a short children's story from an image caption.

    caption: scene description produced by the captioning model.
    tokenizer/model: a GPT-2 tokenizer and LM head model pair.
    Returns the generated continuation with the prompt text removed.
    """
    prompt = f"Create a children's story about {caption} with animals:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly; otherwise generate() warns and may
        # mis-handle padding.
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_return_sequences=1,
        # Bug fix: temperature has no effect under the default greedy
        # decoding — sampling must be enabled for it to apply.
        do_sample=True,
        temperature=0.7,
        # GPT-2 defines no pad token; reuse EOS to silence the warning
        # and make padding well-defined.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
33
 
34
def text_to_speech(text):
    """Synthesize speech for the story text via gTTS.

    Only the first 300 characters are voiced, keeping the TTS request
    small. Returns an in-memory MP3 buffer rewound to position 0.
    """
    speech = gTTS(text=text[:300], lang='en')
    mp3_buffer = io.BytesIO()
    speech.write_to_fp(mp3_buffer)
    mp3_buffer.seek(0)
    return mp3_buffer
40
 
41
def main():
    """Streamlit entry point: upload photo -> caption -> story -> audio."""
    st.title("Children's Story Maker")

    img_processor, img_model, text_tokenizer, text_model = load_models()

    uploaded_file = st.file_uploader("Upload photo (JPG/PNG)", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        # Nothing to do until the user provides an image.
        return

    st.image(uploaded_file, use_container_width=True)

    with st.status("Processing Pipeline", expanded=True):
        # Stage 1: Image Analysis
        st.write("🖼️ Analyzing image...")
        caption = process_image(uploaded_file, img_processor, img_model)

        # Stage 2: Story Generation
        st.write("📖 Creating story...")
        story = generate_story(caption, text_tokenizer, text_model)

        # Stage 3: Audio Conversion
        st.write("🔊 Generating audio...")
        audio = text_to_speech(story)

        st.subheader("Results")
        st.write(f"**Caption:** {caption}")
        st.write(f"**Story:** {story}")
        st.audio(audio, format="audio/mp3")

        # Download buttons
        st.download_button("Download Story", story, "story.txt")
        st.download_button("Download Audio", audio.getvalue(), "story.mp3")
71
 
72
  if __name__ == "__main__":
73
  main()