# Source header (Hugging Face Spaces page residue): "Update app.py", commit a416fca (verified), by Pepguy.
# pip install flask google-genai
import os, time, base64, struct, re
from flask import Flask, request, render_template_string, Response, stream_with_context
from google import genai
from google.genai import types
import json
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading
app = Flask(__name__)
# Single-page UI served at '/'. The inline JS posts the form to /generate_stream,
# parses the SSE-style "data: {...}" lines from the chunked response, renders the
# generated text, and schedules each base64 WAV chunk back-to-back through a
# shared Web Audio context (nextStartTime) for gapless playback.
HTML = """
<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><title>Gemini Multi (Text → Chunked Streaming TTS)</title></head>
<body style="font-family:sans-serif;padding:2rem;">
<h1>Gemini Multi (Text + Image → Chunked Streaming TTS)</h1>
<form id="genai-form" enctype="multipart/form-data">
<textarea id="prompt" name="text" rows="6" cols="60" placeholder="Enter prompt"></textarea><br/><br/>
<input type="file" id="image" name="image" accept="image/*" /><br/><br/>
<label>Voice: <input id="voice" name="voice" value="Puck" /></label><br/>
<label>Accent: <input id="accent" name="accent" value="British" /></label><br/>
<label>Tone: <input id="tone" name="tone" value="casual and friendly" /></label><br/><br/>
<button type="submit">Generate</button>
</form>
<pre id="output" style="background:#f4f4f4;padding:1rem;margin-top:1rem;"></pre>
<div id="audio-out" style="margin-top:1rem;"></div>
<div id="status" style="margin-top:1rem;color:#666;"></div>
<script>
const form = document.getElementById('genai-form');
// Audio streaming setup
let audioContext = null;
let nextStartTime = 0;
let chunksReceived = 0;
function initAudioContext() {
if (!audioContext) {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
}
return audioContext;
}
function base64ToArrayBuffer(base64) {
const binaryString = atob(base64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes.buffer;
}
async function playAudioChunk(audioBase64) {
const ctx = initAudioContext();
const arrayBuffer = base64ToArrayBuffer(audioBase64);
try {
const audioBuffer = await ctx.decodeAudioData(arrayBuffer);
const source = ctx.createBufferSource();
source.buffer = audioBuffer;
source.connect(ctx.destination);
const currentTime = ctx.currentTime;
const startTime = Math.max(currentTime, nextStartTime);
source.start(startTime);
nextStartTime = startTime + audioBuffer.duration;
return audioBuffer.duration;
} catch (err) {
console.error('Error playing audio chunk:', err);
return 0;
}
}
form.addEventListener('submit', async e => {
e.preventDefault();
const out = document.getElementById('output');
const audioDiv = document.getElementById('audio-out');
const status = document.getElementById('status');
out.textContent = 'Generating text…';
audioDiv.innerHTML = '';
status.textContent = '';
chunksReceived = 0;
// Reset audio state
if (audioContext) {
nextStartTime = audioContext.currentTime;
}
const formData = new FormData(form);
try {
const resp = await fetch('/generate_stream', { method: 'POST', body: formData });
if (!resp.ok) {
out.textContent = 'Server error: ' + resp.statusText;
return;
}
const reader = resp.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
let textReceived = false;
let firstAudioTime = null;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\\n');
buffer = lines.pop();
for (const line of lines) {
if (!line.trim() || !line.startsWith('data: ')) continue;
try {
const data = JSON.parse(line.slice(6));
if (data.error) {
out.textContent = 'Error: ' + data.error;
status.textContent = '';
return;
}
if (data.type === 'text') {
out.textContent = data.text;
textReceived = true;
status.textContent = `Text received (${data.chunk_count} chunks), generating audio...`;
}
if (data.type === 'audio_chunk' && data.audio_base64) {
chunksReceived++;
if (!firstAudioTime) {
firstAudioTime = Date.now();
status.textContent = `First audio chunk received! Playing...`;
} else {
status.textContent = `Streaming audio... (${chunksReceived} chunks received)`;
}
await playAudioChunk(data.audio_base64);
}
if (data.type === 'complete') {
status.textContent = `Complete! ${chunksReceived} audio chunks. Text: ${data.timings.text_seconds}s, TTS: ${data.timings.tts_seconds}s, Total: ${data.timings.total_seconds}s`;
}
} catch (err) {
console.error('Error parsing SSE:', err, line);
}
}
}
} catch (err) {
console.error(err);
out.textContent = 'Fetch error: ' + err.message;
status.textContent = '';
}
});
</script>
</body>
</html>
"""
# SECURITY: a Google API key was previously hard-coded on this line and has been
# committed to source control — it must be considered leaked and revoked.
# Read the key from the environment instead so it never lives in the source.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", ""))
def wrap_pcm_to_wav(pcm_data: bytes, sample_rate=24000, num_channels=1, bits_per_sample=16) -> bytes:
    """Prepend a 44-byte RIFF/WAVE header to raw PCM samples.

    The TTS API returns headerless PCM; wrapping it as a standard PCM WAV
    (format tag 1) lets the browser's decodeAudioData play it directly.
    """
    payload_len = len(pcm_data)
    bytes_per_second = sample_rate * num_channels * bits_per_sample // 8
    bytes_per_frame = num_channels * bits_per_sample // 8
    fmt_chunk = struct.pack(
        "<IHHIIHH",
        16,                # fmt chunk size
        1,                 # audio format: uncompressed PCM
        num_channels,
        sample_rate,
        bytes_per_second,  # byte rate
        bytes_per_frame,   # block align
        bits_per_sample,
    )
    riff = b"RIFF" + struct.pack("<I", 36 + payload_len) + b"WAVE"
    return riff + b"fmt " + fmt_chunk + b"data" + struct.pack("<I", payload_len) + pcm_data
def extract_text(resp) -> str:
    """Return the text carried by a generate_content response.

    Prefers the convenience ``resp.text`` attribute when it is truthy;
    otherwise walks every candidate's content parts and joins their text
    fields with newlines (stripped of surrounding whitespace).
    """
    direct = getattr(resp, "text", None)
    if direct:
        return direct
    collected = []
    for candidate in getattr(resp, "candidates", None) or []:
        body = getattr(candidate, "content", None)
        for part in getattr(body, "parts", None) or []:
            fragment = getattr(part, "text", None)
            if fragment:
                collected.append(fragment)
    return "\n".join(collected).strip()
def chunk_text(text, max_words=25):
    """Split *text* into sentence-aligned chunks of at most ~max_words words.

    Sentences (delimited by ., ! or ? followed by whitespace) are greedily
    packed into the current chunk; when adding the next sentence would push
    the word count past *max_words*, the chunk is flushed and a new one
    begins. Sentences are never split internally, so a single sentence
    longer than *max_words* still becomes its own chunk.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text)
    result = []
    pending, pending_words = [], 0
    for piece in pieces:
        n = len(piece.split())
        # Flush only when the chunk is non-empty and would overflow.
        if pending and pending_words + n > max_words:
            result.append(' '.join(pending))
            pending, pending_words = [piece], n
        else:
            pending.append(piece)
            pending_words += n
    if pending:
        result.append(' '.join(pending))
    return result
def generate_audio_for_chunk(chunk, voice, accent, tone, chunk_index):
    """Synthesize one text chunk with the Gemini TTS model.

    Returns ``(chunk_index, wav_bytes)`` on success, or ``(chunk_index, None)``
    when the response carries no audio part or the request raises. The index
    lets the caller reassemble parallel results in their original order.
    """
    style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{chunk}"
    speech_cfg = types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
        )
    )
    try:
        tts_resp = client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=speech_cfg,
            ),
        )
        for candidate in getattr(tts_resp, "candidates", None) or []:
            for part in getattr(candidate.content, "parts", []):
                inline = getattr(part, "inline_data", None)
                if inline and inline.data:
                    # First audio part wins; wrap the raw PCM so it is playable.
                    return (chunk_index, wrap_pcm_to_wav(inline.data))
    except Exception as e:
        # Best-effort: a failed chunk is reported as missing, not fatal.
        print(f"Error generating audio for chunk {chunk_index}: {e}")
    return (chunk_index, None)
@app.route('/')
def index():
    """Serve the single-page UI (form, SSE reader, and audio player)."""
    page = HTML
    return render_template_string(page)
@app.route('/generate_stream', methods=['POST'])
def generate_stream():
    """SSE-style endpoint: generate text from the prompt/image, then stream
    per-sentence TTS audio chunks back to the browser as they are produced.

    Events emitted (one JSON object per ``data:`` line):
      - {'error': ...}                       on any failure
      - {'type': 'text', ...}                the full generated text + chunk count
      - {'type': 'audio_chunk', ...}         base64 WAV, strictly in chunk order
      - {'type': 'complete', 'timings': ...} final timing summary
    """
    def generate():
        t_start = time.perf_counter()
        prompt = (request.form.get("text") or "").strip()
        file = request.files.get("image")
        voice = (request.form.get("voice") or "Puck").strip()
        accent = (request.form.get("accent") or "British").strip()
        tone = (request.form.get("tone") or "casual and friendly").strip()
        if not prompt and not file:
            yield f"data: {json.dumps({'error': 'No input provided'})}\n\n"
            return
        # Build multimodal input: optional text part plus optional image part.
        parts = []
        if prompt:
            parts.append(types.Part.from_text(text=prompt))
        if file:
            # Fall back to image/png when the browser supplied no MIME type.
            parts.append(types.Part.from_bytes(data=file.read(), mime_type=file.mimetype or "image/png"))
        # 1) Generate the full text response (blocking, single request).
        t0 = time.perf_counter()
        try:
            gen_resp = client.models.generate_content(
                model="gemini-2.5-flash-lite",
                contents=[types.Content(role="user", parts=parts)],
                config=types.GenerateContentConfig(response_mime_type="text/plain"),
            )
        except Exception as e:
            yield f"data: {json.dumps({'error': f'text generation failed: {str(e)}'})}\n\n"
            return
        t1 = time.perf_counter()
        final_text = extract_text(gen_resp)
        if not final_text:
            yield f"data: {json.dumps({'error': 'Text generation returned empty'})}\n\n"
            return
        # Split text into sentence-based chunks so TTS can run in parallel.
        text_chunks = chunk_text(final_text)
        # Send the text immediately so the UI renders before any audio exists.
        yield f"data: {json.dumps({'type': 'text', 'text': final_text, 'chunk_count': len(text_chunks)})}\n\n"
        # 2) Generate audio chunks in parallel, streaming results via a queue.
        tts_start = time.perf_counter()
        audio_queue = Queue()
        def process_chunks():
            # Worker thread: fan out TTS requests, feed results into the queue.
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for i, chunk in enumerate(text_chunks):
                    future = executor.submit(generate_audio_for_chunk, chunk, voice, accent, tone, i)
                    futures.append(future)
                # NOTE: results are awaited in SUBMISSION order (not completion
                # order); the consumer below reorders anyway, so this only
                # affects how early out-of-order chunks enter the queue.
                for future in futures:
                    try:
                        result = future.result(timeout=30)
                        audio_queue.put(result)
                    except Exception as e:
                        print(f"Error in chunk processing: {e}")
                        audio_queue.put((None, None))
            # Sentinel tells the consumer loop that all chunks were attempted.
            audio_queue.put(('DONE', None))
        # Start parallel processing in a background thread so this generator
        # can keep yielding SSE lines while TTS requests are in flight.
        processing_thread = threading.Thread(target=process_chunks)
        processing_thread.start()
        # Stream audio chunks as they become available, but strictly in order:
        # out-of-order arrivals are buffered in completed_chunks until every
        # lower-indexed chunk has been sent.
        completed_chunks = {}
        next_chunk_to_send = 0
        while True:
            chunk_index, wav_data = audio_queue.get()
            if chunk_index == 'DONE':
                break
            if wav_data is None:
                # Failed chunk (index None or synthesis returned nothing) —
                # skipped; note the client never receives it, so chunks after
                # a failed index may be held in the buffer and never sent.
                continue
            # Store completed chunk
            completed_chunks[chunk_index] = wav_data
            # Flush the longest contiguous run starting at next_chunk_to_send.
            while next_chunk_to_send in completed_chunks:
                audio_b64 = base64.b64encode(completed_chunks[next_chunk_to_send]).decode("ascii")
                yield f"data: {json.dumps({'type': 'audio_chunk', 'audio_base64': audio_b64, 'chunk_index': next_chunk_to_send})}\n\n"
                del completed_chunks[next_chunk_to_send]
                next_chunk_to_send += 1
        processing_thread.join()
        tts_end = time.perf_counter()
        t_total = time.perf_counter() - t_start
        # Send completion signal with timing breakdown for the status line.
        yield f"data: {json.dumps({'type': 'complete', 'timings': {'text_seconds': round(t1 - t0, 3), 'tts_seconds': round(tts_end - tts_start, 3), 'total_seconds': round(t_total, 3)}})}\n\n"
    return Response(stream_with_context(generate()), mimetype='text/event-stream')
if __name__ == "__main__":
    # Default to 7860 (the Hugging Face Spaces convention) unless PORT is set.
    serve_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=serve_port, threaded=True)