TLH01 committed on
Commit
6a5a7a4
·
verified ·
1 Parent(s): 1394a8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -46
app.py CHANGED
@@ -5,69 +5,48 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  from gtts import gTTS
6
  import io
7
 
8
# Cached model setup — the heavy pretrained downloads happen once per session
@st.cache_resource
def load_models():
    """Load and cache the BLIP captioning pair and the GPT-2 text pair.

    Returns:
        Tuple of (BLIP processor, BLIP model, GPT-2 tokenizer, GPT-2 model).
    """
    caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    story_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    story_model = GPT2LMHeadModel.from_pretrained("gpt2")
    return caption_processor, caption_model, story_tokenizer, story_model
-
17
def process_image(uploaded_file, processor, model):
    """Caption an uploaded image file with the BLIP model.

    Args:
        uploaded_file: File-like object accepted by PIL's Image.open.
        processor: BLIP processor used to build tensors and decode output.
        model: BLIP conditional-generation model.

    Returns:
        The decoded caption string.
    """
    rgb_image = Image.open(uploaded_file).convert('RGB')
    model_inputs = processor(images=rgb_image, return_tensors="pt", padding=True)
    generated_ids = model.generate(**model_inputs)
    return processor.decode(generated_ids[0], skip_special_tokens=True)
22
-
23
def generate_story(caption, tokenizer, model):
    """Generate a short children's story seeded by an image caption.

    Args:
        caption: Scene description used to build the prompt.
        tokenizer: GPT-2 tokenizer.
        model: GPT-2 language model.

    Returns:
        The generated continuation with the prompt text stripped off.
    """
    prompt = f"Create a children's story about {caption} with animals:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # avoid the missing-attention-mask warning
        max_length=300,
        num_return_sequences=1,
        do_sample=True,  # BUG FIX: without sampling, `temperature` is silently ignored by greedy search
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token by default
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
33
-
34
def text_to_speech(text):
    """Synthesize (up to the first 300 chars of) *text* as English MP3 audio.

    Returns:
        io.BytesIO buffer rewound to the start of the audio data.
    """
    speech = gTTS(text=text[:300], lang='en')
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return buffer
40
 
41
def main():
    """Streamlit entry point: upload photo -> caption -> story -> narrated audio."""
    st.title("Children's Story Maker")

    img_processor, img_model, text_tokenizer, text_model = load_models()

    uploaded_file = st.file_uploader("Upload photo (JPG/PNG)", type=["jpg", "png", "jpeg"])

    if not uploaded_file:
        return  # guard clause: nothing to do until the user uploads a photo

    st.image(uploaded_file, use_container_width=True)

    with st.status("Processing Pipeline", expanded=True):
        # Stage 1: Image Analysis
        st.write("🖼️ Analyzing image...")
        caption = process_image(uploaded_file, img_processor, img_model)

        # Stage 2: Story Generation
        st.write("📖 Creating story...")
        story = generate_story(caption, text_tokenizer, text_model)

        # Stage 3: Audio Conversion
        st.write("🔊 Generating audio...")
        audio = text_to_speech(story)

    st.subheader("Results")
    st.write(f"**Caption:** {caption}")
    st.write(f"**Story:** {story}")
    st.audio(audio, format="audio/mp3")

    # Download buttons
    st.download_button("Download Story", story, "story.txt")
    st.download_button("Download Audio", audio.getvalue(), "story.mp3")
71
 
72
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()
 
5
  from gtts import gTTS
6
  import io
7
 
 
8
@st.cache_resource
def load_models():
    """Load the BLIP captioning pair and the GPT-2 pair, cached for the session.

    Returns:
        Tuple of (BLIP processor, BLIP model, GPT-2 tokenizer, GPT-2 model).
    """
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
    return blip_processor, blip_model, gpt2_tokenizer, gpt2_model
 
 
 
 
 
 
 
 
16
 
17
def main():
    """Streamlit entry point: caption an uploaded image, spin it into a story, narrate it."""
    st.title("Stable Story Maker")

    img_processor, img_model, text_tokenizer, text_model = load_models()

    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])

    if uploaded_file:
        st.image(uploaded_file, use_container_width=True)

        with st.status("Processing"):
            # Stage 1: caption the image with BLIP
            img = Image.open(uploaded_file).convert("RGB")
            inputs = img_processor(images=img, return_tensors="pt")
            caption = img_processor.decode(img_model.generate(**inputs)[0], skip_special_tokens=True)

            # Stage 2: continue the caption into a story with GPT-2
            prompt = f"Children's story about {caption}:"
            inputs = text_tokenizer(prompt, return_tensors="pt")
            story = text_tokenizer.decode(
                text_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,  # BUG FIX: avoid missing-attention-mask warning
                    max_length=200,
                    pad_token_id=text_tokenizer.eos_token_id,  # GPT-2 defines no pad token by default
                )[0],
                skip_special_tokens=True
            ).replace(prompt, "")

            # Stage 3: narrate the story; gTTS rejects empty text, so fall
            # back to the caption when the model produced nothing usable
            speech_text = (story.strip() or caption)[:250]
            tts = gTTS(text=speech_text, lang='en')
            audio = io.BytesIO()
            tts.write_to_fp(audio)
            audio.seek(0)

        st.write(f"**Caption:** {caption}")
        st.write(f"**Story:** {story}")
        st.audio(audio, format="audio/mp3")
 
 
 
 
50
 
51
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()