|
|
import streamlit as st |
|
|
from g4f.client import Client |
|
|
import json |
|
|
import torch |
|
|
import soundfile as sf |
|
|
import os |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from openvoice_cli.downloader import download_checkpoint |
|
|
from openvoice_cli.api import ToneColorConverter |
|
|
import openvoice_cli.se_extractor as se_extractor |
|
|
import glob |
|
|
import uuid |
|
|
import logging |
|
|
import numpy as np |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
import threading |
|
|
from moviepy.editor import AudioFileClip, VideoFileClip, concatenate_videoclips |
|
|
|
|
|
|
|
|
# Page-level settings for the Streamlit app: browser tab title and icon,
# content layout and the initial sidebar state.
_PAGE_CONFIG = {
    "page_title": "Прямая линия с Путиным",
    "page_icon": "🇷🇺",
    "layout": "centered",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
|
|
|
# Root logging configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used throughout this file, pinned to INFO explicitly.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
# g4f client used for the chat-completion requests in this file.
client = Client()

# Silero TTS settings: language, model identifier passed as ``speaker`` to
# torch.hub, output sample rate in Hz, and the torch device to run on.
language = 'ru'
model_id = 'ru_v3'
sample_rate = 48000
device = torch.device('cpu')


@st.cache_resource
def _load_silero_tts():
    """Load the Silero TTS model once and return it with its guarding lock.

    Streamlit re-executes this script on every user interaction, so loading
    the model at module level repeats the work on each rerun. Caching the
    loader (the same mechanism ``setup_openvoice`` uses) keeps one model
    instance alive. The lock is created together with the model so that
    every script run serialises on the same lock for the same model.

    Returns:
        tuple: ``(model, lock)`` - the loaded model moved to ``device`` and
        a ``threading.Lock`` to hold while calling it.
    """
    tts_model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                  model='silero_tts',
                                  language=language,
                                  speaker=model_id)
    tts_model.to(device)
    return tts_model, threading.Lock()


model, tts_lock = _load_silero_tts()

# Speaker name (as it appears in the generated script) -> GIF file that is
# looped as the picture while that speaker's audio plays.
GIF_MAPPING = {
    "Киселёв": "kisel.gif",
    "Путин": "putin.gif"
}
|
|
|
|
|
@st.cache_resource
def setup_openvoice():
    """Initialize OpenVoice components once with caching.

    Locates (downloading if necessary) the converter checkpoint in
    ``checkpoints/converter`` next to this file, builds the tone colour
    converter on CPU and extracts a speaker embedding from each reference
    recording.

    Returns:
        tuple: ``(tone_color_converter, ref_se)`` where ``ref_se`` maps a
        speaker name to the embedding extracted from its reference file.

    Raises:
        FileNotFoundError: if a reference recording does not exist.
    """
    logger.info("Initializing OpenVoice components...")

    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(current_dir, 'checkpoints')
    ckpt_converter = os.path.join(checkpoints_dir, 'converter')
    config_path = os.path.join(ckpt_converter, 'config.json')
    checkpoint_path = os.path.join(ckpt_converter, 'checkpoint.pth')

    # Test for the files that are loaded below, not for the directory: the
    # directory is created before the download starts, so after an
    # interrupted download a directory-only check would never retry.
    if not (os.path.isfile(config_path) and os.path.isfile(checkpoint_path)):
        logger.info("Downloading OpenVoice checkpoints...")
        os.makedirs(ckpt_converter, exist_ok=True)
        download_checkpoint(ckpt_converter)

    device = 'cpu'
    tone_color_converter = ToneColorConverter(config_path, device=device)
    tone_color_converter.load_ckpt(checkpoint_path)

    # Speaker name -> reference recording; paths are relative to the
    # current working directory.
    speaker_refs = {
        "Киселёв": "kisel.mp3",
        "Путин": "putin.mp3"
    }

    ref_se = {}
    for speaker, ref_file in speaker_refs.items():
        if not os.path.exists(ref_file):
            raise FileNotFoundError(f"Reference file {ref_file} not found")
        logger.info(f"Extracting speaker embedding for {speaker}...")
        target_se, _ = se_extractor.get_se(ref_file, tone_color_converter, vad=True)
        ref_se[speaker] = target_se

    return tone_color_converter, ref_se
|
|
|
|
|
|
|
|
# Initialise the voice-conversion engine before any UI is drawn. If this
# fails, show an error to the user and stop the current script run.
try:
    tone_color_converter, ref_se = setup_openvoice()
except Exception as e:
    # logger.exception records the traceback as well as the message, which
    # is needed to diagnose checkpoint/reference-file problems.
    logger.exception(f"Failed to initialize OpenVoice: {str(e)}")
    st.error("Ошибка инициализации голосового движка. Пожалуйста, проверьте файлы конфигурации.")
    st.stop()
|
|
|
|
|
def get_task(prompt):
    """Build the LLM instruction text for a viewer question.

    The result is a fixed block of instructions (roles, style, and the
    required JSON output format) followed on a new line by the user's
    question wrapped in double quotes.

    Args:
        prompt: The question entered by the user.

    Returns:
        str: The complete prompt to send to the chat model.
    """
    instructions = (
        'Ты генерируешь скрипт продолжения прямой линии с Владимиром Владимировичем Путиным. '
        'Вопросы телезрителей передаёт Киселёв - он максимально подобострастен, поддерживает Путина и Россию. '
        'Пиши в духе пропаганды пророссийской. '
        'Каждая реплика - не менее 150-200 символов. '
        'Английские названия при необходимости всегда транскрибируй кириллицей. '
        'Числительные также записывай кириллицей. '
        'Пиши в стиле кремлёвской пропаганды. '
        'Пиши в стереотипном стиле. '
        'Киселёв задаёт вопросы, обрашается к Путину как "Владимир Владимирович" или "Господин Президент". '
        'Путин отвечает. '
        'Всё максимально аутентично. '
        'Фразы Киселёва: {"Киселёв":"фраза"} '
        'Фразы Путина: {"Путин":"фраза"} '
        'Ответ дай в формате JSON без дополнительных символов: '
        '[{"Киселёв":"фраза"}, {"Путин":"фраза"} . . . ].'
    )
    question_line = f'Вопрос от пользователя поступил: "{prompt}"'
    return instructions + '\n' + question_line
|
|
|
|
|
def validate_response(response):
    """Check that *response* is a usable dialogue script in JSON form.

    A valid script is a non-empty JSON array whose items are objects with
    exactly one key, either "Киселёв" or "Путин", mapped to a non-blank
    string (the replica to be spoken).

    Args:
        response: The raw text returned by the chat model.

    Returns:
        bool: True if the script is valid. False otherwise; the reason is
        logged as a warning.
    """
    try:
        data = json.loads(response)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers a missing
        # (None) or otherwise non-text response.
        logger.warning(f"JSON decode error: {str(e)}")
        return False

    if not isinstance(data, list):
        logger.warning("Invalid response: Root element is not a list")
        return False

    if not data:
        # An empty script would pass the per-item checks but yields nothing
        # to synthesise.
        logger.warning("Invalid response: Script contains no replicas")
        return False

    for number, item in enumerate(data, start=1):
        if not isinstance(item, dict):
            logger.warning(f"Invalid item #{number}: Not a dictionary")
            return False
        if len(item) != 1:
            logger.warning(f"Invalid item #{number}: Contains {len(item)} keys instead of 1")
            return False

        key, text = next(iter(item.items()))
        if key not in ("Киселёв", "Путин"):
            logger.warning(f"Invalid item #{number}: Unexpected speaker '{key}'")
            return False
        if not isinstance(text, str) or not text.strip():
            # The value is passed to speech synthesis as text, so anything
            # but a non-blank string cannot be processed.
            logger.warning(f"Invalid item #{number}: Replica is not a non-empty string")
            return False

    return True
|
|
|
|
|
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if it has one.

    Non-string input is returned unchanged so the caller's validation can
    reject it.
    """
    if not isinstance(text, str):
        return text

    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the whole opening fence line; it may carry a language tag
        # such as "json".
        _, _, cleaned = cleaned.partition("\n")
        cleaned = cleaned.rstrip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def generate_text(prompt):
    """Ask the chat model for a dialogue script answering *prompt*.

    The request is retried up to 40 times until a response passes
    ``validate_response``. A response wrapped in a Markdown code fence is
    unwrapped before validation, so an otherwise valid script is not
    rejected for its wrapper alone.

    Args:
        prompt: The question entered by the user.

    Returns:
        str: JSON text of the script. If no attempt yields a valid script,
        a fixed fallback script is returned instead.
    """
    logger.info(f"Generating text for prompt: '{prompt}'")
    max_retries = 40
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b",
                messages=[{"role": "user", "content": get_task(prompt)}],
                web_search=False
            )
            response_text = _strip_code_fence(response.choices[0].message.content)
            if validate_response(response_text):
                return response_text
            logger.warning(f"Attempt {attempt}/{max_retries}: response is not a valid script")
        except Exception as e:
            # Best-effort retry: any failure of a single attempt is logged
            # and the next attempt is made.
            logger.error(f"API call failed: {str(e)}")

    logger.error(f"No valid script after {max_retries} attempts, using fallback script")
    return '[{"Киселёв":"К сожалению, не удалось расслышать вопрос. Пожалуйста, попробуйте еще раз."}, {"Путин":"Мы работаем над улучшением системы. Спасибо за понимание."}]'
|
|
|
|
|
def split_text(text, max_length=800):
    """Split *text* into chunks of at most *max_length* characters.

    Each chunk ends at the last space inside the limit when there is one;
    otherwise the text is cut exactly at *max_length*. Whitespace at the
    start of the remaining text is dropped after every cut. Empty or
    whitespace-only pieces produced by a cut are discarded, so they are
    never handed on to speech synthesis.

    Args:
        text: The text to split.
        max_length: Maximum chunk length in characters; must be >= 1.

    Returns:
        list[str]: The chunks in order. Text that already fits is returned
        unchanged as a single-element list (including the empty string).

    Raises:
        ValueError: If *max_length* is smaller than 1 (the splitting loop
            could not make progress).
    """
    if max_length < 1:
        raise ValueError("max_length must be a positive integer")

    chunks = []
    remaining = text
    while len(remaining) > max_length:
        split_at = remaining.rfind(' ', 0, max_length)
        if split_at == -1:
            split_at = max_length

        piece = remaining[:split_at]
        # split_at == 0 (text starts with a space) gives an empty piece; the
        # lstrip() below still consumes that space, so the loop advances.
        if piece.strip():
            chunks.append(piece)
        remaining = remaining[split_at:].lstrip()

    # Keep the tail unless it is empty; always return at least one element.
    if remaining or not chunks:
        chunks.append(remaining)
    return chunks
|
|
|
|
|
def generate_audio(text, speaker_name):
    """Synthesise *text* with the Silero model and save it as a WAV file.

    The text is split with ``split_text`` and each chunk is synthesised
    separately while holding ``tts_lock``; the results are concatenated and
    written to a uniquely named file in the current directory.

    Args:
        text: The replica to speak.
        speaker_name: Script speaker; "Киселёв" uses the 'aidar' voice, any
            other name uses 'baya'.

    Returns:
        str: Path of the written WAV file. The caller is responsible for
        deleting it.
    """
    # Local import keeps this function self-contained.
    from xml.sax.saxutils import escape

    logger.info(f"Generating audio for {speaker_name} ({len(text)} characters)")

    silero_speaker = 'aidar' if speaker_name == 'Киселёв' else 'baya'
    audio_arrays = []

    for chunk in split_text(text):
        # The chunk is embedded in SSML markup, so &, < and > must be
        # escaped or generated text containing them makes the markup
        # malformed.
        ssml = f"<speak>{escape(chunk)}</speak>"
        with tts_lock:
            audio = model.apply_tts(
                ssml_text=ssml,
                speaker=silero_speaker,
                sample_rate=sample_rate,
                put_accent=True,
                put_yo=True
            )
        audio_arrays.append(audio)

    full_audio = np.concatenate(audio_arrays)
    temp_filename = f"temp_{uuid.uuid4().hex}.wav"
    sf.write(temp_filename, full_audio, sample_rate)
    return temp_filename
|
|
|
|
|
def process_single_chunk(chunk_file, speaker):
    """Convert the voice in *chunk_file* to the reference voice of *speaker*.

    Args:
        chunk_file: Path of the WAV file to convert.
        speaker: Key into ``ref_se`` selecting the target speaker embedding.

    Returns:
        str | None: Path of the converted WAV file, or None if the
        conversion failed (the error is logged and no output file is left
        behind).
    """
    output_filename = f"temp_output_{uuid.uuid4().hex}.wav"
    try:
        source_se, _ = se_extractor.get_se(chunk_file, tone_color_converter, vad=True)
        tone_color_converter.convert(
            audio_src_path=chunk_file,
            src_se=source_se,
            tgt_se=ref_se[speaker],
            output_path=output_filename,
        )
        if not os.path.exists(output_filename):
            # Callers read the returned path, so never return a path that
            # does not exist.
            logger.error("Error processing chunk: converter produced no output file")
            return None
        return output_filename
    except Exception as e:
        logger.error(f"Error processing chunk: {str(e)}")
        # A failed conversion may leave a partially written file; remove it
        # because the caller only cleans up paths that were returned.
        try:
            if os.path.exists(output_filename):
                os.remove(output_filename)
        except OSError as cleanup_error:
            logger.warning(f"Could not remove {output_filename}: {str(cleanup_error)}")
        return None
|
|
|
|
|
def merge_audio_files(files, sample_rate):
    """Read *files* in order and return their audio data joined end to end.

    All files are read first and joined with a single concatenation, rather
    than re-copying a growing array once per file.

    Args:
        files: Paths of the audio files to join.
        sample_rate: Not used by this function; kept so existing callers
            keep working.

    Returns:
        numpy.ndarray: The concatenated samples; an empty 1-D array when
        *files* is empty.
    """
    segments = [sf.read(path)[0] for path in files]
    if not segments:
        return np.array([])
    return np.concatenate(segments)
|
|
|
|
|
def _remove_if_exists(path):
    """Delete *path* if present; log instead of raising when deletion fails."""
    try:
        if path and os.path.exists(path):
            os.remove(path)
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {str(e)}")


def _split_into_chunks(audio_array, sr, chunk_seconds=15, min_tail_seconds=10):
    """Cut *audio_array* into consecutive pieces of *chunk_seconds* each.

    Audio no longer than one piece is returned whole. A final piece shorter
    than *min_tail_seconds* is not emitted on its own but left attached to
    the piece before it.
    """
    chunk_samples = chunk_seconds * sr
    total = len(audio_array)
    if total <= chunk_samples:
        return [audio_array]

    starts = list(range(0, total, chunk_samples))
    # total > chunk_samples guarantees at least two starts, so dropping the
    # last one still leaves a piece that absorbs the short tail.
    if total - starts[-1] < min_tail_seconds * sr:
        starts.pop()
    ends = starts[1:] + [total]
    return [audio_array[start:end] for start, end in zip(starts, ends)]


def process_line(args):
    """Produce the final voice-converted WAV file for one script line.

    The text is synthesised with ``generate_audio``, cut into pieces, each
    piece is voice-converted with ``process_single_chunk``, and the
    converted pieces are merged into ``t<line number>-<speaker>.wav``.
    Pieces whose conversion fails are skipped. All intermediate files are
    removed whether or not processing succeeds.

    Args:
        args: Tuple ``(idx, speaker, text)`` with the zero-based line index,
            the speaker name and the replica text.

    Returns:
        str | None: The final file name, or None if nothing could be
        produced (errors are logged, not raised).
    """
    idx, speaker, text = args
    final_filename = f"t{idx+1}-{speaker}.wav"
    base_audio = None
    converted_files = []

    try:
        logger.info(f"Processing line {idx+1} for {speaker}")
        base_audio = generate_audio(text, speaker)

        if not os.path.exists(base_audio):
            return None

        audio_array, sr = sf.read(base_audio)

        for chunk in _split_into_chunks(audio_array, sr):
            chunk_file = f"temp_chunk_{uuid.uuid4().hex}.wav"
            try:
                sf.write(chunk_file, chunk, sr)
                chunk_output = process_single_chunk(chunk_file, speaker)
            finally:
                # The input piece is not needed after conversion, whether it
                # succeeded or not.
                _remove_if_exists(chunk_file)
            if chunk_output:
                converted_files.append(chunk_output)

        if not converted_files:
            return None

        # Write the result with the sample rate the converted files have.
        # NOTE(review): the converter may write its output at a different
        # rate than the TTS audio - not verifiable from this file; reading
        # the rate back is correct in either case.
        output_sr = sf.info(converted_files[0]).samplerate
        merged_audio = merge_audio_files(converted_files, output_sr)
        sf.write(final_filename, merged_audio, output_sr)

        return final_filename

    except Exception as e:
        logger.error(f"Error processing line {idx+1}: {str(e)}")
        return None
    finally:
        for path in converted_files:
            _remove_if_exists(path)
        _remove_if_exists(base_audio)
|
|
|
|
|
def create_video(audio_files):
    """Render an MP4 in which each speaker's GIF loops over their audio.

    Files are ordered by the line number embedded in their name
    (``t<number>-<speaker>.wav``). For each file the speaker's GIF from
    ``GIF_MAPPING`` is looped for the duration of the audio; the resulting
    clips are concatenated and encoded with libx264/aac.

    Args:
        audio_files: Names of the per-line WAV files. The list is not
            modified.

    Returns:
        str: Name of the written video file.

    Raises:
        ValueError: If no clip could be built (no file had a usable GIF).
        Exception: Any other error from parsing or rendering is logged and
            re-raised.
    """
    opened = []  # every clip object created here, closed in ``finally``
    try:
        ordered = sorted(audio_files, key=lambda x: int(x.split('t')[1].split('-')[0]))
        clips = []

        for audio_file in ordered:
            speaker = audio_file.split('-')[1].split('.')[0]
            gif_file = GIF_MAPPING.get(speaker)

            if not gif_file or not os.path.exists(gif_file):
                logger.warning(f"No GIF available for speaker '{speaker}', skipping {audio_file}")
                continue

            audio_clip = AudioFileClip(audio_file)
            opened.append(audio_clip)
            source_clip = VideoFileClip(gif_file)
            opened.append(source_clip)

            gif_clip = source_clip.loop(duration=audio_clip.duration)
            gif_clip = gif_clip.set_audio(audio_clip)
            clips.append(gif_clip)

        if not clips:
            raise ValueError("No video clips could be created from the audio files")

        final_video = concatenate_videoclips(clips)
        opened.append(final_video)
        video_filename = f"output_{uuid.uuid4().hex[:8]}.mp4"
        final_video.write_videofile(video_filename, codec='libx264', audio_codec='aac')
        return video_filename
    except Exception as e:
        logger.error(f"Video creation failed: {str(e)}")
        raise
    finally:
        # Release the clip readers so the source files are no longer held
        # open; a failure to close must not hide the real outcome.
        for clip in opened:
            try:
                clip.close()
            except Exception as close_error:
                logger.warning(f"Failed to close clip: {str(close_error)}")
|
|
|
|
|
def process_prompt(prompt):
    """Turn a user question into a rendered video.

    Generates the script, produces one audio file per script line with
    ``process_line``, and passes the successful ones to ``create_video``.
    The per-line audio files are deleted afterwards, whatever the outcome.

    Args:
        prompt: The question entered by the user.

    Returns:
        str | None: Name of the video file, or None if no audio could be
        produced or any step failed (the error is logged).
    """
    audio_files = []
    try:
        script = generate_text(prompt)
        script_data = json.loads(script)

        tasks = [(idx, speaker, text)
                 for idx, item in enumerate(script_data)
                 for speaker, text in item.items()]

        with ThreadPoolExecutor(max_workers=1) as executor:
            futures = [executor.submit(process_line, task) for task in tasks]
            for future in as_completed(futures):
                result = future.result()
                if result:
                    audio_files.append(result)

        return create_video(audio_files) if audio_files else None

    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        return None
    finally:
        # The per-line WAV files are only an input to the video; without
        # this they pile up in the working directory on every request.
        for path in audio_files:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove {path}: {str(cleanup_error)}")
|
|
|
|
|
|
|
|
# Page styling: colours and fonts for the app background, title, text area,
# button and sub-heading.
st.markdown("""
<style>
.stApp { background-color: #E6E6FA; color: #000080; }
h1 { color: #FF0000; text-align: center; font-family: 'Times New Roman', serif; }
.stTextArea textarea { background-color: #FFFFFF; color: #000000; }
.stButton button { background-color: #FF0000; color: #FFFFFF; font-weight: bold; border-radius: 5px; padding: 10px 20px; }
.stMarkdown h3 { color: #000080; text-align: center; }
</style>
""", unsafe_allow_html=True)

st.markdown("# Прямая линия с Владимиром Путиным ")
st.markdown("### Великая Россия! Великий Путин! Великие победы!")

prompt = st.text_area("Введите ваш вопрос:", placeholder="Напишите ваш вопрос здесь...", height=100)

if st.button("Создать видео"):
    # Ignore surrounding whitespace so a blank question is not sent through
    # the whole generation pipeline.
    question = prompt.strip() if prompt else ""
    if not question:
        st.warning("Пожалуйста, введите вопрос.")
    else:
        with st.spinner("Генерация видео..."):
            video_filename = process_prompt(question)

        if video_filename:
            try:
                with open(video_filename, "rb") as f:
                    st.video(f.read())
            finally:
                # The bytes have been handed to Streamlit (or display
                # failed); either way the file on disk is no longer needed.
                os.remove(video_filename)
        else:
            st.error("Не удалось создать видео.")