UI-VieNeu / worker.py
HuuDatLego's picture
Upload folder using huggingface_hub
911c66e verified
import os
import re
import unicodedata
import soundfile as sf
from dotenv import load_dotenv
load_dotenv(override=True)
from celery import Celery
import tempfile
import sys
import redis
from contextlib import contextmanager
from services.ai_pipeline import process_video_pipeline, generate_tts_only, process_studio_pipeline
from supabase import create_client, Client
# Module-level Redis connection; ProgressCatcher publishes per-job progress keys through it.
_progress_redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
redis_client = redis.from_url(_progress_redis_url)
class ProgressCatcher:
    """Tee-style wrapper around a text stream that mirrors progress timing to Redis.

    Everything written is forwarded unchanged to ``original_stream``. Whenever a
    written chunk contains an ``[<elapsed><<remaining>`` fragment (this looks like
    the timing section of a tqdm-style progress bar), the two timestamps are
    stored in Redis under ``progress_<job_id>`` as ``"<elapsed>|<remaining>"``.

    Attributes that are not defined here (``isatty``, ``encoding``, ``fileno``,
    ...) are delegated to the wrapped stream so the object can stand in for
    ``sys.stderr``.
    """

    # Accepts both "MM:SS" and "H:MM:SS" timestamps, e.g. "[00:12<01:30" or
    # "[1:02:03<0:10:00". Compiled once because write() runs for every chunk.
    _PROGRESS_RE = re.compile(r'\[(\d+(?::\d+)+)<(\d+(?::\d+)+)')

    def __init__(self, original_stream, job_id):
        self.original_stream = original_stream
        self.job_id = job_id

    def write(self, text):
        """Forward ``text`` to the wrapped stream and publish any progress found in it.

        Returns whatever the wrapped stream's ``write`` returns.
        """
        written = self.original_stream.write(text)
        match = self._PROGRESS_RE.search(text)
        if match:
            elapsed, remaining = match.groups()
            try:
                redis_client.set(f"progress_{self.job_id}", f"{elapsed}|{remaining}")
            except redis.RedisError:
                # Progress reporting is a best-effort side channel: a Redis
                # hiccup must not abort the job that is printing the bar.
                pass
        return written

    def flush(self):
        """Flush the wrapped stream."""
        self.original_stream.flush()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard against
        # infinite recursion when the instance has not been initialised yet.
        if name == "original_stream":
            raise AttributeError(name)
        return getattr(self.original_stream, name)
@contextmanager
def catch_progress(job_id):
    """Route ``sys.stderr`` through a ``ProgressCatcher`` for ``job_id`` while the body runs.

    The stream that was installed on entry is put back on exit, including when
    the body raises.
    """
    saved_stream = sys.stderr
    sys.stderr = ProgressCatcher(saved_stream, job_id)
    try:
        yield
    finally:
        # Always undo the redirection, whatever happened inside the block.
        sys.stderr = saved_stream
# "đ"/"Đ" are standalone letters with no Unicode decomposition, so NFD cannot
# reduce them to "d"/"D"; they have to be mapped explicitly.
_VIETNAMESE_D_TABLE = str.maketrans({"đ": "d", "Đ": "D"})


def slugify(text):
    """Return a short ASCII, lowercase, underscore-separated slug of ``text``.

    Diacritics are stripped (Vietnamese "đ"/"Đ" become "d"), characters other
    than word characters, whitespace and hyphens are removed, runs of
    whitespace/hyphens collapse to a single underscore, and the result is
    truncated to 30 characters. The result may be empty.
    """
    # Handle the letters NFD cannot decompose before dropping non-ASCII bytes.
    text = text.translate(_VIETNAMESE_D_TABLE)
    # Split accented letters into base letter + combining marks, then drop the marks.
    text = unicodedata.normalize('NFD', text).encode('ascii', 'ignore').decode("utf-8")
    # Remove special characters, lowercase, and turn whitespace/hyphen runs into underscores.
    text = re.sub(r'[^\w\s-]', '', text).strip().lower()
    text = re.sub(r'[-\s]+', '_', text)
    return text[:30]
# Celery application; the same Redis URL serves as message broker and result backend.
_celery_redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
celery_app = Celery("video_tasks", broker=_celery_redis_url, backend=_celery_redis_url)
@celery_app.task(bind=True, max_retries=3)
def render_video_task(self, job_id: str, script: str, ref_audio_path: str, aspect_ratio: str, sub_style: str, font_name: str, highlight_color: str):
    """Render the video for ``job_id`` and publish it to the "content" storage bucket.

    The job row in ``video_jobs`` is moved to "processing", the optional
    reference audio is downloaded, ``process_video_pipeline`` produces the MP4,
    the file is uploaded as ``rendered/<job_id>_final.mp4`` and the row is
    marked "completed" with the public URL.

    Args:
        job_id: Primary key of the ``video_jobs`` row to update.
        script: Text handed to ``process_video_pipeline``.
        ref_audio_path: Path of a reference audio file inside the "content"
            bucket; falsy means no reference audio is downloaded.
        aspect_ratio, sub_style, font_name, highlight_color: Passed through
            unchanged to ``process_video_pipeline``.

    Raises:
        Exception: Any failure is recorded on the job row (status "failed")
            and then re-raised.
    """
    # NOTE(review): `self` is never used below, so no `self.retry()` is issued;
    # confirm whether `bind=True, max_retries=3` is meant to have an effect.

    # Supabase client is created per invocation (avoids a circular dependency with main.py).
    supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_ROLE_KEY"))
    try:
        # Update DB status
        supabase.table("video_jobs").update({"status": "processing"}).eq("id", job_id).execute()

        # Everything below works on files in a temp dir that is removed on exit.
        with tempfile.TemporaryDirectory() as tmpdir:
            local_ref = None
            if ref_audio_path:
                local_ref = os.path.join(tmpdir, "ref.wav")
                with open(local_ref, "wb") as f:
                    f.write(supabase.storage.from_("content").download(ref_audio_path))

            # Run the core ML & FFmpeg logic
            output_mp4 = process_video_pipeline(tmpdir, script, local_ref, aspect_ratio, sub_style, font_name, highlight_color)

            # Upload the result. The content type is set explicitly, matching
            # the other video upload in this module, so the object is not
            # stored under the storage client's default type.
            result_path = f"rendered/{job_id}_final.mp4"
            with open(output_mp4, "rb") as f:
                supabase.storage.from_("content").upload(
                    path=result_path,
                    file=f.read(),
                    file_options={"content-type": "video/mp4"},
                )

            # Finish
            supabase.table("video_jobs").update({
                "status": "completed",
                "result_url": supabase.storage.from_("content").get_public_url(result_path)
            }).eq("id", job_id).execute()
    except Exception as e:
        # NOTE(review): this task reports "failed" while the other tasks in
        # this module report "error" — confirm which value consumers expect.
        supabase.table("video_jobs").update({"status": "failed", "error": str(e)}).eq("id", job_id).execute()
        raise
def _trim_reference_audio(path):
    """Reduce the reference clip at ``path`` to its first channel and at most 15 seconds.

    The clip is capped to avoid running out of memory (OOM) in OnnxRuntime.
    Best-effort: any failure is printed and the pipeline continues with
    whatever file is on disk.
    """
    max_seconds = 15
    try:
        data, samplerate = sf.read(path)
        # Multi-channel audio: keep only the first channel.
        if data.ndim > 1:
            data = data[:, 0]
        max_samples = max_seconds * samplerate
        if len(data) > max_samples:
            print(f"DEBUG: Audio mẫu quá dài ({len(data)/samplerate:.2f}s), tự động cắt còn 15s.")
            data = data[:max_samples]
        # Always write back, so the mono downmix is persisted even when no
        # trimming was needed and the file content matches its .wav name.
        sf.write(path, data, samplerate)
    except Exception as trim_err:
        print(f"Warning: Không thể cắt audio mẫu: {trim_err}")


@celery_app.task
def generate_tts_task(job_id: str, script: str, voice: str, temperature: float, ref_audio_path: str = None, bgm_path: str = None, bgm_volume: float = 0.1):
    """Generate speech audio for ``job_id`` and publish it to the "content" storage bucket.

    The job row in ``video_jobs`` is moved to "processing", the script is
    cleaned, an optional reference voice and background music are staged in a
    temp dir, ``generate_tts_only`` produces the audio, the file is uploaded
    as ``results/<slug>_<job_id[:8]>.wav`` and the row is marked "completed"
    with the public URL.

    Args:
        job_id: Primary key of the ``video_jobs`` row to update.
        script: Text to synthesise; bracketed notes are stripped first.
        voice: If it ends in ".mp3"/".wav" and exists on the local disk, it is
            used as the reference clip when ``ref_audio_path`` is not given.
        temperature: Passed through to ``generate_tts_only``.
        ref_audio_path: Path of a reference clip inside the "content" bucket.
        bgm_path: Background music, either an existing local path or a path
            inside the "content" bucket.
        bgm_volume: Passed through to ``generate_tts_only``.

    Raises:
        Exception: Any failure is recorded on the job row (status "error")
            and then re-raised.
    """
    # Supabase client is created per invocation.
    supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_ROLE_KEY"))
    try:
        # Inside the try so that a failure here is also reported on the job row.
        supabase.table("video_jobs").update({"status": "processing"}).eq("id", job_id).execute()

        # --- AUTO-CLEAN SCRIPT ---
        # Remove bracketed notes [like this] so they are not read aloud, but
        # keep control tags of the form [p:<number>], [v:<number>], [s:<number>].
        # NOTE(review): the lookahead only accepts digits/dot after the colon,
        # so a named tag such as [v:name] would be stripped too — confirm intent.
        script = re.sub(r'\[(?!(?:p|v|s):\d*\.?\d*\]).*?\]', '', script).strip()
        # Runs of two or more dots become a comma; whitespace runs collapse to one space.
        script = re.sub(r'\.{2,}', ',', script)
        script = re.sub(r'\s+', ' ', script).strip()
        # -------------------------

        with tempfile.TemporaryDirectory() as tmpdir:
            # 1. Stage the reference audio, if any
            local_ref_path = None
            if ref_audio_path:
                local_ref_path = os.path.join(tmpdir, "input_ref.wav")
                with open(local_ref_path, 'wb') as f:
                    f.write(supabase.storage.from_("content").download(ref_audio_path))
            elif voice and voice.endswith((".mp3", ".wav")) and os.path.exists(voice):
                # Use a local static voice file as reference; work on a copy in
                # tmpdir so the original is never overwritten.
                import shutil
                local_ref_path = os.path.join(tmpdir, "static_ref.wav")
                shutil.copy(voice, local_ref_path)

            if local_ref_path:
                _trim_reference_audio(local_ref_path)

            # 2. Stage the background music, if any
            local_bgm_path = None
            if bgm_path:
                if os.path.exists(bgm_path):
                    local_bgm_path = bgm_path
                else:
                    local_bgm_path = os.path.join(tmpdir, "bgm.mp3")
                    with open(local_bgm_path, 'wb') as f:
                        f.write(supabase.storage.from_("content").download(bgm_path))

            # 3. Run the pure TTS engine
            result_audio_local = generate_tts_only(tmpdir, script, local_ref_path, temperature, local_bgm_path, bgm_volume)

            # 4. Upload the result audio under a friendly name derived from the script
            friendly_name = slugify(script)
            final_audio_path = f"results/{friendly_name}_{job_id[:8]}.wav"
            with open(result_audio_local, 'rb') as f:
                supabase.storage.from_("content").upload(
                    path=final_audio_path,
                    file=f,
                    file_options={"content-type": "audio/wav"}
                )
            public_url = supabase.storage.from_("content").get_public_url(final_audio_path)

            # 5. Mark the job as complete
            supabase.table("video_jobs").update({
                "status": "completed",
                "result_url": public_url
            }).eq("id", job_id).execute()
    except Exception as e:
        import traceback
        traceback.print_exc()
        supabase.table("video_jobs").update({
            "status": "error",
            "error": str(e)
        }).eq("id", job_id).execute()
        raise
@celery_app.task
def render_studio_task(job_id: str, script: str, temperature: float = 0.5, voice_preset: str = "default", bgm_path: str = None, bgm_volume: float = 0.1):
    """
    Background Studio job: render the full MP4 for ``job_id`` and publish it.

    Marks the ``video_jobs`` row "processing", stages optional background
    music, runs ``process_studio_pipeline`` while stderr progress is captured
    for the job, uploads the video as
    ``results/studio_<slug>_<job_id[:8]>.mp4`` and marks the row "completed".
    On any failure the row is set to "error" and the exception is re-raised.
    """
    # Each invocation builds its own Supabase client from the environment.
    supabase: Client = create_client(
        os.getenv("SUPABASE_URL"),
        os.getenv("SUPABASE_SERVICE_ROLE_KEY"),
    )
    try:
        supabase.table("video_jobs").update({"status": "processing"}).eq("id", job_id).execute()

        # Scratch directory plus stderr progress capture for the whole render.
        with tempfile.TemporaryDirectory() as workdir, catch_progress(job_id):
            # Background music: use a local file directly, otherwise fetch it from storage.
            bgm_local = None
            if bgm_path and os.path.exists(bgm_path):
                bgm_local = bgm_path
            elif bgm_path:
                bgm_local = os.path.join(workdir, "bgm.mp3")
                with open(bgm_local, 'wb') as bgm_file:
                    bgm_file.write(supabase.storage.from_("content").download(bgm_path))

            rendered_mp4 = process_studio_pipeline(workdir, script, temperature, voice_preset, bgm_local, bgm_volume)

            # Publish the video under a name derived from the script and job id.
            slug = slugify(script)
            storage_key = f"results/studio_{slug}_{job_id[:8]}.mp4"
            with open(rendered_mp4, 'rb') as video_file:
                supabase.storage.from_("content").upload(
                    path=storage_key,
                    file=video_file,
                    file_options={"content-type": "video/mp4"}
                )

            video_url = supabase.storage.from_("content").get_public_url(storage_key)
            completed_fields = {
                "status": "completed",
                "result_url": video_url
            }
            supabase.table("video_jobs").update(completed_fields).eq("id", job_id).execute()
    except Exception as e:
        import traceback
        traceback.print_exc()
        failure_fields = {
            "status": "error",
            "error": str(e)
        }
        supabase.table("video_jobs").update(failure_fields).eq("id", job_id).execute()
        raise e