File size: 6,799 Bytes
3a8e7a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

import gradio as gr
from langchain_groq import ChatGroq
import nltk
import requests
from PIL import Image
import io
import time
from gtts import gTTS
from dotenv import load_dotenv
import os

# **NLTK Setup**
# Tokenizer data required by nltk.sent_tokenize (used to split the story).
nltk.download('punkt')
nltk.download('punkt_tab')

# Load variables from a local .env file (if one exists) into the environment
load_dotenv()

# Retrieve API keys from environment variables
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # e.g. export GROQ_API_KEY="your-groq-key"
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")  # e.g. export HUGGINGFACE_API_TOKEN="hf_..."

# Surface missing credentials at startup instead of as opaque API errors later.
if not GROQ_API_KEY:
    print("Warning: GROQ_API_KEY is not set; prompt generation will fail.")
if not HUGGINGFACE_API_TOKEN:
    print("Warning: HUGGINGFACE_API_TOKEN is not set; image requests will be unauthenticated.")

# **Initialize ChatGroq LLM**
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    temperature=0.5,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    api_key=GROQ_API_KEY,  # Taken from the GROQ_API_KEY environment variable
)

# **Stable Diffusion XL API Setup**
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# Only send an Authorization header when a token is configured; otherwise the
# literal string "Bearer None" would be sent as the credential.
headers_sdxl = (
    {"Authorization": f"Bearer {HUGGINGFACE_API_TOKEN}"}
    if HUGGINGFACE_API_TOKEN
    else {}
)
API_COOLDOWN = 3  # Seconds between API calls

def query_sdxl(prompt, max_retries=5, timeout=120):
    """Query the Stable Diffusion XL API to generate an image from a prompt.

    Args:
        prompt: Text description sent as the ``inputs`` field of the request.
        max_retries: Maximum number of additional attempts made after an
            HTTP 429 (rate limited) response before giving up.
        timeout: Seconds to wait for the API on each attempt, so a stalled
            connection cannot block the app indefinitely.

    Returns:
        The raw response body (the image bytes) on HTTP 200, or ``None`` when
        the request fails, the API returns an error status, or the rate-limit
        retries are exhausted. Failures are reported with ``print``.
    """
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}

    # Iterate instead of recursing so persistent rate limiting cannot grow
    # the call stack without bound.
    for attempt in range(max_retries + 1):
        try:
            r = requests.post(
                API_URL, headers=headers_sdxl, json=payload, timeout=timeout
            )
        except requests.RequestException as e:
            print("Request failed:", e)
            return None

        if r.status_code == 200:
            return r.content
        if r.status_code != 429:
            print(f"SDXL API error {r.status_code}: {r.text}")
            return None
        if attempt == max_retries:
            break

        # Retry-After may be missing, an HTTP date, or otherwise not an
        # integer; fall back to the default cooldown in those cases.
        try:
            wait = max(0, int(r.headers.get("Retry-After", API_COOLDOWN)))
        except (TypeError, ValueError):
            wait = API_COOLDOWN
        print(f"Rate limited—waiting {wait}s…")
        time.sleep(wait)

    print(f"SDXL API still rate limited after {max_retries} retries; giving up.")
    return None

def _extract_frame_prompts(frames_txt, limit):
    """Return up to ``limit`` non-empty prompts taken from ``Frame N: ...`` lines."""
    prompts = []
    for line in frames_txt.splitlines():
        # Tolerate markdown decoration such as "**Frame 1:** ..." or "- Frame 1: ...".
        cleaned = line.strip().lstrip("*#-> ").strip()
        if not cleaned.lower().startswith("frame"):
            continue
        prompt = cleaned.partition(":")[2].strip().strip("*").strip()
        # Skip "Frame N" lines with no text: there is nothing to speak or draw.
        if prompt:
            prompts.append(prompt)
    return prompts[:limit]


def _storyboard_outputs(image_paths, audio_paths, slots):
    """Build the Gradio return value: one gallery update plus ``slots`` audio updates."""
    # Failed frames are recorded as None; drop them so the gallery only
    # receives real image paths.
    gallery_update = gr.update(value=[p for p in image_paths if p is not None])
    # Bound by the audio files actually produced, not by the number of story
    # sentences, since the LLM may return fewer frames than requested.
    audio_updates = [
        gr.update(value=audio_paths[i], visible=True)
        if i < len(audio_paths)
        else gr.update(value=None, visible=False)
        for i in range(slots)
    ]
    return [gallery_update] + audio_updates


def generate_storyboard(story):
    """Generate a storyboard with images and voice-overs from a user's story.

    The story is split into at most 10 sentences, rewritten by the LLM into
    one image prompt per sentence, and each prompt is turned into an MP3
    voice-over (``frame_<n>.mp3``) and an SDXL image (``frame_<n>.png``) in
    the current working directory.

    Args:
        story: Free-form story text entered by the user.

    Returns:
        A list of 11 Gradio updates: the gallery update followed by one
        update per audio slot. Slots without a voice-over are hidden.
    """
    max_frames = 10  # Matches the ten audio slots this function returns updates for

    # Nothing to illustrate: skip the LLM call and clear the outputs.
    if not story or not story.strip():
        return _storyboard_outputs([], [], max_frames)

    # Split story into sentences (up to 10 frames)
    sentences = nltk.sent_tokenize(story)[:max_frames]
    num_frames = len(sentences)
    if num_frames == 0:
        return _storyboard_outputs([], [], max_frames)

    # Build few-shot messages for ChatGroq
    messages = [
        ("system", f"You are a helpful assistant that rewrites a story into frame-by-frame prompts for image generation. "
                   f"Generate EXACTLY {num_frames} frames - no more, no less. "
                   "Ensure that each frame is a continuation of the previous one, maintaining consistency in characters, objects, colors, clothing, environment, lighting, and other visual elements. "
                   "Avoid vague references—refer back to previously introduced characters and objects explicitly. "
                   "Make sure the setting and visual context flow smoothly from one frame to the next."),
        # Example 1
        ("human", "A girl in a red dress walks into a forest.\n"
                  "She sees a white rabbit hopping past.\n"
                  "Curious, she follows the rabbit deeper into the woods.\n"
                  "She stumbles upon a glowing cave entrance.\n"
                  "The girl steps into the cave, her red dress glowing under the blue light."),
        ("assistant", "Frame 1: A young girl wearing a red dress walks into a dense forest surrounded by tall trees.\n"
                      "Frame 2: The young girl wearing a red dress walks, a small white rabbit hops past her feet on the dense forest path.\n"
                      "Frame 3: The young girl in the red dress follows the white rabbit, moving deeper into the darker parts of the dense forest.\n"
                      "Frame 4: The young girl in the red dress reaches a glowing blue cave entrance hidden between mossy rocks and trees in the dense forest.\n"
                      "Frame 5: Inside the glowing blue cave in the dense forest, the young girl's red dress softly illuminates the surroundings."),
        # Example 2
        ("human", "A boy flies a kite in a windy field.\n"
                  "The kite gets tangled in a tree.\n"
                  "He climbs the tree to untangle it.\n"
                  "Suddenly, it starts to rain.\n"
                  "The boy holds his kite and runs home drenched."),
        ("assistant", "Frame 1: A boy joyfully flies a colorful kite in a wide, windy green field.\n"
                      "Frame 2: A boy watches the kite gets tangled in the branches of a tall tree nearby in the windy green field.\n"
                      "Frame 3: Climbing up the tree carefully, the boy reaches out to untangle the kite in the windy green field.\n"
                      "Frame 4: Dark clouds roll in as rain begins to fall, soaking the boy on the tree in the windy green field.\n"
                      "Frame 5: Holding the damp kite, the boy runs across the field, drenched and smiling."),
        ("human", "\n".join(sentences))
    ]

    # Generate frame prompts with ChatGroq
    ai_msg = llm.invoke(messages)
    prompts = _extract_frame_prompts(ai_msg.content.strip(), num_frames)
    if not prompts:
        # The reply contained no usable "Frame N:" lines; illustrate the
        # user's own sentences rather than returning an empty storyboard.
        prompts = sentences

    # Generate images and voice-overs
    image_paths = []
    audio_paths = []
    for idx, prompt in enumerate(prompts, start=1):
        # Generate voice-over
        audio_path = f"frame_{idx}.mp3"
        gTTS(text=prompt, lang='en').save(audio_path)
        audio_paths.append(audio_path)

        # Generate image
        img_bytes = query_sdxl(prompt)
        if img_bytes:
            img_path = f"frame_{idx}.png"
            Image.open(io.BytesIO(img_bytes)).save(img_path)
            image_paths.append(img_path)
        else:
            image_paths.append(None)  # Placeholder for failed images

        # Respect API cooldown (no need to wait after the final frame)
        if idx < len(prompts):
            time.sleep(API_COOLDOWN)

    # Prepare Gradio updates
    return _storyboard_outputs(image_paths, audio_paths, max_frames)

# **Gradio Interface**
# Page layout, top to bottom: heading, story textbox, generate button,
# image gallery, then a column of ten voice-over players.
with gr.Blocks(title="Storyboard Generator") as demo:
    gr.Markdown(
        "# Storyboard Generator\nEnter a story (up to 10 sentences) to generate a storyboard with images and voice-overs."
    )
    story_input = gr.Textbox(
        label="Enter your story",
        lines=5,
        placeholder="Type your story here...",
    )
    generate_btn = gr.Button("Generate Storyboard")
    image_gallery = gr.Gallery(label="Storyboard Images", show_label=True)

    # The players start hidden; generate_storyboard's returned updates
    # reveal the ones that have a voice-over.
    with gr.Column():
        audios = [
            gr.Audio(label=f"Frame {i+1} Voice-Over", visible=False)
            for i in range(10)
        ]

    # Output order must match generate_storyboard's return value:
    # the gallery first, then the ten audio slots.
    generate_btn.click(
        generate_storyboard,
        story_input,
        [image_gallery, *audios],
    )

# **Launch the App**
demo.launch()