TLH01 commited on
Commit
c75f8e2
·
verified ·
1 Parent(s): 8f279a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -32
app.py CHANGED
@@ -30,32 +30,40 @@ def stage1_process(uploaded_file):
30
  return processor.decode(outputs[0], skip_special_tokens=True)
31
 
32
  # ======================
33
- # Stage 2: Story Generation
34
  # ======================
35
  @st.cache_resource
36
  def load_story_model():
37
- """Load reliable story model"""
38
  return (
39
- AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium"),
40
- AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
41
  )
42
 
43
  def stage2_process(keyword):
44
- """Generate children's story"""
45
  tokenizer, model = load_story_model()
46
- prompt = f"""Write a children's story about {keyword} with animals in 100 words.
47
- Story: Once upon a time, there was a little rabbit named Fluffy who found"""
48
 
49
- inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
 
 
 
 
 
 
 
 
50
  outputs = model.generate(
51
  inputs.input_ids,
52
- max_length=300,
53
- temperature=0.85,
54
  top_k=50,
55
- repetition_penalty=1.2
 
 
56
  )
57
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
58
- return full_text.replace(prompt, "").strip()
59
 
60
  # ======================
61
  # Stage 3: Text-to-Speech
@@ -63,8 +71,8 @@ Story: Once upon a time, there was a little rabbit named Fluffy who found"""
63
  def stage3_process(text):
64
  """Convert text to audio"""
65
  try:
66
- clean_text = text.strip().replace('\n', ' ')[:200]
67
- if len(clean_text) < 10:
68
  return None
69
  tts = gTTS(text=clean_text, lang='en')
70
  audio = io.BytesIO()
@@ -81,12 +89,11 @@ def main():
81
  st.title("📖 Children's Story Generator")
82
 
83
  # Initialize session state
84
- if 'stage1_done' not in st.session_state:
85
  st.session_state.update({
86
- 'stage1_done': False,
87
- 'stage2_done': False,
88
- 'caption': "",
89
- 'story': ""
90
  })
91
 
92
  # File upload
@@ -97,31 +104,31 @@ def main():
97
  st.image(uploaded_file, width=300)
98
 
99
  # Stage 1
100
- if not st.session_state.stage1_done:
101
  with st.spinner("Analyzing image..."):
102
  st.session_state.caption = stage1_process(uploaded_file)
103
- st.session_state.stage1_done = True
104
  st.success(f"Detected Theme: {st.session_state.caption}")
105
 
106
  # Stage 2
107
- if not st.session_state.stage2_done:
108
- with st.spinner("Writing story..."):
109
  st.session_state.story = stage2_process(st.session_state.caption)
110
- st.session_state.stage2_done = True
111
 
112
- # Display results
113
  if st.session_state.story:
114
  st.subheader("Generated Story")
115
  st.write(st.session_state.story)
116
 
117
  # Stage 3
118
- with st.spinner("Generating audio..."):
119
- audio = stage3_process(st.session_state.story)
120
- if audio:
121
- st.audio(audio, format="audio/mp3")
122
- st.download_button("Download Audio", audio.getvalue(), "story.mp3")
123
- else:
124
- st.warning("Audio generation skipped due to short text")
 
 
125
 
126
  if __name__ == "__main__":
127
  main()
 
30
  return processor.decode(outputs[0], skip_special_tokens=True)
31
 
32
  # ======================
33
+ # Stage 2: Story Generation (Optimized)
34
  # ======================
35
  @st.cache_resource
36
  def load_story_model():
37
+ """Load optimized story model"""
38
  return (
39
+ AutoTokenizer.from_pretrained("gpt2-medium"),
40
+ AutoModelForCausalLM.from_pretrained("gpt2-medium")
41
  )
42
 
43
  def stage2_process(keyword):
44
+ """Generate structured story"""
45
  tokenizer, model = load_story_model()
 
 
46
 
47
+ # Enhanced prompt template
48
+ prompt = f"""Write a children's story in 100-150 words with these elements:
49
+ - Theme: {keyword}
50
+ - Characters: Friendly animals
51
+ - Moral: Sharing is caring
52
+
53
+ Story begins: One sunny morning, a little rabbit named Cotton discovered"""
54
+
55
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=150, truncation=True)
56
  outputs = model.generate(
57
  inputs.input_ids,
58
+ max_new_tokens=300,
59
+ temperature=0.9,
60
  top_k=50,
61
+ no_repeat_ngram_size=3,
62
+ repetition_penalty=1.2,
63
+ do_sample=True
64
  )
65
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ return full_text.split("Story begins:")[-1].strip()
67
 
68
  # ======================
69
  # Stage 3: Text-to-Speech
 
71
  def stage3_process(text):
72
  """Convert text to audio"""
73
  try:
74
+ clean_text = text.strip().replace('\n', ' ')[:300]
75
+ if len(clean_text) < 20:
76
  return None
77
  tts = gTTS(text=clean_text, lang='en')
78
  audio = io.BytesIO()
 
89
  st.title("📖 Children's Story Generator")
90
 
91
  # Initialize session state
92
+ if 'processing' not in st.session_state:
93
  st.session_state.update({
94
+ 'caption': None,
95
+ 'story': None,
96
+ 'audio': None
 
97
  })
98
 
99
  # File upload
 
104
  st.image(uploaded_file, width=300)
105
 
106
  # Stage 1
107
+ if not st.session_state.caption:
108
  with st.spinner("Analyzing image..."):
109
  st.session_state.caption = stage1_process(uploaded_file)
 
110
  st.success(f"Detected Theme: {st.session_state.caption}")
111
 
112
  # Stage 2
113
+ if not st.session_state.story:
114
+ with st.spinner("Writing magical story..."):
115
  st.session_state.story = stage2_process(st.session_state.caption)
 
116
 
117
+ # Display story
118
  if st.session_state.story:
119
  st.subheader("Generated Story")
120
  st.write(st.session_state.story)
121
 
122
  # Stage 3
123
+ if not st.session_state.audio:
124
+ with st.spinner("Generating audio..."):
125
+ st.session_state.audio = stage3_process(st.session_state.story)
126
+ if st.session_state.audio:
127
+ st.audio(st.session_state.audio, format="audio/mp3")
128
+ st.download_button("Download Audio",
129
+ st.session_state.audio.getvalue(),
130
+ "story.mp3",
131
+ mime="audio/mp3")
132
 
133
  if __name__ == "__main__":
134
  main()