TLH01 committed on
Commit
cb59de3
Β·
verified Β·
1 Parent(s): e0e1e09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -69
app.py CHANGED
@@ -1,86 +1,45 @@
1
  import streamlit as st
2
  from PIL import Image
 
3
  import torch
4
- from transformers import BlipProcessor, BlipForConditionalGeneration
5
- from transformers import pipeline
6
- import io
7
 
8
- st.set_page_config(page_title="Image Storytelling App", layout="centered")
 
9
 
10
- # Title
11
- st.title("πŸ–ΌοΈβ†’πŸ“–β†’πŸ—£οΈ Image Storytelling for Children")
12
 
13
- # Load models (with caching)
14
  @st.cache_resource
15
- def load_caption_model():
16
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
17
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
18
- return processor, model
 
 
19
 
20
- @st.cache_resource
21
- def load_story_model():
22
- return pipeline("text-generation", model="cahya/gpt2-small-indonesian-522M", device=0 if torch.cuda.is_available() else -1)
23
-
24
- @st.cache_resource
25
- def load_tts_model():
26
- from TTS.api import TTS
27
- return TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
28
-
29
- # Step 1: Generate caption
30
- def generate_caption(image):
31
- processor, model = load_caption_model()
32
- try:
33
- inputs = processor(images=[image], return_tensors="pt") # πŸ”§ fix: wrap in list
34
- out = model.generate(**inputs)
35
- return processor.decode(out[0], skip_special_tokens=True)
36
- except Exception as e:
37
- st.error(f"Image captioning failed: {e}")
38
- return None
39
-
40
- # Step 2: Generate story from caption
41
- def generate_story(caption):
42
- story_model = load_story_model()
43
- prompt = f"Write a short story of 50 to 100 words for children about: {caption}"
44
- outputs = story_model(prompt, max_new_tokens=120, do_sample=True, temperature=0.85)
45
- return outputs[0]["generated_text"].strip()
46
 
47
- # Step 3: Convert story to speech
48
- def generate_audio(story):
49
- tts = load_tts_model()
50
- try:
51
- audio_array = tts.tts(story)
52
- byte_io = io.BytesIO()
53
- tts.save_wav(audio_array, byte_io)
54
- byte_io.seek(0)
55
- return byte_io.read()
56
- except Exception as e:
57
- st.error(f"Audio generation failed: {e}")
58
- return None
59
-
60
- # App UI
61
- uploaded_file = st.file_uploader("Upload an image (illustration or drawing)", type=["jpg", "jpeg", "png"])
62
 
63
  if uploaded_file:
64
  image = Image.open(uploaded_file).convert("RGB")
65
  st.image(image, caption="Uploaded Image", use_column_width=True)
66
 
67
- with st.spinner("Generating description..."):
68
- caption = generate_caption(image)
69
-
70
- if caption:
71
- st.subheader("πŸ“ Description")
72
- st.info(caption)
73
 
74
  with st.spinner("Creating story..."):
75
- story = generate_story(caption)
76
-
77
- if story:
78
- st.subheader("πŸ“– Story")
79
- st.write(story)
80
-
81
- with st.spinner("Generating voice..."):
82
- audio = generate_audio(story)
83
 
84
- if audio:
85
- st.subheader("πŸ”Š Listen to the Story")
86
- st.audio(audio, format="audio/wav")
 
1
  import streamlit as st
2
  from PIL import Image
3
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
4
  import torch
5
+ from TTS.api import TTS
 
 
# Set page config (must be the first Streamlit call in the script).
st.set_page_config(page_title="Image Storytelling for Kids", layout="wide")

# Page header shown above the uploader.
st.title("πŸ§’πŸ“– AI Image Storytelling")
st.write("Upload an image, and let AI generate a story with voice for children aged 3–10.")
12
 
13
+ # Load models
14
  @st.cache_resource
15
+ def load_models():
16
+ vision_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
17
+ processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
18
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
19
+ tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=torch.cuda.is_available())
20
+ return vision_model, processor, tokenizer, tts
21
 
22
+ vision_model, processor, tokenizer, tts_model = load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Upload image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Normalize to RGB so the ViT processor always receives 3-channel input.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Generate Story"):
        # Step 1: caption the image with the ViT-GPT2 encoder-decoder.
        with st.spinner("Generating description..."):
            pixel_values = processor(images=image, return_tensors="pt").pixel_values
            output_ids = vision_model.generate(pixel_values, max_length=50, num_beams=4)
            caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            st.success("Image Description: " + caption)

        # Step 2: build the story. NOTE(review): this is a fixed template around
        # the caption, not model-generated text — the previous `story_prompt`
        # variable was never used, so it has been removed. TODO: feed a prompt
        # to a text-generation model if real story generation is wanted.
        with st.spinner("Creating story..."):
            story = caption + " Once upon a time, " + caption.lower() + " went on an adventure and made new friends in a magical forest."
            st.success("Story: " + story)

        # Step 3: narrate the story; Streamlit then serves the written WAV file.
        with st.spinner("Generating voice..."):
            tts_model.tts_to_file(text=story, file_path="story.wav")
            st.audio("story.wav", format="audio/wav")