# NOTE(review): the three lines below were Hugging Face Spaces page chrome
# ("Spaces: / Sleeping / Sleeping") accidentally captured with the source;
# commented out so the module parses.
# Spaces: Sleeping
# pip install flask google-genai
import base64
import json
import os
import struct
import time

from flask import (
    Flask,
    Response,
    jsonify,
    render_template_string,
    request,
    stream_with_context,
)
from google import genai
from google.genai import types
app = Flask(__name__)

# Single-page UI returned by index(). The inline JavaScript submits the form
# to /generate_stream, parses the SSE ("data: {json}") response line by line,
# renders the generated text, and schedules each base64-encoded WAV chunk
# back-to-back through the Web Audio API (nextStartTime tracks the queue tail).
# NOTE: this is a runtime string — do not reformat its contents; the '\\n'
# escapes are deliberate so the JS receives a literal "\n" split pattern.
HTML = """
<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><title>Gemini Multi (Text → Streaming TTS)</title></head>
<body style="font-family:sans-serif;padding:2rem;">
<h1>Gemini Multi (Text + Image → Streaming TTS)</h1>
<form id="genai-form" enctype="multipart/form-data">
<textarea id="prompt" name="text" rows="6" cols="60" placeholder="Enter prompt"></textarea><br/><br/>
<input type="file" id="image" name="image" accept="image/*" /><br/><br/>
<label>Voice: <input id="voice" name="voice" value="Sadachbia" /></label><br/>
<label>Accent: <input id="accent" name="accent" value="British" /></label><br/>
<label>Tone: <input id="tone" name="tone" value="casual and friendly" /></label><br/><br/>
<button type="submit">Generate</button>
</form>
<pre id="output" style="background:#f4f4f4;padding:1rem;margin-top:1rem;"></pre>
<div id="audio-out" style="margin-top:1rem;"></div>
<div id="status" style="margin-top:1rem;color:#666;"></div>
<script>
const form = document.getElementById('genai-form');
// Audio streaming setup
let audioContext = null;
let nextStartTime = 0;
let audioQueue = [];
let isPlaying = false;
function initAudioContext() {
if (!audioContext) {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
}
return audioContext;
}
function base64ToArrayBuffer(base64) {
const binaryString = atob(base64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes.buffer;
}
async function playAudioChunk(wavBase64) {
const ctx = initAudioContext();
const arrayBuffer = base64ToArrayBuffer(wavBase64);
try {
const audioBuffer = await ctx.decodeAudioData(arrayBuffer);
const source = ctx.createBufferSource();
source.buffer = audioBuffer;
source.connect(ctx.destination);
const currentTime = ctx.currentTime;
const startTime = Math.max(currentTime, nextStartTime);
source.start(startTime);
nextStartTime = startTime + audioBuffer.duration;
return audioBuffer.duration;
} catch (err) {
console.error('Error playing audio chunk:', err);
return 0;
}
}
form.addEventListener('submit', async e => {
e.preventDefault();
const out = document.getElementById('output');
const audioDiv = document.getElementById('audio-out');
const status = document.getElementById('status');
out.textContent = 'Generating text…';
audioDiv.innerHTML = '';
status.textContent = '';
// Reset audio state
if (audioContext) {
nextStartTime = audioContext.currentTime;
}
const formData = new FormData(form);
try {
const resp = await fetch('/generate_stream', { method: 'POST', body: formData });
if (!resp.ok) {
out.textContent = 'Server error: ' + resp.statusText;
return;
}
const reader = resp.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
let textReceived = false;
let audioChunks = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\\n');
buffer = lines.pop(); // Keep incomplete line in buffer
for (const line of lines) {
if (!line.trim() || !line.startsWith('data: ')) continue;
try {
const data = JSON.parse(line.slice(6));
if (data.error) {
out.textContent = 'Error: ' + data.error;
status.textContent = '';
return;
}
if (data.type === 'text') {
out.textContent = data.text;
textReceived = true;
status.textContent = 'Text received, generating audio...';
}
if (data.type === 'audio_chunk' && data.audio_base64) {
audioChunks++;
status.textContent = `Streaming audio... (chunk ${audioChunks})`;
await playAudioChunk(data.audio_base64);
}
if (data.type === 'complete') {
status.textContent = `Complete! Text: ${data.timings.text_seconds}s, TTS: ${data.timings.tts_seconds}s, Total: ${data.timings.total_seconds}s`;
}
} catch (err) {
console.error('Error parsing SSE:', err, line);
}
}
}
} catch (err) {
console.error(err);
out.textContent = 'Fetch error: ' + err.message;
status.textContent = '';
}
});
</script>
</body>
</html>
"""
| client = genai.Client(api_key="AIzaSyDolbPUZBPUPvQUu-RGktJmvnUpkcEKIYo") | |
| def wrap_pcm_to_wav(pcm_data: bytes, sample_rate=24000, num_channels=1, bits_per_sample=16) -> bytes: | |
| byte_rate = sample_rate * num_channels * bits_per_sample // 8 | |
| block_align = num_channels * bits_per_sample // 8 | |
| data_size = len(pcm_data) | |
| header = b"RIFF" + struct.pack("<I", 36 + data_size) + b"WAVE" | |
| header += b"fmt " + struct.pack("<IHHIIHH", 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample) | |
| header += b"data" + struct.pack("<I", data_size) | |
| return header + pcm_data | |
def extract_text(resp) -> str:
    """Pull the plain-text payload out of a generate_content response.

    Prefers the convenience `.text` attribute when it is truthy; otherwise
    walks candidates -> content -> parts, joining the non-empty part texts
    with newlines. Returns "" when no text is found anywhere.
    """
    direct = getattr(resp, "text", None)
    if direct:
        return direct
    fragments = []
    for candidate in getattr(resp, "candidates", None) or []:
        body = getattr(candidate, "content", None)
        for part in getattr(body, "parts", None) or []:
            piece = getattr(part, "text", None)
            if piece:
                fragments.append(piece)
    return "\n".join(fragments).strip()
@app.route("/")
def index():
    """Serve the single-page UI.

    Fix: the function had no route decorator, so Flask never registered it
    and the page was unreachable.
    """
    return render_template_string(HTML)
@app.route("/generate_stream", methods=["POST"])
def generate_stream():
    """SSE endpoint: generate a text answer, then stream TTS audio for it.

    Form fields: text (prompt), image (optional upload), voice, accent, tone.
    Emits "data: {json}\\n\\n" events: {'type':'text'}, repeated
    {'type':'audio_chunk','audio_base64':...}, final {'type':'complete',
    'timings':...}; on failure a terminal {'error': ...} event.

    Fixes: the function had no route decorator (the page's JS posts to
    /generate_stream); SSE payloads are serialized with json.dumps instead
    of building a throwaway Flask Response via jsonify for every event.
    """
    def sse(payload: dict) -> str:
        # One SSE event: a single "data:" line holding the JSON payload.
        return f"data: {json.dumps(payload)}\n\n"

    def generate():
        t_start = time.perf_counter()
        prompt = (request.form.get("text") or "").strip()
        file = request.files.get("image")
        voice = (request.form.get("voice") or "Sadachbia").strip()
        accent = (request.form.get("accent") or "British").strip()
        tone = (request.form.get("tone") or "casual and friendly").strip()
        if not prompt and not file:
            yield sse({"error": "No input provided"})
            return

        # Build multimodal input (text part and/or image part).
        parts = []
        if prompt:
            parts.append(types.Part.from_text(text=prompt))
        if file:
            parts.append(types.Part.from_bytes(data=file.read(), mime_type=file.mimetype or "image/png"))

        # 1) Generate the full text answer (non-streaming call).
        t0 = time.perf_counter()
        try:
            gen_resp = client.models.generate_content(
                model="gemini-2.5-flash-lite",
                contents=[types.Content(role="user", parts=parts)],
                config=types.GenerateContentConfig(response_mime_type="text/plain"),
            )
        except Exception as e:
            yield sse({"error": f"text generation failed: {str(e)}"})
            return
        t1 = time.perf_counter()

        final_text = extract_text(gen_resp)
        if not final_text:
            yield sse({"error": "Text generation returned empty"})
            return

        # Push the text immediately so the client renders it before audio.
        yield sse({"type": "text", "text": final_text})

        # 2) Stream TTS; each inline PCM chunk is wrapped as a standalone WAV
        # so the browser can decode it independently of its neighbours.
        style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{final_text}"
        tts_start = time.perf_counter()
        try:
            tts_stream = client.models.generate_content_stream(
                model="gemini-2.5-flash-preview-tts",
                contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                        )
                    ),
                ),
            )
            for chunk in tts_stream:
                for cand in getattr(chunk, "candidates", []) or []:
                    # "or []" also guards an explicit parts=None, which the
                    # plain getattr default would not.
                    for p in getattr(cand.content, "parts", None) or []:
                        if getattr(p, "inline_data", None) and p.inline_data.data:
                            wav = wrap_pcm_to_wav(p.inline_data.data)
                            audio_b64 = base64.b64encode(wav).decode("ascii")
                            yield sse({"type": "audio_chunk", "audio_base64": audio_b64})
        except Exception as e:
            yield sse({"error": f"tts streaming failed: {str(e)}", "text": final_text})
            return
        tts_end = time.perf_counter()

        # Completion signal with per-stage timings (seconds, 3 decimals).
        t_total = time.perf_counter() - t_start
        yield sse({
            "type": "complete",
            "timings": {
                "text_seconds": round(t1 - t0, 3),
                "tts_seconds": round(tts_end - tts_start, 3),
                "total_seconds": round(t_total, 3),
            },
        })

    # stream_with_context keeps the request context alive while the
    # generator is consumed.
    return Response(stream_with_context(generate()), mimetype='text/event-stream')
| # Keep the original endpoint for compatibility | |
# Keep the original endpoint for compatibility
@app.route("/generate", methods=["POST"])
def generate():
    """Non-streaming endpoint: text generation + one-shot TTS in a single JSON reply.

    Form fields: text (prompt), image (optional upload), voice, accent, tone.
    Returns {'text', 'audio_base64', 'timings'} on success, or
    ({'error': ...}, 4xx/5xx) on failure.

    Fix: the function had no route decorator, so the "compatibility"
    endpoint was never actually registered with Flask.
    """
    t_start = time.perf_counter()
    prompt = (request.form.get("text") or "").strip()
    file = request.files.get("image")
    voice = (request.form.get("voice") or "Sadachbia").strip()
    accent = (request.form.get("accent") or "British").strip()
    tone = (request.form.get("tone") or "casual and friendly").strip()
    if not prompt and not file:
        return jsonify({"error": "No input provided"}), 400

    # Build multimodal input (text part and/or image part).
    parts = []
    if prompt:
        parts.append(types.Part.from_text(text=prompt))
    if file:
        parts.append(types.Part.from_bytes(data=file.read(), mime_type=file.mimetype or "image/png"))

    # 1) Generate the text answer.
    t0 = time.perf_counter()
    try:
        gen_resp = client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=[types.Content(role="user", parts=parts)],
            config=types.GenerateContentConfig(response_mime_type="text/plain"),
        )
    except Exception as e:
        return jsonify({"error": f"text generation failed: {str(e)}"}), 500
    t1 = time.perf_counter()

    final_text = extract_text(gen_resp)
    if not final_text:
        return jsonify({"error": "Text generation returned empty"}), 500

    # 2) Synthesize the whole answer as audio in one call.
    style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{final_text}"
    tts_start = time.perf_counter()
    try:
        tts_resp = client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                    )
                ),
            ),
        )
    except Exception as e:
        return jsonify({"error": f"tts generation failed: {str(e)}", "text": final_text}), 500
    tts_end = time.perf_counter()

    # Take the first inline audio blob found in the response.
    pcm_bytes = None
    for cand in getattr(tts_resp, "candidates", []) or []:
        for p in getattr(cand.content, "parts", None) or []:
            if getattr(p, "inline_data", None) and p.inline_data.data:
                pcm_bytes = p.inline_data.data
                break
        if pcm_bytes:
            break
    if not pcm_bytes:
        return jsonify({"error": "TTS returned no audio", "text": final_text}), 500

    wav = wrap_pcm_to_wav(pcm_bytes)
    audio_b64 = base64.b64encode(wav).decode("ascii")
    t_total = time.perf_counter() - t_start
    return jsonify({
        "text": final_text,
        "audio_base64": audio_b64,
        "timings": {
            "text_seconds": round(t1 - t0, 3),
            "tts_seconds": round(tts_end - tts_start, 3),
            "total_seconds": round(t_total, 3),
        },
    })
if __name__ == "__main__":
    # PORT is overridable via the environment; 7860 is the usual default for
    # hosted Spaces. Bind 0.0.0.0 so the container's mapped port is reachable.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)