# speech/model_engine.py
# Uploaded by ahmed-03 ("Upload 3 files", commit c559735, verified).
import torch
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
import librosa
class ArabicASREngine:
    """Speech-to-text engine for an ArTST checkpoint loaded via the SpeechT5 classes.

    The processor and model are loaded once in ``__init__``; each call to
    :meth:`transcribe` decodes one audio file into text.
    """

    # Audio is resampled to this rate on load, and the processor is told the same rate.
    SAMPLING_RATE = 16000
    # Default cap on the generated sequence length passed to ``model.generate``.
    DEFAULT_MAX_LENGTH = 200

    def __init__(self, model_id="MBZUAI/artst_asr_v3", device=None):
        """Load the processor and model and move the model to the target device.

        Args:
            model_id: Hugging Face Hub id or local path of the checkpoint.
            device: Torch device to run on. When ``None`` (the default), uses
                ``"cuda"`` if CUDA is available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Initializing ArTST Model: {model_id} on {self.device}")
        # Load processor and model
        self.processor = SpeechT5Processor.from_pretrained(model_id)
        # eval() makes explicit that this engine is inference-only
        # (disables dropout); from_pretrained normally returns eval mode already.
        self.model = (
            SpeechT5ForSpeechToText.from_pretrained(model_id).to(self.device).eval()
        )

    def transcribe(self, audio_path, max_length=DEFAULT_MAX_LENGTH):
        """Transcribe a single audio file to Arabic text.

        Args:
            audio_path: Path to an audio file readable by ``librosa.load``.
            max_length: Maximum sequence length passed to ``model.generate``
                (defaults to 200, the previously hard-coded value).

        Returns:
            The decoded transcription of the clip, with special tokens removed.

        Raises:
            ValueError: If the file decodes to zero audio samples.
        """
        # 1. Load audio and resample to 16 kHz (Requirement 4).
        speech, _ = librosa.load(audio_path, sr=self.SAMPLING_RATE)
        if speech.size == 0:
            # An empty clip would otherwise fail deep inside normalisation / the
            # model's convolutional front-end with an unhelpful error.
            raise ValueError(f"No audio samples decoded from {audio_path!r}")

        # 2. Pre-process: the processor's audio path yields raw-waveform input_values.
        input_values = self.processor(
            audio=speech,
            sampling_rate=self.SAMPLING_RATE,
            return_tensors="pt",
        ).input_values.to(self.device)

        # 3. Generate token IDs without tracking gradients.
        with torch.no_grad():
            predicted_ids = self.model.generate(input_values, max_length=max_length)

        # 4. Decode to Arabic text; a single input clip yields a single sequence.
        return self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]