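"""Automated video dubbing pipeline (Gradio app).

Stages: download a YouTube video (yt-dlp), transcribe it (Whisper),
translate the segments (Qwen2.5-1.5B-Instruct), synthesize target-language
speech (MMS-TTS / VITS), duck and mix the audio tracks (pydub), and burn
the translated subtitles into the final video (ffmpeg).
"""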
import os
import re
import json
import datetime
import subprocess
import tempfile
import shutil
import numpy as np
import torch
import srt
import gradio as gr
from pathlib import Path
from pydub import AudioSegment
from pydub.effects import speedup
from functools import reduce
import whisper
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    VitsModel, AutoTokenizer as TTSTokenizer
)
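# External binaries assumed to be available on PATH: yt-dlp, ffmpeg, ffprobe.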
# ============================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SUPPORTED_LANGUAGES = {
    "French": ("facebook/mms-tts-fra", "fr"),
    "Arabic": ("facebook/mms-tts-ara", "ar"),
    "Spanish": ("facebook/mms-tts-spa", "es"),
    "German": ("facebook/mms-tts-deu", "de"),
    "English": ("facebook/mms-tts-eng", "en"),
}
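# (Only the model ids above are read; the ISO codes are currently unused.)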

# Model cache (avoids reloading the models on every request)
_model_cache = {}

def get_whisper():
    if "whisper" not in _model_cache:
        _model_cache["whisper"] = whisper.load_model("base", device=DEVICE)
    return _model_cache["whisper"]

def get_llm():
    mid = "Qwen/Qwen2.5-1.5B-Instruct"
    if "llm" not in _model_cache:
        tok = AutoTokenizer.from_pretrained(mid)
        mdl = AutoModelForCausalLM.from_pretrained(
            mid,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"  # requires the `accelerate` package
        )
        _model_cache["llm"] = (tok, mdl)
    return _model_cache["llm"]

def get_tts(lang: str):
    model_id = SUPPORTED_LANGUAGES[lang][0]
    key = f"tts_{lang}"
    if key not in _model_cache:
        tok = TTSTokenizer.from_pretrained(model_id)
        mdl = VitsModel.from_pretrained(model_id).to(DEVICE)
        mdl.eval()
        _model_cache[key] = (tok, mdl)
    return _model_cache[key]

# ---- Pipeline functions ----

def download_video(url: str, work_dir: Path) -> dict:
    video = work_dir / "video.mp4"
    audio = work_dir / "audio.wav"
    yt_cmd = [
        "yt-dlp",
        "-f", "bestvideo[height<=480][ext=mp4]+bestaudio[ext=m4a]/best[height<=480][ext=mp4]",
        "--merge-output-format", "mp4",
        "-o", str(video), "--no-playlist", url
    ]
    # Cookies are optional: only pass the flag if the file actually exists,
    # otherwise yt-dlp exits with an error before downloading anything.
    cookies_path = "/app/cookies.txt"
    if os.path.exists(cookies_path):
        yt_cmd[1:1] = ["--cookies", cookies_path]
    subprocess.run(yt_cmd, check=True, capture_output=True)
    # Extract mono 16 kHz audio, the format Whisper works with natively.
    subprocess.run([
        "ffmpeg", "-y", "-i", str(video),
        "-ac", "1", "-ar", "16000", "-vn", str(audio)
    ], check=True, capture_output=True)
    probe = subprocess.run([
        "ffprobe", "-v", "error", "-show_entries", "format=duration",
        "-of", "json", str(audio)
    ], capture_output=True, text=True)
    duration = float(json.loads(probe.stdout)["format"]["duration"])
    return {"video": video, "audio": audio, "duration": duration}

def transcribe(audio_path: Path) -> tuple:
    model = get_whisper()
    result = model.transcribe(str(audio_path), word_timestamps=True, verbose=False)
    segments = [
        {"text": s["text"].strip(), "start": round(s["start"], 3),
         "end": round(s["end"], 3), "duration": round(s["end"] - s["start"], 3)}
        for s in result["segments"] if s["text"].strip()
    ]
    # Whisper reports an ISO code ("en", "fr", ...); map it to a full language
    # name so the translation prompt reads naturally.
    code = result.get("language", "en")
    lang = whisper.tokenizer.LANGUAGES.get(code, "english").capitalize()
    return segments, lang

def translate_segment(text: str, src: str, tgt: str) -> str:
    tok, mdl = get_llm()
    sys_p = (f"You are a professional subtitle translator. Translate from {src} to {tgt}. "
             f"Output ONLY the translation, nothing else.")
    msgs = [{"role": "system", "content": sys_p}, {"role": "user", "content": text}]
    input_text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tok(input_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = mdl.generate(**inputs, max_new_tokens=150, temperature=0.3,
                           do_sample=True, repetition_penalty=1.1,
                           pad_token_id=tok.eos_token_id)
    gen = out[0][inputs.input_ids.shape[1]:]
    tr = tok.decode(gen, skip_special_tokens=True).strip()
    # Keep only the first non-empty line in case the model adds commentary.
    lines = [l.strip() for l in tr.splitlines() if l.strip()]
    return lines[0] if lines else tr

def build_dubbed_audio(segments: list, tgt_lang: str, total_s: float, out: Path) -> Path:
    tts_tok, tts_mdl = get_tts(tgt_lang)
    sr = tts_mdl.config.sampling_rate
    # Start from a silent track spanning the whole video, then overlay each
    # synthesized segment at its original timestamp.
    track = AudioSegment.silent(duration=int(total_s * 1000))
    for seg in segments:
        text = seg["translated_text"].strip()
        if not text:
            continue
        inputs = tts_tok(text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            wav = tts_mdl(**inputs).waveform[0].cpu().numpy()
        # Clip before int16 conversion to avoid wrap-around distortion.
        wav_i16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
        seg_aud = AudioSegment(wav_i16.tobytes(), frame_rate=sr, sample_width=2, channels=1)
        # If the TTS output is longer than the original slot, speed it up,
        # capped at 2x to keep it intelligible.
        target_s = seg["duration"]
        actual_s = len(seg_aud) / 1000
        if actual_s > target_s and target_s > 0.1:
            spd = min(actual_s / target_s, 2.0)
            seg_aud = speedup(seg_aud, spd, chunk_size=50, crossfade=25)
        track = track.overlay(seg_aud, position=int(seg["start"] * 1000))
    track.export(str(out), format="wav")
    return out

def mix_audio(orig: Path, dubbed: Path, segments: list, out: Path) -> Path:
    original = AudioSegment.from_wav(str(orig)).set_frame_rate(44100).set_channels(2)
    dub = AudioSegment.from_wav(str(dubbed)).set_frame_rate(44100).set_channels(2)
    total = len(original)
    # Pad or trim the dub track to exactly the original length.
    if len(dub) < total:
        dub = dub + AudioSegment.silent(duration=total - len(dub), frame_rate=44100)
    dub = dub[:total]
    # Duck the original by 20 dB wherever speech occurs, so the dubbed voice
    # sits on top while ambience stays audible in between.
    parts, prev = [], 0
    for seg in segments:
        s, e = int(seg["start"] * 1000), int(seg["end"] * 1000)
        if s > prev:
            parts.append(original[prev:s])
        parts.append(original[s:e] + (-20))
        prev = e
    if prev < total:
        parts.append(original[prev:])
    ducked = reduce(lambda a, b: a + b, parts) if parts else original
    final = ducked.overlay(dub + (-0.9))
    final.export(str(out), format="wav")
    return out

def burn_subtitles(video: Path, audio: Path, srt_path: Path, out: Path) -> Path:
    srt_esc = str(srt_path).replace("\\", "/")
    cmd = [
        "ffmpeg", "-y",
        "-i", str(video),
        "-i", str(audio),
        "-vf", f"subtitles={srt_esc}:force_style='FontSize=18,PrimaryColour=&H00FFFFFF,OutlineColour=&H00000000,Outline=2,Bold=1'",
        "-c:v", "libx264", "-preset", "ultrafast", "-crf", "28",
        "-c:a", "aac", "-b:a", "128k",
        "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(out)
    ]
    r = subprocess.run(cmd, capture_output=True, text=True)
    if r.returncode != 0:
        # Fallback: if the subtitles filter fails (e.g. ffmpeg built without
        # libass), mux the dubbed audio in without burning subtitles.
        cmd2 = ["ffmpeg", "-y", "-i", str(video), "-i", str(audio),
                "-c:v", "copy", "-c:a", "aac", "-b:a", "128k",
                "-map", "0:v:0", "-map", "1:a:0", "-shortest", str(out)]
        subprocess.run(cmd2, check=True, capture_output=True)
    return out

# ---- FULL PIPELINE ----

def run_pipeline(youtube_url: str, target_language: str, progress=gr.Progress()) -> str:
    if not youtube_url.strip():
        raise gr.Error("Please enter a valid YouTube URL")
    work_dir = Path(tempfile.mkdtemp(prefix="dubbing_"))
    try:
        progress(0.05, desc="Downloading video...")
        files = download_video(youtube_url, work_dir)
        progress(0.20, desc="Whisper transcription...")
        segments, src_lang = transcribe(files["audio"])
        # Original-language subtitles (built but not currently written out).
        subs_orig = [srt.Subtitle(i + 1,
                                  datetime.timedelta(seconds=s["start"]),
                                  datetime.timedelta(seconds=s["end"]),
                                  s["text"]) for i, s in enumerate(segments)]
        progress(0.40, desc="Translating...")
        translated = []
        for seg in segments:
            tr = translate_segment(seg["text"], src_lang, target_language)
            translated.append({**seg, "translated_text": tr})
        srt_file = work_dir / "translated.srt"
        subs_tr = [srt.Subtitle(i + 1,
                                datetime.timedelta(seconds=s["start"]),
                                datetime.timedelta(seconds=s["end"]),
                                s["translated_text"]) for i, s in enumerate(translated)]
        srt_file.write_text(srt.compose(subs_tr), encoding="utf-8")
        progress(0.60, desc="Generating audio (TTS)...")
        dubbed_wav = work_dir / "dubbed.wav"
        build_dubbed_audio(translated, target_language, files["duration"], dubbed_wav)
        progress(0.80, desc="Mixing audio...")
        mixed_wav = work_dir / "mixed.wav"
        mix_audio(files["audio"], dubbed_wav, translated, mixed_wav)
        progress(0.90, desc="Rendering final video...")
        final_video = work_dir / "final.mp4"
        burn_subtitles(files["video"], mixed_wav, srt_file, final_video)
        # Copy the result out before the temp dir is deleted in `finally`.
        output_path = Path("/tmp/output_dubbed.mp4")
        shutil.copy(final_video, output_path)
        progress(1.0, desc="Done!")
        return str(output_path)
    except Exception as e:
        raise gr.Error(f"Pipeline error: {str(e)[:300]}")
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)

# ---- GRADIO INTERFACE ----
with gr.Blocks(title="Video Dubbing Pipeline", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎬 Automated Video Dubbing Pipeline
    Enter a YouTube link (30 s to 1 min) and pick a target language.
    The pipeline transcribes, translates, synthesizes a voice, and produces a dubbed video.
    > **Note:** Processing takes roughly 5-15 minutes depending on video length.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            url_input = gr.Textbox(
                label="YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1
            )
            lang_choice = gr.Dropdown(
                choices=list(SUPPORTED_LANGUAGES.keys()),
                value="French",
                label="Target language"
            )
            run_btn = gr.Button("🚀 Start dubbing", variant="primary")
        with gr.Column(scale=3):
            video_output = gr.Video(label="Dubbed video")
    run_btn.click(
        fn=run_pipeline,
        inputs=[url_input, lang_choice],
        outputs=video_output
    )

if __name__ == "__main__":
    # 0.0.0.0:7860 is the host/port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)