import io
import os
import re
import tempfile

import dash
import dash_bootstrap_components as dbc
import librosa
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
from pyannote.audio import Pipeline
from pydub import AudioSegment
from spellchecker import SpellChecker
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
|
# Initialize the speaker-diarization pipeline (optional: transcription still
# works without it). Note: pyannote models are gated on the Hugging Face Hub,
# so loading may require accepting the model terms and authenticating.
try:
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
    print("Speaker diarization pipeline initialized successfully")
except Exception as e:
    print(f"Error initializing speaker diarization pipeline: {str(e)}")
    pipeline = None
|
|
# Use a GPU if one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
|
|
# Load the Whisper model and processor for speech-to-text.
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
|
|
# Spell checker used for a light post-processing pass on the transcript.
spell = SpellChecker()
|
|
def download_audio_from_url(url):
    """Download the raw media bytes behind a direct or shareable video URL."""
    try:
        if "share" in url:
            # Shareable pages embed the video; scrape the <video> tag for its source.
            print("Processing shareable link...")
            response = requests.get(url)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                video_url = video_tag['src']
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
        else:
            video_url = url

        print(f"Downloading video from URL: {video_url}")
        response = requests.get(video_url)
        response.raise_for_status()
        audio_bytes = response.content
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
|
|
def correct_spelling(text):
    """Run a best-effort spell-correction pass, keeping words that have no suggestion."""
    words = text.split()
    corrected_words = [spell.correction(word) or word for word in words]
    return ' '.join(corrected_words)
|
|
def format_transcript_with_speakers(transcript, diarization):
    """Interleave speaker labels with the transcript text.

    Whisper output here has no word-level timestamps, so diarization segment
    times (in seconds) are mapped to character offsets as a rough heuristic.
    """
    formatted_transcript = []
    current_speaker = None
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        # Slice indices must be integers; treat seconds as character offsets.
        start = int(segment.start)
        end = int(segment.end)
        if speaker != current_speaker:
            if current_speaker is not None:
                formatted_transcript.append("\n")
            formatted_transcript.append(f"Speaker {speaker}:\n")
            current_speaker = speaker
        segment_text = transcript[start:end].strip()
        if segment_text:
            formatted_transcript.append(f"{segment_text}\n")
    return "".join(formatted_transcript)
|
|
def transcribe_audio(audio_file):
    """Transcribe a WAV file with Whisper, optionally labelling speakers."""
    try:
        print("Loading audio file...")
        # Whisper expects 16 kHz mono float32 audio.
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        print(f"Audio duration: {len(audio_input) / sr:.2f} seconds")

        if pipeline:
            print("Applying speaker diarization...")
            diarization = pipeline(audio_file)
            print("Speaker diarization complete.")
        else:
            diarization = None

        # Whisper handles at most 30 s at a time, so process the audio in
        # overlapping chunks. The 5 s overlap is not deduplicated, so a little
        # text may repeat at chunk boundaries.
        chunk_length = 30 * sr
        overlap = 5 * sr
        transcriptions = []

        print("Starting transcription...")
        for i in range(0, len(audio_input), chunk_length - overlap):
            chunk = audio_input[i:i + chunk_length]
            input_features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            print(f"Processed {i / sr:.2f} to {(i + chunk_length) / sr:.2f} seconds")

        full_transcription = " ".join(transcriptions)
        print(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")

        if diarization:
            print("Applying formatting with speaker diarization...")
            formatted_transcription = format_transcript_with_speakers(full_transcription, diarization)
        else:
            print("Applying formatting without speaker diarization...")
            formatted_transcription = format_transcript_with_breaks(full_transcription)

        return formatted_transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
|
|
def format_transcript_with_breaks(transcript):
    """Split the transcript into paragraphs of roughly three sentences each."""
    sentences = re.split(r'(?<=[.!?]) +', transcript)
    paragraphs = []
    current_paragraph = []

    for sentence in sentences:
        current_paragraph.append(sentence)
        if len(current_paragraph) >= 3:
            paragraphs.append(' '.join(current_paragraph))
            current_paragraph = []

    if current_paragraph:
        paragraphs.append(' '.join(current_paragraph))

    return '\n\n'.join(paragraphs)
|
|
def transcribe_video(url):
    """Download a video, extract its audio, transcribe it, and spell-correct the result."""
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # pydub (via ffmpeg) decodes the container and exposes the audio track.
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
        print(f"Audio duration: {len(audio) / 1000} seconds")

        # Write the audio to a temporary WAV file for librosa/pyannote to read.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio.export(temp_audio.name, format="wav")
            temp_audio_path = temp_audio.name

        print("Starting audio transcription...")
        transcript = transcribe_audio(temp_audio_path)
        print(f"Transcription completed. Transcript length: {len(transcript)} characters")

        os.unlink(temp_audio_path)

        transcript = correct_spelling(transcript)

        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
|
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
|
app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("Video Transcription", className="text-center mb-4"),
            dbc.Card([
                dbc.CardBody([
                    dbc.Input(id="video-url", type="text", placeholder="Enter video URL"),
                    dbc.Button("Transcribe", id="transcribe-button", color="primary", className="mt-3"),
                    dbc.Spinner(html.Div(id="transcription-output", className="mt-3")),
                    dcc.Download(id="download-transcript")
                ])
            ])
        ], width=12)
    ])
], fluid=True)
|
|
@app.callback(
    Output("transcription-output", "children"),
    Output("download-transcript", "data"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    if not url:
        raise PreventUpdate

    # Run the transcription once, synchronously; the callback blocks until it
    # returns and the dbc.Spinner shows progress in the meantime.
    transcript = transcribe_video(url)

    if transcript:
        # Returning data to dcc.Download triggers the file download as soon as
        # the transcription finishes; the button below is only a visual cue.
        download_data = dict(content=transcript, filename="transcript.txt")
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result"),
                html.Pre(transcript),
                dbc.Button("Download Transcript", id="btn-download", color="secondary", className="mt-3")
            ])
        ]), download_data
    else:
        return "Failed to transcribe video.", None
|
|
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)
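
# Usage note (illustrative): the Dash UI above is the intended entry point, but
# the pipeline can also be exercised from a Python shell by calling
# transcribe_video() directly. The URL below is a placeholder, not a real endpoint.
#
#     >>> text = transcribe_video("https://example.com/media/talk.mp4")
#     >>> print(text[:200])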