7H4M3R committed on
Commit
0381907
·
verified ·
1 Parent(s): b528c0e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +182 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,184 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # import altair as alt
2
+ # import numpy as np
3
+ # import pandas as pd
4
+ # import streamlit as st
5
+
6
+ # """
7
+ # # Welcome to Streamlit!
8
+
9
+ # Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
+ # If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
+ # forums](https://discuss.streamlit.io).
12
+
13
+ # In the meantime, below is an example of what you can do with just a few lines of code:
14
+ # """
15
+
16
+ # num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
+ # num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
+
19
+ # indices = np.linspace(0, 1, num_points)
20
+ # theta = 2 * np.pi * num_turns * indices
21
+ # radius = indices
22
+
23
+ # x = radius * np.cos(theta)
24
+ # y = radius * np.sin(theta)
25
+
26
+ # df = pd.DataFrame({
27
+ # "x": x,
28
+ # "y": y,
29
+ # "idx": indices,
30
+ # "rand": np.random.randn(num_points),
31
+ # })
32
+
33
+ # st.altair_chart(alt.Chart(df, height=700, width=700)
34
+ # .mark_point(filled=True)
35
+ # .encode(
36
+ # x=alt.X("x", axis=None),
37
+ # y=alt.Y("y", axis=None),
38
+ # color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
+ # size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
+ # ))
41
+
42
+
43
# Standard library
import os

# Third-party (duplicate `import yt_dlp` removed; groups sorted)
import ffmpeg
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import streamlit as st
import torchaudio
import whisper
import yt_dlp
from transformers import pipeline

# from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
# from utils import download_video, extract_audio, accent_classify
55
+
56
# Audio preprocessing parameters.
# Target sample rate (Hz) that every segment is resampled to.
RATE_HZ = 16000
# Length of one classification window, in seconds.
MAX_SECONDS = 1
# The same window length expressed in samples at the target rate.
MAX_LENGTH = RATE_HZ * MAX_SECONDS
62
+
63
+
64
def download_video(url, output_path="video.mp4"):
    """Download a remote video with yt-dlp and save it to *output_path*.

    Selects the smallest mp4 video stream plus the best m4a audio stream
    (falling back to best audio-only), merged into an mp4 container.
    Playlists are refused and downloads are retried up to three times.
    Returns the path the file was written to.
    """
    options = {
        'format': 'worstvideo[ext=mp4]+bestaudio[ext=m4a]/bestaudio',
        'outtmpl': output_path,
        'merge_output_format': 'mp4',
        'quiet': True,
        'noplaylist': True,
        'nocheckcertificate': True,
        'retries': 3,
    }
    downloader = yt_dlp.YoutubeDL(options)
    with downloader as ydl:
        ydl.download([url])
    return output_path
78
+
79
def extract_audio(input_path, output_path="audio.mp3"):
    """Transcode the audio track of *input_path* into a 192 kbps mp3.

    Overwrites *output_path* if it already exists and returns its path.
    """
    stream = ffmpeg.input(input_path)
    stream = stream.output(output_path, format='mp3', acodec='libmp3lame', audio_bitrate='192k')
    stream.overwrite_output().run(quiet=True)
    return output_path
88
+
89
# Split files by chunks with == MAX_LENGTH size
def split_audio(file):
    """Load *file* with torchaudio and cut it into MAX_LENGTH-sample chunks.

    Only the first channel is used; a trailing partial chunk shorter than
    MAX_LENGTH is discarded (floor division below). Each chunk is resampled
    to RATE_HZ and stored as a flat numpy array.

    Returns a single-column DataFrame ('audio') of chunk arrays, or None
    if loading/processing fails.
    """
    try:
        # Load the audio file and its native sample rate.
        audio, rate = torchaudio.load(str(file))

        # First channel only.
        # NOTE(review): assumes channel 0 is representative for
        # multi-channel files — confirm against expected inputs.
        channel = audio[0]

        # The resampler depends only on (rate, RATE_HZ), so build it once;
        # the original rebuilt it on every loop iteration.
        resample = torchaudio.transforms.Resample(rate, RATE_HZ)

        # Floor division: any tail shorter than MAX_LENGTH is dropped, so
        # every slice below is exactly MAX_LENGTH samples (no min() needed).
        num_segments = len(channel) // MAX_LENGTH

        segmented_audio = []
        for i in range(num_segments):
            start = i * MAX_LENGTH
            segment = channel[start:start + MAX_LENGTH]
            segmented_audio.append(resample(segment).squeeze(0).numpy().reshape(-1))

        return pd.DataFrame({'audio': segmented_audio})

    except Exception as e:
        # Deliberate best-effort: a bad or missing file yields None rather
        # than crashing the app; the error is only printed (no logger here).
        print(f"Error processing file: {e}")
        return None
122
+
123
def accent_classify(pipe, audio_path):
    """Classify the accent of *audio_path* with the given HF pipeline.

    Concatenates (at most) the first 50 chunks produced by split_audio
    into one waveform and returns the pipeline's top-ranked result.
    """
    chunks = split_audio(audio_path)
    waveform = np.concatenate(chunks["audio"][:50].to_list())
    return pipe(waveform)[0]
126
+
127
# Load HF pipeline model (audio classification)
@st.cache_resource
def load_audio_classifier():
    """Build and cache the accent-classification pipeline.

    Runs on GPU when one is visible to torch, otherwise CPU. The original
    hard-coded device=0 (GPU), which raises on CPU-only hosts such as
    free Spaces hardware.
    """
    import torch  # local import: only needed for the device probe
    model_name = "dima806/english_accents_classification"
    device = 0 if torch.cuda.is_available() else -1  # GPU (0) or CPU (-1)
    return pipeline('audio-classification', model=model_name, device=device)
132
+
133
# Load Whisper model
@st.cache_resource
def load_whisper_model():
    """Load the Whisper 'base' model once per session (Streamlit resource cache)."""
    model = whisper.load_model("base")
    return model
137
+
138
# Load models once (both loaders are cached via st.cache_resource).
pipe = load_audio_classifier()
whisper_model = load_whisper_model()

st.set_page_config(page_title="Accent Classifier", layout="centered")

st.title("🎙️ English Accent Classifier")
st.markdown("Upload a video link and get the English accent with confidence.")

video_url = st.text_input("Paste a public video URL (YouTube, Loom, or MP4):")

if st.button("Analyze"):
    if not video_url.strip():
        st.warning("Please enter a valid URL.")
    else:
        video_path = None
        audio_path = None
        try:
            with st.spinner("Downloading video..."):
                video_path = download_video(video_url)

            with st.spinner("Extracting audio..."):
                audio_path = extract_audio(video_path)

            with st.spinner("Transcribing with Whisper..."):
                result = whisper_model.transcribe(audio_path)
                transcription = result['text']

            with st.spinner("Classifying accent..."):
                accent_data = accent_classify(pipe, audio_path)
                accent = accent_data.get("label", "us")
                confidence = accent_data.get("score", 0)

            st.success("Analysis Complete!")
            st.markdown(f"**Accent:** {accent}")
            # The pipeline score is a 0-1 fraction; scale it before adding
            # the percent sign (the original printed the raw fraction).
            st.markdown(f"**Confidence Score:** {confidence * 100:.2f}%")
            st.markdown("**Transcription:**")
            st.text_area("Transcript", transcription, height=200)
        except Exception as e:
            # Surface failures (bad URL, download/transcode/model errors)
            # in the UI instead of crashing the app with a traceback.
            st.error(f"Analysis failed: {e}")
        finally:
            # Cleanup: only remove files that were actually created.
            for path in (video_path, audio_path):
                if path and os.path.exists(path):
                    os.remove(path)