File size: 2,646 Bytes
07e7687
edda7d8
07e7687
 
 
edda7d8
07e7687
edda7d8
 
07e7687
 
edda7d8
07e7687
edda7d8
 
07e7687
 
 
edda7d8
 
07e7687
 
 
edda7d8
 
07e7687
edda7d8
07e7687
edda7d8
 
07e7687
edda7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07e7687
 
edda7d8
07e7687
edda7d8
 
fd5d864
edda7d8
 
 
 
fd5d864
edda7d8
 
fd5d864
edda7d8
 
 
 
 
fd5d864
edda7d8
07e7687
edda7d8
 
07e7687
edda7d8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# modules.py

"""
modules.py

Helper functions for the Image-to-Story Streamlit application.
Provides:
- Cached loaders for Hugging Face pipelines (captioning, story generation)
- Inference functions: generate_caption, generate_story_simple, generate_audio
"""
import streamlit as st
import re
from transformers import pipeline
from gtts import gTTS
import io

@st.cache_resource
def load_captioner():
    """Build the BLIP image-captioning pipeline once; Streamlit caches it across reruns."""
    captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
    return captioner

@st.cache_resource
def load_story_gen():
    """Build the genre-story-generator-v2 text-generation pipeline once; cached by Streamlit."""
    storyteller = pipeline("text-generation", model="pranavpsv/genre-story-generator-v2")
    return storyteller

def trim_to_sentence(text: str, max_words: int) -> str:
    """
    Return the longest prefix of complete sentences whose total word count
    does not exceed max_words.

    If even the first sentence is too long, fall back to a hard cut at
    max_words words.
    """
    kept = []
    budget = max_words
    # Sentences end at ., ! or ? followed by whitespace (lookbehind keeps the punctuation).
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        n_words = len(sentence.split())
        if n_words > budget:
            break
        kept.append(sentence)
        budget -= n_words
    if kept:
        return " ".join(kept)
    # No whole sentence fits: naive word-level trim.
    return " ".join(text.split()[:max_words])

def generate_caption(captioner, image) -> str:
    """Caption a PIL image with the given image-to-text pipeline and return the text."""
    result = captioner(image)[0]
    if isinstance(result, dict):
        # HF image-to-text pipelines return [{"generated_text": ...}]
        return result.get("generated_text", "")
    return str(result)

def generate_story_simple(storyteller, prompt_text: str,
                          min_words: int = 50, max_words: int = 100) -> str:
    """
    Generate a story of roughly min_words-max_words words.

    Strategy: sample ~120 new tokens with nucleus sampling; if the result
    falls short of min_words, retry once with ~200 tokens and a wider
    nucleus; finally trim to the last complete sentence within max_words.
    """
    def _sample(token_budget, nucleus_p):
        # One pipeline call returning the single generated string.
        result = storyteller(prompt_text, max_new_tokens=token_budget,
                             do_sample=True, top_p=nucleus_p,
                             num_return_sequences=1)
        return result[0]["generated_text"]

    story = _sample(120, 0.9)
    if len(story.split()) < min_words:
        # Too short — one retry with a larger budget and broader sampling.
        story = _sample(200, 0.95)
    return trim_to_sentence(story, max_words)

def generate_audio(text: str) -> tuple[bytes, str]:
    """
    Convert text to MP3 bytes using gTTS.

    Args:
        text: The text to synthesize (English voice).

    Returns:
        (audio_bytes, mime_type): the MP3 payload and the MIME string,
        ready for st.audio(...).
    """
    # Fix: the original annotation `(bytes, str)` is a tuple of classes,
    # not a valid type annotation; `tuple[bytes, str]` is the correct form.
    tts = gTTS(text=text, lang="en")
    buf = io.BytesIO()
    tts.write_to_fp(buf)
    # getvalue() returns the whole buffer regardless of stream position.
    return buf.getvalue(), "audio/mp3"