# kiseltv_server2 / app.py
# (Hugging Face Space page header artifact removed: "tester1hf — Update app.py — 5a92983 verified")
import gradio as gr
from g4f.client import Client
import json
import torch
import soundfile as sf
from openvoice_cli.__main__ import tune_one
import os
import uuid
import logging
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from moviepy.editor import AudioFileClip, VideoFileClip, concatenate_videoclips
# Configure logging for the whole app
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # NOTE(review): redundant — basicConfig above already sets INFO

# Initialize G4F client (free LLM gateway used by generate_text)
client = Client()

# Silero TTS configuration
language = 'ru'
model_id = 'ru_v3'
sample_rate = 48000  # Sample rate for Russian model
device = torch.device('cpu')

# Load Silero TTS model at import time (blocking download on first run)
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                          model='silero_tts',
                          language=language,
                          speaker=model_id)
model.to(device)
tts_lock = threading.Lock()  # Lock for TTS model thread safety (model is shared across worker threads)

# GIF mappings: dialogue speaker name -> looping avatar GIF used in the final video
GIF_MAPPING = {
    "Киселёв": "kisel.gif",
    "Путин": "putin.gif"
}
def get_task(prompt):
    """Build the Russian-language LLM instruction asking for a Kiselyov/Putin
    dialogue script, returned by the model as a JSON list of one-key dicts.

    The template text is runtime behavior (it is sent verbatim to the LLM)
    and is therefore left untouched.
    """
    return f'''Ты генерируешь скрипт продолжения прямой линии с Владимиром Владимировичем Путиным. Вопросы телезрителей передаёт Киселёв - он максимально подобострастен, поддерживает Путина и Россию. Пиши в духе пропаганды пророссийской. Каждая реплика - не менее 150-200 символов. Английские названия при необходимости всегда транскрибируй кириллицей. Числительные также записывай кириллицей. Пиши в стиле кремлёвской пропаганды. Пиши в стереотипном стиле. Киселёв задаёт вопросы, обрашается к Путину как "Владимир Владимирович" или "Господин Президент". Путин отвечает. Всё максимально аутентично. Фразы Киселёва: {{"Киселёв":"фраза"}} Фразы Путина: {{"Путин":"фраза"}} Ответ дай в формате JSON без дополнительных символов: [{{"Киселёв":"фраза"}}, {{"Путин":"фраза"}} . . . ].
Вопрос от пользователя поступил: "{prompt}"'''
def validate_response(response):
    """Check that *response* is the JSON shape the app expects.

    Valid shape: a JSON list where every item is a dict with exactly one
    key, and that key is either "Киселёв" or "Путин".

    Returns True when the shape is valid, False otherwise. Never raises.
    """
    try:
        data = json.loads(response)
    # Fix: json.loads raises TypeError (not JSONDecodeError) when the API
    # hands back None/bytes-like garbage — treat that as invalid too.
    except (json.JSONDecodeError, TypeError) as e:
        logger.warning(f"JSON decode error: {str(e)}")
        return False
    if not isinstance(data, list):
        logger.warning("Invalid response: Root element is not a list")
        return False
    for idx, item in enumerate(data):
        if not isinstance(item, dict):
            logger.warning(f"Invalid item #{idx+1}: Not a dictionary")
            return False
        if len(item) != 1:
            logger.warning(f"Invalid item #{idx+1}: Contains {len(item)} keys instead of 1")
            return False
        key = next(iter(item.keys()))
        if key not in ["Киселёв", "Путин"]:
            logger.warning(f"Invalid item #{idx+1}: Unexpected speaker '{key}'")
            return False
    return True
def generate_text(prompt):
    """Ask the LLM for a dialogue script, retrying until the JSON validates.

    Makes up to four attempts; on total failure returns a hard-coded
    apology script so the pipeline downstream always has valid JSON.
    """
    logger.info(f"Generating text for prompt: '{prompt}'")
    attempts = 4
    for attempt in range(attempts):
        try:
            completion = client.chat.completions.create(
                model="llama-3.3-70b",
                messages=[{"role": "user", "content": get_task(prompt)}],
                web_search=False
            )
            candidate = completion.choices[0].message.content
            logger.debug(f"Raw API response: {candidate}")
        except Exception as e:
            logger.error(f"API call failed: {str(e)}")
            continue
        if validate_response(candidate):
            logger.info(f"Successfully validated response (attempt {attempt+1})")
            return candidate
        logger.warning(f"Validation failed (attempt {attempt+1})")
    logger.error("Failed to generate valid response after 4 attempts")
    # Fallback script — keeps downstream JSON parsing happy.
    return '[{"Киселёв":"К сожалению, не удалось расслышать вопрос. Пожалуйста, попробуйте еще раз."}, {"Путин":"Мы работаем над улучшением системы. Спасибо за понимание."}]'
def split_text(text, max_length=800):
    """Break *text* into pieces no longer than *max_length* characters,
    cutting at the last space before the limit when one exists (otherwise
    hard-cutting at exactly *max_length*)."""
    parts = []
    remaining = text
    while len(remaining) > max_length:
        cut = remaining.rfind(' ', 0, max_length)
        if cut < 0:
            # No space available in the window — fall back to a hard cut.
            cut = max_length
        parts.append(remaining[:cut])
        remaining = remaining[cut:].lstrip()
    parts.append(remaining)
    logger.debug(f"Split text into {len(parts)} chunks")
    return parts
def generate_audio(text, speaker_name):
    """Synthesize *text* with Silero TTS and write it to a temp wav file.

    Returns the temp filename. Long texts are chunked first because the
    model has an input-length limit; chunks are concatenated back into a
    single waveform. Synthesis is serialized via tts_lock.
    """
    logger.info(f"Generating audio for {speaker_name} ({len(text)} characters)")
    # Voice selection: 'aidar' for the host, 'baya' for everyone else.
    silero_speaker = 'aidar' if speaker_name == 'Киселёв' else 'baya'
    logger.debug(f"Using Silero speaker: {silero_speaker}")
    pieces = split_text(text)
    total = len(pieces)
    rendered = []
    for n, piece in enumerate(pieces, 1):
        logger.debug(f"Processing chunk {n}/{total}")
        with tts_lock:  # the shared model is not thread-safe
            waveform = model.apply_tts(
                ssml_text=f"<speak>{piece}</speak>",
                speaker=silero_speaker,
                sample_rate=sample_rate,
                put_accent=True,
                put_yo=True
            )
        rendered.append(waveform)
    combined = np.concatenate(rendered)
    temp_filename = f"temp_{uuid.uuid4().hex}.wav"
    sf.write(temp_filename, combined, sample_rate)
    logger.debug(f"Temporary audio saved: {temp_filename}")
    return temp_filename
def process_line(args):
    """Render one dialogue line end-to-end: TTS, then voice conversion.

    *args* is an (idx, speaker, text) tuple. On success returns the final
    filename "t{idx+1}-{speaker}.wav"; on any failure returns None.
    Intermediate wav files are always removed in the finally block (the
    tuned file survives only because it is renamed away before cleanup).
    """
    idx, speaker, text = args
    final_filename = f"t{idx+1}-{speaker}.wav"
    raw_wav = None
    tuned_wav = None
    try:
        logger.info(f"Processing line {idx+1} for {speaker}")
        raw_wav = generate_audio(text, speaker)
        if not os.path.exists(raw_wav):
            logger.error(f"Base audio not generated for line {idx+1}")
            return None
        # Voice-convert toward the speaker's reference recording.
        reference = "kisel.mp3" if speaker == "Киселёв" else "putin.mp3"
        tuned_wav = f"output_{uuid.uuid4().hex[:6]}.wav"
        logger.debug(f"Tuning audio with reference: {reference}")
        tune_one(
            input_file=raw_wav,
            ref_file=reference,
            output_file=tuned_wav,
            device='cpu'
        )
        if not os.path.exists(tuned_wav):
            logger.error(f"Voice tuning failed for line {idx+1}")
            return None
        os.rename(tuned_wav, final_filename)
        logger.info(f"Created final file: {final_filename}")
        return final_filename
    except Exception as e:
        logger.error(f"Error processing line {idx+1}: {str(e)}", exc_info=True)
        return None
    finally:
        # Remove any intermediate files still on disk.
        for leftover in (raw_wav, tuned_wav):
            if leftover and os.path.exists(leftover):
                os.remove(leftover)
def create_video(audio_files):
    """Pair each dialogue wav with its speaker's looping GIF and concatenate
    the segments into a single mp4.

    *audio_files* are named "t{N}-{speaker}.wav"; they are sorted by N so
    the dialogue plays in order. Returns the output mp4 filename; re-raises
    on failure after logging.
    """
    logger.info(f"⏳ Starting video creation with {len(audio_files)} audio files")
    clips = []
    final_video = None
    try:
        # Sort audio files by their numerical index (the N in "tN-...").
        audio_files.sort(key=lambda x: int(x.split('t')[1].split('-')[0]))
        logger.info("Processing audio-GIF pairs:")
        for audio_file in audio_files:
            speaker = audio_file.split('-')[1].split('.')[0]
            gif_file = GIF_MAPPING.get(speaker)
            if not gif_file or not os.path.exists(gif_file):
                logger.error(f"Missing GIF file for {speaker}")
                continue
            audio_clip = AudioFileClip(audio_file)
            logger.info(f"🔊 {os.path.basename(audio_file)} ({audio_clip.duration:.1f}s)")
            gif_clip = VideoFileClip(gif_file).loop(duration=audio_clip.duration)
            gif_clip = gif_clip.set_audio(audio_clip)
            clips.append(gif_clip)
            logger.debug(f"Processed {speaker} segment")
        if not clips:
            raise ValueError("No valid video clips created")
        final_video = concatenate_videoclips(clips)
        video_filename = f"output_{uuid.uuid4().hex[:8]}.mp4"
        logger.info(f"🎬 Concatenating {len(clips)} clips (total: {final_video.duration:.1f}s)")
        final_video.write_videofile(video_filename, codec='libx264', audio_codec='aac')
        logger.info(f"✅ Successfully created video: {video_filename}")
        return video_filename
    except Exception as e:
        logger.error(f"Video creation failed: {str(e)}", exc_info=True)
        raise
    finally:
        # Fix: moviepy clips hold ffmpeg reader subprocesses and file
        # handles; without close() they leak on every call. Best-effort
        # close so cleanup never masks the real error.
        for clip in clips:
            try:
                clip.close()
            except Exception:
                pass
        if final_video is not None:
            try:
                final_video.close()
            except Exception:
                pass
def process_prompt(prompt):
    """Main pipeline: script generation -> parallel TTS/voice-conversion ->
    video assembly.

    Returns the mp4 filename on success, or None on failure (Gradio shows
    an empty video component for None). Per-line wav files are deleted in
    the finally block once the video has been written.
    """
    logger.info(f"🚀 Starting processing for prompt: '{prompt}'")
    # Fix: must be bound BEFORE the try block — the finally clause iterates
    # it, and previously an early failure (e.g. json.loads) raised
    # UnboundLocalError in finally, masking the real exception.
    audio_files = []
    try:
        # Generate script
        script = generate_text(prompt)
        logger.debug(f"Raw script data: {script}")
        script_data = json.loads(script)
        logger.info(f"📝 Generated script with {len(script_data)} lines")
        # Prepare (index, speaker, text) tasks for parallel processing.
        tasks = [(idx, speaker, text)
                 for idx, item in enumerate(script_data)
                 for speaker, text in item.items()]
        # Process lines in parallel
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(process_line, task) for task in tasks]
            total_tasks = len(futures)
            logger.info(f"📦 Processing {total_tasks} audio segments in parallel")
            for i, future in enumerate(as_completed(futures), 1):
                result = future.result()
                if result:
                    audio_files.append(result)
                    remaining = total_tasks - i
                    logger.info(f"🔧 Processed {os.path.basename(result)} ({i}/{total_tasks}, {remaining} remaining)")
                else:
                    logger.warning(f"⚠️ Failed to process task {i}/{total_tasks}")
        # Create final video
        if not audio_files:
            raise ValueError("No audio files generated")
        return create_video(audio_files)
    except Exception as e:
        logger.error(f"❌ Processing failed: {str(e)}", exc_info=True)
        return None
    finally:
        # Cleanup per-line audio files after video creation (or failure).
        for file in audio_files:
            if os.path.exists(file):
                os.remove(file)
# Gradio interface
# Canned prompts shown below the input box as one-click examples.
examples = [
    "Почему такие высокие налоги?",
    "Какие цели СВО?",
    "Когда развалится Америка?"
]
with gr.Blocks() as demo:
    gr.Markdown("# Kisel TV")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Input Prompt",
            placeholder="Enter your text here...",
            lines=3
        )
    generate_btn = gr.Button("Generate", variant="primary")
    # Result component: the pipeline returns an mp4 path (or None on failure).
    output = gr.Video(label="Generated Video", format="mp4")
    gr.Examples(
        examples=examples,
        inputs=prompt_input,
        outputs=output,
        fn=process_prompt,
        cache_examples=False  # each run regenerates; nothing is precomputed
    )
    generate_btn.click(
        fn=process_prompt,
        inputs=prompt_input,
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()