TLH01 committed on
Commit
ef7e1aa
·
verified ·
1 Parent(s): cae99ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -89
app.py CHANGED
@@ -3,131 +3,114 @@ from PIL import Image
3
  from transformers import (
4
  BlipProcessor,
5
  BlipForConditionalGeneration,
6
- pipeline
 
7
  )
8
  from gtts import gTTS
9
  import io
10
- import logging
11
  import torch
12
 
13
# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the model-loading helpers below.
logger = logging.getLogger(__name__)
16
-
17
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP image-captioning (processor, model) pair."""
    checkpoint = "Salesforce/blip-image-captioning-base"
    try:
        pair = (
            BlipProcessor.from_pretrained(checkpoint),
            BlipForConditionalGeneration.from_pretrained(checkpoint),
        )
    except Exception:
        # Surface the failure in the UI, then let the traceback propagate.
        st.error("❌ Failed to load image model")
        raise
    logger.info("Stage 1 model loaded")
    return pair
31
 
32
def stage1_generate_caption(uploaded_file):
    """Caption the uploaded image with BLIP; fall back to a stock theme on error."""
    processor, model = load_image_model()
    try:
        image = Image.open(uploaded_file).convert("RGB")
        # Cap resolution so captioning stays fast on large uploads.
        image.thumbnail((512, 512))
        batch = processor(images=image, return_tensors="pt", padding=True)
        generated = model.generate(**batch, max_length=30)
        caption = processor.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Image processing failed: {str(e)}")
        return "children playing"
    return caption
44
 
45
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_model():
    """Load and cache the Mistral-based text-generation pipeline."""
    try:
        story_pipe = pipeline(
            "text-generation",
            model="ajibawa-2023/Young-Children-Storyteller-Mistral-7B",
            device_map="auto",          # spread layers across available devices
            torch_dtype=torch.float16,  # halve memory for the 7B weights
        )
    except Exception:
        st.error("❌ Failed to load story model")
        raise
    logger.info("Stage 2 model loaded")
    return story_pipe
63
 
64
def stage2_generate_story(keyword):
    """Ask the chat pipeline for a short children's story about *keyword*."""
    pipe = load_story_model()
    try:
        user_turn = {
            "role": "user",
            "content": f"Write a children's story about {keyword} with animals under 100 words",
        }
        result = pipe(
            [user_turn],
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
        )
        # Chat pipelines return the whole conversation; the last turn is the reply.
        return result[0]['generated_text'][-1]['content']
    except Exception as e:
        st.error(f"Story generation failed: {str(e)}")
        return "The animals had a great time playing together!"
82
 
83
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_generate_audio(text):
    """Synthesize *text* to MP3 via gTTS; return a BytesIO, or None on failure."""
    try:
        # Flatten newlines and cap length to keep the TTS request small.
        snippet = text.strip().replace('\n', ' ')[:300]
        if len(snippet) < 10:
            raise ValueError("Text too short")

        buffer = io.BytesIO()
        gTTS(text=snippet, lang='en').write_to_fp(buffer)
        buffer.seek(0)
        return buffer
    except Exception as e:
        st.error(f"Audio Error: {str(e)}")
        return None
101
 
102
# ======================
# Main Application
# ======================
def main():
    """Streamlit entry point: image -> caption -> story -> narrated audio."""
    st.title("📚 Smart Story Generator")

    uploaded_file = st.file_uploader("Upload Photo (JPG/PNG)", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    # Stage 1: show the photo and derive a theme from it.
    st.image(uploaded_file, use_container_width=True)
    with st.spinner("Analyzing image..."):
        caption = stage1_generate_caption(uploaded_file)
    st.write(f"✨ Detected Theme: **{caption}**")

    # Stage 2: turn the theme into a short story.
    with st.spinner("Generating story..."):
        story = stage2_generate_story(caption)
    st.subheader("Generated Story")
    st.write(story)

    # Stage 3: always attempt narration (forced display of the audio section).
    with st.spinner("Creating audio..."):
        audio = stage3_generate_audio(story)
    if audio is None:
        st.warning("Audio generation skipped due to invalid input")
    else:
        st.audio(audio, format="audio/mp3")
        st.download_button("Download Audio", audio.getvalue(), "story.mp3")
 
131
 
132
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()
 
3
  from transformers import (
4
  BlipProcessor,
5
  BlipForConditionalGeneration,
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM
8
  )
9
  from gtts import gTTS
10
  import io
 
11
  import torch
12
 
 
 
 
 
13
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_model():
    """Load and cache the BLIP captioning (processor, model) pair."""
    checkpoint = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model
 
 
 
 
23
 
24
def stage1_process(uploaded_file):
    """Return a BLIP-generated caption for the uploaded image file."""
    processor, model = load_image_model()
    rgb_image = Image.open(uploaded_file).convert("RGB")
    encoded = processor(images=rgb_image, return_tensors="pt")
    token_ids = model.generate(**encoded)[0]
    return processor.decode(token_ids, skip_special_tokens=True)
 
 
 
 
 
31
 
32
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_model():
    """Load and cache the fairytale GPT-2 (tokenizer, model) pair."""
    checkpoint = "prpappas/fairytale-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model
 
 
 
 
 
 
 
 
42
 
43
def stage2_process(keyword):
    """Generate a short children's story about *keyword* with the GPT-2 model.

    Returns the generated continuation with the prompt stripped off.
    """
    tokenizer, model = load_story_model()
    prompt = f"Write a children's story about {keyword} in 100 words:\n"
    # Bound the prompt so generation always has room under max_length.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=50, truncation=True)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # avoid the missing-mask warning
        max_length=200,
        # BUGFIX: temperature/top_k are silently ignored under greedy decoding;
        # sampling must be enabled for them to take effect.
        do_sample=True,
        temperature=0.85,
        top_k=50,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    # Strip the prompt by token count: str.replace can miss when decoding does
    # not reproduce the prompt byte-for-byte (e.g. after truncation).
    story_ids = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(story_ids, skip_special_tokens=True)
 
 
 
 
 
56
 
57
# ======================
# Stage 3: Text-to-Speech
# ======================
def stage3_process(text):
    """Render the first 200 chars of *text* as English speech in a BytesIO."""
    buffer = io.BytesIO()
    speech = gTTS(text=text[:200], lang='en')
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return buffer
 
 
 
 
 
 
 
 
67
 
68
# ======================
# Main Application
# ======================
def main():
    """Streamlit entry point: caption an image, write a story, narrate it.

    Stage results are kept in st.session_state so Streamlit reruns do not
    repeat the expensive model/network calls; the cached results are
    invalidated whenever a different file is uploaded.
    """
    st.title("📖 Children's Story Generator")

    # Initialize session state
    if 'stage1_done' not in st.session_state:
        st.session_state.stage1_done = False
    if 'stage2_done' not in st.session_state:
        st.session_state.stage2_done = False

    # File upload section
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png"])

    if uploaded_file:
        # BUGFIX: a new upload must discard results from the previous image,
        # otherwise the stale caption/story are shown forever.
        file_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('file_key') != file_key:
            st.session_state.file_key = file_key
            st.session_state.stage1_done = False
            st.session_state.stage2_done = False
            st.session_state.pop('audio_bytes', None)

        # Always show image and Stage 1 result
        st.image(uploaded_file, width=300)

        # Stage 1 Processing
        if not st.session_state.stage1_done:
            with st.spinner("Analyzing image..."):
                st.session_state.caption = stage1_process(uploaded_file)
                st.session_state.stage1_done = True
        st.success(f"Detected Theme: {st.session_state.caption}")

        # Stage 2 Processing
        if not st.session_state.stage2_done:
            with st.spinner("Creating story..."):
                st.session_state.story = stage2_process(st.session_state.caption)
                st.session_state.stage2_done = True

        if st.session_state.stage2_done:
            st.subheader("Generated Story")
            st.write(st.session_state.story)

            # Stage 3 Processing — cache the MP3 bytes so each Streamlit rerun
            # does not hit the gTTS network service again.
            if 'audio_bytes' not in st.session_state:
                with st.spinner("Generating audio..."):
                    audio = stage3_process(st.session_state.story)
                    st.session_state.audio_bytes = audio.getvalue()
            st.audio(st.session_state.audio_bytes, format="audio/mp3")
            st.download_button("Download Audio",
                               data=st.session_state.audio_bytes,
                               file_name="story.mp3",
                               mime="audio/mp3")
114
 
115
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()