Spaces:
No application file
No application file
File size: 6,799 Bytes
3a8e7a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import gradio as gr
from langchain_groq import ChatGroq
import nltk
import requests
from PIL import Image
import io
import time
from gtts import gTTS
from dotenv import load_dotenv
import os
# **NLTK Setup**
# Fetch the tokenizer resources that nltk.sent_tokenize relies on; both the
# "punkt" and "punkt_tab" packages are requested, in that order.
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource)
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Retrieve API keys from environment variables, e.g.
#   export GROQ_API_KEY="your-groq-key"
#   export HUGGINGFACE_API_TOKEN="your-hf-token"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")

# Warn at startup: otherwise a missing key only shows up later as an opaque
# authentication error (the SDXL header would literally read "Bearer None").
_missing_keys = [
    name
    for name, value in (
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("HUGGINGFACE_API_TOKEN", HUGGINGFACE_API_TOKEN),
    )
    if not value
]
if _missing_keys:
    print(
        "Warning: environment variable(s) not set: "
        + ", ".join(_missing_keys)
        + " — the corresponding API calls are expected to fail."
    )
# **Initialize ChatGroq LLM**
# Shared chat model used to rewrite stories into per-frame image prompts.
# The API key is taken from the GROQ_API_KEY environment variable.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama-3.3-70b-versatile",
    temperature=0.5,
    max_retries=2,
    max_tokens=None,  # no explicit completion-length cap is passed
    timeout=None,  # no client-side timeout is passed
)
# **Stable Diffusion XL API Setup**
# Hugging Face Inference API endpoint for SDXL base 1.0.
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
# Bearer auth using the token read from the HUGGINGFACE_API_TOKEN environment variable.
headers_sdxl = dict(Authorization=f"Bearer {HUGGINGFACE_API_TOKEN}")
# Pause (in seconds) between consecutive SDXL requests; also the fallback
# wait when a rate-limit response carries no usable Retry-After header.
API_COOLDOWN = 3
def query_sdxl(prompt, max_retries=3, timeout=180):
    """Query the Stable Diffusion XL API to generate an image from a prompt.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the request.
        max_retries: How many times to retry after an HTTP 429 (rate limited)
            response before giving up.
        timeout: Seconds passed to ``requests.post`` so a stalled connection
            cannot block the caller indefinitely.

    Returns:
        The raw response body (image bytes) on HTTP 200, otherwise ``None``.
        Failures are reported with ``print`` rather than raised.
    """
    payload = {"inputs": prompt, "options": {"wait_for_model": True}}
    for attempt in range(max_retries + 1):
        try:
            r = requests.post(API_URL, headers=headers_sdxl, json=payload, timeout=timeout)
        except requests.RequestException as e:
            print("Request failed:", e)
            return None

        if r.status_code == 200:
            return r.content
        if r.status_code != 429:
            print(f"SDXL API error {r.status_code}: {r.text}")
            return None
        if attempt == max_retries:
            break

        # Retry-After may be delta-seconds or an HTTP-date; only the numeric
        # form is honoured, anything else falls back to the default cooldown.
        try:
            wait = max(0, int(r.headers.get("Retry-After", API_COOLDOWN)))
        except (TypeError, ValueError):
            wait = API_COOLDOWN
        print(f"Rate limited—waiting {wait}s…")
        time.sleep(wait)

    print(f"SDXL API still rate limited after {max_retries} retries; giving up.")
    return None
# Upper bound on storyboard length; must match the number of gr.Audio outputs
# wired to generate_storyboard in the UI.
_MAX_FRAMES = 10


def _build_messages(sentences):
    """Return the few-shot ChatGroq messages asking for one frame prompt per sentence."""
    num_frames = len(sentences)
    return [
        ("system", "You are a helpful assistant that rewrites a story into frame-by-frame prompts for image generation. "
                   f"Generate EXACTLY {num_frames} frames - no more, no less. "
                   "Ensure that each frame is a continuation of the previous one, maintaining consistency in characters, objects, colors, clothing, environment, lighting, and other visual elements. "
                   "Avoid vague references—refer back to previously introduced characters and objects explicitly. "
                   "Make sure the setting and visual context flow smoothly from one frame to the next."),
        # Example 1
        ("human", "A girl in a red dress walks into a forest.\n"
                  "She sees a white rabbit hopping past.\n"
                  "Curious, she follows the rabbit deeper into the woods.\n"
                  "She stumbles upon a glowing cave entrance.\n"
                  "The girl steps into the cave, her red dress glowing under the blue light."),
        ("assistant", "Frame 1: A young girl wearing a red dress walks into a dense forest surrounded by tall trees.\n"
                      "Frame 2: The young girl wearing a red dress walks, a small white rabbit hops past her feet on the dense forest path.\n"
                      "Frame 3: The young girl in the red dress follows the white rabbit, moving deeper into the darker parts of the dense forest.\n"
                      "Frame 4: The young girl in the red dress reaches a glowing blue cave entrance hidden between mossy rocks and trees in the dense forest.\n"
                      "Frame 5: Inside the glowing blue cave in the dense forest, the young girl's red dress softly illuminates the surroundings."),
        # Example 2
        ("human", "A boy flies a kite in a windy field.\n"
                  "The kite gets tangled in a tree.\n"
                  "He climbs the tree to untangle it.\n"
                  "Suddenly, it starts to rain.\n"
                  "The boy holds his kite and runs home drenched."),
        ("assistant", "Frame 1: A boy joyfully flies a colorful kite in a wide, windy green field.\n"
                      "Frame 2: A boy watches the kite gets tangled in the branches of a tall tree nearby in the windy green field.\n"
                      "Frame 3: Climbing up the tree carefully, the boy reaches out to untangle the kite in the windy green field.\n"
                      "Frame 4: Dark clouds roll in as rain begins to fall, soaking the boy on the tree in the windy green field.\n"
                      "Frame 5: Holding the damp kite, the boy runs across the field, drenched and smiling."),
        ("human", "\n".join(sentences)),
    ]


def _parse_frame_prompts(frames_txt, limit):
    """Extract at most *limit* non-empty prompts from ``Frame N: <prompt>`` lines.

    Tolerates leading whitespace and markdown decoration such as
    ``**Frame 1:** text`` or ``- Frame 1: text``; lines without a colon or
    with nothing after it are skipped (gTTS rejects empty text).
    """
    prompts = []
    for raw_line in frames_txt.splitlines():
        line = raw_line.strip().lstrip("*#-> ").strip()
        if not line.lower().startswith("frame"):
            continue
        _, sep, text = line.partition(":")
        text = text.strip().strip("*").strip()
        if sep and text:
            prompts.append(text)
    return prompts[:limit]


def _render_frame(idx, prompt):
    """Write frame *idx*'s voice-over and image; return (audio_path, image_path), None on failure."""
    from gtts import gTTSError  # local import: the file only imports gTTS itself

    # Voice-over: a TTS failure should not abort the rest of the storyboard.
    audio_path = f"frame_{idx}.mp3"
    try:
        gTTS(text=prompt, lang='en').save(audio_path)
    except gTTSError as e:
        print(f"Voice-over for frame {idx} failed:", e)
        audio_path = None

    # Image: query_sdxl returns None on failure; the bytes may also not decode.
    img_path = None
    img_bytes = query_sdxl(prompt)
    if img_bytes:
        try:
            img = Image.open(io.BytesIO(img_bytes))
            img_path = f"frame_{idx}.png"
            img.save(img_path)
        except OSError as e:  # includes PIL.UnidentifiedImageError
            print(f"Could not decode/save image for frame {idx}:", e)
            img_path = None
    return audio_path, img_path


def _storyboard_updates(gallery_items, audio_paths):
    """Build the Gradio outputs: one Gallery update, then _MAX_FRAMES Audio updates."""
    slots = list(audio_paths[:_MAX_FRAMES])
    slots += [None] * (_MAX_FRAMES - len(slots))
    audio_updates = [
        gr.update(value=path, visible=True) if path else gr.update(value=None, visible=False)
        for path in slots
    ]
    return [gr.update(value=gallery_items)] + audio_updates


def generate_storyboard(story):
    """Generate a storyboard with images and voice-overs from a user's story.

    The story is split into at most ``_MAX_FRAMES`` sentences. The LLM
    rewrites them into one image prompt per sentence. Each prompt is then
    turned into an MP3 voice-over (gTTS) and a PNG image (SDXL), saved as
    ``frame_<n>.*`` in the current working directory.

    Args:
        story: Free-form story text from the Gradio textbox (may be empty).

    Returns:
        A list of ``1 + _MAX_FRAMES`` Gradio updates. The first is the
        gallery (successful images only, captioned with their frame number,
        so failed frames are skipped instead of breaking the Gallery). It is
        followed by one update per audio slot: visible with a file, or
        hidden when unused or failed.
    """
    sentences = nltk.sent_tokenize(story or "")[:_MAX_FRAMES]
    if not sentences:
        # Nothing to illustrate: skip the LLM call and clear every output.
        return _storyboard_updates([], [])

    ai_msg = llm.invoke(_build_messages(sentences))
    # The LLM may return fewer frames than requested, so everything below
    # is driven by len(prompts), never by len(sentences).
    prompts = _parse_frame_prompts(ai_msg.content.strip(), len(sentences))

    gallery_items = []
    audio_paths = []
    for idx, prompt in enumerate(prompts, start=1):
        audio_path, img_path = _render_frame(idx, prompt)
        audio_paths.append(audio_path)
        if img_path is not None:
            gallery_items.append((img_path, f"Frame {idx}"))
        # Respect API cooldown between consecutive SDXL calls.
        if idx < len(prompts):
            time.sleep(API_COOLDOWN)

    return _storyboard_updates(gallery_items, audio_paths)
# **Gradio Interface**
# Layout, top to bottom: intro text, story box, button, image gallery, and a
# column of ten initially hidden audio players (one per possible frame).
with gr.Blocks(title="Storyboard Generator") as demo:
    gr.Markdown(
        "# Storyboard Generator\n"
        "Enter a story (up to 10 sentences) to generate a storyboard with images and voice-overs."
    )
    story_input = gr.Textbox(
        label="Enter your story",
        lines=5,
        placeholder="Type your story here...",
    )
    generate_btn = gr.Button("Generate Storyboard")
    image_gallery = gr.Gallery(label="Storyboard Images", show_label=True)
    with gr.Column():
        audios = [
            gr.Audio(label=f"Frame {frame_no} Voice-Over", visible=False)
            for frame_no in range(1, 11)
        ]
    # Outputs order must match generate_storyboard's return: gallery first, then audios.
    generate_btn.click(
        fn=generate_storyboard,
        inputs=[story_input],
        outputs=[image_gallery, *audios],
    )
# **Launch the App**
# Guarded so that importing this module does not start the web server.
if __name__ == "__main__":
    demo.launch()