Upload 3 files
Browse files- app.py +159 -0
- audio_transcribe.png +0 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import whisper
|
| 3 |
+
from pyannote.audio import Pipeline
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
# Redirect the Hugging Face model cache.
# NOTE(review): hard-coded Windows user path — this only works on this one
# machine; presumably it should come from configuration or be left at the
# library default. TODO confirm before deploying elsewhere.
os.environ["TRANSFORMERS_CACHE"] = "C:\\Users\\Admin\\.cache\\Documents\\huggingface_cache"
# Load models with error handling
# Directory (relative to the working directory) where uploaded audio is saved.
UPLOAD_FOLDER = 'upload'

# Set up logging to write to a file
logging.basicConfig(filename="audio_transcription.log",
                    level=logging.INFO, # Adjust the log level as needed (e.g., DEBUG, INFO, WARNING)
                    format="%(asctime)s - %(levelname)s - %(message)s")
@st.cache_resource
def load_models():
    """Load and cache the Whisper ASR model and the PyAnnote diarization pipeline.

    Cached by ``st.cache_resource`` so the heavy model load happens once per
    Streamlit server process.

    Returns:
        tuple: ``(whisper_model, diarization_pipeline)`` on success, or
        ``(None, None)`` if either load fails (the error is shown in the UI
        and written to the log file).
    """
    try:
        # Load Whisper model ("medium"; use "large" for better accuracy).
        whisper_model = whisper.load_model("medium")
        if whisper_model is None:
            raise ValueError("Whisper model failed to load.")

        # Load PyAnnote diarization pipeline. Read the Hugging Face access
        # token from the environment instead of hard-coding a secret in
        # source control; the original placeholder is kept as the fallback
        # so behavior is unchanged when HF_TOKEN is unset.
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization@2.1",
            use_auth_token=os.environ.get("HF_TOKEN", "YOUR_TOKEN"),
        )
        if diarization_pipeline is None:
            raise ValueError("Diarization model failed to load.")

        return whisper_model, diarization_pipeline

    except Exception as e:
        # Record the failure in the log file (the original print() would be
        # lost on a headless server) and surface it in the UI.
        logging.error(f"Error loading models: {e}")
        st.error(f"Error loading models: {e}")
        return None, None
# Initialize models once at import time; both names are None if loading
# failed (process_audio re-checks this before doing any work).
whisper_model, diarization_pipeline = load_models()
# Function to handle file upload and save to the 'upload' directory
def save_uploaded_file(uploaded_file):
    """Persist a Streamlit upload to UPLOAD_FOLDER under a timestamped name.

    Args:
        uploaded_file: Streamlit ``UploadedFile`` (provides ``.name`` and
            ``.getbuffer()``).

    Returns:
        str: Filesystem path of the saved copy.
    """
    # Make sure the destination directory exists; without this the open()
    # below raises FileNotFoundError on a fresh checkout.
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)

    # Prefix with a timestamp so repeated uploads of the same file
    # don't overwrite each other.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    file_name = f"{timestamp}_{uploaded_file.name}"

    # Save the uploaded file to the 'upload' directory with the timestamped filename
    file_path = os.path.join(UPLOAD_FOLDER, file_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    return file_path
# Function to process audio from a file path
def process_audio(audio_file_path):
    """Transcribe an audio file and label each transcript segment with its speaker.

    Args:
        audio_file_path: Path to the audio file on disk.

    Returns:
        tuple: ``(labeled_text, detected_language)`` where ``labeled_text`` is a
        list of ``"[SPEAKER] text"`` strings, or ``(None, None)`` on failure.
    """
    if whisper_model is None or diarization_pipeline is None:
        st.error("Models are not loaded properly. Please check the logs.")
        return None, None

    try:
        # Log the start of processing
        logging.info(f"Started processing audio file: {audio_file_path}")

        # Transcribe the full file with timestamps. NOTE: the previous version
        # ran whisper.pad_or_trim() first, which truncates audio to 30 seconds
        # and desynchronizes it from the diarization (which processes the whole
        # file) — transcribe() already handles long audio by chunking.
        transcription = whisper_model.transcribe(audio_file_path, word_timestamps=True)
        segment_timestamps = transcription["segments"]  # segment-level start/end/text
        detected_language = transcription["language"]

        # Log the detected language
        logging.info(f"Detected language: {detected_language}")

        # Perform speaker diarization on the full file.
        diarization = diarization_pipeline({"uri": "audio", "audio": audio_file_path})

        # To store speaker-labeled text
        labeled_text = []

        # Align transcript segments with diarization turns: consume, in order,
        # every transcript segment that finishes within the current turn.
        current_segment_index = 0
        for segment, _, speaker in diarization.itertracks(yield_label=True):
            end = segment.end
            labeled_segment = f"[{speaker}] "

            while current_segment_index < len(segment_timestamps):
                info = segment_timestamps[current_segment_index]
                # Segment belongs to this turn only if it ends before the turn does.
                if info["end"] <= end:
                    labeled_segment += info["text"] + " "
                    current_segment_index += 1
                else:
                    break  # Remaining segments belong to a later speaker turn

            labeled_text.append(labeled_segment.strip())
            # Log each speaker's contribution
            logging.info(f"Speaker {speaker} spoke: {labeled_segment.strip()}")

        # Log the completion of processing
        logging.info(f"Processing completed for: {audio_file_path}")

        return labeled_text, detected_language

    except Exception as e:
        st.error(f"Error processing audio: {e}")
        # logging.exception also records the traceback, which the old
        # logging.error + print pair did not.
        logging.exception(f"Error processing audio {audio_file_path}: {e}")
        return None, None
# Streamlit App UI
st.title("Multilingual Audio Transcription with Speaker Labels")
# The original text told users to pick a file "from the 'upload' folder",
# but the widget below is a browser upload — describe what actually happens.
st.write("Upload an audio file to transcribe it and detect the speakers.")

# Upload audio file
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "m4a"])
if uploaded_file is not None:
    # Save the uploaded file
    audio_file_path = save_uploaded_file(uploaded_file)

    # Display file path for debugging purposes
    st.write(f"File uploaded successfully: {audio_file_path}")

    # Let Streamlit infer the MIME type from the file; the old hard-coded
    # format="audio/wav" was wrong for mp3/m4a uploads.
    st.audio(audio_file_path)

    with st.spinner("Processing audio..."):
        try:
            labeled_text, detected_language = process_audio(audio_file_path)
            # process_audio already reports its own errors via st.error,
            # so a None result needs no extra message here.
            if labeled_text is not None:
                st.success("Processing complete!")

                # Display detected language
                st.subheader("Detected Language")
                st.write(f"**{detected_language}**")

                # Display speaker-labeled transcription
                st.subheader("Transcription with Speaker Labels")
                for line in labeled_text:
                    st.write(line)
        except Exception as e:
            st.error(f"An error occurred: {e}")

# Footer
st.markdown("---")
st.markdown(
    "Developed using [Whisper](https://github.com/openai/whisper) and "
    "[PyAnnote](https://github.com/pyannote/pyannote-audio)."
)
audio_transcribe.png
ADDED
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.20.0
|
| 2 |
+
openai-whisper==20230314
|
| 3 |
+
pyannote.audio==2.1.1
|
| 4 |
+
torch==2.0.0
|
| 5 |
+
transformers==4.27.0
|