peteparker456 commited on
Commit
8d816bb
·
verified ·
1 Parent(s): f644ce0

Upload tti.py

Browse files
Files changed (1) hide show
  1. tti.py +153 -0
tti.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from langchain_groq import ChatGroq
4
+ import nltk
5
+ import requests
6
+ from PIL import Image
7
+ import io
8
+ import time
9
+ from gtts import gTTS
10
+ from dotenv import load_dotenv
11
+ import os
12
+
13
+ # **NLTK Setup**
14
+ nltk.download('punkt')
15
+ nltk.download('punkt_tab')
16
+ # Load .env file
17
+ load_dotenv()
18
+
19
+ # Retrieve API keys from environment variables
20
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY") # e.g. export GROQ_API_KEY="your-groq-key"
21
+ HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
22
+ # **Initialize ChatGroq LLM**
23
+ llm = ChatGroq(
24
+ model="llama-3.3-70b-versatile",
25
+ temperature=0.5,
26
+ max_tokens=None,
27
+ timeout=None,
28
+ max_retries=2,
29
+ api_key=GROQ_API_KEY # Replace with your Groq API key
30
+ )
31
+
32
+ # **Stable Diffusion XL API Setup**
33
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
34
+ headers_sdxl = {"Authorization": f"Bearer {HUGGINGFACE_API_TOKEN}"} # Replace with your Hugging Face API token
35
+ API_COOLDOWN = 3 # Seconds between API calls
36
+
37
+ def query_sdxl(prompt):
38
+ """Query the Stable Diffusion XL API to generate an image from a prompt."""
39
+ payload = {"inputs": prompt, "options": {"wait_for_model": True}}
40
+ try:
41
+ r = requests.post(API_URL, headers=headers_sdxl, json=payload)
42
+ if r.status_code == 200:
43
+ return r.content
44
+ elif r.status_code == 429:
45
+ wait = int(r.headers.get("Retry-After", API_COOLDOWN))
46
+ print(f"Rate limited—waiting {wait}s…")
47
+ time.sleep(wait)
48
+ return query_sdxl(prompt)
49
+ else:
50
+ print(f"SDXL API error {r.status_code}: {r.text}")
51
+ except Exception as e:
52
+ print("Request failed:", e)
53
+ return None
54
+
55
+ def generate_storyboard(story):
56
+ """Generate a storyboard with images and voice-overs from a user's story."""
57
+ # Split story into sentences (up to 10 frames)
58
+ sentences = nltk.sent_tokenize(story)[:10]
59
+ num_frames = len(sentences)
60
+
61
+ # Build few-shot messages for ChatGroq
62
+ messages = [
63
+ ("system", f"You are a helpful assistant that rewrites a story into frame-by-frame prompts for image generation. "
64
+ f"Generate EXACTLY {num_frames} frames - no more, no less. "
65
+ "Ensure that each frame is a continuation of the previous one, maintaining consistency in characters, objects, colors, clothing, environment, lighting, and other visual elements. "
66
+ "Avoid vague references—refer back to previously introduced characters and objects explicitly. "
67
+ "Make sure the setting and visual context flow smoothly from one frame to the next."),
68
+ # Example 1
69
+ ("human", "A girl in a red dress walks into a forest.\n"
70
+ "She sees a white rabbit hopping past.\n"
71
+ "Curious, she follows the rabbit deeper into the woods.\n"
72
+ "She stumbles upon a glowing cave entrance.\n"
73
+ "The girl steps into the cave, her red dress glowing under the blue light."),
74
+ ("assistant", "Frame 1: A young girl wearing a red dress walks into a dense forest surrounded by tall trees.\n"
75
+ "Frame 2: The young girl wearing a red dress walks, a small white rabbit hops past her feet on the dense forest path.\n"
76
+ "Frame 3: The young girl in the red dress follows the white rabbit, moving deeper into the darker parts of the dense forest.\n"
77
+ "Frame 4: The young girl in the red dress reaches a glowing blue cave entrance hidden between mossy rocks and trees in the dense forest.\n"
78
+ "Frame 5: Inside the glowing blue cave in the dense forest, the young girl's red dress softly illuminates the surroundings."),
79
+ # Example 2
80
+ ("human", "A boy flies a kite in a windy field.\n"
81
+ "The kite gets tangled in a tree.\n"
82
+ "He climbs the tree to untangle it.\n"
83
+ "Suddenly, it starts to rain.\n"
84
+ "The boy holds his kite and runs home drenched."),
85
+ ("assistant", "Frame 1: A boy joyfully flies a colorful kite in a wide, windy green field.\n"
86
+ "Frame 2: A boy watches the kite gets tangled in the branches of a tall tree nearby in the windy green field.\n"
87
+ "Frame 3: Climbing up the tree carefully, the boy reaches out to untangle the kite in the windy green field.\n"
88
+ "Frame 4: Dark clouds roll in as rain begins to fall, soaking the boy on the tree in the windy green field.\n"
89
+ "Frame 5: Holding the damp kite, the boy runs across the field, drenched and smiling."),
90
+ ("human", "\n".join(sentences))
91
+ ]
92
+
93
+ # Generate frame prompts with ChatGroq
94
+ ai_msg = llm.invoke(messages)
95
+ frames_txt = ai_msg.content.strip()
96
+
97
+ # Extract prompts
98
+ prompts = [
99
+ line.partition(":")[2].strip()
100
+ for line in frames_txt.splitlines()
101
+ if line.lower().startswith("frame")
102
+ ]
103
+ prompts = prompts[:num_frames]
104
+
105
+ # Generate images and voice-overs
106
+ image_paths = []
107
+ audio_paths = []
108
+ for idx, prompt in enumerate(prompts, start=1):
109
+ # Generate voice-over
110
+ tts = gTTS(text=prompt, lang='en')
111
+ audio_path = f"frame_{idx}.mp3"
112
+ tts.save(audio_path)
113
+ audio_paths.append(audio_path)
114
+
115
+ # Generate image
116
+ img_bytes = query_sdxl(prompt)
117
+ if img_bytes:
118
+ img = Image.open(io.BytesIO(img_bytes))
119
+ img_path = f"frame_{idx}.png"
120
+ img.save(img_path)
121
+ image_paths.append(img_path)
122
+ else:
123
+ image_paths.append(None) # Placeholder for failed images
124
+
125
+ # Respect API cooldown
126
+ if idx < len(prompts):
127
+ time.sleep(API_COOLDOWN)
128
+
129
+ # Prepare Gradio updates
130
+ image_updates = gr.update(value=image_paths)
131
+ audio_updates = [
132
+ gr.update(value=audio_paths[i], visible=True) if i < num_frames else gr.update(value=None, visible=False)
133
+ for i in range(10)
134
+ ]
135
+ return [image_updates] + audio_updates
136
+
137
+ # **Gradio Interface**
138
+ with gr.Blocks(title="Storyboard Generator") as demo:
139
+ gr.Markdown("# Storyboard Generator\nEnter a story (up to 10 sentences) to generate a storyboard with images and voice-overs.")
140
+ story_input = gr.Textbox(label="Enter your story", lines=5, placeholder="Type your story here...")
141
+ generate_btn = gr.Button("Generate Storyboard")
142
+ image_gallery = gr.Gallery(label="Storyboard Images", show_label=True)
143
+ with gr.Column():
144
+ audios = [gr.Audio(label=f"Frame {i+1} Voice-Over", visible=False) for i in range(10)]
145
+
146
+ generate_btn.click(
147
+ fn=generate_storyboard,
148
+ inputs=story_input,
149
+ outputs=[image_gallery] + audios
150
+ )
151
+
152
+ # **Launch the App**
153
+ demo.launch()