TLH01's picture
Update app.py
a7bd32c verified
raw
history blame
3.28 kB
import streamlit as st
from transformers import pipeline
from PIL import Image
import tempfile
import torch
from TTS.api import TTS # Coqui TTS
import os
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_captioner():
    """Build the BLIP image-captioning pipeline.

    The function is decorated with ``st.cache_resource``. The pipeline is
    placed on CUDA when ``torch.cuda.is_available()`` is true, otherwise on
    the CPU.

    Returns:
        The ``image-to-text`` pipeline for
        ``Salesforce/blip-image-captioning-base``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=target_device,
    )
    return captioner
def generate_caption(_pipeline, image):
    """Describe ``image`` with the given captioning pipeline.

    Args:
        _pipeline: Callable invoked as ``_pipeline(image, max_new_tokens=50)``;
            it must return a sequence whose first item is a mapping with a
            ``'generated_text'`` entry.
        image: The image handed to the pipeline unchanged.

    Returns:
        The generated caption, or ``None`` when anything raises (the failure
        is shown to the user through ``st.error``).
    """
    try:
        outputs = _pipeline(image, max_new_tokens=50)
        # Only the first candidate is used.
        top_candidate = outputs[0]
        return top_candidate['generated_text']
    except Exception as e:
        st.error(f"Caption generation failed: {str(e)}")
        return None
# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_generator():
    """Build the GPT-2 based story-generation pipeline.

    The function is decorated with ``st.cache_resource``. The pipeline is
    placed on CUDA when ``torch.cuda.is_available()`` is true, otherwise on
    the CPU.

    Returns:
        The ``text-generation`` pipeline for
        ``pranavpsv/gpt2-genre-story-generator``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    # This checkpoint can be swapped for a stronger model.
    story_model = "pranavpsv/gpt2-genre-story-generator"
    generator = pipeline(
        "text-generation",
        model=story_model,
        device=target_device,
    )
    return generator
def generate_story(_pipeline, caption):
    """Generate a short children's story from an image caption.

    Args:
        _pipeline: Text-generation callable invoked as
            ``_pipeline(prompt, **generation_kwargs)``; it must return a
            sequence whose first item is a mapping with a
            ``'generated_text'`` entry.
        caption: Image description that is embedded in the prompt.

    Returns:
        The generated story with surrounding whitespace removed, or ``None``
        when generation raises (the failure is shown through ``st.error``).
    """
    prompt = f"""You are a children's storyteller. Based on the following image description: "{caption}", write a short children's story (80 words max).
The story should:
- Use simple and friendly language
- Be related to the content of the image
- Include a magical or fun twist
- End happily
Story:"""
    try:
        outputs = _pipeline(
            prompt,
            # ``max_length`` would count the prompt tokens as well, so the
            # room left for the story would shrink as the caption grows.
            # ``max_new_tokens`` budgets only the generated continuation.
            max_new_tokens=120,
            # ``temperature`` only has an effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            # Ask for the continuation only, not prompt + continuation.
            return_full_text=False,
        )
        story = outputs[0]['generated_text']
        # Fallback for callables that still echo the prompt back.
        if story.startswith(prompt):
            story = story[len(prompt):]
        return story.strip()
    except Exception as e:
        st.error(f"Story generation failed: {str(e)}")
        return None
# ======================
# Stage 3: Text-to-Speech using Coqui TTS
# ======================
@st.cache_resource
def load_tts():
    """Instantiate the Coqui TTS model used for narration.

    The function is decorated with ``st.cache_resource``. GPU use is
    requested exactly when ``torch.cuda.is_available()`` is true.

    Returns:
        A ``TTS`` instance for ``tts_models/en/ljspeech/tacotron2-DDC`` with
        the progress bar disabled.
    """
    use_gpu = torch.cuda.is_available()
    return TTS(
        model_name="tts_models/en/ljspeech/tacotron2-DDC",
        progress_bar=False,
        gpu=use_gpu,
    )
def text_to_speech(tts_model, story_text):
    """Synthesize ``story_text`` into a temporary WAV file.

    Args:
        tts_model: Object exposing ``tts_to_file(text=..., file_path=...)``.
        story_text: The text to speak.

    Returns:
        Path of the generated ``.wav`` file, or ``None`` when synthesis
        raises (the failure is shown through ``st.error``). On success the
        caller owns the file and is responsible for deleting it.
    """
    audio_path = None
    try:
        # Only reserve a unique path here and close the handle right away:
        # writing to a path whose NamedTemporaryFile handle is still open
        # fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            audio_path = f.name
        tts_model.tts_to_file(text=story_text, file_path=audio_path)
        return audio_path
    except Exception as e:
        # ``delete=False`` means nobody else cleans up; do not leave an
        # orphaned (possibly partial) file behind when synthesis fails.
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                pass
        st.error(f"Audio generation failed: {str(e)}")
        return None
# ======================
# Main Streamlit App
# ======================
def main():
    """Run the Streamlit page: upload an image, caption it, tell a story, read it aloud.

    Each stage only runs if the previous one produced a result; the helper
    functions report their own failures, so this function simply stops when
    a stage returns nothing.
    """
    st.set_page_config(page_title="Magic Story Generator", layout="wide")
    st.title("🧚 Magic Story Generator")

    uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
    if not uploaded_image:
        return

    try:
        image = Image.open(uploaded_image)
    except OSError as e:
        # A file with an accepted extension can still be unreadable; PIL
        # signals that with an OSError subclass.
        st.error(f"Could not read the uploaded image: {e}")
        return
    st.image(image, use_container_width=True)

    with st.spinner("Processing your magical story..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_model = load_tts()

        caption = generate_caption(caption_pipe, image)
        if not caption:
            return
        st.success(f"Image description: {caption}")

        story = generate_story(story_pipe, caption)
        if not story:
            return
        st.subheader("Your Magical Story")
        st.markdown(story)

        audio_path = text_to_speech(tts_model, story)
        if not audio_path:
            return
        # Load the audio into memory and delete the temporary file right
        # away; otherwise every generated story leaves a WAV file behind.
        try:
            with open(audio_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            try:
                os.remove(audio_path)
            except OSError:
                pass
        st.audio(audio_bytes, format="audio/wav")
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()