# Author: Pepguy
# Commit: 19cbf37 "Update app.py" (verified)
# pip install flask google-genai
import base64
import json
import os
import struct
import time

from flask import Flask, request, render_template_string, jsonify, Response, stream_with_context
from google import genai
from google.genai import types
# WSGI application object; every route in this module registers on it.
app = Flask(import_name=__name__)
HTML = """
<!DOCTYPE html>
<html>
<head><meta charset="UTF-8"><title>Gemini Multi (Text → Streaming TTS)</title></head>
<body style="font-family:sans-serif;padding:2rem;">
<h1>Gemini Multi (Text + Image → Streaming TTS)</h1>
<form id="genai-form" enctype="multipart/form-data">
<textarea id="prompt" name="text" rows="6" cols="60" placeholder="Enter prompt"></textarea><br/><br/>
<input type="file" id="image" name="image" accept="image/*" /><br/><br/>
<label>Voice: <input id="voice" name="voice" value="Sadachbia" /></label><br/>
<label>Accent: <input id="accent" name="accent" value="British" /></label><br/>
<label>Tone: <input id="tone" name="tone" value="casual and friendly" /></label><br/><br/>
<button type="submit">Generate</button>
</form>
<pre id="output" style="background:#f4f4f4;padding:1rem;margin-top:1rem;"></pre>
<div id="audio-out" style="margin-top:1rem;"></div>
<div id="status" style="margin-top:1rem;color:#666;"></div>
<script>
const form = document.getElementById('genai-form');
// Audio streaming setup
let audioContext = null;
let nextStartTime = 0;
let audioQueue = [];
let isPlaying = false;
function initAudioContext() {
if (!audioContext) {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
}
return audioContext;
}
function base64ToArrayBuffer(base64) {
const binaryString = atob(base64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes.buffer;
}
async function playAudioChunk(wavBase64) {
const ctx = initAudioContext();
const arrayBuffer = base64ToArrayBuffer(wavBase64);
try {
const audioBuffer = await ctx.decodeAudioData(arrayBuffer);
const source = ctx.createBufferSource();
source.buffer = audioBuffer;
source.connect(ctx.destination);
const currentTime = ctx.currentTime;
const startTime = Math.max(currentTime, nextStartTime);
source.start(startTime);
nextStartTime = startTime + audioBuffer.duration;
return audioBuffer.duration;
} catch (err) {
console.error('Error playing audio chunk:', err);
return 0;
}
}
form.addEventListener('submit', async e => {
e.preventDefault();
const out = document.getElementById('output');
const audioDiv = document.getElementById('audio-out');
const status = document.getElementById('status');
out.textContent = 'Generating text…';
audioDiv.innerHTML = '';
status.textContent = '';
// Reset audio state
if (audioContext) {
nextStartTime = audioContext.currentTime;
}
const formData = new FormData(form);
try {
const resp = await fetch('/generate_stream', { method: 'POST', body: formData });
if (!resp.ok) {
out.textContent = 'Server error: ' + resp.statusText;
return;
}
const reader = resp.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
let textReceived = false;
let audioChunks = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\\n');
buffer = lines.pop(); // Keep incomplete line in buffer
for (const line of lines) {
if (!line.trim() || !line.startsWith('data: ')) continue;
try {
const data = JSON.parse(line.slice(6));
if (data.error) {
out.textContent = 'Error: ' + data.error;
status.textContent = '';
return;
}
if (data.type === 'text') {
out.textContent = data.text;
textReceived = true;
status.textContent = 'Text received, generating audio...';
}
if (data.type === 'audio_chunk' && data.audio_base64) {
audioChunks++;
status.textContent = `Streaming audio... (chunk ${audioChunks})`;
await playAudioChunk(data.audio_base64);
}
if (data.type === 'complete') {
status.textContent = `Complete! Text: ${data.timings.text_seconds}s, TTS: ${data.timings.tts_seconds}s, Total: ${data.timings.total_seconds}s`;
}
} catch (err) {
console.error('Error parsing SSE:', err, line);
}
}
}
} catch (err) {
console.error(err);
out.textContent = 'Fetch error: ' + err.message;
status.textContent = '';
}
});
</script>
</body>
</html>
"""
# Shared Gemini API client. The key is read from the environment
# (GEMINI_API_KEY, else GOOGLE_API_KEY) instead of being hard-coded: a key
# committed to source or published with the app must be considered leaked
# and should be revoked/rotated.
client = genai.Client(
    api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
)
def wrap_pcm_to_wav(pcm_data: bytes, sample_rate=24000, num_channels=1, bits_per_sample=16) -> bytes:
    """Prefix raw PCM samples with a canonical 44-byte RIFF/WAVE header.

    Args:
        pcm_data: raw little-endian PCM sample bytes, copied verbatim after
            the header.
        sample_rate: samples per second per channel.
        num_channels: number of interleaved channels.
        bits_per_sample: bit depth of one sample of one channel.

    Returns:
        The complete WAV file as bytes (header followed by ``pcm_data``).
    """
    payload_len = len(pcm_data)
    frame_bytes = num_channels * bits_per_sample // 8
    bytes_per_second = sample_rate * num_channels * bits_per_sample // 8
    # One pack for the whole header: RIFF chunk descriptor, a 16-byte
    # "fmt " sub-chunk with format tag 1 (integer PCM), then the "data"
    # sub-chunk header whose size is the payload length.
    riff_header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + payload_len, b"WAVE",
        b"fmt ", 16, 1, num_channels, sample_rate, bytes_per_second, frame_bytes, bits_per_sample,
        b"data", payload_len,
    )
    return riff_header + pcm_data
def extract_text(resp) -> str:
    """Return the generated text of a Gemini response, without surrounding whitespace.

    Uses the response's ``text`` attribute when it is non-empty. Otherwise
    falls back to joining, with newlines, the ``text`` of every part of every
    candidate. Missing or ``None`` ``candidates``/``content``/``parts``
    attributes are tolerated and contribute nothing.

    Both paths are stripped, so a whitespace-only answer comes back as ``""``
    and is caught by an ``if not text`` emptiness check.
    """
    # Read the property once; it is computed from the candidates each time.
    direct = getattr(resp, "text", None)
    if direct:
        return direct.strip()
    parts_text = []
    for cand in getattr(resp, "candidates", None) or []:
        content = getattr(cand, "content", None)
        for part in getattr(content, "parts", None) or []:
            text = getattr(part, "text", None)
            if text:
                parts_text.append(text)
    return "\n".join(parts_text).strip()
@app.route('/')
def index():
    """Serve the single-page UI stored in the module-level ``HTML`` template."""
    page = render_template_string(HTML)
    return page
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as one ``data: <json>`` event-stream frame.

    ``json.dumps`` (default separators, no indent) always emits a single
    line. The browser client parses the stream line by line, so multi-line
    JSON (which ``jsonify`` can produce when Flask pretty-prints, e.g. in
    debug mode) would break it.
    """
    return f"data: {json.dumps(payload)}\n\n"


def _iter_inline_audio(response):
    """Yield the raw bytes of every non-empty inline-data part of a response chunk.

    ``candidates``, ``content`` and ``parts`` may each be missing or ``None``
    (iterating a ``None`` parts list would raise ``TypeError``); such
    entries are skipped.
    """
    for cand in getattr(response, "candidates", None) or []:
        content = getattr(cand, "content", None)
        for part in getattr(content, "parts", None) or []:
            inline = getattr(part, "inline_data", None)
            if inline is not None and inline.data:
                yield inline.data


@app.route('/generate_stream', methods=['POST'])
def generate_stream():
    """Generate text, then stream TTS audio for it as line-based ``data:`` events.

    Form fields: ``text`` (prompt), ``image`` (optional upload), ``voice``,
    ``accent``, ``tone`` (each with a default when blank).

    Events, in order:
      * ``{"type": "text", "text": ...}`` once the text model answers;
      * ``{"type": "audio_chunk", "audio_base64": ...}`` per WAV-wrapped chunk;
      * ``{"type": "complete", "timings": {...}}`` at the end.
    Any failure is reported as an ``{"error": ...}`` event (HTTP status stays
    200 because the stream has already started) and ends the stream.
    """
    def generate():
        t_start = time.perf_counter()
        prompt = (request.form.get("text") or "").strip()
        file = request.files.get("image")
        voice = (request.form.get("voice") or "Sadachbia").strip()
        accent = (request.form.get("accent") or "British").strip()
        tone = (request.form.get("tone") or "casual and friendly").strip()
        if not prompt and not file:
            yield _sse_event({'error': 'No input provided'})
            return

        # Build multimodal input
        parts = []
        if prompt:
            parts.append(types.Part.from_text(text=prompt))
        if file:
            parts.append(types.Part.from_bytes(data=file.read(), mime_type=file.mimetype or "image/png"))

        # 1) Generate text
        t0 = time.perf_counter()
        try:
            gen_resp = client.models.generate_content(
                model="gemini-2.5-flash-lite",
                contents=[types.Content(role="user", parts=parts)],
                config=types.GenerateContentConfig(response_mime_type="text/plain"),
            )
        except Exception as e:
            yield _sse_event({'error': f'text generation failed: {str(e)}'})
            return
        t1 = time.perf_counter()
        final_text = extract_text(gen_resp)
        if not final_text:
            yield _sse_event({'error': 'Text generation returned empty'})
            return
        # Send text immediately so the user can read while audio is generated.
        yield _sse_event({'type': 'text', 'text': final_text})

        # 2) Stream TTS audio
        style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{final_text}"
        # wrap_pcm_to_wav is called with its defaults (mono, 16-bit), i.e. 2
        # bytes per sample frame. Stream chunks are cut at arbitrary byte
        # offsets as far as this code can tell, so only whole frames are
        # emitted and any remainder is carried into the next chunk.
        frame_bytes = 2
        pending = b""
        sent_audio = False
        tts_start = time.perf_counter()
        try:
            tts_stream = client.models.generate_content_stream(
                model="gemini-2.5-flash-preview-tts",
                contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                        )
                    )
                )
            )
            for chunk in tts_stream:
                for pcm_bytes in _iter_inline_audio(chunk):
                    pending += pcm_bytes
                    usable = len(pending) - len(pending) % frame_bytes
                    if not usable:
                        continue
                    frames, pending = pending[:usable], pending[usable:]
                    audio_b64 = base64.b64encode(wrap_pcm_to_wav(frames)).decode("ascii")
                    yield _sse_event({'type': 'audio_chunk', 'audio_base64': audio_b64})
                    sent_audio = True
        except Exception as e:
            yield _sse_event({'error': f'tts streaming failed: {str(e)}', 'text': final_text})
            return
        tts_end = time.perf_counter()
        if not sent_audio:
            yield _sse_event({'error': 'TTS returned no audio', 'text': final_text})
            return

        t_total = time.perf_counter() - t_start
        # Send completion signal
        yield _sse_event({
            'type': 'complete',
            'timings': {
                'text_seconds': round(t1 - t0, 3),
                'tts_seconds': round(tts_end - tts_start, 3),
                'total_seconds': round(t_total, 3),
            },
        })

    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        # Ask caches and reverse proxies (X-Accel-Buffering is honoured by
        # nginx) not to store or buffer the stream, which would delay audio.
        headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'},
    )
# Keep the original endpoint for compatibility
@app.route('/generate', methods=['POST'])
def generate():
    """Generate text, then synthesize all of it with a single TTS call.

    Form fields: ``text`` (prompt), ``image`` (optional upload), ``voice``,
    ``accent``, ``tone`` (each with a default when blank).

    Returns:
        200 with ``{"text", "audio_base64" (WAV), "timings"}`` on success;
        400 ``{"error"}`` when neither prompt nor image is given;
        500 ``{"error"[, "text"]}`` when either model call fails or returns
        nothing usable.
    """
    t_start = time.perf_counter()
    prompt = (request.form.get("text") or "").strip()
    file = request.files.get("image")
    voice = (request.form.get("voice") or "Sadachbia").strip()
    accent = (request.form.get("accent") or "British").strip()
    tone = (request.form.get("tone") or "casual and friendly").strip()
    if not prompt and not file:
        return jsonify({"error": "No input provided"}), 400

    parts = []
    if prompt:
        parts.append(types.Part.from_text(text=prompt))
    if file:
        parts.append(types.Part.from_bytes(data=file.read(), mime_type=file.mimetype or "image/png"))

    t0 = time.perf_counter()
    try:
        gen_resp = client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=[types.Content(role="user", parts=parts)],
            config=types.GenerateContentConfig(response_mime_type="text/plain"),
        )
    except Exception as e:
        return jsonify({"error": f"text generation failed: {str(e)}"}), 500
    t1 = time.perf_counter()
    final_text = extract_text(gen_resp)
    if not final_text:
        return jsonify({"error": "Text generation returned empty"}), 500

    style_prompt = f"Say the following in a {accent} accent with a {tone} tone:\n\n{final_text}"
    tts_start = time.perf_counter()
    try:
        tts_resp = client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=[types.Content(role="user", parts=[types.Part.from_text(text=style_prompt)])],
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice)
                    )
                )
            )
        )
    except Exception as e:
        return jsonify({"error": f"tts generation failed: {str(e)}", "text": final_text}), 500
    tts_end = time.perf_counter()

    # Take the first non-empty inline-data part. `content` and `content.parts`
    # may be None; `or []` keeps that from raising TypeError.
    pcm_bytes = None
    for cand in getattr(tts_resp, "candidates", None) or []:
        content = getattr(cand, "content", None)
        for p in getattr(content, "parts", None) or []:
            inline = getattr(p, "inline_data", None)
            if inline is not None and inline.data:
                pcm_bytes = inline.data
                break
        if pcm_bytes:
            break
    if not pcm_bytes:
        return jsonify({"error": "TTS returned no audio", "text": final_text}), 500

    wav = wrap_pcm_to_wav(pcm_bytes)
    audio_b64 = base64.b64encode(wav).decode("ascii")
    t_total = time.perf_counter() - t_start
    return jsonify({
        "text": final_text,
        "audio_base64": audio_b64,
        "timings": {
            "text_seconds": round(t1 - t0, 3),
            "tts_seconds": round(tts_end - tts_start, 3),
            "total_seconds": round(t_total, 3)
        }
    })
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
app.run(host="0.0.0.0", port=port)