TLH01 committed on
Commit
fb3ff7f
Β·
verified Β·
1 Parent(s): 4b5a116

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -69
app.py CHANGED
@@ -1,86 +1,86 @@
1
- # app.py
2
-
3
  import streamlit as st
4
  from PIL import Image
5
- from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
6
  import torch
7
- import pyttsx3
 
8
  import io
9
 
10
- # ----------- Stage 1: Image to Description -----------
 
 
 
11
 
 
12
  @st.cache_resource
13
  def load_caption_model():
14
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
15
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
16
  return processor, model
17
 
18
- def generate_caption(image):
19
- processor, model = load_caption_model()
20
- inputs = processor(images=image, return_tensors="pt")
21
- out = model.generate(**inputs)
22
- return processor.decode(out[0], skip_special_tokens=True)
23
-
24
- # ----------- Stage 2: Description to Story -----------
25
-
26
  @st.cache_resource
27
  def load_story_model():
28
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
29
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5")
30
- return tokenizer, model
31
-
32
- def generate_story(description):
33
- tokenizer, model = load_story_model()
34
- prompt = (
35
- f"Write a short and fun story (50-100 words) for children based on the following: {description}\n\n"
36
- "Story:"
37
- )
38
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
39
- output = model.generate(**inputs, max_new_tokens=120, do_sample=True, top_k=50, top_p=0.95)
40
- story = tokenizer.decode(output[0], skip_special_tokens=True)
41
- return story.split("Story:")[-1].strip()
42
-
43
- # ----------- Stage 3: Story to Speech -----------
44
-
45
- def generate_speech(story):
46
- engine = pyttsx3.init()
47
- engine.setProperty('rate', 150)
48
- engine.setProperty('volume', 0.9)
49
-
50
- with io.BytesIO() as audio:
51
- engine.save_to_file(story, 'temp.mp3')
52
- engine.runAndWait()
53
- with open('temp.mp3', 'rb') as f:
54
- audio_bytes = f.read()
55
- return audio_bytes
56
-
57
- # ----------- Streamlit Interface -----------
58
-
59
- st.set_page_config(page_title="Children's Story Generator", layout="centered")
60
-
61
- st.title("πŸ“– Children's Storytelling from Images")
62
- st.markdown("Upload an illustration and we'll turn it into a fun story with voice narration!")
63
-
64
- uploaded_image = st.file_uploader("Upload a drawing or illustration", type=["jpg", "jpeg", "png"])
65
-
66
- if uploaded_image:
67
- image = Image.open(uploaded_image)
 
 
 
68
  st.image(image, caption="Uploaded Image", use_column_width=True)
69
 
70
- # Stage 1
71
  with st.spinner("Generating description..."):
72
- description = generate_caption(image)
73
- st.success("βœ… Description Generated!")
74
- st.markdown(f"**Image Caption:** _{description}_")
75
-
76
- # Stage 2
77
- with st.spinner("Generating children's story..."):
78
- story = generate_story(description)
79
- st.success("βœ… Story Generated!")
80
- st.markdown("**Generated Story:**")
81
- st.write(story)
82
-
83
- # Stage 3
84
- with st.spinner("Generating voice..."):
85
- audio_data = generate_speech(story)
86
- st.audio(audio_data, format='audio/mp3')
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
 
3
  import torch
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from transformers import pipeline
6
  import io
7
 
8
# Page configuration must be the first Streamlit call that renders.
st.set_page_config(page_title="Image Storytelling App", layout="centered")

# Title
st.title("πŸ–ΌοΈβ†’πŸ“–β†’πŸ—£οΈ Image Storytelling for Children")
12
 
13
# Load models (with caching)
@st.cache_resource
def load_caption_model():
    """Load and cache the BLIP image-captioning processor and model."""
    caption_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    caption_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    return caption_processor, caption_model
19
 
 
 
 
 
 
 
 
 
20
@st.cache_resource
def load_story_model():
    """Load and cache a text-generation pipeline for story writing.

    Fix: the previous checkpoint, ``cahya/gpt2-small-indonesian-522M``, is an
    Indonesian-language GPT-2, but the prompt, the UI, and the TTS voice
    (English LJSpeech) in this app are all English — stories were generated
    in the wrong language. Plain English ``gpt2`` keeps the exact same
    pipeline interface for callers.
    """
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("text-generation", model="gpt2", device=device)
23
+
24
@st.cache_resource
def load_tts_model():
    """Load and cache the Coqui TTS model (English LJSpeech Tacotron2)."""
    # Imported lazily so the heavy TTS package is only loaded when needed.
    from TTS.api import TTS

    use_gpu = torch.cuda.is_available()
    return TTS(
        model_name="tts_models/en/ljspeech/tacotron2-DDC",
        progress_bar=False,
        gpu=use_gpu,
    )
28
+
29
# Step 1: Generate caption
def generate_caption(image):
    """Return a BLIP-generated caption for *image*, or None on failure."""
    processor, model = load_caption_model()
    try:
        # Wrapping the image in a list matches the processor's batch API.
        batch = processor(images=[image], return_tensors="pt")
        token_ids = model.generate(**batch)
        caption = processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        st.error(f"Image captioning failed: {e}")
        return None
    return caption
39
+
40
# Step 2: Generate story from caption
def generate_story(caption):
    """Generate a short children's story from an image caption.

    Fix: ``text-generation`` pipelines include the prompt at the start of
    ``generated_text``, so the displayed "story" previously began with the
    instruction text itself. The prompt prefix is now stripped before the
    story is returned.
    """
    story_model = load_story_model()
    prompt = f"Write a short story of 50 to 100 words for children about: {caption}"
    outputs = story_model(prompt, max_new_tokens=120, do_sample=True, temperature=0.85)
    full_text = outputs[0]["generated_text"]
    # removeprefix is a no-op if the pipeline ever stops echoing the prompt.
    return full_text.removeprefix(prompt).strip()
46
+
47
# Step 3: Convert story to speech
def generate_audio(story):
    """Synthesize *story* into WAV bytes via Coqui TTS; None on failure."""
    tts = load_tts_model()
    try:
        waveform = tts.tts(story)
        buffer = io.BytesIO()
        tts.save_wav(waveform, buffer)
        buffer.seek(0)
        audio_bytes = buffer.read()
    except Exception as e:
        # Audio is optional — report the failure and degrade gracefully.
        st.error(f"Audio generation failed: {e}")
        return None
    return audio_bytes
59
+
60
# App UI
upload = st.file_uploader(
    "Upload an image (illustration or drawing)", type=["jpg", "jpeg", "png"]
)

if upload:
    # Normalize to RGB so the caption model always sees a 3-channel image.
    pil_image = Image.open(upload).convert("RGB")
    st.image(pil_image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Generating description..."):
        caption = generate_caption(pil_image)

    if caption:
        st.subheader("πŸ“ Description")
        st.info(caption)

        with st.spinner("Creating story..."):
            story = generate_story(caption)

        if story:
            st.subheader("πŸ“– Story")
            st.write(story)

            with st.spinner("Generating voice..."):
                audio = generate_audio(story)

            if audio:
                st.subheader("πŸ”Š Listen to the Story")
                st.audio(audio, format="audio/wav")