# Nick021402's picture
# Update app.py
# 845138a verified
import functools
import os
import time
from datetime import timedelta

import ffmpeg
import gradio as gr
import srt
import torch
import torchaudio
import whisper
import whisper.tokenizer
import yt_dlp
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import pipeline
from whisper.tokenizer import LANGUAGES as LLANGUAGES
# -----------------------------
# Helper Functions
# -----------------------------
def download_youtube_audio(url, output_stem="audio"):
    """Download the audio track of *url* with yt-dlp and convert it to MP3.

    Args:
        url: Video URL accepted by yt-dlp (e.g. a YouTube watch link).
        output_stem: Output file name without extension; the result is
            ``<output_stem>.mp3`` in the working directory. Defaults to
            ``"audio"`` (i.e. ``audio.mp3``).

    Returns:
        Path of the MP3 written by the FFmpegExtractAudio post-processor.

    Raises:
        RuntimeError: if yt-dlp fails; the original exception is chained.
    """
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": f"{output_stem}.%(ext)s",
        # A watch URL carrying "&list=..." would otherwise pull the whole
        # playlist, with every item overwriting the same output file.
        "noplaylist": True,
        # The output name is fixed, so never reuse a file left by a previous run.
        "overwrites": True,
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "mp3",
            "preferredquality": "192",
        }],
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.extract_info(url, download=True)
    except Exception as e:
        raise RuntimeError(f"Error downloading audio: {e}") from e
    return f"{output_stem}.mp3"
def extract_audio_from_video(video_path, output_audio="audio.mp3"):
    """Extract a video's audio track as 16 kHz mono MP3 using ffmpeg.

    Args:
        video_path: Path of the input video file.
        output_audio: Destination MP3 path; overwritten if it already exists.
            Defaults to ``"audio.mp3"``.

    Returns:
        *output_audio*.

    Raises:
        RuntimeError: if ffmpeg fails. The last line of ffmpeg's stderr is
            included in the message (``quiet=True`` captures stderr, so it is
            otherwise lost), and the original exception is chained.
    """
    try:
        (
            ffmpeg.input(video_path)
            .output(
                output_audio,
                vn=None,              # value-less flag: drop the video stream
                acodec="libmp3lame",  # MP3 encoder
                ar="16000",           # 16 kHz sample rate
                ac="1",               # mono
            )
            .run(quiet=True, overwrite_output=True)
        )
    except ffmpeg.Error as e:
        stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        detail = stderr.splitlines()[-1] if stderr else str(e)
        raise RuntimeError(f"FFmpeg error: {detail}") from e
    except Exception as e:
        raise RuntimeError(f"FFmpeg error: {e}") from e
    print(f"Audio extracted to: {output_audio}")
    return output_audio
def generate_srt(segments):
    """Render transcription *segments* as SubRip (.srt) text.

    Each segment is a mapping with ``start`` and ``end`` times in seconds and
    a ``text`` string. Segments whose text is blank after stripping are left
    out; every kept subtitle is numbered with its 1-based segment position.
    """
    def _to_subtitle(position, segment):
        # Build one subtitle, or None when the segment has no visible text.
        begin = timedelta(seconds=segment["start"])
        finish = timedelta(seconds=segment["end"])
        content = segment["text"].strip()
        if not content:
            return None
        return srt.Subtitle(index=position, start=begin, end=finish, content=content)

    candidates = (_to_subtitle(pos, seg) for pos, seg in enumerate(segments, start=1))
    return srt.compose([sub for sub in candidates if sub is not None])
# -----------------------------
# Model Loading Functions
# -----------------------------
def load_kotani_model():
    """Pre-load the Whisper "small" checkpoint shown in the UI as Kotani Whisper Small.

    Called when the user picks this model, so the checkpoint is fetched into
    the working directory (``download_root="."``) before transcription starts.
    The loaded model object is discarded here.

    Returns:
        Status text for the UI status box, reported after loading completes.
    """
    print("📥 Loading Kotani Whisper Small model...")
    whisper.load_model("small", download_root=".")
    print("Model loaded successfully.")
    # Report completion: this value is displayed only after loading has finished.
    return "✅ Kotani Whisper Small model loaded."
def load_khaiii_model():
    """Pre-load the Khaiii Wav2Vec2 processor and CTC model from the Hugging Face Hub.

    Called when the user picks this model, so the weights are fetched before
    transcription starts. The loaded objects are discarded here.

    Returns:
        Status text for the UI status box, reported after loading completes.
    """
    # NOTE(review): confirm this Hub repo id exists; if it does not,
    # from_pretrained raises and the model cannot be selected.
    model_id = "khaiii/wav2vec2-xls1r-aishell-korean"
    print("📥 Loading Khaiii Wav2Vec2 model...")
    Wav2Vec2Processor.from_pretrained(model_id)
    Wav2Vec2ForCTC.from_pretrained(model_id)
    print("Model loaded successfully.")
    # Report completion: this value is displayed only after loading has finished.
    return "✅ Khaiii Wav2Vec2 model loaded."
# -----------------------------
# Transcription Functions
# -----------------------------
@functools.lru_cache(maxsize=1)
def _kotani_model():
    """Load the Whisper "small" checkpoint once and reuse it for every request."""
    return whisper.load_model("small", download_root=".")


def transcribe_kotani(audio_path):
    """Transcribe *audio_path* with Whisper "small", auto-detecting the language.

    The model is loaded on first use and cached, instead of being reloaded
    from disk for every transcription.

    Returns:
        ``(segments, language)``: Whisper's segment dicts (each carrying
        ``start``, ``end`` and ``text``) and the detected language code.
    """
    result = _kotani_model().transcribe(audio_path, language=None)  # None -> auto-detect
    return result["segments"], result["language"]
@functools.lru_cache(maxsize=1)
def _khaiii_components():
    """Load the Khaiii Wav2Vec2 processor and CTC model once and reuse them."""
    model_id = "khaiii/wav2vec2-xls1r-aishell-korean"
    processor = Wav2Vec2Processor.from_pretrained(model_id)
    model = Wav2Vec2ForCTC.from_pretrained(model_id)
    model.eval()
    return processor, model


def transcribe_khaiii(audio_path):
    """Transcribe Korean speech in *audio_path* with the Khaiii Wav2Vec2 CTC model.

    The audio is downmixed to mono and resampled to 16 kHz, which is the rate
    declared to the processor. YouTube downloads keep their native rate, so
    without resampling they would be decoded at the wrong speed.

    Returns:
        ``(segments, "ko")`` where ``segments`` is a single
        ``{"start", "end", "text"}`` dict spanning the whole clip (seconds).
    """
    target_sr = 16000
    processor, model = _khaiii_components()
    waveform, sr = torchaudio.load(audio_path)  # (channels, frames)
    # Clip length comes from the frame axis; len(waveform) would count channels.
    duration = waveform.shape[-1] / sr
    # Downmix: a (channels, frames) input would be treated as a batch of clips.
    mono = waveform.mean(dim=0)
    if sr != target_sr:
        mono = torchaudio.functional.resample(mono, sr, target_sr)
    input_values = processor(mono.numpy(), return_tensors="pt", sampling_rate=target_sr).input_values
    # NOTE(review): the whole clip goes through one forward pass; long videos
    # may need chunking to bound memory.
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return [{"start": 0, "end": duration, "text": transcription}], "ko"
# -----------------------------
# Translation Function
# -----------------------------
@functools.lru_cache(maxsize=4)
def _translation_pipeline(model_name):
    """Build the Hugging Face translation pipeline for *model_name* once and reuse it."""
    return pipeline("translation", model=model_name)


def translate_text(text, src_lang, tgt_lang="en"):
    """Translate *text* with the Helsinki-NLP OPUS-MT model for the language pair.

    The pipeline is cached per model name, so translating many segments in a
    row does not reload the model for each one.

    Args:
        text: Text to translate.
        src_lang: Source language code (e.g. Whisper's detected ``"ko"``);
            surrounding whitespace and case are ignored.
        tgt_lang: Target language code; defaults to ``"en"``.

    Returns:
        One of the following:
        - the translated text;
        - *text* unchanged, when it is blank or the source and target
          languages match;
        - an ``"[Translation error: ...]"`` placeholder, if the model cannot
          be loaded or run.
    """
    src = (src_lang or "").strip().lower()
    tgt = (tgt_lang or "").strip().lower()
    # There is no same-language opus-mt model; without this check an en->en
    # request would replace every subtitle with an error placeholder.
    if src == tgt or not text or not text.strip():
        return text
    model_name = f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    try:
        translated = _translation_pipeline(model_name)(text, max_length=400)
        return translated[0]["translation_text"]
    except Exception as e:
        # Best-effort: keep the subtitle pipeline running and flag the segment inline.
        return f"[Translation error: {e}]"
# -----------------------------
# Main Processing Function
# -----------------------------
def process_video(youtube_url, video_file, selected_model, translate, target_lang):
    """Generate an .srt subtitle file, streaming progress to the UI.

    Pipeline:
    1. Get audio, either by downloading it from YouTube or by extracting it
       from an uploaded video.
    2. Transcribe it with the selected ASR backend.
    3. Optionally translate every segment.
    4. Write ``output.srt``.

    Args:
        youtube_url: YouTube link; used in preference to *video_file* when it
            is non-blank.
        video_file: Uploaded video, either as a path (``str``/``os.PathLike``,
            which ``gr.File(type="filepath")`` delivers) or as a file-like
            object exposing ``.name``.
        selected_model: ``"kotani"`` selects Whisper; any other value selects
            the Khaiii Wav2Vec2 model.
        translate: When truthy, translate each segment into *target_lang*.
        target_lang: Target language code such as ``"en"``; surrounding
            whitespace and case are ignored.

    Yields:
        ``(status, srt_preview, srt_path)`` tuples for the status box, the
        preview textbox and the download component. ``srt_path`` is ``None``
        until the final success update.
    """
    yield "⏳ Starting...", "", None

    url = (youtube_url or "").strip()
    if not url and not video_file:
        yield "❌ Please provide a video or YouTube URL", "", None
        return
    tgt_lang = (target_lang or "").strip().lower()
    if translate and not tgt_lang:
        # Fail fast instead of discovering this after download and transcription.
        yield "❌ Please provide a target language code", "", None
        return

    try:
        # Step 1: Extract audio
        if url:
            yield "📥 Downloading YouTube audio...", "", None
            audio_path = download_youtube_audio(url)
        else:
            # gr.File(type="filepath") passes a plain path string, which has no
            # `.name`; file-object style uploads do.
            if isinstance(video_file, (str, os.PathLike)):
                video_path = os.fspath(video_file)
            else:
                video_path = video_file.name
            yield "📼 Waiting for upload to complete...", "", None
            deadline = time.time() + 30
            while not os.path.exists(video_path):
                if time.time() > deadline:
                    raise RuntimeError("Timeout: File upload took too long.")
                time.sleep(1)
            yield "📼 Extracting audio from video...", "", None
            audio_path = extract_audio_from_video(video_path)
        print(f"Audio path: {audio_path}")

        # Step 2: Transcribe
        if selected_model == "kotani":
            yield "🎙️ Transcribing using Kotani Whisper Small...", "", None
            segments, lang = transcribe_kotani(audio_path)
        else:
            yield "🎙️ Transcribing using Khaiii Wav2Vec2...", "", None
            segments, lang = transcribe_khaiii(audio_path)
        lang_desc = LLANGUAGES.get(lang, lang.upper())

        # Step 3: Translate if needed
        if translate:
            yield f"🌐 Translating {lang_desc} to {tgt_lang.upper()}...", "", None
            segments = [
                {**seg, "text": translate_text(seg["text"], lang, tgt_lang)}
                for seg in segments
            ]

        # Step 4: Generate SRT
        yield "📝 Generating subtitle file...", "", None
        srt_content = generate_srt(segments)
        # Explicit UTF-8: subtitles may be Korean, which the platform default
        # encoding (e.g. cp1252 on Windows) cannot represent.
        with open("output.srt", "w", encoding="utf-8") as f:
            f.write(srt_content)

        preview = srt_content[:1000] + ("\n..." if len(srt_content) > 1000 else "")
        yield f"✅ Done! ({lang_desc})", preview, "output.srt"
    except Exception as e:
        yield f"❌ Error: {e}", "", None
# -----------------------------
# UI Layout
# -----------------------------
# HTML card shown under the "Select Kotani Whisper Small" button.
# (Bullets were mis-encoded "β–ͺ" mojibake; restored to U+25AA "▪".)
model_desc_kotani = """
<div style="border:1px solid #ddd; padding: 10px; border-radius:8px;">
<strong>Kotani Whisper Small</strong><br>
▪ Fast & multilingual<br>
▪ Good for quick subtitles<br>
▪ Moderate accuracy for Korean
</div>
"""
# HTML card shown under the "Select Khaiii Wav2Vec2" button.
# (Bullets were mis-encoded "β–ͺ" mojibake; restored to U+25AA "▪".)
model_desc_khaiii = """
<div style="border:1px solid #ddd; padding: 10px; border-radius:8px;">
<strong>Khaiii Wav2Vec2</strong><br>
▪ Best Korean speech recognition<br>
▪ Slower but highly accurate<br>
▪ Only supports Korean
</div>
"""
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Shared status line, written by the model-select buttons and by process_video.
    status_box = gr.Textbox(label="Status", interactive=False)
    gr.Markdown("## 🌍 Multilingual Subtitle Generator")
    gr.Markdown("Upload a video or paste a YouTube link. Automatically detect language and optionally translate subtitles.")

    # Key of the chosen ASR backend ("kotani" or "khaiii"), passed to process_video.
    selected_model = gr.State(value="kotani")  # default model

    gr.Markdown("### 🔍 Choose ASR Model")
    with gr.Row():
        with gr.Column():
            kotani_btn = gr.Button("✅ Select Kotani Whisper Small")
            gr.HTML(model_desc_kotani)
        with gr.Column():
            khaiii_btn = gr.Button("✅ Select Khaiii Wav2Vec2")
            gr.HTML(model_desc_khaiii)

    def select_kotani():
        """Pre-load Whisper, then record it as the selected model."""
        msg = load_kotani_model()
        return "kotani", msg

    def select_khaiii():
        """Pre-load Khaiii Wav2Vec2, then record it as the selected model."""
        msg = load_khaiii_model()
        return "khaiii", msg

    kotani_btn.click(fn=select_kotani, outputs=[selected_model, status_box])
    khaiii_btn.click(fn=select_khaiii, outputs=[selected_model, status_box])

    gr.Markdown("### 📥 Input Source")
    with gr.Row():
        youtube_url = gr.Textbox(label="YouTube URL", scale=2)
        # type="filepath": the handler receives the uploaded file's path.
        video_upload = gr.File(label="Upload Video", type="filepath", file_types=["video"], scale=1)

    gr.Markdown("### 🌍 Translation Options")
    with gr.Row():
        translate_checkbox = gr.Checkbox(label="Translate to another language?")
        # Hidden until the checkbox is ticked (toggled by toggle_translate below).
        target_lang = gr.Textbox(label="Target Language Code (e.g., 'en')", value="en", visible=False)

    def toggle_translate(checked):
        """Show the target-language box only while translation is enabled."""
        return gr.update(visible=checked)

    translate_checkbox.change(fn=toggle_translate, inputs=translate_checkbox, outputs=target_lang)

    subtitle_preview = gr.Textbox(label="Generated Subtitles", lines=10)
    download_file = gr.File(label="Download .srt File")
    submit_btn = gr.Button("🎬 Generate Subtitles")
    # process_video is a generator: every yielded tuple updates these three outputs.
    submit_btn.click(
        fn=process_video,
        inputs=[youtube_url, video_upload, selected_model, translate_checkbox, target_lang],
        outputs=[status_box, subtitle_preview, download_file],
    )
if __name__ == "__main__":
    # process_video is a generator handler (it streams status updates).
    # Gradio 3.x only runs generator handlers when the request queue is
    # enabled. Gradio 4 enables it by default, so queue() is harmless there.
    demo.queue().launch()