sreenathsree1578 commited on
Commit
fb977ca
·
verified ·
1 Parent(s): 6a16850

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +337 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,339 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import psutil
3
+ import signal
4
+ import time
5
+ import re
6
+ import requests
7
+ import subprocess
8
  import streamlit as st
9
+ from fastapi import FastAPI, File, UploadFile, Form
10
+ from fastapi.responses import JSONResponse
11
+ from faster_whisper import WhisperModel
12
+ from pydub import AudioSegment
13
+ from io import BytesIO
14
+ import torch
15
+ import logging
16
+ import uvicorn
17
+ from pyngrok import ngrok
18
+ import nest_asyncio
19
+ from tenacity import retry, stop_after_attempt, wait_fixed
20
+ from deep_translator import GoogleTranslator
21
 
22
+ # Set up logging
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # GitHub Gist configuration
27
+ GITHUB_TOKEN = "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" # Replace with your GitHub PAT
28
+ GIST_ID = "e1b9e5a9d7167ab2458405c38a2803c2"
29
+ GIST_FILENAME = "ngrok_public_url.txt"
30
+ GIST_DESCRIPTION = "Ngrok Public URL for FastAPI Server"
31
+
32
+ # Function to free port 8000
33
+ def free_port(port=8000):
34
+ try:
35
+ for proc in psutil.process_iter(['pid', 'name']):
36
+ for conn in proc.net_connections(kind='inet'):
37
+ if conn.laddr.port == port:
38
+ logger.info(f"Terminating process {proc.pid} using port {port}")
39
+ os.kill(proc.pid, signal.SIGTERM)
40
+ proc.wait(timeout=5)
41
+ except Exception as e:
42
+ logger.warning(f"Error freeing port {port}: {str(e)}")
43
+
44
+ # Function to create or update a GitHub Gist
45
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
46
+ def update_github_gist(public_url: str) -> str:
47
+ global GIST_ID
48
+ headers = {
49
+ "Authorization": f"token {GITHUB_TOKEN}",
50
+ "Accept": "application/vnd.github.v3+json"
51
+ }
52
+ gist_data = {
53
+ "description": GIST_DESCRIPTION,
54
+ "public": False,
55
+ "files": {
56
+ GIST_FILENAME: {
57
+ "content": public_url
58
+ }
59
+ }
60
+ }
61
+
62
+ try:
63
+ if GIST_ID:
64
+ response = requests.patch(
65
+ f"https://api.github.com/gists/{GIST_ID}",
66
+ headers=headers,
67
+ json=gist_data
68
+ )
69
+ if response.status_code == 200:
70
+ logger.info(f"Successfully updated Gist {GIST_ID} with URL: {public_url}")
71
+ return GIST_ID
72
+ else:
73
+ logger.error(f"Failed to update Gist {GIST_ID}: {response.text}")
74
+ raise Exception(f"Failed to update Gist: {response.text}")
75
+ else:
76
+ response = requests.post(
77
+ "https://api.github.com/gists",
78
+ headers=headers,
79
+ json=gist_data
80
+ )
81
+ if response.status_code == 201:
82
+ GIST_ID = response.json()['id']
83
+ logger.info(f"Created new Gist {GIST_ID} with URL: {public_url}")
84
+ return GIST_ID
85
+ else:
86
+ logger.error(f"Failed to create Gist: {response.text}")
87
+ raise Exception(f"Failed to create Gist: {response.text}")
88
+ except Exception as e:
89
+ logger.error(f"Error updating/creating Gist: {str(e)}")
90
+ raise
91
+
92
+ # Function to start ngrok tunnel
93
+ def start_ngrok_tunnel(port=8000):
94
+ try:
95
+ tunnel = ngrok.connect(port, bind_tls=True)
96
+ public_url = tunnel.public_url
97
+ logger.info(f"ngrok tunnel created at {public_url}")
98
+ return public_url
99
+ except Exception as e:
100
+ logger.error(f"Failed to start ngrok tunnel: {str(e)}")
101
+ raise
102
+
103
+ # Initialize FastAPI app
104
+ app = FastAPI()
105
+
106
+ # Load Whisper model
107
+ model = None
108
+ try:
109
+ model = WhisperModel("large", device="cuda", compute_type="float16")
110
+ logger.info("Whisper model loaded on GPU")
111
+ except Exception as e:
112
+ logger.warning(f"Failed to load Whisper model on GPU: {str(e)}. Falling back to CPU.")
113
+ try:
114
+ model = WhisperModel("large", device="cpu", compute_type="int8")
115
+ logger.info("Whisper model loaded on CPU")
116
+ except Exception as cpu_e:
117
+ logger.error(f"Failed to load Whisper model on CPU: {str(cpu_e)}")
118
+ raise RuntimeError(f"Failed to load Whisper model on both GPU and CPU: {str(cpu_e)}")
119
+
120
+ def clean_transcription(text):
121
+ """Clean the transcription by removing consecutive duplicates"""
122
+ sentences = [s.strip() for s in re.split(r'[.!?]', text) if s.strip()]
123
+ unique_sentences = []
124
+ prev_sentence = None
125
+
126
+ for sentence in sentences:
127
+ if sentence != prev_sentence:
128
+ unique_sentences.append(sentence)
129
+ prev_sentence = sentence
130
+
131
+ processed_sentences = []
132
+ for sentence in unique_sentences:
133
+ words = sentence.split()
134
+ unique_words = []
135
+ prev_word = None
136
+ for word in words:
137
+ if word != prev_word:
138
+ unique_words.append(word)
139
+ prev_word = word
140
+ processed_sentences.append(' '.join(unique_words))
141
+
142
+ cleaned_text = '. '.join(processed_sentences)
143
+ if text.endswith('.'):
144
+ cleaned_text += '.'
145
+ elif text.endswith('?'):
146
+ cleaned_text += '?'
147
+ elif text.endswith('!'):
148
+ cleaned_text += '!'
149
+
150
+ return cleaned_text
151
+
152
+ def translate_to_language(text_or_dict, source_lang, target_lang):
153
+ lang_codes = {
154
+ 'en': 'en', 'hi': 'hi', 'bn': 'bn', 'te': 'te', 'mr': 'mr',
155
+ 'ta': 'ta', 'ur': 'ur', 'gu': 'gu', 'kn': 'kn', 'ml': 'ml',
156
+ 'pa': 'pa', 'or': 'or', 'as': 'as', 'ne': 'ne'
157
+ }
158
+ malayalam_reverse_map = {
159
+ 'ആശയവുമില്ല': 'ഐഡിയയുമില്ല',
160
+ 'എന്ന് ': 'നിന്നും ',
161
+ 'ചെയ്യാൻ': 'ചെയ്യാവോ',
162
+ # ... (keeping the full malayalam_reverse_map as in the original code)
163
+ 'നികുതി': 'ടാക്സ്'
164
+ }
165
+ try:
166
+ if isinstance(text_or_dict, dict):
167
+ text = text_or_dict.get('raw_transcription', '')
168
+ if not text:
169
+ logger.warning("No raw_transcription found in input dictionary")
170
+ return {'translated_text': ''}
171
+ else:
172
+ text = text_or_dict
173
+ if not isinstance(text, str):
174
+ logger.warning(f"Input text_or_dict is not a string or dict: {type(text_or_dict)}")
175
+ text = str(text_or_dict)
176
+ text = text.strip()
177
+ if not text:
178
+ logger.warning("Input text is empty after stripping")
179
+ return {'translated_text': ''}
180
+ text = re.sub(r'[^\w\s.,!?]', '', text)
181
+ if source_lang not in lang_codes or target_lang not in lang_codes:
182
+ logger.error(f"Unsupported language: source={source_lang}, target={target_lang}")
183
+ return {'translated_text': '', 'error': f"Unsupported language: source={source_lang}, target={target_lang}"}
184
+ source_code = lang_codes.get(source_lang, 'en')
185
+ target_code = lang_codes.get(target_lang, 'ml')
186
+ logger.info(f"Translating from {source_lang} ({source_code}) to {target_lang} ({target_code})")
187
+ translator = GoogleTranslator(source=source_code, target=target_code)
188
+ translated_text = translator.translate(text)
189
+ if not translated_text:
190
+ logger.warning("Translation returned empty result")
191
+ return {'translated_text': ''}
192
+ if target_lang == 'ml':
193
+ for incorrect, correct in malayalam_reverse_map.items():
194
+ translated_text = translated_text.replace(incorrect, correct)
195
+ logger.debug(f"Translated text: {translated_text[:100]}...")
196
+ return {'translated_text': translated_text}
197
+ except Exception as e:
198
+ logger.error(f"Translation failed: {str(e)}")
199
+ return {'translated_text': '', 'error': f"Translation failed: {str(e)}"}
200
+
201
+ @app.post("/transcribe")
202
+ async def transcribe(
203
+ audio: UploadFile = File(...),
204
+ model_size: str = Form("small"),
205
+ transcription_language: str = Form("en"),
206
+ target_language: str = Form(None)
207
+ ):
208
+ try:
209
+ audio_data = await audio.read()
210
+ audio_buffer = BytesIO(audio_data)
211
+ audio_segment = AudioSegment.from_file(audio_buffer, format="wav")
212
+ wav_buffer = BytesIO()
213
+ audio_segment.export(wav_buffer, format="wav")
214
+ wav_buffer.seek(0)
215
+ segments, info = model.transcribe(
216
+ wav_buffer,
217
+ language=transcription_language,
218
+ task="transcribe",
219
+ beam_size=1,
220
+ vad_filter=True
221
+ )
222
+
223
+ raw_segments = [segment.text for segment in segments]
224
+ raw_transcription = " ".join(raw_segments)
225
+ raw_transcription = clean_transcription(raw_transcription)
226
+
227
+ response = {
228
+ "raw_transcription": raw_transcription,
229
+ "audio_metadata": {
230
+ "extension": audio.filename.split('.')[-1],
231
+ "duration": len(audio_segment) / 1000,
232
+ "emotion": "unknown"
233
+ },
234
+ "detected_language": info.language,
235
+ "detected_language_name": {
236
+ 'en': 'English', 'hi': 'Hindi', 'bn': 'Bengali', 'te': 'Telugu', 'mr': 'Marathi',
237
+ 'ta': 'Tamil', 'ur': 'Urdu', 'gu': 'Gujarati', 'kn': 'Kannada', 'ml': 'Malayalam',
238
+ 'pa': 'Punjabi', 'or': 'Odia', 'as': 'Assamese', 'ne': 'Nepali'
239
+ }.get(info.language, info.language)
240
+ }
241
+
242
+ if target_language:
243
+ translation_result = translate_to_language(
244
+ {'raw_transcription': raw_transcription},
245
+ source_lang=info.language,
246
+ target_lang=target_language
247
+ )
248
+ response['translated_text'] = translation_result.get('translated_text', '')
249
+ if 'error' in translation_result:
250
+ response['translation_error'] = translation_result['error']
251
+
252
+ return JSONResponse(response)
253
+ except Exception as e:
254
+ logger.error(f"Transcription failed: {str(e)}")
255
+ return JSONResponse({"error": str(e)}, status_code=500)
256
+
257
+ # Function to start the FastAPI server as a subprocess
258
+ def start_server():
259
+ try:
260
+ free_port(8000)
261
+ nest_asyncio.apply()
262
+ ngrok.set_auth_token("2zXsNdquBfJFwENcejztPfNEdan_57vHEJ6RhnFxNoaNxR5cW") # Replace with your token
263
+ public_url = start_ngrok_tunnel(8000)
264
+ st.session_state['public_url'] = public_url
265
+ st.session_state['gist_id'] = update_github_gist(str(public_url))
266
+
267
+ # Start the FastAPI server as a subprocess
268
+ process = subprocess.Popen(
269
+ ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"],
270
+ stdout=subprocess.PIPE,
271
+ stderr=subprocess.PIPE,
272
+ text=True
273
+ )
274
+ st.session_state['server_process'] = process
275
+ st.success(f"Server started successfully! Public URL: {public_url}")
276
+ st.write(f"Updated Gist ID: {st.session_state['gist_id']}")
277
+ return process
278
+ except Exception as e:
279
+ logger.error(f"Failed to start server: {str(e)}")
280
+ st.error("Error: Could not start server. Please check the ngrok dashboard (https://dashboard.ngrok.com/agents) to terminate existing sessions or upgrade your ngrok account.")
281
+ return None
282
+
283
+ # Function to stop the FastAPI server and ngrok tunnel
284
+ def stop_server():
285
+ try:
286
+ # Terminate the server process
287
+ if 'server_process' in st.session_state and st.session_state['server_process']:
288
+ process = st.session_state['server_process']
289
+ process.terminate()
290
+ process.wait(timeout=5)
291
+ logger.info("Server process terminated")
292
+
293
+ # Terminate ngrok tunnel
294
+ ngrok.kill()
295
+ logger.info("ngrok tunnel terminated")
296
+
297
+ # Clear session state
298
+ st.session_state['server_process'] = None
299
+ st.session_state['public_url'] = None
300
+ st.session_state['gist_id'] = None
301
+ st.success("Server and ngrok tunnel stopped successfully")
302
+ except Exception as e:
303
+ logger.error(f"Failed to stop server: {str(e)}")
304
+ st.error(f"Error stopping server: {str(e)}")
305
+
306
+ # Streamlit app
307
+ def main():
308
+ st.title("FastAPI Server Control")
309
+ st.write("Use the buttons below to start or stop the FastAPI server and ngrok tunnel.")
310
+
311
+ # Initialize session state
312
+ if 'server_process' not in st.session_state:
313
+ st.session_state['server_process'] = None
314
+ st.session_state['public_url'] = None
315
+ st.session_state['gist_id'] = None
316
+
317
+ # Start Server Button
318
+ if st.button("Start Server"):
319
+ if st.session_state['server_process'] is None:
320
+ start_server()
321
+ else:
322
+ st.warning("Server is already running!")
323
+
324
+ # Stop Server Button
325
+ if st.button("Stop Server"):
326
+ if st.session_state['server_process'] is not None:
327
+ stop_server()
328
+ else:
329
+ st.warning("No server is running!")
330
+
331
+ # Display current status
332
+ if st.session_state['public_url']:
333
+ st.write(f"**Current Public URL**: {st.session_state['public_url']}")
334
+ st.write(f"**Gist ID**: {st.session_state['gist_id']}")
335
+ else:
336
+ st.write("**Status**: Server is not running")
337
+
338
+ if __name__ == "__main__":
339
+ main()