# -*- coding: utf-8 -*-
"""
ROBOTSMALI V38 FINAL — SOUS-TITRAGE BAMBARA (STYLE NETFLIX)
Correction V38 : Durée exacte, QuartzNet fonctionnel, pipeline simplifiée
"""
import os, tempfile, traceback, random, textwrap
import numpy as np
import torch
import soundfile as sf
import librosa
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
import gradio as gr
from moviepy.editor import VideoFileClip
# ----------------------------
# CONFIG
# ----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
MODELS = {
    "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
    "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),
    "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"),
    "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"),
}
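# Each entry maps a display name to (Hugging Face repo id, decoding mode).
# The mode selects both the NeMo class used in load_model() and the alignment
# strategy in pipeline(): "rnnt" goes through CTC segmentation, while "ctc"
# and "ctc_char" fall back to VAD-based alignment.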
_cache = {}
# ----------------------------
# MODEL LOADING
# ----------------------------
def load_model(name):
    if name in _cache: return _cache[name]
    repo, mode = MODELS[name]
    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
    if not nemo_file:
        raise FileNotFoundError(f"No .nemo file found for {name}")
    if mode == "rnnt":
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
    elif mode == "ctc":
        model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
    else:
        # QuartzNet checkpoints use a character vocabulary, not BPE
        model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    model.to(DEVICE).eval()
    _cache[name] = model
    return model
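# Illustrative usage: load_model("Soloba V1 (CTC)") downloads the checkpoint
# on the first call and serves it from _cache afterwards.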
# ----------------------------
# AUDIO EXTRACTION & CLEANING
# ----------------------------
def extract_audio(video, wav):
    # 16 kHz mono WAV, video stream dropped
    os.system(f'ffmpeg -y -i "{video}" -ar 16000 -ac 1 -vn "{wav}"')
def clean_audio(wav):
    audio, sr = sf.read(wav)
    if audio.ndim == 2:  # downmix stereo to mono
        audio = audio.mean(1)
    max_val = np.max(np.abs(audio)) if audio.size > 0 else 0
    if max_val > 1e-6:  # peak-normalize to 0.9 to avoid clipping
        audio = audio / max_val * 0.9
    clean = wav.replace(".wav", "_clean.wav")
    sf.write(clean, audio, sr)
    return clean, audio, sr
# ----------------------------
# TRANSCRIPTION
# ----------------------------
def transcribe(model, wav):
    out = model.transcribe([wav])
    # NeMo's transcribe() return type varies across versions:
    # a list of strings or a list of Hypothesis objects with a .text field
    if isinstance(out, list):
        if out and hasattr(out[0], "text"):
            return out[0].text.strip()
        if out and isinstance(out[0], str):
            return out[0].strip()
    if hasattr(out, "text"):
        return out.text.strip()
    return str(out).strip()
# ----------------------------
# UTILITIES
# ----------------------------
def keep_bambara(words):
    res = []
    for w in words:
        wl = w.lower()
        # Keep words containing Bambara-specific letters or at least two vowels
        if any(c in wl for c in ["ɛ", "ɔ", "ŋ"]) or sum(c in "aeiou" for c in wl) >= 2:
            res.append(w)
    return res
# Netflix-style constraints: characters per line, cue duration bounds, words per cue
MAX_CHARS = 45; MIN_DUR = 0.3; MAX_DUR = 3.2; MAX_WORDS = 8
def wrap2(txt):
    parts = textwrap.wrap(txt, MAX_CHARS)
    if len(parts) <= 1: return txt
    # Cut at the space nearest the midpoint to get two balanced lines
    mid = len(txt) // 2
    left = txt.rfind(" ", 0, mid)
    right = txt.find(" ", mid)
    if left == -1 and right == -1:  # no space to cut at
        return txt
    if left == -1:
        cut = right
    elif right == -1:
        cut = left
    else:
        cut = left if (mid - left) <= (right - mid) else right
    l1 = txt[:cut].strip(); l2 = txt[cut:].strip()
    return l1 + "\n" + l2 if l2 else l1
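# Illustrative example (hypothetical input): a line just over MAX_CHARS such as
#   wrap2("aaaa bbbb cccc dddd eeee ffff gggg hhhh iiii jjjj")
# is cut at the space closest to its midpoint into two roughly equal lines.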
def pack(spans, total):
    tmp = []
    for s, e, t in spans:
        s = max(0, min(s, total)); e = max(0, min(e, total))
        if e <= s or not t.strip(): continue
        tmp.append((s, e, t.strip()))
    merged = []
    for seg in tmp:
        if not merged:
            merged.append(seg); continue
        ps, pe, pt = merged[-1]; s, e, t = seg
        if (e - s) < MIN_DUR or (s - pe) < 0.1:
            merged[-1] = (ps, max(pe, e), (pt + " " + t).strip())
        else:
            merged.append(seg)
    out = []; last_end = 0
    for s, e, t in merged:
        dur = e - s; words = t.split()
        blocks = [" ".join(words[i:i + MAX_WORDS]) for i in range(0, len(words), MAX_WORDS)]
        step = dur / max(1, len(blocks)); base = s
        for b in blocks:
            st = base; en = min(base + step, e); base = en
            if en <= st: en = min(st + 0.05, total)
            txt = wrap2(b)
            if st < last_end:
                st = last_end + 1e-3; en = max(en, st + 0.05)
            out.append((st, en, txt)); last_end = en
    return out
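# Behavior sketch: spans shorter than MIN_DUR, or starting within 0.1 s of the
# previous span, are merged into it; merged spans are then re-split into blocks
# of at most MAX_WORDS words, each receiving an equal share of the span's time.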
# ----------------------------
# SIMPLE ALIGNMENT (VAD)
# ----------------------------
def align_vad(text, audio, sr, total_dur, top_db=28):
    words = keep_bambara(text.split())
    total = total_dur
    iv = librosa.effects.split(audio, top_db=top_db)  # voiced intervals, in samples
    if len(iv) == 0 or not words:
        return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
    spans = []; L = sum(e - s for s, e in iv); idx = 0
    for s, e in iv:
        seg = e - s; segt = seg / sr
        k = max(1, int(round(len(words) * (seg / L))))  # word budget for this interval
        chunk = words[idx:idx + k]; idx += k
        if not chunk: continue
        lines = [chunk[i:i + MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)]
        step = max(MIN_DUR, min(MAX_DUR, segt / len(lines))); base = s / sr
        for j, ln in enumerate(lines):
            st = base + j * step; en = base + (j + 1) * step
            spans.append((st, en, " ".join(ln)))
    return pack(spans, total)
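# Words are allotted proportionally to voiced time: an interval covering, say,
# 40% of all voiced samples receives roughly 40% of the transcript's words.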
# ----------------------------
# SRT SUBTITLES + FFmpeg
# ----------------------------
def burn(video, subs):
    tmp_srt = tempfile.mktemp(suffix=".srt")
    out_file = "RobotsMali_Subtitled.mp4"
    # Write the SRT file
    def sec_to_srt(t):
        h = int(t // 3600); m = int((t % 3600) // 60)
        s = int(t % 60); ms = int((t - int(t)) * 1000)
        return f"{h:02}:{m:02}:{s:02},{ms:03}"
    with open(tmp_srt, "w", encoding="utf-8") as f:
        for i, (start, end, text) in enumerate(subs, 1):
            f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")
    # Burn the subtitles into the video without changing its duration.
    # The subtitles filter requires re-encoding the video stream, so it
    # cannot be combined with -c:v copy.
    os.system(f'ffmpeg -y -i "{video}" -vf "subtitles={tmp_srt}" -c:v libx264 -c:a aac -b:a 192k "{out_file}"')
    if os.path.exists(tmp_srt): os.remove(tmp_srt)
    return out_file
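# The generated file follows the standard SRT cue layout, e.g.:
#   1
#   00:00:01,250 --> 00:00:03,900
#   <wrapped subtitle text>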
# ----------------------------
# MAIN PIPELINE
# ----------------------------
def pipeline(video, model_name):
    try:
        wav = tempfile.mktemp(suffix=".wav")
        # Audio extraction and normalization
        extract_audio(video, wav)
        clean, audio, sr = clean_audio(wav)
        model = load_model(model_name)
        text = transcribe(model, clean)
        mode = MODELS[model_name][1]
        total = VideoFileClip(video).duration
        if mode == "rnnt":
            # Word-level timing via CTC segmentation on the hybrid model's CTC branch
            from ctc_segmentation import (
                ctc_segmentation, CtcSegmentationParameters,
                prepare_text, determine_utterance_segments,
            )
            words = keep_bambara(text.split())
            if not words: return "⚠️ No usable subtitles", None
            x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE)
            ln = torch.tensor([x.shape[1]]).to(DEVICE)
            with torch.no_grad():
                # forward() on the hybrid model returns the encoder output;
                # the auxiliary CTC decoder turns it into per-frame log-probs (B, T, V)
                encoded, _ = model(input_signal=x, input_signal_length=ln)
                logits = model.ctc_decoder(encoder_output=encoded)
            tps = total / logits.shape[1]  # seconds per logit frame
            raw = model.tokenizer.vocab
            vocab = list(raw.keys()) if isinstance(raw, dict) else list(raw)
            cfg = CtcSegmentationParameters()
            cfg.char_list = vocab
            cfg.index_duration = tps  # so returned timings are in seconds
            gt, utt_begin = prepare_text(cfg, words)
            timings, char_probs, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt)
            segments = determine_utterance_segments(cfg, utt_begin, char_probs, timings, words)
            spans = [(st, en, w) for (st, en, _), w in zip(segments, words)]
            subs = pack(spans, total)
        else:
            subs = align_vad(text, audio, sr, total)
        if not subs: return "⚠️ No usable subtitles", None
        out = burn(video, subs)
        return "✅ Finished successfully", out
    except Exception:
        traceback.print_exc()
        return "❌ Error - see logs above", None
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
with gr.Blocks(title="RobotsMali V38 Final") as demo:
    gr.Markdown("## ⚡ RobotsMali V38 - Netflix-Style Subtitling (stable QuartzNet & RNNT)")
    v = gr.Video(label="Video to subtitle")
    m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="ASR model")
    b = gr.Button("▶️ Generate")
    s = gr.Markdown()
    o = gr.Video(label="Subtitled video")
    b.click(pipeline, [v, m], [s, o])
demo.launch(share=True, debug=False)