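"""Kisel TV: generate a parody "Direct Line" dialogue with an LLM (via g4f),
synthesize each line with Silero TTS, re-voice it with openvoice_cli, and
serve the numbered WAV files as a ZIP through a Gradio interface."""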
import gradio as gr
from g4f.client import Client
import json
import torch
import soundfile as sf
from openvoice_cli.__main__ import tune_one
import os
import uuid
import zipfile
import logging
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import threading
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize G4F client
client = Client()
# Silero TTS configuration
language = 'ru'
model_id = 'ru_v3'
speaker = 'baya' # Russian voice
sample_rate = 48000 # Sample rate for Russian model
device = torch.device('cpu')
# Load Silero TTS model with thread safety
# (torch.hub downloads and caches the model on the first run)
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                          model='silero_tts',
                          language=language,
                          speaker=model_id)
model.to(device)
tts_lock = threading.Lock()  # Lock for TTS model thread safety
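# Required local assets (not created by this script): kisel.mp3 and putin.mp3,
# the reference voice recordings consumed by tune_one in process_line below.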
def get_task(prompt):
    """Build the Russian-language LLM prompt that requests the dialogue as JSON."""
    return f'''Ты генерируешь скрипт продолжения прямой линии с Владимиром Владимировичем Путиным. Вопросы телезрителей передаёт Киселёв - он максимально подобострастен, поддерживает Путина и Россию. Пиши в духе пророссийской пропаганды. Каждая реплика - не менее 150-200 символов. Английские названия при необходимости всегда транскрибируй кириллицей. Числительные также записывай кириллицей. Пиши в стиле кремлёвской пропаганды. Пиши в стереотипном стиле. Киселёв задаёт вопросы, обращается к Путину как "Владимир Владимирович" или "Господин Президент". Путин отвечает. Всё максимально аутентично. Фразы Киселёва: {{"Киселёв":"фраза"}} Фразы Путина: {{"Путин":"фраза"}} Ответ дай в формате JSON без дополнительных символов: [{{"Киселёв":"фраза"}}, {{"Путин":"фраза"}} . . . ].
Вопрос от пользователя поступил: "{prompt}"'''
def validate_response(response):
    """Return True if the response parses as a JSON list of single-key dicts."""
    try:
        data = json.loads(response)
        if isinstance(data, list) and all(isinstance(item, dict) and len(item) == 1 for item in data):
            return True
    except json.JSONDecodeError:
        pass
    return False
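# Examples of the accepted shape (hypothetical replies, not real model output):
#   validate_response('[{"Киселёв":"..."}, {"Путин":"..."}]')  -> True
#   validate_response('{"Киселёв":"..."}')                     -> False (bare object, not a list)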
def generate_text(prompt):
    """Query the LLM, retrying until the reply passes validate_response."""
    max_retries = 4
    for attempt in range(max_retries):
        logger.info(f"Generating response for prompt: {prompt} (attempt {attempt+1})")
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b",
                messages=[{"role": "user", "content": get_task(prompt)}],
                web_search=False
            )
            response_text = response.choices[0].message.content
        except Exception as e:
            # g4f providers can fail transiently; treat an error like an invalid reply and retry
            logger.warning(f"LLM request failed: {e}")
            continue
        logger.info(f"Generated response: {response_text}")
        if validate_response(response_text):
            return response_text
        logger.warning("Invalid response format, retrying...")
    logger.error(f"Failed to generate valid response after {max_retries} attempts")
    # Fallback script so the downstream audio pipeline always gets valid input
    return '[{"Киселёв":"К сожалению, не удалось расслышать вопрос. Пожалуйста, попробуйте еще раз."}, {"Путин":"Мы работаем над улучшением системы. Спасибо за понимание."}]'
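# The script format requested from the LLM and produced by the fallback above:
#   [{"Киселёв": "фраза"}, {"Путин": "фраза"}, ...]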
def split_text(text, max_length=800):
    """Split text into chunks of maximum length, trying to preserve word boundaries"""
    chunks = []
    while len(text) > max_length:
        split_at = text.rfind(' ', 0, max_length)
        if split_at == -1:
            split_at = max_length
        chunks.append(text[:split_at])
        text = text[split_at:].lstrip()
    chunks.append(text)
    return chunks
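# Example: a 1,700-character run with no spaces falls back to hard splits,
# yielding chunks of 800, 800 and 100 characters; text containing spaces is
# split at the last space before the limit instead.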
def generate_audio(text, speaker_name):
    """Generate audio with thread-safe splitting and synthesis.

    Note: speaker_name is only used for logging; synthesis always uses the
    module-level Silero `speaker` voice, which is re-voiced later by tune_one.
    """
    logger.info(f"Generating audio for {speaker_name} ({len(text)} chars)")
    chunks = split_text(text)
    audio_arrays = []
    for chunk in chunks:
        with tts_lock:  # Ensure thread-safe TTS operations
            audio = model.apply_tts(
                ssml_text=f"<speak>{chunk}</speak>",
                speaker=speaker,
                sample_rate=sample_rate,
                put_accent=True,
                put_yo=True
            )
        audio_arrays.append(audio)
    # apply_tts returns CPU torch tensors; np.concatenate converts them into a
    # single numpy array that soundfile can write directly
    full_audio = np.concatenate(audio_arrays)
    temp_filename = f"temp_{uuid.uuid4().hex}.wav"
    sf.write(temp_filename, full_audio, sample_rate)
    return temp_filename
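# generate_audio returns the path to a temporary WAV; process_line below owns
# that file and removes it in its finally block once voice conversion is done.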
def process_line(args):
    """Process a single dialogue line with enhanced error handling."""
    idx, speaker_name, text = args  # named to avoid shadowing the global Silero `speaker`
    final_filename = f"t{idx+1}-{speaker_name}.wav"
    base_audio = None
    output_filename = None
    try:
        logger.info(f"Processing line {idx+1} for {speaker_name}")
        # Generate base audio
        base_audio = generate_audio(text, speaker_name)
        if not os.path.exists(base_audio):
            logger.error(f"Base audio not generated for line {idx+1}")
            return None
        # Generate voice cover using the matching reference recording
        ref_audio = "kisel.mp3" if speaker_name == "Киселёв" else "putin.mp3"
        output_filename = f"output_{uuid.uuid4().hex[:6]}.wav"
        logger.debug(f"Tuning audio with reference: {ref_audio}")
        tune_one(
            input_file=base_audio,
            ref_file=ref_audio,
            output_file=output_filename,
            device='cpu'
        )
        # Verify output file creation
        if not os.path.exists(output_filename):
            logger.error(f"Voice tuning failed for line {idx+1}")
            return None
        # Rename final file (the finally block then skips the old name, which is gone)
        os.rename(output_filename, final_filename)
        logger.info(f"Created final file: {final_filename}")
        return final_filename
    except Exception as e:
        logger.error(f"Error processing line {idx+1}: {str(e)}", exc_info=True)
        return None
    finally:
        # Cleanup temporary files
        for f in [base_audio, output_filename]:
            if f and os.path.exists(f):
                os.remove(f)
def process_prompt(prompt):
    """Main processing pipeline with parallel execution"""
    logger.info(f"Starting processing for prompt: {prompt}")
    # Generate script
    script = generate_text(prompt)
    script_data = json.loads(script)
    # Prepare tasks for parallel processing
    tasks = [(idx, speaker_name, text)
             for idx, item in enumerate(script_data)
             for speaker_name, text in item.items()]
    # Process lines in parallel. Note: tts_lock serializes Silero synthesis,
    # so the threads mainly overlap the tune_one voice-conversion stage.
    audio_files = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_line, task) for task in tasks]
        for future in futures:  # iterate in submission order to keep dialogue order
            result = future.result()
            if result:
                audio_files.append(result)
    # Package results
    # NOTE: a fixed zip name means concurrent requests overwrite each other;
    # a per-request name (e.g. including a UUID) would avoid that.
    zip_filename = "output_audio_files.zip"
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in audio_files:
            zipf.write(file)
    # Cleanup working files
    for file in audio_files:
        os.remove(file)
    return zip_filename
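# process_prompt returns the zip path, which gr.File serves as a download;
# the files inside are named t<line>-<speaker>.wav in dialogue order.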
# Gradio interface
examples = [
    "Почему такие высокие налоги?",
    "Какие цели СВО?",
    "Когда развалится Америка?"
]
with gr.Blocks() as demo:
    gr.Markdown("# Kisel TV")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your text here...",
            lines=3
        )
    generate_btn = gr.Button("Generate", variant="primary")
    output = gr.File(label="Generated Audio Files")
    gr.Examples(
        examples=examples,
        inputs=prompt_input,
        outputs=output,
        fn=process_prompt,
        cache_examples=False
    )
    generate_btn.click(
        fn=process_prompt,
        inputs=prompt_input,
        outputs=output
    )
if __name__ == "__main__":
    demo.launch()