Spaces:
Sleeping
Sleeping
| # pip install flask google-genai | |
| import os, time, base64, struct, re | |
| from flask import Flask, request, render_template_string, Response, stream_with_context | |
| from google import genai | |
| from google.genai import types | |
| import json | |
| from concurrent.futures import ThreadPoolExecutor | |
| from queue import Queue | |
| import threading | |
app = Flask(__name__)

# Single-page UI served by index(). It is rendered with render_template_string,
# i.e. through Jinja, so it must not contain Jinja delimiters ("{{", "{%", "{#").
# The "\\n" inside the script is a Python escape that yields the JS string '\n'.
HTML = """
<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><title>Gemini Multi (Text + Image → Chunked Streaming TTS)</title></head>
<body style="font-family:sans-serif;padding:2rem;">
<h1>Gemini Multi (Text + Image → Chunked Streaming TTS)</h1>
<form id="genai-form" enctype="multipart/form-data">
  <textarea id="prompt" name="text" rows="6" cols="60" placeholder="Enter prompt"></textarea><br/><br/>
  <input type="file" id="image" name="image" accept="image/*" /><br/><br/>
  <label>Voice: <input id="voice" name="voice" value="Puck" /></label><br/>
  <label>Accent: <input id="accent" name="accent" value="British" /></label><br/>
  <label>Tone: <input id="tone" name="tone" value="casual and friendly" /></label><br/><br/>
  <button type="submit" id="submit-btn">Generate</button>
</form>
<pre id="output" style="background:#f4f4f4;padding:1rem;margin-top:1rem;"></pre>
<div id="audio-out" style="margin-top:1rem;"></div>
<div id="status" style="margin-top:1rem;color:#666;"></div>
<script>
const form = document.getElementById('genai-form');
const submitBtn = document.getElementById('submit-btn');
// Audio streaming setup
let audioContext = null;
let nextStartTime = 0;
let chunksReceived = 0;
function initAudioContext() {
  if (!audioContext) {
    audioContext = new (window.AudioContext || window.webkitAudioContext)();
  }
  // Autoplay policies may create or leave the context suspended; resuming it
  // (ideally from inside a user gesture) allows scheduled audio to play.
  if (audioContext.state === 'suspended') {
    audioContext.resume().catch(err => console.error('AudioContext resume failed:', err));
  }
  return audioContext;
}
function base64ToArrayBuffer(base64) {
  const binaryString = atob(base64);
  const bytes = new Uint8Array(binaryString.length);
  for (let i = 0; i < binaryString.length; i++) {
    bytes[i] = binaryString.charCodeAt(i);
  }
  return bytes.buffer;
}
// Decode one WAV chunk and schedule it to start right after the previous one.
async function playAudioChunk(audioBase64) {
  const ctx = initAudioContext();
  const arrayBuffer = base64ToArrayBuffer(audioBase64);
  try {
    const audioBuffer = await ctx.decodeAudioData(arrayBuffer);
    const source = ctx.createBufferSource();
    source.buffer = audioBuffer;
    source.connect(ctx.destination);
    const currentTime = ctx.currentTime;
    const startTime = Math.max(currentTime, nextStartTime);
    source.start(startTime);
    nextStartTime = startTime + audioBuffer.duration;
    return audioBuffer.duration;
  } catch (err) {
    console.error('Error playing audio chunk:', err);
    return 0;
  }
}
form.addEventListener('submit', async e => {
  e.preventDefault();
  const out = document.getElementById('output');
  const audioDiv = document.getElementById('audio-out');
  const status = document.getElementById('status');
  out.textContent = 'Generating text…';
  audioDiv.innerHTML = '';
  status.textContent = '';
  chunksReceived = 0;
  // Create/resume the AudioContext synchronously, while still inside the user
  // gesture: if it were first created after the awaits below, browsers with
  // strict autoplay rules could keep it suspended. New audio starts from now.
  const ctx = initAudioContext();
  nextStartTime = ctx.currentTime;
  // One request at a time so two streams never interleave in one audio schedule.
  submitBtn.disabled = true;
  const formData = new FormData(form);
  try {
    const resp = await fetch('/generate_stream', { method: 'POST', body: formData });
    if (!resp.ok) {
      out.textContent = 'Server error: ' + (resp.statusText || resp.status);
      return;
    }
    const reader = resp.body.getReader();
    const decoder = new TextDecoder();
    let buffer = '';
    let firstAudioTime = null;
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      // SSE events are "data: <json>" lines; keep any partial last line buffered.
      const lines = buffer.split('\\n');
      buffer = lines.pop();
      for (const line of lines) {
        if (!line.trim() || !line.startsWith('data: ')) continue;
        try {
          const data = JSON.parse(line.slice(6));
          if (data.error) {
            out.textContent = 'Error: ' + data.error;
            status.textContent = '';
            reader.cancel();
            return;
          }
          if (data.type === 'text') {
            out.textContent = data.text;
            status.textContent = `Text received (${data.chunk_count} chunks), generating audio...`;
          }
          if (data.type === 'audio_chunk' && data.audio_base64) {
            chunksReceived++;
            if (!firstAudioTime) {
              firstAudioTime = Date.now();
              status.textContent = `First audio chunk received! Playing...`;
            } else {
              status.textContent = `Streaming audio... (${chunksReceived} chunks received)`;
            }
            await playAudioChunk(data.audio_base64);
          }
          if (data.type === 'complete') {
            status.textContent = `Complete! ${chunksReceived} audio chunks. Text: ${data.timings.text_seconds}s, TTS: ${data.timings.tts_seconds}s, Total: ${data.timings.total_seconds}s`;
          }
        } catch (err) {
          console.error('Error parsing SSE:', err, line);
        }
      }
    }
  } catch (err) {
    console.error(err);
    out.textContent = 'Fetch error: ' + err.message;
    status.textContent = '';
  } finally {
    submitBtn.disabled = false;
  }
});
</script>
</body>
</html>
"""
# Gemini API client shared by every request. Secrets must not live in source
# code: set GEMINI_API_KEY (or GOOGLE_API_KEY) in the environment instead. The
# key that used to be hard-coded on this line should be treated as compromised
# and revoked.
client = genai.Client(
    api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
)
def wrap_pcm_to_wav(pcm_data: bytes, sample_rate=24000, num_channels=1, bits_per_sample=16) -> bytes:
    """Return *pcm_data* prefixed with a 44-byte RIFF/WAVE header (format tag 1 = PCM).

    Args:
        pcm_data: raw sample bytes, copied verbatim into the ``data`` chunk.
        sample_rate: frames per second written to the ``fmt `` chunk.
        num_channels: channel count written to the ``fmt `` chunk.
        bits_per_sample: sample width in bits written to the ``fmt `` chunk.

    Returns:
        The complete WAV file as bytes.
    """
    payload_size = len(pcm_data)
    bytes_per_second = sample_rate * num_channels * bits_per_sample // 8
    frame_size = num_channels * bits_per_sample // 8

    riff_chunk = struct.pack("<4sI4s", b"RIFF", 36 + payload_size, b"WAVE")
    fmt_chunk = struct.pack(
        "<4sIHHIIHH",
        b"fmt ",
        16,  # size of the fmt chunk body
        1,  # PCM
        num_channels,
        sample_rate,
        bytes_per_second,
        frame_size,
        bits_per_sample,
    )
    data_header = struct.pack("<4sI", b"data", payload_size)
    return b"".join((riff_chunk, fmt_chunk, data_header, pcm_data))
def extract_text(resp) -> str:
    """Return the generated text of a model response, stripped of outer whitespace.

    Prefers ``resp.text``. If that is missing or empty, it joins the ``text`` of
    every part of every candidate with newlines instead. Both paths strip the
    result, so a whitespace-only reply yields "" and callers can test emptiness
    with ``if not text``.

    Returns:
        The text, or "" when none is found.
    """
    direct = getattr(resp, "text", None)  # read once; it may be a computed property
    if direct:
        return direct.strip()
    parts_text = []
    for cand in getattr(resp, "candidates", None) or []:
        content = getattr(cand, "content", None)
        for part in getattr(content, "parts", None) or []:
            text = getattr(part, "text", None)
            if text:
                parts_text.append(text)
    return "\n".join(parts_text).strip()
def chunk_text(text, max_words=25):
    """Split text into sentence-based chunks for parallel TTS generation.

    The text is split after '.', '!' or '?' followed by whitespace. Sentences are
    appended to the current chunk until adding the next one would push it past
    *max_words* words. Sentences are never broken, so a single sentence longer
    than *max_words* becomes a chunk on its own.

    Blank or ``None`` input yields ``[]``. Empty fragments (e.g. from trailing
    whitespace) are dropped, so no empty text is ever sent to TTS.

    Returns:
        List of chunk strings, in original order, sentences joined by one space.
    """
    chunks = []
    current_chunk = []
    current_words = 0
    for sentence in re.split(r'(?<=[.!?])\s+', (text or "").strip()):
        if not sentence:
            continue
        words = len(sentence.split())
        if current_chunk and current_words + words > max_words:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_words = 0
        current_chunk.append(sentence)
        current_words += words
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks
def generate_audio_for_chunk(chunk, voice, accent, tone, chunk_index):
    """Synthesize speech for one text chunk with the Gemini TTS model.

    Args:
        chunk: text to speak.
        voice: prebuilt voice name passed to ``PrebuiltVoiceConfig``.
        accent: free-text style hint inserted into the spoken prompt.
        tone: free-text style hint inserted into the spoken prompt.
        chunk_index: position of the chunk. It is echoed back so the caller can
            restore the original order.

    Returns:
        ``(chunk_index, wav_bytes)`` built from the first inline-data part with
        non-empty data, or ``(chunk_index, None)`` when no audio comes back or
        any error occurs. Errors are printed rather than raised (best effort),
        so one failing chunk does not abort the others.
    """
    style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{chunk}"
    try:
        tts_resp = client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                    )
                ),
            ),
        )
        for cand in getattr(tts_resp, "candidates", None) or []:
            # Guard against content/parts being None instead of relying on the
            # broad except below to swallow the resulting TypeError.
            content = getattr(cand, "content", None)
            for part in getattr(content, "parts", None) or []:
                inline = getattr(part, "inline_data", None)
                if inline and inline.data:
                    # Use the rate advertised in the MIME type when present
                    # (presumably like "audio/L16;codec=pcm;rate=24000");
                    # otherwise fall back to the previous fixed 24 kHz.
                    rate_match = re.search(r"rate=(\d+)", getattr(inline, "mime_type", None) or "")
                    sample_rate = int(rate_match.group(1)) if rate_match else 24000
                    return (chunk_index, wrap_pcm_to_wav(inline.data, sample_rate=sample_rate))
        return (chunk_index, None)
    except Exception as e:
        print(f"Error generating audio for chunk {chunk_index}: {e}")
        return (chunk_index, None)
@app.route("/")
def index():
    """Serve the single-page UI defined by the module-level ``HTML`` template."""
    return render_template_string(HTML)
@app.route("/generate_stream", methods=["POST"])
def generate_stream():
    """Generate text from the form input, then stream it and its TTS audio as SSE.

    Reads the form fields ``text``, ``voice``, ``accent`` and ``tone``, plus the
    optional ``image`` file. Each event is a ``data: <json>`` line followed by a
    blank line, in this order:

    * ``{"type": "text", "text", "chunk_count"}`` once the text model returns;
    * ``{"type": "audio_chunk", "audio_base64", "chunk_index"}`` for each
      synthesized chunk, strictly in chunk order. Chunks whose synthesis failed
      are skipped.
    * ``{"type": "complete", "timings": {...}}`` at the end.

    If there is no input, or text generation fails or returns nothing, a single
    ``{"error": ...}`` event is sent instead and the stream ends.
    """
    def generate():
        t_start = time.perf_counter()
        prompt = (request.form.get("text") or "").strip()
        image = request.files.get("image")
        voice = (request.form.get("voice") or "Puck").strip()
        accent = (request.form.get("accent") or "British").strip()
        tone = (request.form.get("tone") or "casual and friendly").strip()
        if not prompt and not image:
            yield f"data: {json.dumps({'error': 'No input provided'})}\n\n"
            return

        # Build multimodal input
        parts = []
        if prompt:
            parts.append(types.Part.from_text(text=prompt))
        if image:
            parts.append(types.Part.from_bytes(data=image.read(), mime_type=image.mimetype or "image/png"))

        # 1) Generate text
        t0 = time.perf_counter()
        try:
            gen_resp = client.models.generate_content(
                model="gemini-2.5-flash-lite",
                contents=[types.Content(role="user", parts=parts)],
                config=types.GenerateContentConfig(response_mime_type="text/plain"),
            )
        except Exception as e:
            yield f"data: {json.dumps({'error': f'text generation failed: {str(e)}'})}\n\n"
            return
        t1 = time.perf_counter()
        final_text = extract_text(gen_resp)
        if not final_text:
            yield f"data: {json.dumps({'error': 'Text generation returned empty'})}\n\n"
            return

        # Split text into chunks and send the text immediately with the chunk count
        text_chunks = chunk_text(final_text)
        yield f"data: {json.dumps({'type': 'text', 'text': final_text, 'chunk_count': len(text_chunks)})}\n\n"

        # 2) Synthesize chunks in parallel on a background thread. Each result
        # arrives on audio_queue as (chunk_index, wav_or_None), followed by a
        # final ('DONE', None) marker.
        tts_start = time.perf_counter()
        audio_queue = Queue()

        def process_chunks():
            try:
                with ThreadPoolExecutor(max_workers=4) as executor:
                    futures = [
                        executor.submit(generate_audio_for_chunk, chunk, voice, accent, tone, i)
                        for i, chunk in enumerate(text_chunks)
                    ]
                    # Results are collected in submission order. Audio must be
                    # emitted in order anyway, so waiting on chunk i first costs
                    # no latency.
                    for i, future in enumerate(futures):
                        try:
                            audio_queue.put(future.result(timeout=30))
                        except Exception as e:
                            print(f"Error in chunk processing (chunk {i}): {e!r}")
                            future.cancel()  # only effective if it has not started yet
                            # Report the failure under its real index so the
                            # ordered sender can skip it instead of stalling.
                            audio_queue.put((i, None))
            finally:
                # Always unblock the consumer, even if submission itself failed.
                audio_queue.put(('DONE', None))

        processing_thread = threading.Thread(target=process_chunks)
        processing_thread.start()

        # Emit audio strictly in chunk order. Failures are stored as None so that
        # next_chunk_to_send can advance past them; otherwise one failed chunk
        # would leave every later chunk buffered and never sent.
        pending = {}
        next_chunk_to_send = 0
        while True:
            chunk_index, wav_data = audio_queue.get()
            if chunk_index == 'DONE':
                break
            pending[chunk_index] = wav_data
            while next_chunk_to_send in pending:
                wav = pending.pop(next_chunk_to_send)
                if wav is not None:
                    audio_b64 = base64.b64encode(wav).decode("ascii")
                    yield f"data: {json.dumps({'type': 'audio_chunk', 'audio_base64': audio_b64, 'chunk_index': next_chunk_to_send})}\n\n"
                next_chunk_to_send += 1
        processing_thread.join()
        tts_end = time.perf_counter()
        t_total = time.perf_counter() - t_start

        # Send completion signal
        yield f"data: {json.dumps({'type': 'complete', 'timings': {'text_seconds': round(t1 - t0, 3), 'tts_seconds': round(tts_end - tts_start, 3), 'total_seconds': round(t_total, 3)}})}\n\n"

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
if __name__ == "__main__":
    # Listen on every interface. The port comes from $PORT and falls back to 7860.
    # threaded=True lets a long-lived SSE response coexist with other requests.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)