adiitya29 commited on
Commit
cf2c908
·
1 Parent(s): 8f2047c

frontend UI created using gradio, fastAPI created, notebooks folder created for fine tuning and evaluation of models

Browse files
gradio_ui.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from app.asr_model import load_model, transcribe_audio
3
+ from app.language_detection import detect_language_from_text
4
+ from app.history import save_to_history, export_history
5
+
6
+ def process_audio(audio_path):
7
+ if audio_path is None:
8
+ return "No audio uploaded.", "Unknown"
9
+
10
+ print(f"\n--- New Request ---")
11
+ print(f"Processing audio: {audio_path}")
12
+
13
+ # Transcribe Speech
14
+ print("Transcribing... (If this is the first time, it is downloading a 400MB model)")
15
+ transcript = transcribe_audio(audio_path)
16
+ print(f"Transcription complete: {transcript[:50]}...")
17
+
18
+ # Detect Language from transcript
19
+ print("Detecting language...")
20
+ lang = detect_language_from_text(transcript)
21
+
22
+ # Save History
23
+ print("Saving to history...")
24
+ save_to_history(audio_path, transcript, lang)
25
+
26
+ print("Done!\n")
27
+ return transcript, lang
28
+
29
+ def export_history_wrapper():
30
+ path = export_history("csv")
31
+ return path if path else None
32
+
33
+ def create_ui():
34
+ with gr.Blocks(title="Multilingual ASR") as demo:
35
+ gr.Markdown("# Multilingual Automatic Speech Recognition")
36
+
37
+ with gr.Tabs():
38
+ with gr.TabItem("Transcribe"):
39
+ gr.Markdown("Upload an audio file to get a text transcription using Wav2Vec.")
40
+
41
+ with gr.Row():
42
+ with gr.Column():
43
+ audio_input = gr.Audio(type="filepath", label="Upload Audio")
44
+ transcribe_btn = gr.Button("Transcribe", variant="primary")
45
+
46
+ with gr.Column():
47
+ lang_output = gr.Textbox(label="Detected Language")
48
+ transcript_output = gr.Textbox(label="Transcription", lines=10)
49
+
50
+ transcribe_btn.click(
51
+ fn=process_audio,
52
+ inputs=audio_input,
53
+ outputs=[transcript_output, lang_output]
54
+ )
55
+
56
+ with gr.TabItem("History"):
57
+ gr.Markdown("Download your past transcriptions.")
58
+ download_btn = gr.Button("Prepare History for Download")
59
+ file_output = gr.File(label="Download CSV")
60
+
61
+ download_btn.click(
62
+ fn=export_history_wrapper,
63
+ outputs=file_output
64
+ )
65
+
66
+ return demo
67
+
68
+ if __name__ == "__main__":
69
+ demo = create_ui()
70
+ demo.launch()
main.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ import gradio as gr
3
+ from gradio_ui import create_ui
4
+ from app.asr_model import transcribe_audio
5
+ from app.language_detection import detect_language_from_text
6
+ from app.history import save_to_history
7
+ import os
8
+ import tempfile
9
+ import shutil
10
+
11
+ # Initialize FastAPI app
12
+ api = FastAPI(title="Multilingual ASR API", description="REST API for audio transcription", version="1.0.0")
13
+
14
+ @api.post("/api/transcribe")
15
+ async def api_transcribe(audio_file: UploadFile = File(...)):
16
+ """
17
+ REST endpoint to upload an audio file and get its transcription.
18
+ """
19
+ if not audio_file.filename:
20
+ raise HTTPException(status_code=400, detail="No file provided")
21
+
22
+ try:
23
+ # Save uploaded file to a temporary file
24
+ fd, temp_path = tempfile.mkstemp(suffix=os.path.splitext(audio_file.filename)[1])
25
+ with os.fdopen(fd, "wb") as f:
26
+ shutil.copyfileobj(audio_file.file, f)
27
+
28
+ # Run inference
29
+ transcript = transcribe_audio(temp_path)
30
+ lang = detect_language_from_text(transcript)
31
+
32
+ # Save to history
33
+ save_to_history(audio_file.filename, transcript, lang)
34
+
35
+ # Cleanup temp file
36
+ os.remove(temp_path)
37
+
38
+ return {
39
+ "filename": audio_file.filename,
40
+ "language": lang,
41
+ "transcript": transcript
42
+ }
43
+ except Exception as e:
44
+ raise HTTPException(status_code=500, detail=str(e))
45
+
46
+ # Mount Gradio app on root
47
+ demo = create_ui()
48
+ app = gr.mount_gradio_app(api, demo, path="/")
49
+
50
+ if __name__ == "__main__":
51
+ import uvicorn
52
+ # Run the unified app with uvicorn
53
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
notebooks/01_evaluation.ipynb ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Model Evaluation\n",
8
+ "\n",
9
+ "This notebook demonstrates how to evaluate your Wav2Vec2 model on a test dataset using the Word Error Rate (WER) metric."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "source": [
17
+ "!pip install evaluate jiwer datasets"
18
+ ],
19
+ "outputs": []
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "source": [
26
+ "import evaluate\n",
27
+ "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
28
+ "import torch\n",
29
+ "\n",
30
+ "# Load metric\n",
31
+ "wer_metric = evaluate.load(\"wer\")\n",
32
+ "\n",
33
+ "# Load model and processor\n",
34
+ "model_id = \"facebook/wav2vec2-base-960h\"\n",
35
+ "processor = Wav2Vec2Processor.from_pretrained(model_id)\n",
36
+ "model = Wav2Vec2ForCTC.from_pretrained(model_id)"
37
+ ],
38
+ "outputs": []
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "## Compute WER on sample predictions"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "source": [
52
+ "predictions = [\"this is a test\", \"hello world\"]\n",
53
+ "references = [\"this is a test\", \"hello word\"]\n",
54
+ "\n",
55
+ "wer = wer_metric.compute(predictions=predictions, references=references)\n",
56
+ "print(f\"Word Error Rate (WER): {wer}\")"
57
+ ],
58
+ "outputs": []
59
+ }
60
+ ],
61
+ "metadata": {
62
+ "kernelspec": {
63
+ "display_name": "Python 3",
64
+ "language": "python",
65
+ "name": "python3"
66
+ },
67
+ "language_info": {
68
+ "name": "python",
69
+ "version": "3.12"
70
+ }
71
+ },
72
+ "nbformat": 4,
73
+ "nbformat_minor": 4
74
+ }
notebooks/02_finetuning.ipynb ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Model Fine-tuning\n",
8
+ "\n",
9
+ "This notebook provides a skeleton for fine-tuning the Wav2Vec2 model on your custom dataset."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "source": [
17
+ "!pip install datasets transformers accelerate librosa soundfile"
18
+ ],
19
+ "outputs": []
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "source": [
26
+ "from datasets import load_dataset, Audio\n",
27
+ "\n",
28
+ "# Load your dataset here (example uses common_voice, you can replace with your own)\n",
29
+ "# dataset = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"en\", split=\"train\")\n",
30
+ "\n",
31
+ "# Ensure audio is resampled to 16kHz\n",
32
+ "# dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))"
33
+ ],
34
+ "outputs": []
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "source": [
41
+ "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer\n",
42
+ "\n",
43
+ "model_id = \"facebook/wav2vec2-base\"\n",
44
+ "processor = Wav2Vec2Processor.from_pretrained(model_id)\n",
45
+ "model = Wav2Vec2ForCTC.from_pretrained(\n",
46
+ " model_id, \n",
47
+ " ctc_loss_reduction=\"mean\", \n",
48
+ " pad_token_id=processor.tokenizer.pad_token_id\n",
49
+ ")"
50
+ ],
51
+ "outputs": []
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "## Training Setup\n",
58
+ "Set up the DataCollator and TrainingArguments here."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "source": [
66
+ "# training_args = TrainingArguments(\n",
67
+ "# output_dir=\"./wav2vec2-finetuned\",\n",
68
+ "# group_by_length=True,\n",
69
+ "# per_device_train_batch_size=16,\n",
70
+ "# evaluation_strategy=\"steps\",\n",
71
+ "# num_train_epochs=10,\n",
72
+ "# fp16=True, # Use false if on MPS without FP16 support\n",
73
+ "# save_steps=500,\n",
74
+ "# eval_steps=500,\n",
75
+ "# logging_steps=500,\n",
76
+ "# learning_rate=1e-4,\n",
77
+ "# warmup_steps=1000,\n",
78
+ "# save_total_limit=2,\n",
79
+ "# )\n",
80
+ "\n",
81
+ "# trainer = Trainer(\n",
82
+ "# model=model,\n",
83
+ "# data_collator=data_collator,\n",
84
+ "# args=training_args,\n",
85
+ "# compute_metrics=compute_metrics,\n",
86
+ "# train_dataset=dataset,\n",
87
+ "# eval_dataset=dataset,\n",
88
+ "# tokenizer=processor.feature_extractor,\n",
89
+ "# )\n",
90
+ "\n",
91
+ "# trainer.train()"
92
+ ],
93
+ "outputs": []
94
+ }
95
+ ],
96
+ "metadata": {
97
+ "kernelspec": {
98
+ "display_name": "Python 3",
99
+ "language": "python",
100
+ "name": "python3"
101
+ },
102
+ "language_info": {
103
+ "name": "python",
104
+ "version": "3.12"
105
+ }
106
+ },
107
+ "nbformat": 4,
108
+ "nbformat_minor": 4
109
+ }