| | |
| |
|
| | """ |
| | modules.py |
| | |
| | Helper functions for the Image-to-Story Streamlit application. |
| | Provides: |
| | - Cached loaders for Hugging Face pipelines (captioning, story generation) |
| | - Inference functions: generate_caption, generate_story_simple, generate_audio |
| | """ |
| | import streamlit as st |
| | import re |
| | from transformers import pipeline |
| | from gtts import gTTS |
| | import io |
| |
|
@st.cache_resource
def load_captioner():
    """
    Build the BLIP image-captioning pipeline.

    Decorated with st.cache_resource, so Streamlit constructs the pipeline
    once and returns the same cached object on later calls instead of
    reloading the model on every script rerun.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    captioner = pipeline("image-to-text", model=checkpoint)
    return captioner
| |
|
@st.cache_resource
def load_story_gen():
    """
    Build the genre-story-generator-v2 text-generation pipeline.

    Decorated with st.cache_resource, so Streamlit constructs the pipeline
    once and returns the same cached object on later calls instead of
    reloading the model on every script rerun.
    """
    checkpoint = "pranavpsv/genre-story-generator-v2"
    generator = pipeline("text-generation", model=checkpoint)
    return generator
| |
|
def trim_to_sentence(text: str, max_words: int) -> str:
    """
    Trim the story to the last complete sentence under max_words.

    The text is split into sentences at whitespace that follows '.', '!' or
    '?' (optionally followed by one closing quote or bracket). Whole
    sentences are kept from the start while the running word count stays
    within max_words. A trailing fragment without terminal punctuation
    (e.g. a sentence cut off by a generation token limit) is dropped.
    If no complete sentence fits, fall back to the first max_words words.

    Args:
        text: Text to trim.
        max_words: Maximum number of whitespace-separated words to keep.
            Negative values are treated as 0.

    Returns:
        The trimmed text; "" for empty input or when max_words <= 0.
    """
    limit = max(max_words, 0)  # a negative slice bound would cut from the end
    # Python lookbehinds must be fixed-width, hence two alternatives: one for
    # a bare end mark, one for an end mark followed by a closing quote/bracket
    # (so 'He said "Stop." Then ...' splits after the quote).
    sentences = re.split(
        r"(?<=[.!?])\s+|(?<=[.!?][\"')\u201d\u2019])\s+", text.strip()
    )
    trimmed = []
    count = 0
    for s in sentences:
        wc = len(s.split())
        if count + wc > limit:
            break
        trimmed.append(s)
        count += wc
    # Every piece except the very last one produced by the split ends with an
    # end mark, so only a piece at the end of the text can be unfinished.
    # Drop it so the result never ends mid-sentence.
    if trimmed and not re.search(r"[.!?][\"')\u201d\u2019]?$", trimmed[-1]):
        trimmed.pop()
    if trimmed:
        return " ".join(trimmed)
    # No complete sentence fits: hard-cut to the first `limit` words.
    return " ".join(text.split()[:limit])
| |
|
def generate_caption(captioner, image) -> str:
    """
    Run the captioner pipeline on the PIL image and return the caption text.

    Args:
        captioner: Image-to-text pipeline (e.g. from load_captioner()),
            called as captioner(image).
        image: Image passed straight to the pipeline (a PIL image in this app).

    Returns:
        The "generated_text" of the first result when it is a dict ("" if
        that key is missing), str() of the first result otherwise, and ""
        when the pipeline returns no results at all.
    """
    results = captioner(image)
    # Treat an empty/None result like a missing caption instead of raising
    # IndexError, consistent with the "" fallback for a dict without text.
    if not results:
        return ""
    first = results[0]
    if isinstance(first, dict):
        return first.get("generated_text", "")
    return str(first)
| |
|
def generate_story_simple(storyteller, prompt_text: str,
                          min_words: int = 50, max_words: int = 100) -> str:
    """
    Generate a short story of roughly min_words to max_words words.

    1. Sample up to 120 new tokens with nucleus sampling (top_p=0.9).
    2. If that has fewer than min_words words, sample again with up to 200
       new tokens and top_p=0.95, keeping the retry unless it came out
       shorter (each draw is random, so a retry is not guaranteed longer).
    3. Trim to whole sentences totalling at most max_words words using
       trim_to_sentence.

    min_words is best-effort: if both samples are short, the longer one is
    returned (after trimming).

    Args:
        storyteller: Text-generation pipeline (e.g. from load_story_gen()),
            called as storyteller(prompt, **generation_kwargs) and returning
            a list whose first item is a dict with a "generated_text" key.
        prompt_text: Prompt that seeds the story.
        min_words: Word count below which a second, longer sample is drawn.
        max_words: Upper bound on the number of words in the returned story.

    Returns:
        The generated story text.
    """
    def _sample(max_new_tokens: int, top_p: float) -> str:
        # One stochastic generation pass; returns the first sequence's text.
        out = storyteller(prompt_text, max_new_tokens=max_new_tokens,
                          do_sample=True, top_p=top_p, num_return_sequences=1)
        return out[0]["generated_text"]

    # NOTE(review): text-generation pipelines return prompt + continuation by
    # default (return_full_text=True), so prompt words count toward
    # min_words/max_words here -- confirm that is intended.
    story = _sample(120, 0.9)
    if len(story.split()) < min_words:
        retry = _sample(200, 0.95)
        # Never replace a longer first draft with a shorter random retry.
        if len(retry.split()) >= len(story.split()):
            story = retry
    return trim_to_sentence(story, max_words)
| |
|
def generate_audio(text: str) -> "tuple[bytes, str]":
    """
    Convert text to MP3 bytes using gTTS.

    Args:
        text: English text to synthesise.

    Returns:
        (audio_bytes, mime_type) for use in st.audio(audio_bytes, format=mime_type).
        mime_type is "audio/mpeg", the registered MIME type for MP3 audio;
        "audio/mp3" is a non-standard alias.

    NOTE(review): gTTS fetches speech from Google's online TTS service, so
    this needs network access, and gTTS documents an error for empty text --
    confirm callers guard against both.
    """
    tts = gTTS(text=text, lang="en")
    # The buffer only lives long enough to collect the encoded bytes.
    with io.BytesIO() as buf:
        tts.write_to_fp(buf)
        audio_bytes = buf.getvalue()
    return audio_bytes, "audio/mpeg"
| |
|