# streamlit / app.py — Hugging Face Space "tester1hf"
# Last change: "Update app.py", commit 2f74b6f (verified)
import streamlit as st
from g4f.client import Client
import json
import torch
import soundfile as sf
import os
import argparse
from tqdm import tqdm
from openvoice_cli.downloader import download_checkpoint
from openvoice_cli.api import ToneColorConverter
import openvoice_cli.se_extractor as se_extractor
import glob
import uuid
import logging
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from moviepy.editor import AudioFileClip, VideoFileClip, concatenate_videoclips
# --- Streamlit page configuration (must be the first Streamlit call) ---
st.set_page_config(
    page_title="Прямая линия с Путиным",
    page_icon="🇷🇺",
    layout="centered",
    initial_sidebar_state="expanded"
)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Initialize G4F client used for LLM script generation
client = Client()

# Silero TTS configuration
language = 'ru'               # Russian voice pack
model_id = 'ru_v3'            # Silero Russian model, v3
sample_rate = 48000           # output sample rate in Hz
device = torch.device('cpu')  # CPU-only inference

# Load Silero TTS model with thread safety.
# NOTE(review): this torch.hub.load runs at module import on every Streamlit
# rerun — consider caching (e.g. st.cache_resource); confirm rerun behavior.
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                          model='silero_tts',
                          language=language,
                          speaker=model_id)
model.to(device)
# Serializes access to the TTS model across worker threads.
tts_lock = threading.Lock()

# Maps speaker name -> looping GIF shown while that speaker's audio plays.
GIF_MAPPING = {
    "Киселёв": "kisel.gif",
    "Путин": "putin.gif"
}
@st.cache_resource
def setup_openvoice():
    """Set up the OpenVoice tone-color converter and reference embeddings.

    Downloads converter checkpoints on first run, then extracts a tone-color
    embedding for each reference speaker file.  Cached by Streamlit so this
    expensive initialization happens only once per process.

    Returns:
        tuple: (ToneColorConverter instance,
                dict mapping speaker name -> reference speaker embedding)

    Raises:
        FileNotFoundError: if a reference audio file is missing.
    """
    logger.info("Initializing OpenVoice components...")

    base_dir = os.path.dirname(os.path.realpath(__file__))
    converter_dir = os.path.join(base_dir, 'checkpoints', 'converter')

    # Fetch pretrained converter weights the first time the app runs.
    if not os.path.exists(converter_dir):
        logger.info("Downloading OpenVoice checkpoints...")
        os.makedirs(converter_dir, exist_ok=True)
        download_checkpoint(converter_dir)

    converter = ToneColorConverter(
        os.path.join(converter_dir, 'config.json'),
        device='cpu'
    )
    converter.load_ckpt(os.path.join(converter_dir, 'checkpoint.pth'))

    # Precompute one tone-color embedding per target speaker.
    speaker_refs = {
        "Киселёв": "kisel.mp3",
        "Путин": "putin.mp3"
    }
    embeddings = {}
    for name, ref_path in speaker_refs.items():
        if not os.path.exists(ref_path):
            raise FileNotFoundError(f"Reference file {ref_path} not found")
        logger.info(f"Extracting speaker embedding for {name}...")
        embeddings[name], _ = se_extractor.get_se(ref_path, converter, vad=True)

    return converter, embeddings
# Initialize OpenVoice components at import time; halt the app cleanly on failure.
try:
    tone_color_converter, ref_se = setup_openvoice()
except Exception as e:
    logger.error(f"Failed to initialize OpenVoice: {str(e)}")
    st.error("Ошибка инициализации голосового движка. Пожалуйста, проверьте файлы конфигурации.")
    st.stop()
def get_task(prompt):
    """Build the LLM instruction prompt that embeds the user's question.

    Args:
        prompt: the viewer's question, inserted verbatim at the end.

    Returns:
        str: a Russian-language instruction asking the model to emit a JSON
        list of alternating single-key "Киселёв"/"Путин" utterances.
    """
    instruction = f'''Ты генерируешь скрипт продолжения прямой линии с Владимиром Владимировичем Путиным. Вопросы телезрителей передаёт Киселёв - он максимально подобострастен, поддерживает Путина и Россию. Пиши в духе пропаганды пророссийской. Каждая реплика - не менее 150-200 символов. Английские названия при необходимости всегда транскрибируй кириллицей. Числительные также записывай кириллицей. Пиши в стиле кремлёвской пропаганды. Пиши в стереотипном стиле. Киселёв задаёт вопросы, обрашается к Путину как "Владимир Владимирович" или "Господин Президент". Путин отвечает. Всё максимально аутентично. Фразы Киселёва: {{"Киселёв":"фраза"}} Фразы Путина: {{"Путин":"фраза"}} Ответ дай в формате JSON без дополнительных символов: [{{"Киселёв":"фраза"}}, {{"Путин":"фраза"}} . . . ].
Вопрос от пользователя поступил: "{prompt}"'''
    return instruction
def validate_response(response):
    """Validate that *response* is a JSON list of single-key speaker dicts.

    The expected shape is ``[{"Киселёв": "..."} | {"Путин": "..."}, ...]``.
    Any structural violation is logged and rejected.

    Args:
        response: raw JSON string returned by the LLM.

    Returns:
        bool: True when the structure is valid, False otherwise.
    """
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError as err:
        logger.warning(f"JSON decode error: {str(err)}")
        return False

    if not isinstance(parsed, list):
        logger.warning("Invalid response: Root element is not a list")
        return False

    for position, entry in enumerate(parsed, start=1):
        if not isinstance(entry, dict):
            logger.warning(f"Invalid item #{position}: Not a dictionary")
            return False
        if len(entry) != 1:
            logger.warning(f"Invalid item #{position}: Contains {len(entry)} keys instead of 1")
            return False
        speaker = next(iter(entry))
        if speaker not in ["Киселёв", "Путин"]:
            logger.warning(f"Invalid item #{position}: Unexpected speaker '{speaker}'")
            return False
    return True
def generate_text(prompt):
    """Generate a dialogue script for *prompt* via the LLM, retrying on failure.

    Calls the chat API up to ``max_retries`` times until a response passes
    validate_response().  Both API errors and structurally invalid responses
    count as a failed attempt and trigger another try.  (The original code
    returned the fallback on the *first* API error — aborting the retry loop —
    and returned None when all retries produced invalid responses, which made
    the caller's json.loads() raise.)

    Args:
        prompt: the user's question.

    Returns:
        str: a JSON string — a validated script, or a canned fallback dialogue
        after all retries are exhausted, so callers can always json.loads() it.
    """
    logger.info(f"Generating text for prompt: '{prompt}'")
    max_retries = 40
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b",
                messages=[{"role": "user", "content": get_task(prompt)}],
                web_search=False
            )
            response_text = response.choices[0].message.content
            if validate_response(response_text):
                return response_text
            logger.warning(f"Attempt {attempt + 1}/{max_retries}: response failed validation")
        except Exception as e:
            # Count the error as a failed attempt and keep retrying.
            logger.error(f"API call failed (attempt {attempt + 1}/{max_retries}): {str(e)}")
    # All attempts exhausted: return a canned, structurally valid script.
    return '[{"Киселёв":"К сожалению, не удалось расслышать вопрос. Пожалуйста, попробуйте еще раз."}, {"Путин":"Мы работаем над улучшением системы. Спасибо за понимание."}]'
def split_text(text, max_length=800):
    """Split *text* into chunks of at most *max_length* characters.

    Prefers breaking at the last space before the limit; falls back to a hard
    cut at *max_length* when a chunk contains no space.  Whitespace at the
    start of each continuation is stripped.

    Args:
        text: input string (may be empty).
        max_length: maximum chunk length in characters.

    Returns:
        list[str]: the ordered chunks (always at least one element).
    """
    pieces = []
    remaining = text
    while len(remaining) > max_length:
        boundary = remaining.rfind(' ', 0, max_length)
        if boundary == -1:
            boundary = max_length
        pieces.append(remaining[:boundary])
        remaining = remaining[boundary:].lstrip()
    pieces.append(remaining)
    return pieces
def generate_audio(text, speaker_name):
    """Synthesize *text* with Silero TTS and write it to a temporary WAV file.

    Args:
        text: the line to speak; split into <=800-char pieces for the model.
        speaker_name: dialogue speaker; "Киселёв" uses the 'aidar' voice,
            anyone else uses 'baya'.

    Returns:
        str: path of the newly written temporary WAV file.
    """
    logger.info(f"Generating audio for {speaker_name} ({len(text)} characters)")
    voice = 'aidar' if speaker_name == 'Киселёв' else 'baya'

    pieces = []
    for fragment in split_text(text):
        # The shared Silero model is guarded by a lock; synthesize serially.
        with tts_lock:
            pieces.append(model.apply_tts(
                ssml_text=f"<speak>{fragment}</speak>",
                speaker=voice,
                sample_rate=sample_rate,
                put_accent=True,
                put_yo=True
            ))

    out_path = f"temp_{uuid.uuid4().hex}.wav"
    sf.write(out_path, np.concatenate(pieces), sample_rate)
    return out_path
def process_single_chunk(chunk_file, speaker):
    """Apply OpenVoice tone-color conversion to one audio chunk.

    Args:
        chunk_file: path of the source WAV chunk.
        speaker: key into the precomputed ``ref_se`` embedding table.

    Returns:
        str | None: path of the converted WAV, or None if conversion failed
        (the error is logged, not raised).
    """
    converted_path = f"temp_output_{uuid.uuid4().hex}.wav"
    try:
        chunk_se, _ = se_extractor.get_se(chunk_file, tone_color_converter, vad=True)
        tone_color_converter.convert(
            audio_src_path=chunk_file,
            src_se=chunk_se,
            tgt_se=ref_se[speaker],
            output_path=converted_path,
        )
    except Exception as e:
        logger.error(f"Error processing chunk: {str(e)}")
        return None
    return converted_path
def merge_audio_files(files, sample_rate):
    """Read the audio data of *files* and concatenate it into one array.

    Args:
        files: ordered list of WAV file paths to concatenate.
        sample_rate: unused; kept for interface compatibility.  The files are
            assumed to already share one sample rate — no resampling happens.

    Returns:
        numpy.ndarray: the concatenated samples (empty array when *files* is
        empty, matching the previous behavior).
    """
    if not files:
        return np.array([])
    # Single concatenation instead of repeated np.concatenate per file,
    # which was accidentally O(n^2) in total samples copied.
    return np.concatenate([sf.read(path)[0] for path in files])
def process_line(args):
    """Synthesize, chunk, voice-convert and merge the audio for one dialogue line.

    Pipeline: Silero TTS -> split long audio into ~15 s chunks -> OpenVoice
    tone-color conversion per chunk -> merge converted chunks into one WAV.

    Args:
        args: (idx, speaker, text) — zero-based line index, speaker name
            ("Киселёв"/"Путин"), and the line's text.

    Returns:
        str | None: the final WAV filename "t{idx+1}-{speaker}.wav" (this
        naming is parsed later by create_video), or None if any stage failed.
    """
    idx, speaker, text = args
    final_filename = f"t{idx+1}-{speaker}.wav"
    base_audio = None
    try:
        logger.info(f"Processing line {idx+1} for {speaker}")
        # Base TTS audio is a temporary WAV written by generate_audio.
        base_audio = generate_audio(text, speaker)
        if not os.path.exists(base_audio):
            return None
        audio_array, sr = sf.read(base_audio)
        duration = len(audio_array) / sr
        chunks = []
        if duration > 15:
            # Long line: split into 15-second sample blocks for conversion.
            chunk_samples = 15 * sr
            num_full_chunks = len(audio_array) // chunk_samples
            remainder_samples = len(audio_array) % chunk_samples
            remainder_duration = remainder_samples / sr
            chunks = []
            for i in range(num_full_chunks):
                start = i * chunk_samples
                end = start + chunk_samples
                chunks.append(audio_array[start:end])
            # Handle remainder: a short tail (<10 s) is merged into the last
            # full chunk rather than converted on its own.
            if remainder_samples > 0:
                if remainder_duration < 10:
                    if chunks:
                        last_chunk = chunks.pop()
                        merged = np.concatenate([last_chunk, audio_array[num_full_chunks*chunk_samples:]])
                        chunks.append(merged)
                    else:
                        chunks.append(audio_array)
                else:
                    chunks.append(audio_array[num_full_chunks*chunk_samples:])
            # Validate chunks durations: fold any chunk shorter than 10 s into
            # its predecessor so (except possibly the first) every converted
            # piece is at least 10 s long.
            valid_chunks = []
            for chunk in chunks:
                chunk_duration = len(chunk)/sr
                if chunk_duration >= 10:
                    valid_chunks.append(chunk)
                else:
                    if valid_chunks:
                        prev = valid_chunks.pop()
                        merged = np.concatenate([prev, chunk])
                        valid_chunks.append(merged)
                    else:
                        valid_chunks.append(chunk)
            chunks = valid_chunks
        else:
            # Short line: convert in a single piece.
            chunks = [audio_array]
        # Process each chunk through OpenVoice; temp chunk files are deleted
        # whether or not conversion succeeded.
        converted_files = []
        for i, chunk in enumerate(chunks):
            chunk_file = f"temp_chunk_{uuid.uuid4().hex}.wav"
            sf.write(chunk_file, chunk, sr)
            chunk_output = process_single_chunk(chunk_file, speaker)
            if chunk_output:
                converted_files.append(chunk_output)
            os.remove(chunk_file)
        if not converted_files:
            return None
        merged_audio = merge_audio_files(converted_files, sr)
        sf.write(final_filename, merged_audio, sr)
        # Cleanup converted files
        for f in converted_files:
            os.remove(f)
        return final_filename
    except Exception as e:
        logger.error(f"Error processing line {idx+1}: {str(e)}")
        return None
    finally:
        # Always delete the intermediate TTS file, even on failure.
        if base_audio and os.path.exists(base_audio):
            os.remove(base_audio)
def create_video(audio_files):
    """Assemble the final video: one looping speaker GIF per audio line.

    Args:
        audio_files: list of "t{n}-{speaker}.wav" paths; sorted in place by
            the numeric line index embedded in the filename.

    Returns:
        str: filename of the rendered MP4.

    Raises:
        Exception: re-raises any moviepy/IO failure after logging it.
    """
    try:
        # Restore dialogue order from the "t<number>-" filename prefix.
        audio_files.sort(key=lambda name: int(name.split('t')[1].split('-')[0]))
        segments = []
        for wav_path in audio_files:
            who = wav_path.split('-')[1].split('.')[0]
            gif_path = GIF_MAPPING.get(who)
            if not gif_path or not os.path.exists(gif_path):
                # No artwork available for this speaker; skip the line.
                continue
            narration = AudioFileClip(wav_path)
            looped = VideoFileClip(gif_path).loop(duration=narration.duration)
            segments.append(looped.set_audio(narration))
        final = concatenate_videoclips(segments)
        out_name = f"output_{uuid.uuid4().hex[:8]}.mp4"
        final.write_videofile(out_name, codec='libx264', audio_codec='aac')
        return out_name
    except Exception as e:
        logger.error(f"Video creation failed: {str(e)}")
        raise
def process_prompt(prompt):
    """End-to-end pipeline: question -> script -> per-line audio -> video.

    Args:
        prompt: the user's question.

    Returns:
        str | None: path of the generated MP4, or None when any stage failed
        (the error is logged, never raised to the UI).
    """
    try:
        script_data = json.loads(generate_text(prompt))
        # Flatten the script into (index, speaker, text) work items.
        tasks = []
        for idx, entry in enumerate(script_data):
            for speaker, text in entry.items():
                tasks.append((idx, speaker, text))
        audio_files = []
        # max_workers=1 keeps line processing sequential while reusing the
        # futures bookkeeping; the TTS/conversion stack is serialized anyway.
        with ThreadPoolExecutor(max_workers=1) as executor:
            pending = [executor.submit(process_line, task) for task in tasks]
            for done in as_completed(pending):
                produced = done.result()
                if produced:
                    audio_files.append(produced)
        if not audio_files:
            return None
        return create_video(audio_files)
    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        return None
# --- Page styling (injected CSS) and UI layout ---
st.markdown("""
<style>
.stApp { background-color: #E6E6FA; color: #000080; }
h1 { color: #FF0000; text-align: center; font-family: 'Times New Roman', serif; }
.stTextArea textarea { background-color: #FFFFFF; color: #000000; }
.stButton button { background-color: #FF0000; color: #FFFFFF; font-weight: bold; border-radius: 5px; padding: 10px 20px; }
.stMarkdown h3 { color: #000080; text-align: center; }
</style>
""", unsafe_allow_html=True)

st.markdown("# Прямая линия с Владимиром Путиным ")
st.markdown("### Великая Россия! Великий Путин! Великие победы!")

prompt = st.text_area("Введите ваш вопрос:", placeholder="Напишите ваш вопрос здесь...", height=100)

# Run the pipeline on button click; embed the video inline, then delete the
# temporary MP4 since its bytes are already handed to st.video.
if st.button("Создать видео") and prompt:
    with st.spinner("Генерация видео..."):
        video_filename = process_prompt(prompt)
        if video_filename:
            with open(video_filename, "rb") as f:
                st.video(f.read())
            os.remove(video_filename)
        else:
            st.error("Не удалось создать видео.")