# Last change by hskwon7: "Update modules.py" (commit edda7d8, verified).
# modules.py
"""
modules.py
Helper functions for the Image-to-Story Streamlit application.
Provides:
- Cached loaders for Hugging Face pipelines (captioning, story generation)
- Inference functions: generate_caption, generate_story_simple, generate_audio
- Text helper: trim_to_sentence (cuts generated text back to whole sentences)
"""
import streamlit as st
import re
from transformers import pipeline
from gtts import gTTS
import io
@st.cache_resource
def load_captioner():
    """
    Build the BLIP image-captioning pipeline.

    Decorated with st.cache_resource, so the pipeline is constructed once and
    the same object is handed back on subsequent calls.
    """
    task = "image-to-text"
    model_id = "Salesforce/blip-image-captioning-base"
    return pipeline(task, model=model_id)
@st.cache_resource
def load_story_gen():
    """
    Build the genre-story-generator-v2 text-generation pipeline.

    Decorated with st.cache_resource, so the pipeline is constructed once and
    the same object is handed back on subsequent calls.
    """
    task = "text-generation"
    model_id = "pranavpsv/genre-story-generator-v2"
    return pipeline(task, model=model_id)
def trim_to_sentence(text: str, max_words: int) -> str:
    """
    Trim the story to the last complete sentence under max_words.

    Sentences are split after '.', '!' or '?' (optionally followed by a
    closing quote or parenthesis). A trailing fragment without terminal
    punctuation -- typical when generation stops at the token limit -- is
    not a complete sentence and is dropped.

    If no complete sentence fits, fall back to the first max_words words.

    Args:
        text: Text to trim.
        max_words: Maximum number of whitespace-separated words to keep.

    Returns:
        The trimmed text; "" when max_words <= 0 or text is blank.
    """
    if max_words <= 0:
        # A negative slice bound below would drop words from the END instead.
        return ""
    closers = "\"'\u201d\u2019)]"
    # Two fixed-width lookbehinds: plain "end." and quoted/parenthesised 'end."'.
    pieces = re.split(r"(?<=[.!?])\s+|(?<=[.!?][\"'\u201d\u2019)\]])\s+",
                      text.strip())
    kept = []
    count = 0
    for piece in pieces:
        # Splits only happen after terminal punctuation, so only the final
        # piece can be an unfinished fragment; stop before including it.
        if not piece.rstrip(closers).endswith((".", "!", "?")):
            break
        wc = len(piece.split())
        if count + wc > max_words:
            break
        kept.append(piece)
        count += wc
    if kept:
        return " ".join(kept)
    # No complete sentence fits: fall back to a naive word trim.
    return " ".join(text.split()[:max_words])
def generate_caption(captioner, image) -> str:
    """
    Caption an image with the given captioning pipeline.

    Only the first result is used. If it is a dict, its "generated_text"
    value is returned ("" when that key is missing); any other result is
    converted with str().
    """
    results = captioner(image)
    top = results[0]
    if not isinstance(top, dict):
        return str(top)
    return top.get("generated_text", "")
def generate_story_simple(storyteller, prompt_text: str,
                          min_words: int = 50, max_words: int = 100) -> str:
    """
    Generate a 50–100 word story:
    1. Sample ~120 new tokens with nucleus sampling (top_p=0.9).
    2. If under min_words, re-sample ~200 new tokens with higher top_p (0.95)
       and keep whichever of the two candidates is longer.
    3. Trim to the last complete sentence under max_words.

    Args:
        storyteller: A text-generation pipeline (see load_story_gen).
        prompt_text: Prompt passed to the pipeline.
        min_words: Word count below which a second, longer sample is drawn.
        max_words: Upper bound on words in the returned story.

    Returns:
        The trimmed story text. NOTE(review): with the transformers default
        (return_full_text=True) "generated_text" includes prompt_text, so the
        prompt counts toward min_words/max_words -- confirm this is intended.
    """
    out = storyteller(prompt_text, max_new_tokens=120,
                      do_sample=True, top_p=0.9, num_return_sequences=1)
    story = out[0]["generated_text"]
    if len(story.split()) < min_words:
        retry = storyteller(prompt_text, max_new_tokens=200,
                            do_sample=True, top_p=0.95, num_return_sequences=1)
        retry_story = retry[0]["generated_text"]
        # Sampling is random: the retry can come back even shorter, so keep
        # the longer candidate instead of discarding the first one.
        if len(retry_story.split()) > len(story.split()):
            story = retry_story
    return trim_to_sentence(story, max_words)
def generate_audio(text: str) -> tuple[bytes, str]:
    """
    Convert text to MP3 bytes using gTTS (English voice).

    Args:
        text: Text to synthesize. gTTS performs synthesis through Google's
            TTS web endpoint, so this call needs network access.

    Returns:
        (audio_bytes, mime_type) for use in st.audio(audio_bytes, format=mime_type).
        mime_type is "audio/mpeg", the registered MIME type for MP3 data
        ("audio/mp3" is a non-standard alias).
    """
    tts = gTTS(text=text, lang="en")
    with io.BytesIO() as buf:
        tts.write_to_fp(buf)
        # getvalue() must run before the buffer is closed on leaving the block.
        return buf.getvalue(), "audio/mpeg"