TLH01 committed on
Commit
258921e
·
verified ·
1 Parent(s): 23ad0fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -3,12 +3,12 @@ from PIL import Image
3
  from transformers import (
4
  BlipProcessor,
5
  BlipForConditionalGeneration,
6
- AutoTokenizer,
7
- AutoModelForCausalLM
8
  )
9
  from gtts import gTTS
10
  import io
11
  import logging
 
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
19
  # ======================
20
  @st.cache_resource
21
  def load_image_model():
22
- """Load official image captioning model"""
23
  try:
24
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
25
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -34,7 +34,7 @@ def stage1_generate_caption(uploaded_file):
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
- img.thumbnail((512, 512)) # Optimize image size
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
@@ -47,57 +47,56 @@ def stage1_generate_caption(uploaded_file):
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
- """Load reliable story generation model"""
51
  try:
52
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
53
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
 
 
 
54
  logger.info("Stage 2 model loaded")
55
- return tokenizer, model
56
  except Exception as e:
57
  st.error("❌ Failed to load story model")
58
  raise
59
 
60
  def stage2_generate_story(keyword):
61
- """Generate structured story"""
62
- tokenizer, model = load_story_model()
63
-
64
- # Optimized prompt template
65
- prompt = f"""Write a children's story with:
66
- - Theme: {keyword}
67
- - Characters: Animals
68
- - Length: 100 words
69
-
70
- Story: Once upon a time, a little bear named Honey found"""
71
-
72
  try:
73
- inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
74
- outputs = model.generate(
75
- inputs.input_ids,
76
- max_length=300,
77
- temperature=0.85,
78
- top_k=50,
79
- repetition_penalty=1.2,
80
- pad_token_id=tokenizer.eos_token_id
 
81
  )
82
- full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
83
- return full_text.replace(prompt, "").strip()
84
  except Exception as e:
85
  st.error(f"Story generation failed: {str(e)}")
86
- return "The animals had a wonderful day playing together!"
87
 
88
  # ======================
89
  # Stage 3: Text-to-Speech
90
  # ======================
91
  def stage3_generate_audio(text):
92
- """Convert text to audio"""
93
  try:
94
- tts = gTTS(text=text[:300], lang='en')
 
 
 
 
95
  audio_buffer = io.BytesIO()
96
  tts.write_to_fp(audio_buffer)
97
  audio_buffer.seek(0)
98
  return audio_buffer
99
  except Exception as e:
100
- st.error(f"Audio generation failed: {str(e)}")
101
  return None
102
 
103
  # ======================
@@ -121,13 +120,14 @@ def main():
121
  st.subheader("Generated Story")
122
  st.write(story)
123
 
124
- # Stage 3
125
- if len(story) > 20:
126
- with st.spinner("Creating audio..."):
127
- audio = stage3_generate_audio(story)
128
- if audio:
129
- st.audio(audio, format="audio/mp3")
130
- st.download_button("Download Audio", audio.getvalue(), "story.mp3")
 
131
 
132
  if __name__ == "__main__":
133
  main()
 
3
  from transformers import (
4
  BlipProcessor,
5
  BlipForConditionalGeneration,
6
+ pipeline
 
7
  )
8
  from gtts import gTTS
9
  import io
10
  import logging
11
+ import torch
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
 
19
  # ======================
20
  @st.cache_resource
21
  def load_image_model():
22
+ """Load image captioning model"""
23
  try:
24
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
25
  model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
+ img.thumbnail((512, 512))
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
 
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
+ """Load Mistral story model"""
51
  try:
52
+ pipe = pipeline(
53
+ "text-generation",
54
+ model="ajibawa-2023/Young-Children-Storyteller-Mistral-7B",
55
+ device_map="auto",
56
+ torch_dtype=torch.float16
57
+ )
58
  logger.info("Stage 2 model loaded")
59
+ return pipe
60
  except Exception as e:
61
  st.error("❌ Failed to load story model")
62
  raise
63
 
64
  def stage2_generate_story(keyword):
65
+ """Generate story with chat format"""
66
+ pipe = load_story_model()
 
 
 
 
 
 
 
 
 
67
  try:
68
+ messages = [{
69
+ "role": "user",
70
+ "content": f"Write a children's story about {keyword} with animals under 100 words"
71
+ }]
72
+ outputs = pipe(
73
+ messages,
74
+ max_new_tokens=200,
75
+ temperature=0.7,
76
+ do_sample=True
77
  )
78
+ return outputs[0]['generated_text'][-1]['content']
 
79
  except Exception as e:
80
  st.error(f"Story generation failed: {str(e)}")
81
+ return "The animals had a great time playing together!"
82
 
83
  # ======================
84
  # Stage 3: Text-to-Speech
85
  # ======================
86
  def stage3_generate_audio(text):
87
+ """Generate audio with validation"""
88
  try:
89
+ clean_text = text.strip().replace('\n', ' ')[:300]
90
+ if len(clean_text) < 10:
91
+ raise ValueError("Text too short")
92
+
93
+ tts = gTTS(text=clean_text, lang='en')
94
  audio_buffer = io.BytesIO()
95
  tts.write_to_fp(audio_buffer)
96
  audio_buffer.seek(0)
97
  return audio_buffer
98
  except Exception as e:
99
+ st.error(f"Audio Error: {str(e)}")
100
  return None
101
 
102
  # ======================
 
120
  st.subheader("Generated Story")
121
  st.write(story)
122
 
123
+ # Stage 3 (always shown)
124
+ with st.spinner("Creating audio..."):
125
+ audio = stage3_generate_audio(story)
126
+ if audio:
127
+ st.audio(audio, format="audio/mp3")
128
+ st.download_button("Download Audio", audio.getvalue(), "story.mp3")
129
+ else:
130
+ st.warning("Audio generation skipped due to invalid input")
131
 
132
  if __name__ == "__main__":
133
  main()