---
license: apache-2.0
datasets:
- mozilla-foundation/common_voice_17_0
language:
- uz
base_model:
- facebook/wav2vec2-large-xlsr-53
pipeline_tag: automatic-speech-recognition
---
# Fine-tuned Wav2Vec2-Large-XLSR-53 for Uzbek Speech Recognition

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
on Uzbek using the `train` split of [Common Voice 17.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0).
When using this model, make sure that your speech input is sampled at 16 kHz.
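
The 16 kHz requirement matters because the wav2vec2 convolutional feature encoder turns raw samples into frames at a fixed stride, so the model implicitly assumes a fixed number of samples per second. The minimal sketch that follows is illustrative only. It assumes the default `Wav2Vec2Config` feature-encoder geometry (kernel sizes `(10, 3, 3, 3, 3, 2, 2)`, strides `(5, 2, 2, 2, 2, 2, 2)`), which the base XLSR-53 checkpoint uses, and the helper name `num_output_frames` is hypothetical. It applies the unpadded 1-D convolution output-length formula layer by layer.

```python
from typing import Sequence

# ASSUMPTION: default Wav2Vec2Config feature-encoder geometry,
# which facebook/wav2vec2-large-xlsr-53 also uses.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(
    num_samples: int,
    kernels: Sequence[int] = CONV_KERNELS,
    strides: Sequence[int] = CONV_STRIDES,
) -> int:
    """Hypothetical helper: number of encoder frames for a raw waveform."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        if length < kernel:
            return 0  # input too short for this convolution layer
        # Output length of an unpadded 1-D convolution
        length = (length - kernel) // stride + 1
    return length


if __name__ == "__main__":
    for rate in (16000, 44100):
        # Frames produced for one second of audio recorded at `rate`
        print(rate, num_output_frames(rate))
```

With this geometry, one second of 16 kHz audio yields 49 frames (roughly one frame per 20 ms, since the strides multiply to 320 samples). A one-second clip recorded at 44.1 kHz and passed in without resampling would instead produce 137 frames, so the model would effectively hear the speech slowed down by a factor of about 2.76.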

## Usage

```python
from typing import Tuple

import numpy as np
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

TARGET_SAMPLING_RATE = 16000  # the model expects 16 kHz audio


class Wav2Vec2STTModel:
    def __init__(self, model_name: str):
        """Initialize the Wav2Vec2 model and processor."""
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> None:
        """Load the model and processor from the Hugging Face Hub."""
        try:
            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name).to(self.device)
            self.model.eval()  # inference mode: disables dropout
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    def preprocess_audio(self, file_path: str) -> Tuple[np.ndarray, int]:
        """Load an audio file, downmix it to mono and resample it to 16 kHz."""
        try:
            # speech_array has shape (channels, samples)
            speech_array, sampling_rate = torchaudio.load(file_path)

            # Downmix multi-channel (e.g. stereo) audio to a single channel;
            # a plain squeeze() would leave stereo input as a batch of two
            speech_array = speech_array.mean(dim=0)

            # Resample if needed
            if sampling_rate != TARGET_SAMPLING_RATE:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sampling_rate,
                    new_freq=TARGET_SAMPLING_RATE,
                )
                speech_array = resampler(speech_array)

            return speech_array.numpy(), TARGET_SAMPLING_RATE
        except FileNotFoundError:
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        except Exception as e:
            raise RuntimeError(f"Audio processing error: {e}") from e
    
    def _replace_unk(self, transcription: str) -> str:
        """Replace unknown tokens with apostrophe"""
        return transcription.replace("[UNK]", "ʼ")
    
    def transcribe(self, file_path: str) -> str:
        """Transcribe audio file to text"""
        try:
            # Preprocess audio
            speech_array, sampling_rate = self.preprocess_audio(file_path)
            
            # Process input
            inputs = self.processor(
                speech_array, 
                sampling_rate=sampling_rate, 
                return_tensors="pt"
            ).to(self.device)
            
            # Model inference
            with torch.no_grad():
                logits = self.model(inputs.input_values).logits
            
            # Decode prediction
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            
            # Clean up result
            return self._replace_unk(transcription)
            
        except Exception as e:
            raise RuntimeError(f"Transcription error: {str(e)}")

# Example usage
if __name__ == "__main__":
    try:
        # Initialize model
        stt_model = Wav2Vec2STTModel("ipilot7/uzbek_speach_to_text")
        
        # Transcribe audio
        result = stt_model.transcribe("1.mp3")
        print("Transcription:", result)
        
    except Exception as e:
        print(f"Error occurred: {str(e)}")
```
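
The decoding step in `transcribe` (`torch.argmax` followed by `processor.batch_decode`) is greedy CTC decoding. As a rough, framework-free illustration of what that amounts to, the sketch that follows picks the best token in each frame, collapses consecutive repeats, drops the CTC blank token and maps the word delimiter `|` to a space. The function name `greedy_ctc_decode`, the toy vocabulary and `blank_id=0` are assumptions for illustration only; in `Wav2Vec2CTCTokenizer` the blank is the padding token, whose id depends on the checkpoint's vocabulary.

```python
from itertools import groupby
from typing import List, Sequence


def greedy_ctc_decode(
    frame_scores: Sequence[Sequence[float]],
    vocab: List[str],
    blank_id: int = 0,
    word_delimiter: str = "|",
) -> str:
    """Hypothetical helper: greedy CTC decoding of per-frame scores."""
    # 1. Pick the highest-scoring token in every frame (the argmax step)
    best_ids = [max(range(len(scores)), key=scores.__getitem__) for scores in frame_scores]
    # 2. Collapse runs of the same token into a single token
    collapsed = [token_id for token_id, _ in groupby(best_ids)]
    # 3. Remove the CTC blank token
    tokens = [vocab[token_id] for token_id in collapsed if token_id != blank_id]
    # 4. Turn word delimiters into spaces and normalize whitespace
    text = "".join(tokens).replace(word_delimiter, " ")
    return " ".join(text.split())


if __name__ == "__main__":
    vocab = ["<pad>", "|", "a", "l", "o", "s", "m"]  # toy vocabulary
    frame_ids = [5, 5, 0, 2, 3, 3, 0, 4, 6]  # best token in each frame
    one_hot = [[1.0 if j == i else 0.0 for j in range(len(vocab))] for i in frame_ids]
    print(greedy_ctc_decode(one_hot, vocab))
```

Because repeats are collapsed before blanks are removed, a genuine double letter (such as `ll`) can only be emitted when a blank frame separates the two identical frames.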