Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from faster_whisper import WhisperModel | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| from pydub import AudioSegment | |
| import yt_dlp as youtube_dl | |
| import tempfile | |
| from transformers.pipelines.audio_utils import ffmpeg_read | |
| from gradio.components import Audio, Dropdown, Radio, Textbox | |
| import os | |
| import numpy as np | |
| import soundfile as sf | |
# Disable Hugging Face tokenizers parallelism via its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Limits.
FILE_LIMIT_MB = 1_000  # presumably an upload size limit in MB; not referenced in the visible code
YT_LENGTH_LIMIT_S = 60 * 60  # longest YouTube video accepted: one hour, in seconds
| # Charger les codes de langue | |
| from flores200_codes import flores_codes | |
# Device selection.
def set_device():
    """Return ``torch.device("cuda")`` when CUDA is available, else ``torch.device("cpu")``."""
    has_cuda = torch.cuda.is_available()
    return torch.device("cuda") if has_cuda else torch.device("cpu")


# Selected device. NOTE(review): not passed to any model in the visible code.
device = set_device()
# Translation models and tokenizers, loaded once at import time. Keys are
# "<call_name>_model" and "<call_name>_tokenizer".
model_dict = {}


def load_models():
    """Fill ``model_dict`` with the NLLB translation model(s) and tokenizer(s).

    Returns immediately when ``model_dict`` is already populated, so calling it
    more than once does not reload anything.
    """
    if model_dict:
        return
    # Other NLLB-200 checkpoints that were tried here and can be re-enabled:
    # 'facebook/nllb-200-distilled-1.3B', 'facebook/nllb-200-1.3B',
    # 'facebook/nllb-200-3.3B'.
    model_name_dict = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }
    for alias, checkpoint in model_name_dict.items():
        # Load both objects before registering either of them.
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        tok = AutoTokenizer.from_pretrained(checkpoint)
        model_dict[f"{alias}_model"] = seq2seq
        model_dict[f"{alias}_tokenizer"] = tok


load_models()
# Speech-to-text model (faster-whisper), loaded once at import time and shared
# by every transcription request.
model_size = "large-v2"
model = WhisperModel(model_size)


def transcribe_audio(audio_file):
    """Transcribe ``audio_file`` with the shared Whisper model (beam size 1).

    Returns a list of ``(timestamp, text)`` pairs, one per recognised segment,
    where ``timestamp`` is formatted like ``"[0.00s -> 2.50s]"``.
    """
    segments, _info = model.transcribe(audio_file, beam_size=1)
    pairs = []
    for segment in segments:
        stamp = "[{:.2f}s -> {:.2f}s]".format(segment.start, segment.end)
        pairs.append((stamp, segment.text))
    return pairs
# Translation pipelines already built, keyed by model name. One pipeline is
# reused for every segment instead of constructing a new one on each call.
_translators = {}


def _get_translator(model_name):
    """Return the translation pipeline for ``model_name``, building it on first use."""
    translator = _translators.get(model_name)
    if translator is None:
        translator = pipeline(
            "translation",
            model=model_dict[model_name + "_model"],
            tokenizer=model_dict[model_name + "_tokenizer"],
        )
        _translators[model_name] = translator
    return translator


def traduction(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` into ``target_lang``.

    ``source_lang`` and ``target_lang`` are keys of ``flores_codes`` (the
    language names offered in the UI dropdowns). They are mapped through that
    dict to the ``src_lang`` / ``tgt_lang`` codes handed to the NLLB pipeline.

    Returns the translated string. Returns "" (after printing a message) when
    either language is missing from ``flores_codes``, and "" for blank text.
    """
    # Validate both language names before touching the model.
    if source_lang not in flores_codes or target_lang not in flores_codes:
        print(f"Code de langue non trouvé : {source_lang} ou {target_lang}")
        return ""
    if not text or not text.strip():
        # Nothing to translate (e.g. an empty transcription segment).
        return ""
    translator = _get_translator("nllb-distilled-600M")
    # src_lang/tgt_lang are per-call pipeline arguments, so reusing the cached
    # pipeline across different language pairs is fine.
    result = translator(
        text,
        src_lang=flores_codes[source_lang],
        tgt_lang=flores_codes[target_lang],
    )
    return result[0]["translation_text"]
# Main pipeline: audio input -> transcription -> per-segment translation.
def _materialize_audio(audio_input, workdir):
    """Return a path the transcriber can read for ``audio_input``.

    Any file that has to be created (downloaded or decoded audio) is written
    inside ``workdir``.
    """
    if isinstance(audio_input, str) and audio_input.startswith("http"):
        # Remote video: download_yt_audio takes (url, destination path).
        target = os.path.join(workdir, "video.mp4")
        download_yt_audio(audio_input, target)
        return target
    if isinstance(audio_input, dict) and "array" in audio_input and "sampling_rate" in audio_input:
        # In-memory samples: write them to a WAV file for the transcriber.
        target = os.path.join(workdir, "audio.wav")
        sf.write(target, audio_input["array"], audio_input["sampling_rate"])
        return target
    # Otherwise assume it is already a path to an audio file.
    return audio_input


def full_transcription_and_translation(audio_input, source_lang, target_lang):
    """Transcribe ``audio_input`` and translate every transcribed segment.

    ``audio_input`` may be:
      * a string starting with "http": downloaded with ``download_yt_audio``;
      * a dict with "array" and "sampling_rate" keys: written to a temporary
        WAV file with soundfile;
      * anything else: passed to the transcriber unchanged (normally a file path).

    Returns ``(transcriptions, translations)``: two lists of
    ``(timestamp, text)`` pairs sharing the same timestamps. Temporary files
    created here are removed before returning, including when an error is raised.
    """
    with tempfile.TemporaryDirectory() as workdir:
        audio_file = _materialize_audio(audio_input, workdir)
        transcriptions = transcribe_audio(audio_file)
        translations = [
            (timestamp, traduction(text, source_lang, target_lang))
            for timestamp, text in transcriptions
        ]
    return transcriptions, translations
# Language names offered in the source/target dropdowns (the keys of flores_codes).
lang_codes = list(flores_codes.keys())
# Gradio callback.
def gradio_interface(audio_file, source_lang, target_lang):
    """Transcribe and translate ``audio_file`` for display in the UI.

    ``audio_file`` is what the ``gr.Audio(type="filepath")`` inputs deliver: a
    file path, or None when nothing was recorded or uploaded. URLs are handled
    by ``full_transcription_and_translation`` itself.

    Returns ``(transcribed_text, translated_text)``, each with one
    ``"<timestamp>: <text>"`` line per segment.

    Raises:
        gr.Error: when no audio was provided.
    """
    if not audio_file:
        raise gr.Error("No audio provided: record or upload an audio file first.")
    transcriptions, translations = full_transcription_and_translation(audio_file, source_lang, target_lang)
    transcribed_text = "\n".join(f"{timestamp}: {text}" for timestamp, text in transcriptions)
    translated_text = "\n".join(f"{timestamp}: {text}" for timestamp, text in translations)
    return transcribed_text, translated_text
def _return_yt_html_embed(yt_url):
    """Return an HTML snippet embedding the YouTube video referenced by ``yt_url``.

    Understands "...watch?v=<id>" URLs (other query parameters such as "&t=30s"
    are ignored) and short "youtu.be/<id>" links. Any other input falls back to
    the text after the last "?v=". The id is HTML-escaped because the URL comes
    from user input and is placed inside an attribute.
    """
    from html import escape
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(yt_url)
    video_id = parse_qs(parsed.query).get("v", [None])[0]
    if video_id is None and parsed.netloc.endswith("youtu.be"):
        video_id = parsed.path.lstrip("/").split("/")[0]
    if not video_id:
        video_id = yt_url.split("?v=")[-1]
    video_id = escape(video_id, quote=True)
    return (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
def _video_length_seconds(info):
    """Return the length in whole seconds recorded in a yt-dlp ``info`` dict.

    Prefers the numeric "duration" field, truncated to an int. Otherwise parses
    "duration_string" written as "S", "M:S" or "H:M:S". Returns None when
    neither field is present.
    """
    duration = info.get("duration")
    if duration is not None:
        return int(duration)
    duration_string = info.get("duration_string")
    if not duration_string:
        return None
    seconds = 0
    for part in duration_string.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds


def download_yt_audio(yt_url, filename=None):
    """Download the YouTube video at ``yt_url`` and return the local file path.

    The video's length is checked first, and videos longer than
    ``YT_LENGTH_LIMIT_S`` are rejected. The file is written to ``filename``.
    When ``filename`` is omitted, a new temporary directory is created and the
    file is saved there as "video.mp4"; the caller owns (and should delete)
    that directory.

    Raises:
        gr.Error: if the metadata cannot be fetched, the length is unknown or
            over the limit, or the download fails.
    """
    import time  # only used to format the "too long" error message

    info_loader = youtube_dl.YoutubeDL()
    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err)) from err

    file_length_s = _video_length_seconds(info)
    if file_length_s is None:
        # Without a known length the limit below cannot be enforced, so refuse.
        raise gr.Error("Could not determine the length of this YouTube video.")
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    if filename is None:
        filename = os.path.join(tempfile.mkdtemp(), "video.mp4")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # yt-dlp normally reports download failures as DownloadError;
            # ExtractorError is kept from the original handler.
            raise gr.Error(str(err)) from err
    return filename
# User interface: one tab per audio source, each wrapping the same
# transcription + translation Interface around gradio_interface.
_TABS = (
    ("Microphone", {"sources": ["microphone"], "type": "filepath"}),
    ("Audio file", {"type": "filepath", "label": "Audio file"}),
)

demo = gr.Blocks()
with demo:
    for tab_title, audio_kwargs in _TABS:
        with gr.Tab(tab_title):
            # Components are created inside the tab, in the same order as before.
            gr.Interface(
                fn=gradio_interface,
                inputs=[
                    gr.Audio(**audio_kwargs),
                    gr.Dropdown(lang_codes, value='French', label='Source Language'),
                    gr.Dropdown(lang_codes, value='English', label='Target Language'),
                ],
                outputs=[gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Translated Text")],
            )

demo.launch()