TLH01 commited on
Commit
8f279a7
·
verified ·
1 Parent(s): f703e4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -32
app.py CHANGED
@@ -34,36 +34,45 @@ def stage1_process(uploaded_file):
34
  # ======================
35
  @st.cache_resource
36
  def load_story_model():
37
- """Load story generation model"""
38
  return (
39
- AutoTokenizer.from_pretrained("prpappas/fairytale-gpt2"),
40
- AutoModelForCausalLM.from_pretrained("prpappas/fairytale-gpt2")
41
  )
42
 
43
  def stage2_process(keyword):
44
  """Generate children's story"""
45
  tokenizer, model = load_story_model()
46
- prompt = f"Write a children's story about {keyword} in 100 words:\n"
47
- inputs = tokenizer(prompt, return_tensors="pt", max_length=50, truncation=True)
 
 
48
  outputs = model.generate(
49
  inputs.input_ids,
50
- max_length=200,
51
  temperature=0.85,
52
  top_k=50,
53
  repetition_penalty=1.2
54
  )
55
- return tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
 
56
 
57
  # ======================
58
  # Stage 3: Text-to-Speech
59
  # ======================
60
  def stage3_process(text):
61
  """Convert text to audio"""
62
- tts = gTTS(text=text[:200], lang='en')
63
- audio = io.BytesIO()
64
- tts.write_to_fp(audio)
65
- audio.seek(0)
66
- return audio
 
 
 
 
 
 
67
 
68
  # ======================
69
  # Main Application
@@ -73,44 +82,46 @@ def main():
73
 
74
  # Initialize session state
75
  if 'stage1_done' not in st.session_state:
76
- st.session_state.stage1_done = False
77
- if 'stage2_done' not in st.session_state:
78
- st.session_state.stage2_done = False
 
 
 
79
 
80
- # File upload section
81
  uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
82
 
83
  if uploaded_file:
84
- # Always show image and Stage 1 result
85
  st.image(uploaded_file, width=300)
86
 
87
- # Stage 1 Processing
88
  if not st.session_state.stage1_done:
89
  with st.spinner("Analyzing image..."):
90
- caption = stage1_process(uploaded_file)
91
- st.session_state.caption = caption
92
  st.session_state.stage1_done = True
93
  st.success(f"Detected Theme: {st.session_state.caption}")
94
 
95
- # Stage 2 Processing
96
  if not st.session_state.stage2_done:
97
- with st.spinner("Creating story..."):
98
- story = stage2_process(st.session_state.caption)
99
- st.session_state.story = story
100
  st.session_state.stage2_done = True
101
-
102
- if st.session_state.stage2_done:
 
103
  st.subheader("Generated Story")
104
  st.write(st.session_state.story)
105
 
106
- # Stage 3 Processing
107
  with st.spinner("Generating audio..."):
108
  audio = stage3_process(st.session_state.story)
109
- st.audio(audio, format="audio/mp3")
110
- st.download_button("Download Audio",
111
- data=audio.getvalue(),
112
- file_name="story.mp3",
113
- mime="audio/mp3")
114
 
115
  if __name__ == "__main__":
116
  main()
 
34
  # ======================
35
  @st.cache_resource
36
  def load_story_model():
37
+ """Load reliable story model"""
38
  return (
39
+ AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium"),
40
+ AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
41
  )
42
 
43
  def stage2_process(keyword):
44
  """Generate children's story"""
45
  tokenizer, model = load_story_model()
46
+ prompt = f"""Write a children's story about {keyword} with animals in 100 words.
47
+ Story: Once upon a time, there was a little rabbit named Fluffy who found"""
48
+
49
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
50
  outputs = model.generate(
51
  inputs.input_ids,
52
+ max_length=300,
53
  temperature=0.85,
54
  top_k=50,
55
  repetition_penalty=1.2
56
  )
57
+ full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
58
+ return full_text.replace(prompt, "").strip()
59
 
60
  # ======================
61
  # Stage 3: Text-to-Speech
62
  # ======================
63
  def stage3_process(text):
64
  """Convert text to audio"""
65
+ try:
66
+ clean_text = text.strip().replace('\n', ' ')[:200]
67
+ if len(clean_text) < 10:
68
+ return None
69
+ tts = gTTS(text=clean_text, lang='en')
70
+ audio = io.BytesIO()
71
+ tts.write_to_fp(audio)
72
+ audio.seek(0)
73
+ return audio
74
+ except:
75
+ return None
76
 
77
  # ======================
78
  # Main Application
 
82
 
83
  # Initialize session state
84
  if 'stage1_done' not in st.session_state:
85
+ st.session_state.update({
86
+ 'stage1_done': False,
87
+ 'stage2_done': False,
88
+ 'caption': "",
89
+ 'story': ""
90
+ })
91
 
92
+ # File upload
93
  uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])
94
 
95
  if uploaded_file:
96
+ # Permanent display
97
  st.image(uploaded_file, width=300)
98
 
99
+ # Stage 1
100
  if not st.session_state.stage1_done:
101
  with st.spinner("Analyzing image..."):
102
+ st.session_state.caption = stage1_process(uploaded_file)
 
103
  st.session_state.stage1_done = True
104
  st.success(f"Detected Theme: {st.session_state.caption}")
105
 
106
+ # Stage 2
107
  if not st.session_state.stage2_done:
108
+ with st.spinner("Writing story..."):
109
+ st.session_state.story = stage2_process(st.session_state.caption)
 
110
  st.session_state.stage2_done = True
111
+
112
+ # Display results
113
+ if st.session_state.story:
114
  st.subheader("Generated Story")
115
  st.write(st.session_state.story)
116
 
117
+ # Stage 3
118
  with st.spinner("Generating audio..."):
119
  audio = stage3_process(st.session_state.story)
120
+ if audio:
121
+ st.audio(audio, format="audio/mp3")
122
+ st.download_button("Download Audio", audio.getvalue(), "story.mp3")
123
+ else:
124
+ st.warning("Audio generation skipped due to short text")
125
 
126
  if __name__ == "__main__":
127
  main()