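# NOTE: this file was captured from a hosting page; the page header read
# "Spaces: Sleeping". That text is not part of the program.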
import json
import os
import tempfile
import time
import zipfile
from datetime import datetime

import firebase_admin
import gradio as gr
import librosa
import torch
from firebase_admin import credentials, firestore
from transformers import Wav2Vec2Processor, AutoModelForCTC
# Firebase / Firestore persistence is currently disabled. The initialisation
# code is kept below for reference only and is NOT executed:
#
#     firebase_config = json.loads(os.environ.get('firebase_creds'))
#     cred = credentials.Certificate(firebase_config)
#     firebase_admin.initialize_app(cred)
#     db = firestore.client()
#
# `db` is bound to None so that the name always exists at module level;
# store_correction() reads it and cannot save anything until the
# initialisation above is re-enabled.
db = None

# Load the ASR model and its processor once, at import time, so every request
# reuses the same objects.
MODEL_NAME = "eleferrand/XLSR_gwad"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)
# Segments shorter than this many samples are not sent to the model.
# NOTE(review): the standard Wav2Vec2 convolutional front end needs about 400
# input samples (25 ms at 16 kHz) to produce a single frame and raises on
# shorter input -- confirm the value for this checkpoint.
_MIN_SEGMENT_SAMPLES = 400


def _transcribe_segment(samples, rate):
    """Run the CTC model on one audio segment and return its decoded text.

    "[UNK]" tokens are stripped from the decoded string.
    """
    input_values = processor(
        samples, sampling_rate=rate, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0].replace("[UNK]", "")


def transcribe(audio_file, chunk_seconds=20):
    """Transcribe an audio file, splitting long recordings into fixed windows.

    The file is loaded with librosa and resampled to 16 kHz. Recordings no
    longer than ``chunk_seconds`` are transcribed in one pass and reported as
    ``"0 - <duration>: <text>"``. Longer recordings are cut into consecutive
    ``chunk_seconds`` windows and reported one per line as
    ``"<start> - <end>: <text>\\n"`` (times in seconds).

    Args:
        audio_file: Path of the audio file to transcribe. ``None`` or an empty
            string yields an error message instead of a transcription.
        chunk_seconds: Window length in seconds used for long recordings.
            Defaults to 20, the value previously hard-coded.

    Returns:
        The formatted transcription, or a string starting with
        ``"處理文件錯誤: "`` describing the failure. Loading and inference
        errors are reported this way rather than raised, because the result
        is shown directly in the UI.

    Raises:
        ValueError: If ``chunk_seconds`` is not positive.
    """
    if not audio_file:
        return "處理文件錯誤: no audio file provided"
    if chunk_seconds <= 0:
        raise ValueError("chunk_seconds must be positive")

    try:
        audio, rate = librosa.load(audio_file, sr=16000)
        total = len(audio)

        if total / rate <= chunk_seconds:
            return f"0 - {total / rate}: {_transcribe_segment(audio, rate)}"

        window = max(1, int(chunk_seconds * rate))
        lines = []
        for start in range(0, total, window):
            end = min(start + window, total)
            segment = audio[start:end]
            # A tiny leftover tail would make the model raise and discard the
            # chunks already transcribed; report it with an empty text instead.
            if len(segment) >= _MIN_SEGMENT_SAMPLES:
                text = _transcribe_segment(segment, rate)
            else:
                text = ""
            lines.append(f"{start / rate} - {end / rate}: {text}\n")
        return "".join(lines)
    except Exception as e:
        # Top-level boundary for the UI: surface the problem as text.
        return f"處理文件錯誤: {e}"
def transcribe_both(audio_file):
    """Transcribe ``audio_file`` and time how long the transcription took.

    Args:
        audio_file: Path of the audio file, passed straight to ``transcribe``.

    Returns:
        A 3-tuple ``(transcription, transcription, processing_time)``. The
        text is returned twice so it can fill both the read-only and the
        editable textbox; ``processing_time`` is the elapsed time in seconds.
    """
    # perf_counter is monotonic, so the measurement is not affected by
    # wall-clock adjustments the way datetime.now() differences are.
    started = time.perf_counter()
    transcription = transcribe(audio_file)
    processing_time = time.perf_counter() - started
    return transcription, transcription, processing_time
def store_correction(original_transcription, corrected_transcription, audio_file, age, native_speaker):
    """Save a corrected transcription and its metadata to Firestore.

    Args:
        original_transcription: Text produced by the model.
        corrected_transcription: Text after the user's edits.
        audio_file: Path of the audio file, or ``None``. When the file exists
            its duration and size are stored as ``audio_metadata``.
        age: Age reported by the user.
        native_speaker: Whether the user reported being a native speaker.

    Returns:
        A status message for the UI: a success message, or a message starting
        with ``"保存失败: "`` that describes why saving failed. Errors are
        reported as text rather than raised.
    """
    # The Firestore client is expected in the module-level name `db`. Its
    # initialisation is commented out at the top of this file, so the name may
    # be missing or None; report that clearly before decoding any audio.
    client = globals().get("db")
    if client is None:
        reason = "Firestore client is not initialized"
        return f"保存失败: {reason} (Error saving correction: {reason})"

    try:
        audio_metadata = {}
        if audio_file and os.path.exists(audio_file):
            audio, sr = librosa.load(audio_file, sr=16000)
            audio_metadata = {
                'duration': librosa.get_duration(y=audio, sr=sr),
                'file_size': os.path.getsize(audio_file),
            }

        combined_data = {
            'original_text': original_transcription,
            'corrected_text': corrected_transcription,
            'timestamp': datetime.now().isoformat(),
            'audio_metadata': audio_metadata,
            'model_name': MODEL_NAME,
            'user_info': {
                'native_amis_speaker': native_speaker,
                'age': age,
            },
        }
        client.collection('transcriptions').add(combined_data)
        return "校正保存成功! (Correction saved successfully!)"
    except Exception as e:
        # Top-level boundary for the UI: surface the problem as text.
        return f"保存失败: {e} (Error saving correction: {e})"
def prepare_download(audio_file, original_transcription, corrected_transcription):
    """Bundle the audio and both transcriptions into a ZIP file.

    Args:
        audio_file: Path of the audio file, or ``None``.
        original_transcription: Text produced by the model.
        corrected_transcription: Text after the user's edits.

    Returns:
        The path of a newly created ``.zip`` file, or ``None`` when
        ``audio_file`` is ``None``. The archive contains
        ``original_transcription.txt`` and ``corrected_transcription.txt``
        (UTF-8) and, if the audio file exists, a copy named ``audio<ext>``
        where ``<ext>`` is the source file's extension (``.wav`` if it has
        none). The caller owns the returned file.
    """
    if audio_file is None:
        return None

    # delete=False: the archive has to outlive this call so it can be served.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
        zip_path = tmp_zip.name

    with zipfile.ZipFile(zip_path, "w") as zf:
        if os.path.exists(audio_file):
            # Keep the real extension so e.g. an uploaded mp3 is not
            # mislabelled as a wav file inside the archive.
            extension = os.path.splitext(audio_file)[1] or ".wav"
            zf.write(audio_file, arcname=f"audio{extension}")
        # writestr stores the text directly (str is encoded as UTF-8), so no
        # scratch files are created in the working directory where concurrent
        # requests would overwrite each other's fixed-name files.
        zf.writestr("original_transcription.txt", original_transcription or "")
        zf.writestr("corrected_transcription.txt", corrected_transcription or "")
    return zip_path
# UI strings in display order: title, the four step headings, then the labels
# for audio input, transcribe button, original text, corrected text, age,
# native-speaker checkbox, save button, save status and download button.
_UI_TEXT_ZH_TW = (
    "阿美語轉錄與修正系統",
    "步驟 1:音訊上傳與轉錄",
    "步驟 2:審閱與編輯轉錄",
    "步驟 3:使用者資訊",
    "步驟 4:儲存與下載",
    "音訊輸入", "轉錄音訊",
    "原始轉錄", "更正轉錄",
    "年齡", "以阿美語為母語?",
    "儲存更正", "儲存狀態",
    "下載 ZIP 檔案",
)

_UI_TEXT_EN = (
    "Amis ASR Transcription & Correction System",
    "Step 1: Audio Upload & Transcription",
    "Step 2: Review & Edit Transcription",
    "Step 3: User Information",
    "Step 4: Save & Download",
    "Audio Input", "Transcribe Audio",
    "Original Transcription", "Corrected Transcription",
    "Age", "Native Amis Speaker?",
    "Save Correction", "Save Status",
    "Download ZIP File",
)


def toggle_language(switch):
    """Return the 14 UI strings for the selected language.

    Args:
        switch: Truthy selects Traditional Chinese, falsy selects English.

    Returns:
        A tuple of 14 strings: the title, four step headings and nine
        component labels, in a fixed order.
    """
    return _UI_TEXT_ZH_TW if switch else _UI_TEXT_EN
# Interface
# NOTE(review): the source reached review without indentation, so the row
# grouping below is reconstructed from statement order -- confirm it against
# the intended layout.
with gr.Blocks() as demo:
    # The language toggle is disabled: no lang_switch checkbox is created and
    # toggle_language() is not wired to any component.
    title = gr.Markdown("Creole ASR Transcription & Correction System")

    step1 = gr.Markdown("Step 1: Audio Upload & Transcription")
    # Audio input and playback.
    with gr.Row():
        audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")

    step2 = gr.Markdown("Step 2: Review & Edit Transcription")
    # Separate row so the button sits below the audio player.
    with gr.Row():
        transcribe_button = gr.Button("Transcribe Audio")
    original_text = gr.Textbox(label="Transcription", interactive=False, lines=5)
    corrected_text = gr.Textbox(label="Corrected Transcription", interactive=True, lines=5)
    # transcribe_both returns the elapsed time as a third value; show it
    # instead of leaving that value without a matching output component.
    processing_time_output = gr.Number(label="Processing Time (s)", interactive=False)

    step3 = gr.Markdown("Step 3: User Information")
    with gr.Row():
        age_input = gr.Slider(minimum=0, maximum=100, step=1, label="Age", value=25)
        native_speaker_input = gr.Checkbox(label="Native Creole Speaker?", value=True)

    step4 = gr.Markdown("Step 4: Save & Download")
    with gr.Row():
        save_button = gr.Button("Save Correction")
        save_status = gr.Textbox(label="Save Status", interactive=False)
    with gr.Row():
        download_button = gr.Button("Download ZIP File")
        download_output = gr.File()

    # Event wiring. The same transcription fills both textboxes so the user
    # can edit one copy while the original stays visible.
    transcribe_button.click(
        transcribe_both,
        inputs=audio_input,
        outputs=[original_text, corrected_text, processing_time_output],
    )
    save_button.click(
        store_correction,
        inputs=[original_text, corrected_text, audio_input, age_input, native_speaker_input],
        outputs=save_status,
    )
    download_button.click(
        prepare_download,
        inputs=[audio_input, original_text, corrected_text],
        outputs=download_output,
    )

demo.launch()