# ai_voice / app.py
# Last commit by Pepguy: "Update app.py" (d868c8e, verified)
# pip install flask google-genai
import time
import os
from flask import Flask, request, render_template_string, Response, jsonify
from google import genai
from google.genai import types
import struct
# The Flask WSGI application; the @app.route handlers in this module register on it.
app = Flask(import_name=__name__)
HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>Gemini TTS Test</title>
</head>
<body style="font-family:sans-serif;padding:2rem;">
<h1>Gemini-2.5-Flash-Preview-TTS</h1>
<form id="genai-form">
<textarea id="prompt" rows="6" cols="60" placeholder="Enter text to synthesize"></textarea><br/><br/>
<label>Voice: <input id="voice" value="Sadachbia" /></label><br/>
<label>Accent: <input id="accent" value="British" /></label><br/>
<label>Tone: <input id="tone" value="casual and friendly" /></label><br/><br/>
<button type="submit">Generate</button>
</form>
<div id="output" style="margin-top:1rem;"></div>
<script>
const form = document.getElementById('genai-form');
form.addEventListener('submit', async e => {
e.preventDefault();
const text = document.getElementById('prompt').value.trim();
const voice = document.getElementById('voice').value.trim();
const accent = document.getElementById('accent').value.trim();
const tone = document.getElementById('tone').value.trim();
const out = document.getElementById('output');
out.textContent = 'Generating…';
try {
const resp = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text, voice, accent, tone }),
});
if (!resp.ok) {
const errText = await resp.text();
throw new Error(`Server returned ${resp.status}: ${errText}`);
}
const blob = await resp.blob();
const url = URL.createObjectURL(blob);
out.innerHTML = '<audio controls src="' + url + '"></audio>';
} catch (err) {
console.error(err);
out.textContent = 'Fetch error: ' + err.message;
}
});
</script>
</body>
</html>
"""
# Gemini API client. The key is taken from the environment only: a real key must
# never be committed as a fallback default. If GOOGLE_API_KEY is unset, None is
# passed and the google-genai SDK performs its own environment lookup, raising at
# construction time when no key can be found.
client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY"))
def wrap_pcm_to_wav(pcm_data: bytes, sample_rate: int = 24000, num_channels: int = 1, bits_per_sample: int = 16) -> bytes:
    """Wrap raw PCM bytes in a minimal WAV (RIFF) container.

    The payload is copied verbatim into the ``data`` chunk; no byte-swapping,
    resampling or validation of the sample layout is performed.

    Args:
        pcm_data: Interleaved integer PCM samples.
        sample_rate: Frames per second recorded in the ``fmt `` chunk.
        num_channels: Number of interleaved channels.
        bits_per_sample: Bits per single-channel sample.

    Returns:
        A 44-byte header (RIFF/WAVE, 16-byte PCM ``fmt `` chunk, ``data`` chunk
        header) followed by ``pcm_data``. If ``pcm_data`` has an odd length, a
        single zero pad byte is appended because RIFF chunks are word-aligned;
        the ``data`` size field still records the unpadded length, while the
        RIFF size includes the pad.
    """
    data_size = len(pcm_data)
    byte_rate = sample_rate * num_channels * bits_per_sample // 8
    block_align = num_channels * bits_per_sample // 8
    pad = b"\x00" if data_size % 2 else b""
    header = struct.pack(
        "<4sI4s" "4sIHHIIHH" "4sI",
        # RIFF chunk: size covers "WAVE" + fmt chunk (24) + data chunk header (8) + payload + pad.
        b"RIFF", 36 + data_size + len(pad), b"WAVE",
        # fmt chunk: 16-byte body, format tag 1 = uncompressed PCM.
        b"fmt ", 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample,
        # data chunk header; the payload follows immediately.
        b"data", data_size,
    )
    return header + pcm_data + pad
def generate_audio_from_gemini(prompt: str, accent: str, tone: str, voice: str,
                               model: str = "gemini-2.5-flash-preview-tts") -> bytes:
    """Synthesize *prompt* with a Gemini TTS model and return it as WAV bytes.

    Accent and tone are not separate API settings: they are folded into a
    natural-language instruction that prefixes the text. *voice* is sent as
    the prebuilt voice name.

    Args:
        prompt: Text to be spoken.
        accent: Accent to request, e.g. "British".
        tone: Speaking tone to request, e.g. "casual and friendly".
        voice: Name of a Gemini prebuilt voice.
        model: Gemini model id to call.

    Returns:
        The returned PCM audio wrapped by ``wrap_pcm_to_wav``, using the sample
        rate advertised in the audio part's MIME type (24000 if absent).

    Raises:
        RuntimeError: If the response contains no part with inline audio data.
        Errors from ``client.models.generate_content`` propagate unchanged.
    """
    style_prompt = f"Say the following text in a {accent} accent with a {tone} tone:\n\n{prompt}"
    response = client.models.generate_content(
        model=model,
        contents=[types.Content(role="user", parts=[types.Part(text=style_prompt)])],
        config=types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=voice
                    )
                )
            )
        )
    )
    pcm_bytes, mime_type = _find_inline_audio(response)
    if not pcm_bytes:
        raise RuntimeError("No audio returned from Gemini")
    return wrap_pcm_to_wav(pcm_bytes, sample_rate=_pcm_sample_rate(mime_type))


def _find_inline_audio(response) -> tuple:
    """Return ``(data, mime_type)`` of the first part with non-empty inline data.

    Walks every candidate and part instead of assuming
    ``candidates[0].content.parts[0]``, and tolerates ``None`` or empty values
    at each level. Returns ``(None, None)`` when nothing is found.
    """
    for candidate in response.candidates or ():
        parts = candidate.content.parts if candidate.content else None
        for part in parts or ():
            blob = part.inline_data
            if blob is not None and blob.data:
                return blob.data, blob.mime_type
    return None, None


def _pcm_sample_rate(mime_type, default: int = 24000) -> int:
    """Return the positive integer ``rate=`` parameter of a MIME type, else *default*.

    NOTE(review): assumes raw PCM is labelled like
    ``audio/L16;codec=pcm;rate=24000`` — confirm against the API. The fallback
    matches ``wrap_pcm_to_wav``'s 24000 default.
    """
    for param in (mime_type or "").split(";")[1:]:
        name, _, value = param.partition("=")
        value = value.strip()
        if name.strip().lower() == "rate" and value.isdecimal() and int(value) > 0:
            return int(value)
    return default
@app.route('/')
def index():
    """Serve the TTS test page (the module-level HTML template)."""
    page = render_template_string(HTML)
    return page
@app.route('/generate', methods=['POST'])
def gen():
    """Synthesize speech for a JSON request and return it as ``audio/wav``.

    Expects a JSON object with a required ``text`` string and optional
    ``voice``, ``accent`` and ``tone`` strings. Optional fields that are
    missing, blank or not strings fall back to the same defaults the HTML
    form pre-fills.

    Returns:
        200 with WAV bytes on success; 400 JSON ``{"error": ...}`` when no
        usable ``text`` is supplied; 500 JSON ``{"error": ...}`` when
        synthesis raises.
    """
    payload = request.get_json(silent=True)
    # Valid JSON can still be a list, string or number; only an object has named fields.
    if not isinstance(payload, dict):
        payload = {}
    prompt = _json_text_field(payload, "text")
    voice = _json_text_field(payload, "voice", "Sadachbia")
    accent = _json_text_field(payload, "accent", "British")
    tone = _json_text_field(payload, "tone", "casual and friendly")
    if not prompt:
        return jsonify({"error": "No prompt provided"}), 400
    t0 = time.perf_counter()
    try:
        wav_bytes = generate_audio_from_gemini(prompt, accent, tone, voice)
    except Exception as e:
        # Request boundary: log the full traceback, return the message to the client.
        app.logger.exception("Generation failed")
        return jsonify({"error": str(e)}), 500
    app.logger.info("Gemini TTS API call took %.2fs", time.perf_counter() - t0)
    return Response(wav_bytes, mimetype="audio/wav")


def _json_text_field(payload: dict, key: str, default: str = "") -> str:
    """Return ``payload[key]`` stripped, or *default* if it is absent, non-string or blank."""
    value = payload.get(key)
    if isinstance(value, str):
        value = value.strip()
        if value:
            return value
    return default
if __name__ == "__main__":
    # Development server bound to all network interfaces (0.0.0.0); the port is
    # taken from $PORT, defaulting to 7860.
    app.run(port=int(os.environ.get("PORT", 7860)), host="0.0.0.0")