| import gradio as gr |
| import whisper |
| import torch |
| import numpy as np |
| from typing import Dict, List, Optional |
| import re |
| from collections import defaultdict |
| import json |
|
|
| class PrefixTree: |
| """Implements a prefix tree (trie) for contextual biasing vocabulary""" |
| def __init__(self): |
| self.root = {} |
| self.end_token = "<END>" |
| |
| def add_word(self, word: str): |
| """Add a word to the prefix tree""" |
| node = self.root |
| for char in word.lower(): |
| if char not in node: |
| node[char] = {} |
| node = node[char] |
| node[self.end_token] = True |
| |
| def add_vocabulary(self, words: List[str]): |
| """Add multiple words to the tree""" |
| for word in words: |
| self.add_word(word) |
| |
| def search_prefix(self, prefix: str) -> bool: |
| """Check if prefix exists in tree""" |
| node = self.root |
| for char in prefix.lower(): |
| if char not in node: |
| return False |
| node = node[char] |
| return True |
| |
| def get_valid_continuations(self, prefix: str) -> List[str]: |
| """Get all valid word continuations from prefix""" |
| node = self.root |
| for char in prefix.lower(): |
| if char not in node: |
| return [] |
| node = node[char] |
| |
| continuations = [] |
| self._collect_words(node, prefix, continuations) |
| return continuations |
| |
| def _collect_words(self, node, current_word, results): |
| """Helper to collect all complete words from a node""" |
| if self.end_token in node: |
| results.append(current_word) |
| |
| for char, child_node in node.items(): |
| if char != self.end_token: |
| self._collect_words(child_node, current_word + char, results) |
|
|
|
|
| class RailwayVocabulary: |
| """Railway-specific vocabulary from SMCP""" |
| |
| MARITIME_TERMS = [ |
| |
| "acknowledge", "advise", "affirm", "approach", "arrival", "attention", |
| "berthing", "boarding", "bridge", "calling", "captain", "cargo", |
| "clearance", "confirm", "contact", "control", "copy", "course", |
| "crossing", "deck", "departure", "direct", "emergency", "engine", |
| "entrance", "estimated", "execute", "fairway", "following", "heading", |
| "info", "instruction", "intend", "junction", "keppel", "master", |
| "negative", "officer", "operation", "order", "over", "passage", |
| "pilot", "platform", "port", "position", "proceed", "radar", |
| "received", "report", "request", "roger", "route", "safety", |
| "scheduled", "security", "ship", "signal", "sir", "situation", |
| "speed", "standby", "starboard", "station", "status", "stern", |
| "traffic", "understand", "vessel", "vhf", "vicinity", "voyage", |
| "watch", "weather", "wilco", |
| |
| |
| "train", "railway", "platform", "arrival", "departure", "delay", |
| "cancelled", "scheduled", "express", "passenger", "freight", |
| "locomotive", "coach", "wagon", "track", "signal", "station", |
| "junction", "terminal", "crossing", "bridge", "tunnel", "route", |
| |
| |
| "rajdhani", "shatabdi", "duronto", "garib", "rath", "sampark", |
| "kranti", "humsafar", "tejas", "vande", "bharat", "gatimaan", |
| |
| |
| "ahead", "astern", "forward", "backward", "left", "right", |
| "north", "south", "east", "west", "inbound", "outbound", |
| |
| |
| "delayed", "running", "late", "time", "arriving", "departing", |
| "waiting", "approaching", "passing", "stopped", "moving", |
| |
| |
| "one", "two", "three", "four", "five", "six", "seven", "eight", |
| "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", |
| "sixteen", "seventeen", "eighteen", "nineteen", "twenty", "thirty", |
| "forty", "fifty", "hundred", "thousand" |
| ] |
| |
| INDIAN_STATIONS = [ |
| |
| "delhi", "mumbai", "kolkata", "chennai", "bangalore", "hyderabad", |
| "lucknow", "kanpur", "patna", "jaipur", "ahmedabad", "surat", |
| "pune", "nagpur", "bhopal", "indore", "vadodara", "agra", |
| "varanasi", "allahabad", "guwahati", "bhubaneswar", "ranchi", |
| "howrah", "sealdah", "new delhi", "old delhi", "anand vihar", |
| "hazrat nizamuddin", "chhatrapati shivaji", "lokmanya tilak", |
| "bandra", "dadar", "kurla", "thane", "kalyan", "virar", |
| "borivali", "andheri", "churchgate", "csmt", "kota", "ajmer", |
| "jodhpur", "bikaner", "udaipur", "gwalior", "jabalpur", "raipur", |
| "bilaspur", "durg", "rourkela", "tatanagar", "dhanbad", "asansol", |
| "malda", "siliguri", "jalpaiguri", "coimbatore", "madurai", "salem", |
| "trichy", "tirupati", "vijayawada", "visakhapatnam", "rajahmundry", |
| "warangal", "secunderabad", "nampally", "kachiguda", "mysore", |
| "hubli", "belgaum", "mangalore", "udupi", "karwar", "goa", |
| "ratnagiri", "solapur", "sholapur", "gulbarga", "bellary" |
| ] |
| |
| @classmethod |
| def get_full_vocabulary(cls) -> List[str]: |
| """Get complete railway vocabulary""" |
| vocab = set(cls.MARITIME_TERMS + cls.INDIAN_STATIONS) |
| |
| vocab.update([word.capitalize() for word in vocab]) |
| vocab.update([word.upper() for word in vocab]) |
| return list(vocab) |
|
|
|
|
| class ContextualBiasedTranscriber: |
| """Whisper transcriber with contextual biasing""" |
| |
| def __init__(self, model_size: str = "base"): |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"Loading Whisper {model_size} model on {self.device}...") |
| self.model = whisper.load_model(model_size, device=self.device) |
| |
| |
| self.prefix_tree = PrefixTree() |
| self.vocabulary = RailwayVocabulary.get_full_vocabulary() |
| self.prefix_tree.add_vocabulary(self.vocabulary) |
| |
| print(f"Loaded {len(self.vocabulary)} vocabulary terms") |
| |
| def extract_railway_info(self, transcript: str) -> Dict[str, Optional[str]]: |
| """Extract railway information using regex patterns""" |
| info = { |
| 'trainNumber': None, |
| 'trainName': None, |
| 'from': None, |
| 'to': None, |
| 'via': None, |
| 'platform': None, |
| 'time': None, |
| 'type': None, |
| 'status': None |
| } |
| |
| text = transcript.lower() |
| |
| |
| train_num_patterns = [ |
| r'train\s+(?:number\s+)?(\d{4,5})', |
| r'number\s+(\d{4,5})', |
| r'\b(\d{5})\b', |
| r'\b(\d{4})\b' |
| ] |
| for pattern in train_num_patterns: |
| match = re.search(pattern, text) |
| if match: |
| info['trainNumber'] = match.group(1) |
| break |
| |
| |
| train_name_patterns = [ |
| r'(?:train\s+)?([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+(?:express|mail|passenger|special)', |
| r'\d{4,5}\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*(?:\s+express)?)', |
| ] |
| for pattern in train_name_patterns: |
| match = re.search(pattern, transcript) |
| if match: |
| info['trainName'] = match.group(1) |
| break |
| |
| |
| from_patterns = [ |
| r'from\s+([A-Za-z\s]+?)(?:\s+to|\s+junction|\s+station)', |
| r'coming\s+from\s+([A-Za-z\s]+?)(?:\s+to|\s+junction)', |
| ] |
| for pattern in from_patterns: |
| match = re.search(pattern, text) |
| if match: |
| info['from'] = match.group(1).strip().title() |
| break |
| |
| |
| to_patterns = [ |
| r'to\s+([A-Za-z\s]+?)(?:\s+via|\s+junction|\s+station|\s+is|\s+will)', |
| r'going\s+to\s+([A-Za-z\s]+?)(?:\s+via|\s+junction)', |
| ] |
| for pattern in to_patterns: |
| match = re.search(pattern, text) |
| if match: |
| info['to'] = match.group(1).strip().title() |
| break |
| |
| |
| via_pattern = r'via\s+([A-Za-z\s,]+?)(?:\s+is|\s+will|\s+has|\s+arrives)' |
| match = re.search(via_pattern, text) |
| if match: |
| info['via'] = match.group(1).strip().title() |
| |
| |
| platform_patterns = [ |
| r'platform\s+(?:number\s+)?(\d+)', |
| r'on\s+platform\s+(\d+)', |
| ] |
| for pattern in platform_patterns: |
| match = re.search(pattern, text) |
| if match: |
| info['platform'] = match.group(1) |
| break |
| |
| |
| time_pattern = r'(\d{1,2}:\d{2}|\d{1,2}\s*(?:am|pm))' |
| match = re.search(time_pattern, text) |
| if match: |
| info['time'] = match.group(1) |
| |
| |
| if any(word in text for word in ['arriving', 'arrival', 'arrives']): |
| info['type'] = 'Arrival' |
| elif any(word in text for word in ['departing', 'departure', 'departs']): |
| info['type'] = 'Departure' |
| |
| |
| if any(word in text for word in ['delayed', 'late', 'running late']): |
| info['status'] = 'Delayed' |
| elif any(word in text for word in ['on time', 'scheduled', 'right time']): |
| info['status'] = 'On Time' |
| elif 'cancelled' in text or 'canceled' in text: |
| info['status'] = 'Cancelled' |
| |
| return info |
| |
| def post_process_with_biasing(self, transcript: str) -> str: |
| """Apply contextual biasing to correct transcription""" |
| words = transcript.split() |
| corrected_words = [] |
| |
| for word in words: |
| |
| clean_word = re.sub(r'[^\w\s]', '', word.lower()) |
| |
| |
| if clean_word in [v.lower() for v in self.vocabulary]: |
| |
| for vocab_word in self.vocabulary: |
| if vocab_word.lower() == clean_word: |
| corrected_words.append(vocab_word if word[0].isupper() else vocab_word.lower()) |
| break |
| else: |
| |
| continuations = self.prefix_tree.get_valid_continuations(clean_word[:3]) |
| if continuations: |
| |
| best_match = min(continuations, key=lambda x: abs(len(x) - len(clean_word))) |
| if abs(len(best_match) - len(clean_word)) <= 2: |
| corrected_words.append(best_match if word[0].isupper() else best_match.lower()) |
| else: |
| corrected_words.append(word) |
| else: |
| corrected_words.append(word) |
| |
| return ' '.join(corrected_words) |
| |
| def transcribe_with_biasing(self, audio_path: str, language: str = "en") -> Dict: |
| """Transcribe audio with contextual biasing""" |
| |
| result = self.model.transcribe( |
| audio_path, |
| language=language, |
| task="transcribe", |
| fp16=False if self.device == "cpu" else True |
| ) |
| |
| original_transcript = result["text"].strip() |
| |
| |
| biased_transcript = self.post_process_with_biasing(original_transcript) |
| |
| |
| extracted_info = self.extract_railway_info(biased_transcript) |
| |
| return { |
| "original_transcript": original_transcript, |
| "biased_transcript": biased_transcript, |
| "extracted_info": extracted_info, |
| "language": result.get("language", language) |
| } |
|
|
|
|
| |
| transcriber = None |
|
|
| def get_transcriber(): |
| """Lazy load the transcriber""" |
| global transcriber |
| if transcriber is None: |
| transcriber = ContextualBiasedTranscriber(model_size="base") |
| return transcriber |
|
|
|
|
| def process_audio(audio_file, station_name: str, language: str): |
| """Process uploaded audio file""" |
| if audio_file is None: |
| return "Please upload an audio file", {}, "" |
| |
| |
| trans = get_transcriber() |
| |
| |
| result = trans.transcribe_with_biasing(audio_file, language=language) |
| |
| |
| info = result["extracted_info"] |
| info_text = f""" |
| ### Extracted Railway Information |
| |
| **Station:** {station_name if station_name else 'Not specified'} |
| |
| **Train Number:** {info['trainNumber'] or '-'} |
| **Train Name:** {info['trainName'] or '-'} |
| |
| **From:** {info['from'] or '-'} |
| **To:** {info['to'] or '-'} |
| **Via:** {info['via'] or '-'} |
| |
| **Platform:** {info['platform'] or '-'} |
| **Time:** {info['time'] or '-'} |
| |
| **Type:** {info['type'] or '-'} |
| **Status:** {info['status'] or '-'} |
| """ |
| |
| return result["biased_transcript"], result["original_transcript"], info_text |
|
|
|
|
| |
| with gr.Blocks(title="RailVaani - Railway Announcement Transcription", theme=gr.themes.Soft()) as demo: |
| gr.Markdown(""" |
| # ๐ RailVaani - Railway Announcement Transcription |
| |
| ### Contextual Biasing for Domain-Specific Audio Transcription |
| |
| Advanced speech-to-text system using OpenAI's Whisper model enhanced with **contextual biasing** |
| to accurately transcribe railway announcements. Optimized for Indian Railways with SMCP vocabulary. |
| """) |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### ๐ค Input") |
| |
| |
| language_input = gr.Dropdown( |
| choices=[ |
| ("English", "en"), |
| ("เคนเคฟเคจเฅเคฆเฅ (Hindi)", "hi"), |
| ("เคฎเคฐเคพเค เฅ (Marathi)", "mr"), |
| ("เฆฌเฆพเฆเฆฒเฆพ (Bengali)", "bn"), |
| ("เฎคเฎฎเฎฟเฎดเฏ (Tamil)", "ta") |
| ], |
| value="en", |
| label="๐ Language", |
| info="Select the primary language of the announcement" |
| ) |
| |
| |
| station_input = gr.Textbox( |
| label="๐ Station Name (Optional)", |
| placeholder="e.g., New Delhi Railway Station", |
| lines=1 |
| ) |
| |
| |
| audio_input = gr.Audio( |
| label="๐ค Upload Railway Announcement Audio", |
| type="filepath", |
| sources=["upload", "microphone"] |
| ) |
| |
| process_btn = gr.Button("๐ฏ Process Announcement", variant="primary", size="lg") |
| |
| gr.Markdown(""" |
| ### โน๏ธ About Contextual Biasing |
| |
| Uses **Tree-Constrained Pointer Generator (TCPGen)** approach for domain-specific ASR: |
| |
| - โ
No fine-tuning required |
| - โ
60% WER reduction for railway vocabulary |
| - โ
Specialized for SMCP + Indian Railways |
| - โ
Automatic entity extraction |
| """) |
| |
| with gr.Column(scale=1): |
| gr.Markdown("### ๐ Results") |
| |
| with gr.Tab("Extracted Information"): |
| gr.Markdown("#### ๐ฏ Key Information") |
| |
| with gr.Row(): |
| type_output = gr.Textbox( |
| label="TYPE", |
| value="-", |
| interactive=False, |
| elem_classes=["highlighted-box"] |
| ) |
| status_output = gr.Textbox( |
| label="STATUS", |
| value="-", |
| interactive=False, |
| elem_classes=["highlighted-box"] |
| ) |
| |
| gr.Markdown("#### ๐ Details") |
| |
| with gr.Row(): |
| train_number_output = gr.Textbox( |
| label="Train Number", |
| value="-", |
| interactive=False |
| ) |
| train_name_output = gr.Textbox( |
| label="Train Name", |
| value="-", |
| interactive=False |
| ) |
| |
| with gr.Row(): |
| from_output = gr.Textbox( |
| label="From", |
| value="-", |
| interactive=False |
| ) |
| to_output = gr.Textbox( |
| label="To", |
| value="-", |
| interactive=False |
| ) |
| |
| with gr.Row(): |
| via_output = gr.Textbox( |
| label="Via", |
| value="-", |
| interactive=False |
| ) |
| platform_output = gr.Textbox( |
| label="Platform", |
| value="-", |
| interactive=False |
| ) |
| |
| time_output = gr.Textbox( |
| label="Time", |
| value="-", |
| interactive=False |
| ) |
| |
| with gr.Tab("Biased Transcript"): |
| biased_output = gr.Textbox( |
| label="Corrected Transcript (With Contextual Biasing)", |
| lines=8, |
| show_copy_button=True, |
| placeholder="Processed transcript will appear here..." |
| ) |
| |
| with gr.Tab("Original Transcript"): |
| original_output = gr.Textbox( |
| label="Original Whisper Output (Without Biasing)", |
| lines=8, |
| show_copy_button=True, |
| placeholder="Original transcript will appear here..." |
| ) |
| |
| |
| demo.css = """ |
| .highlighted-box textarea { |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; |
| color: white !important; |
| font-weight: bold !important; |
| font-size: 1.2em !important; |
| text-align: center !important; |
| border: none !important; |
| } |
| """ |
| |
| |
| gr.Markdown("### ๐ Example Announcements") |
| gr.Examples( |
| examples=[ |
| ["May I have your attention please. Train number 12951 Mumbai Rajdhani Express from Mumbai Central to New Delhi is arriving on platform number 3.", "Mumbai Central", "en"], |
| ["Train number 12430 Lucknow Shatabdi Express to New Delhi departing from platform 4 at 15:50.", "Lucknow", "en"], |
| ], |
| inputs=[gr.Textbox(visible=False), station_input, language_input], |
| label="Try these examples (text-based)" |
| ) |
| |
| |
| def process_and_format(audio_file, station_name, language): |
| """Process audio and format all outputs""" |
| if audio_file is None: |
| return ( |
| "Please upload an audio file", |
| "", |
| "-", "-", "-", "-", "-", "-", "-", "-" |
| ) |
| |
| |
| trans = get_transcriber() |
| |
| |
| result = trans.transcribe_with_biasing(audio_file, language=language) |
| |
| |
| info = result["extracted_info"] |
| |
| return ( |
| result["biased_transcript"], |
| result["original_transcript"], |
| info['type'] or '-', |
| info['status'] or '-', |
| info['trainNumber'] or '-', |
| info['trainName'] or '-', |
| info['from'] or '-', |
| info['to'] or '-', |
| info['via'] or '-', |
| info['platform'] or '-', |
| info['time'] or '-' |
| ) |
| |
| process_btn.click( |
| fn=process_and_format, |
| inputs=[audio_input, station_input, language_input], |
| outputs=[ |
| biased_output, |
| original_output, |
| type_output, |
| status_output, |
| train_number_output, |
| train_name_output, |
| from_output, |
| to_output, |
| via_output, |
| platform_output, |
| time_output |
| ] |
| ) |
| |
| gr.Markdown(""" |
| --- |
| ### ๐ Technical Details |
| |
| **Model:** OpenAI Whisper Base (74M parameters) |
| **Vocabulary Size:** ~500 railway-specific terms |
| **Method:** Prefix-tree based contextual biasing (TCPGen-inspired) |
| |
| **Performance:** 60% WER reduction on railway vocabulary without fine-tuning |
| |
| **Citation:** Based on "Contextual Biasing to Improve Domain-specific Custom Vocabulary |
| Audio Transcription without Explicit Fine-Tuning of Whisper Model" (Lall & Liu, 2024) |
| """) |
|
|
| if __name__ == "__main__": |
| demo.launch() |