# pari-tts — src/models/rubaistt.py
# Uzbek speech-to-text using openai/whisper-small (CPU inference).
# (Originally uploaded via huggingface_hub; web-page breadcrumb removed.)
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Always use CPU (safer for low-memory GPUs)
device = torch.device("cpu")
# Clear any leftover CUDA cache
torch.cuda.empty_cache()
# Load model and processor (using smaller model recommended)
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
def transcribe_audio(audio_path):
global model, processor
# Load and preprocess audio
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
# Convert stereo to mono if needed
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Process audio
input_features = processor(
waveform.squeeze().numpy(),
sampling_rate=16000,
return_tensors="pt",
language="uz"
).input_features.to(device)
# Generate transcription (CPU inference)
with torch.no_grad():
predicted_ids = model.generate(input_features)
# Decode
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
return transcription.strip()
# Example usage
if __name__ == "__main__":
audio_file = "some_audio_max_30_sec.wav"
print("Transcribing on CPU, please wait...")
text = transcribe_audio(audio_file)
print(f"Transcription:\n{text}")