# modules.py
"""
modules.py

Helper functions for the Image-to-Story Streamlit application.

Provides:
- Cached loaders for Hugging Face pipelines (captioning, story generation)
- Inference functions: generate_caption, generate_story_simple, generate_audio
"""

import io
import re
from typing import Tuple

import streamlit as st
from gtts import gTTS
from transformers import pipeline

# Sentence boundaries: whitespace that immediately follows '.', '!' or '?'.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
# A sentence is "complete" when it ends in terminal punctuation, optionally
# followed by closing quotes/brackets (e.g. `She said "Goodbye."`).
_SENTENCE_END_RE = re.compile(r"[.!?][\"'\u201d\u2019)\]]*$")


@st.cache_resource
def load_captioner():
    """Load and cache the BLIP image-captioning pipeline.

    Wrapped in ``st.cache_resource`` so the model is not reloaded on every
    Streamlit rerun.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


@st.cache_resource
def load_story_gen():
    """Load and cache the genre-story-generator-v2 text-generation pipeline.

    Wrapped in ``st.cache_resource`` so the model is not reloaded on every
    Streamlit rerun.
    """
    return pipeline("text-generation", model="pranavpsv/genre-story-generator-v2")


def trim_to_sentence(text: str, max_words: int) -> str:
    """
    Trim ``text`` to whole sentences totalling at most ``max_words`` words.

    Sentences are taken from the start, in order, until the next one would
    exceed the word budget. A trailing fragment with no terminal punctuation
    (typical when generation stops at a token limit) is dropped, as long as
    at least one complete sentence remains. If not even the first sentence
    fits, fall back to the first ``max_words`` words.

    Args:
        text: Text to trim; leading/trailing whitespace is ignored.
        max_words: Maximum number of whitespace-separated words to keep.

    Returns:
        The trimmed text (empty if ``text`` is empty or blank).
    """
    sentences = _SENTENCE_SPLIT_RE.split(text.strip())
    trimmed = []
    count = 0
    for sentence in sentences:
        wc = len(sentence.split())
        if count + wc > max_words:
            break
        trimmed.append(sentence)
        count += wc

    # Only the final split piece can lack terminal punctuation (splits happen
    # right after it), so this removes an unfinished tail -- but never the
    # sole remaining sentence.
    if len(trimmed) > 1 and not _SENTENCE_END_RE.search(trimmed[-1]):
        trimmed.pop()

    if trimmed:
        return " ".join(trimmed)
    # Not even the first sentence fits the budget: naive word trim.
    return " ".join(text.split()[:max_words])


def generate_caption(captioner, image) -> str:
    """
    Run the captioning pipeline on a PIL image and return its caption.

    Args:
        captioner: An image-to-text pipeline (e.g. from ``load_captioner``).
        image: The image to caption.

    Returns:
        The ``generated_text`` of the first result when it is a dict, the
        result's string form otherwise, or "" if the pipeline returned nothing.
    """
    raw = captioner(image)
    if not raw:
        # Guard against an empty result instead of raising IndexError.
        return ""
    first = raw[0]
    if isinstance(first, dict):
        return first.get("generated_text", "")
    return str(first)


def _sample_story(storyteller, prompt_text: str, max_new_tokens: int, top_p: float) -> str:
    """Draw one nucleus-sampled continuation and return its ``generated_text``."""
    out = storyteller(
        prompt_text,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        num_return_sequences=1,
    )
    return out[0]["generated_text"]


def generate_story_simple(storyteller, prompt_text: str,
                          min_words: int = 50, max_words: int = 100) -> str:
    """
    Generate a short story (50-100 words with the default limits).

    1. Sample up to 120 new tokens with nucleus sampling (top_p=0.9).
    2. If that has fewer than ``min_words`` words, re-sample up to 200 new
       tokens with a wider nucleus (top_p=0.95) and keep the longer attempt.
    3. Trim to whole sentences totalling at most ``max_words`` words.

    NOTE(review): a text-generation pipeline's ``generated_text`` includes the
    prompt unless ``return_full_text=False`` is passed, so prompt words count
    toward ``min_words``/``max_words`` and appear in the story -- confirm this
    is intended.

    Args:
        storyteller: A text-generation pipeline (e.g. from ``load_story_gen``).
        prompt_text: Prompt passed to the pipeline.
        min_words: Word count below which a second, longer sample is drawn.
        max_words: Word budget for the final trimmed story.

    Returns:
        The trimmed story text.
    """
    story = _sample_story(storyteller, prompt_text, max_new_tokens=120, top_p=0.9)
    if len(story.split()) < min_words:
        retry = _sample_story(storyteller, prompt_text, max_new_tokens=200, top_p=0.95)
        # Sampling is random: the retry is not guaranteed to be longer, so
        # only replace the first story when it actually is.
        if len(retry.split()) > len(story.split()):
            story = retry
    return trim_to_sentence(story, max_words)


def generate_audio(text: str) -> Tuple[bytes, str]:
    """
    Convert text to MP3 bytes using gTTS.

    gTTS synthesizes speech through Google Translate's text-to-speech
    service, so this call needs network access.

    Args:
        text: English text to speak.

    Returns:
        ``(audio_bytes, mime_type)`` for use in ``st.audio(...)``.
    """
    tts = gTTS(text=text, lang="en")
    buf = io.BytesIO()
    tts.write_to_fp(buf)
    return buf.getvalue(), "audio/mp3"